The response has been limited to 50k tokens of the smallest files in the repo. You can remove this limitation by removing the max tokens filter.
├── .devcontainer
    ├── devcontainer.json
    └── scripts
    │   ├── postStart.sh
    │   └── runItOnGo.sh
├── .dockerignore
├── .gitattributes
├── .github
    ├── FUNDING.yml
    ├── dependabot.yml
    └── workflows
    │   ├── build-bin.yml
    │   ├── close_stale.yml
    │   ├── codeql-analysis.yml
    │   ├── development-docker.yml
    │   ├── docs.yml
    │   ├── fork-sync.yml
    │   ├── goreleaser-test.yml
    │   ├── makefile.yml
    │   ├── mirror-repo.yml
    │   └── release.yml
├── .gitignore
├── .golangci.yml
├── .goreleaser.yml
├── .vscode
    ├── launch.json
    ├── settings.json
    └── tasks.json
├── CODE_OF_CONDUCT.md
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── api
    ├── api_client.gen.go
    ├── api_interface_impl.go
    ├── api_interface_impl_test.go
    ├── api_server.gen.go
    ├── api_suite_test.go
    ├── api_types.gen.go
    ├── client.cfg.yaml
    ├── server.cfg.yaml
    └── types.cfg.yaml
├── cache
    ├── expirationcache
    │   ├── cache_interface.go
    │   ├── expiration_cache.go
    │   ├── expiration_cache_suite_test.go
    │   ├── expiration_cache_test.go
    │   ├── prefetching_cache.go
    │   └── prefetching_cache_test.go
    └── stringcache
    │   ├── chained_grouped_cache.go
    │   ├── chained_grouped_cache_test.go
    │   ├── grouped_cache_interface.go
    │   ├── in_memory_grouped_cache.go
    │   ├── in_memory_grouped_cache_test.go
    │   ├── string_cache_suite_test.go
    │   ├── string_caches.go
    │   ├── string_caches_benchmark_test.go
    │   └── string_caches_test.go
├── cmd
    ├── blocking.go
    ├── blocking_test.go
    ├── cache.go
    ├── cache_test.go
    ├── cmd_suite_test.go
    ├── healthcheck.go
    ├── healthcheck_test.go
    ├── lists.go
    ├── lists_test.go
    ├── query.go
    ├── query_test.go
    ├── root.go
    ├── root_test.go
    ├── serve.go
    ├── serve_test.go
    ├── validate.go
    ├── validate_test.go
    ├── version.go
    └── version_test.go
├── codecov.yml
├── config
    ├── blocking.go
    ├── blocking_test.go
    ├── bytes_source.go
    ├── bytes_source_enum.go
    ├── caching.go
    ├── caching_test.go
    ├── client_lookup.go
    ├── client_lookup_test.go
    ├── conditional_upstream.go
    ├── conditional_upstream_test.go
    ├── config.go
    ├── config_enum.go
    ├── config_suite_test.go
    ├── config_test.go
    ├── custom_dns.go
    ├── custom_dns_test.go
    ├── duration.go
    ├── duration_test.go
    ├── ecs.go
    ├── ecs_test.go
    ├── filtering.go
    ├── filtering_test.go
    ├── hosts_file.go
    ├── hosts_file_test.go
    ├── metrics.go
    ├── metrics_test.go
    ├── migration
    │   └── migration.go
    ├── qtype_set.go
    ├── qtype_set_test.go
    ├── query_log.go
    ├── query_log_test.go
    ├── redis.go
    ├── redis_test.go
    ├── rewriter.go
    ├── rewriter_test.go
    ├── sudn.go
    ├── sudn_test.go
    ├── upstream.go
    ├── upstreams.go
    └── upstreams_test.go
├── docs
    ├── additional_information.md
    ├── api
    │   └── openapi.yaml
    ├── blocky-grafana.json
    ├── blocky-query-grafana-postgres.json
    ├── blocky-query-grafana.json
    ├── blocky.svg
    ├── config.yml
    ├── configuration.md
    ├── embed.go
    ├── fb_dns_config.png
    ├── grafana-dashboard.png
    ├── grafana-query-dashboard.png
    ├── includes
    │   └── abbreviations.md
    ├── index.md
    ├── installation.md
    ├── interfaces.md
    ├── network_configuration.md
    ├── prometheus_grafana.md
    └── rapidoc.html
├── e2e
    ├── basic_test.go
    ├── blocking_test.go
    ├── containers.go
    ├── custom_dns_test.go
    ├── e2e_suite_test.go
    ├── helper.go
    ├── metrics_test.go
    ├── querylog_test.go
    ├── redis_test.go
    └── upstream_test.go
├── evt
    └── events.go
├── go.mod
├── go.sum
├── helpertest
    ├── data
    │   ├── oisd-big-plain.txt
    │   └── oisd-big-wildcard.txt
    ├── helper.go
    ├── http.go
    ├── mock_call_sequence.go
    └── tmpdata.go
├── lists
    ├── downloader.go
    ├── downloader_test.go
    ├── list_cache.go
    ├── list_cache_benchmark_test.go
    ├── list_cache_enum.go
    ├── list_cache_test.go
    ├── list_suite_test.go
    ├── parsers
    │   ├── adapt.go
    │   ├── filtererrors.go
    │   ├── filtererrors_test.go
    │   ├── hosts.go
    │   ├── hosts_test.go
    │   ├── lines.go
    │   ├── lines_test.go
    │   ├── parser.go
    │   ├── parser_test.go
    │   └── parsers_suite_test.go
    └── sourcereader.go
├── log
    ├── context.go
    ├── logger.go
    ├── logger_enum.go
    └── mock_entry.go
├── main.go
├── main_static.go
├── metrics
    ├── metrics.go
    └── metrics_event_publisher.go
├── mkdocs.yml
├── model
    ├── models.go
    └── models_enum.go
├── querylog
    ├── database_writer.go
    ├── database_writer_test.go
    ├── file_writer.go
    ├── file_writer_test.go
    ├── logger_writer.go
    ├── logger_writer_test.go
    ├── none_writer.go
    ├── none_writer_test.go
    ├── querylog_suite_test.go
    └── writer.go
├── redis
    ├── redis.go
    ├── redis_suite_test.go
    └── redis_test.go
├── resolver
    ├── blocking_resolver.go
    ├── blocking_resolver_test.go
    ├── bootstrap.go
    ├── bootstrap_test.go
    ├── caching_resolver.go
    ├── caching_resolver_test.go
    ├── client_names_resolver.go
    ├── client_names_resolver_test.go
    ├── conditional_upstream_resolver.go
    ├── conditional_upstream_resolver_test.go
    ├── custom_dns_resolver.go
    ├── custom_dns_resolver_test.go
    ├── ecs_resolver.go
    ├── ecs_resolver_test.go
    ├── ede_resolver.go
    ├── ede_resolver_test.go
    ├── filtering_resolver.go
    ├── filtering_resolver_test.go
    ├── fqdn_only_resolver.go
    ├── fqdn_only_resolver_test.go
    ├── hosts_file_resolver.go
    ├── hosts_file_resolver_test.go
    ├── metrics_resolver.go
    ├── metrics_resolver_test.go
    ├── mock_udp_upstream_server.go
    ├── mocks_test.go
    ├── noop_resolver.go
    ├── noop_resolver_test.go
    ├── parallel_best_resolver.go
    ├── parallel_best_resolver_test.go
    ├── query_logging_resolver.go
    ├── query_logging_resolver_test.go
    ├── resolver.go
    ├── resolver_suite_test.go
    ├── resolver_test.go
    ├── rewriter_resolver.go
    ├── rewriter_resolver_test.go
    ├── strict_resolver.go
    ├── strict_resolver_test.go
    ├── sudn_resolver.go
    ├── sudn_resolver_test.go
    ├── upstream_resolver.go
    ├── upstream_resolver_test.go
    ├── upstream_tree_resolver.go
    └── upstream_tree_resolver_test.go
├── server
    ├── http.go
    ├── server.go
    ├── server_config_trigger.go
    ├── server_config_trigger_windows.go
    ├── server_endpoints.go
    ├── server_suite_test.go
    └── server_test.go
├── trie
    ├── split.go
    ├── split_test.go
    ├── trie.go
    ├── trie_suite_test.go
    └── trie_test.go
├── util
    ├── arpa.go
    ├── arpa_test.go
    ├── buildinfo.go
    ├── common.go
    ├── common_test.go
    ├── context.go
    ├── context_test.go
    ├── edns0.go
    ├── edns0_test.go
    ├── http.go
    ├── http_test.go
    ├── tls.go
    ├── tls_test.go
    └── util_suite_test.go
└── web
    ├── index.go
    ├── index.html
    └── static
        ├── rapidoc-min.js
        └── rapidoc.html


/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "name": "blocky development",
 3 |   "image": "mcr.microsoft.com/devcontainers/base:ubuntu-22.04",
 4 |   "features": {
 5 |     "ghcr.io/devcontainers/features/go:1": {},
 6 |     "ghcr.io/jungaretti/features/make:1": {},
 7 |     "ghcr.io/devcontainers/features/docker-in-docker:2": {
 8 |       "dockerDashComposeVersion": "v2"
 9 |     },
10 |     "ghcr.io/devcontainers/features/python:1": {},
11 |     "ghcr.io/devcontainers/features/github-cli:1": {},
12 |     "ghcr.io/rocker-org/devcontainer-features/apt-packages:1": {
13 |       "packages": "dnsutils "
14 |     }
15 |   },
16 |   "remoteEnv": {
17 |     "LOCAL_WORKSPACE_FOLDER": "${localWorkspaceFolder}",
18 |     "WORKSPACE_FOLDER": "${containerWorkspaceFolder}",
19 |     "GENERATE_LCOV": "true"
20 |   },
21 |   "customizations": {
22 |     "vscode": {
23 |       "extensions": [
24 |         "golang.go",
25 |         "esbenp.prettier-vscode",
26 |         "yzhang.markdown-all-in-one",
27 |         "joselitofilho.ginkgotestexplorer",
28 |         "fsevenm.run-it-on",
29 |         "markis.code-coverage",
30 |         "tooltitudeteam.tooltitude",
31 |         "GitHub.vscode-github-actions"
32 |       ],
33 |       "settings": {
34 |         "go.lintFlags": [
35 |           "--config=${containerWorkspaceFolder}/.golangci.yml",
36 |           "--fast"
37 |         ],
38 |         "go.alternateTools": {
39 |           "go-langserver": "gopls"
40 |         },
41 |         "markiscodecoverage.searchCriteria": "**/*.lcov",
42 |         "runItOn": {
43 |           "commands": [
44 |             {
45 |               "match": "\\.go$",
46 |               "cmd": "${workspaceRoot}/.devcontainer/scripts/runItOnGo.sh ${fileDirname} ${workspaceRoot}"
47 |             }
48 |           ]
49 |         },
50 |         "[go]": {
51 |           "editor.defaultFormatter": "golang.go"
52 |         },
53 |         "[yaml][json][jsonc][github-actions-workflow]": {
54 |           "editor.defaultFormatter": "esbenp.prettier-vscode"
55 |         },
56 |         "[markdown]": {
57 |           "editor.defaultFormatter": "yzhang.markdown-all-in-one"
58 |         }
59 |       }
60 |     }
61 |   },
62 |   "mounts": [
63 |     "type=bind,readonly,source=/etc/localtime,target=/usr/share/host/localtime",
64 |     "type=bind,readonly,source=/etc/timezone,target=/usr/share/host/timezone",
65 |     "type=volume,source=blocky-pkg_cache,target=/go/pkg"
66 |   ],
67 |   "postCreateCommand": "sudo chmod +x .devcontainer/scripts/*.sh",
68 |   "postStartCommand": "sh .devcontainer/scripts/postStart.sh"
69 | }
70 | 


--------------------------------------------------------------------------------
/.devcontainer/scripts/postStart.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash -e
 2 | # Dev container post-start hook (devcontainer.json "postStartCommand").
 3 | # Links the host's timezone files into the container, takes ownership of the
 4 | # /go/pkg cache volume, syncs Go modules and installs the Go dev tools.
 5 | 
 6 | # devcontainer.json runs this as `sh .devcontainer/scripts/postStart.sh`, which
 7 | # ignores the shebang line above and with it the `-e` flag. Enable errexit
 8 | # explicitly so a failing step aborts the hook instead of being skipped silently.
 9 | set -e
10 | 
11 | echo "Setting up go environment..."
12 | # Use the host's timezone and time (bind-mounted read-only under /usr/share/host)
13 | sudo ln -sf /usr/share/host/localtime /etc/localtime
14 | sudo ln -sf /usr/share/host/timezone /etc/timezone
15 | # Change ownership of the pkg cache volume to vscode:golang
16 | sudo chown -R vscode:golang /go/pkg
17 | echo ""
18 | 
19 | echo "Downloading Go modules..."
20 | go mod download -x
21 | echo ""
22 | 
23 | echo "Tidying Go modules..."
24 | go mod tidy -x
25 | echo ""
26 | 
27 | echo "Installing Go tools..."
28 | echo "  - ginkgo"
29 | go install github.com/onsi/ginkgo/v2/ginkgo@latest
30 | echo "  - golangci-lint"
31 | go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
32 | echo "  - gofumpt"
33 | go install mvdan.cc/gofumpt@latest
34 | echo "  - gcov2lcov"
35 | go install github.com/jandelgado/gcov2lcov@latest


--------------------------------------------------------------------------------
/.devcontainer/scripts/runItOnGo.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash -e
 2 | # Runs the ginkgo tests (excluding the "e2e" label) of one folder with
 3 | # coverage and converts the profile to an lcov file in $BASE_PATH/coverage.
 4 | # Wired up in devcontainer.json as the run-it-on command for .go files.
 5 | #
 6 | # Usage: runItOnGo.sh [FOLDER_PATH] [BASE_PATH]
 7 | #   FOLDER_PATH  folder to test (default: $PWD)
 8 | #   BASE_PATH    workspace root (default: $WORKSPACE_FOLDER)
 9 | 
10 | FOLDER_PATH=$1
11 | if [ -z "${FOLDER_PATH}" ]; then
12 |   FOLDER_PATH=$PWD
13 | fi
14 | 
15 | BASE_PATH=$2
16 | if [ -z "${BASE_PATH}" ]; then
17 |   BASE_PATH=$WORKSPACE_FOLDER
18 | fi
19 | 
20 | # Nothing to do for the workspace root itself. This is an expected no-op, not
21 | # a failure, so exit successfully instead of reporting an error to the caller.
22 | if [ "$FOLDER_PATH" = "$BASE_PATH" ]; then
23 |   echo "Skipping lcov creation for base path"
24 |   exit 0
25 | fi
26 | 
27 | # Flatten the folder's path relative to BASE_PATH into a file name,
28 | # e.g. "cache/stringcache" -> "cache-stringcache".
29 | FOLDER_NAME=${FOLDER_PATH#"$BASE_PATH/"}
30 | WORK_NAME="$(echo "$FOLDER_NAME" | sed 's/\//-/g')"
31 | WORK_FILE_NAME="$WORK_NAME.ginkgo"
32 | # Must match where ginkgo puts the profile (--output-dir=/tmp below).
33 | WORK_FILE_PATH="/tmp/$WORK_FILE_NAME"
34 | OUTPUT_FOLDER="$BASE_PATH/coverage"
35 | OUTPUT_FILE_PATH="$OUTPUT_FOLDER/$WORK_NAME.lcov"
36 | 
37 | mkdir -p "$OUTPUT_FOLDER"
38 | 
39 | echo "-- Start $FOLDER_NAME ($(date '+%T')) --"
40 | 
41 | # Each step is best-effort (|| true): a failed test run does not stop the
42 | # lcov conversion or the cleanup of the temporary profile.
43 | TIMEFORMAT=' - Ginkgo tests finished in: %R seconds'
44 | time ginkgo --label-filter="!e2e" --keep-going --timeout=5m --output-dir=/tmp --coverprofile="$WORK_FILE_NAME" --covermode=atomic --cover -r -p "$FOLDER_PATH" || true
45 | 
46 | TIMEFORMAT=' - lcov convert finished in: %R seconds'
47 | time gcov2lcov -infile="$WORK_FILE_PATH" -outfile="$OUTPUT_FILE_PATH" || true
48 | 
49 | TIMEFORMAT=' - cleanup finished in: %R seconds'
50 | time rm "$WORK_FILE_PATH" || true
51 | 
52 | echo "-- Finished $FOLDER_NAME ($(date '+%T')) --"


--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
 1 | bin
 2 | dist
 3 | site
 4 | node_modules
 5 | .git
 6 | .idea
 7 | .github
 8 | .vscode
 9 | .gitignore
10 | *.md
11 | LICENSE
12 | vendor
13 | e2e/
14 | .devcontainer/
15 | coverage.txt
16 | coverage/
17 | 


--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | * 		text=auto eol=lf


--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: 0xERR0R
2 | ko_fi: 0xerr0r
3 | custom: ["paypal.me/spx01"]
4 | 


--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
 1 | version: 2
 2 | updates:
 3 | - package-ecosystem: gomod
 4 |   directory: "/"
 5 |   schedule:
 6 |     interval: daily
 7 |   open-pull-requests-limit: 10
 8 |   assignees:
 9 |   - 0xERR0R
10 | 
11 | - package-ecosystem: github-actions
12 |   directory: "/"
13 |   schedule:
14 |     interval: daily
15 | 


--------------------------------------------------------------------------------
/.github/workflows/close_stale.yml:
--------------------------------------------------------------------------------
 1 | name: Close stale issues and PRs
 2 | 
 3 | on:
 4 |   schedule:
 5 |     - cron: "0 4 * * *"
 6 | 
 7 | concurrency:
 8 |   group: ${{ github.workflow }}-${{ github.ref }}
 9 | 
10 | jobs:
11 |   stale:
12 |     runs-on: ubuntu-latest
13 |     if: github.repository_owner == '0xERR0R'
14 |     permissions:
15 |       issues: write
16 |       pull-requests: write
17 |     steps:
18 |       - uses: actions/stale@v9
19 |         with:
20 |           stale-issue-message: "This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 5 days."
21 |           stale-pr-message: "This PR is stale because it has been open 45 days with no activity. Remove stale label or comment or this will be closed in 10 days."
22 |           close-issue-message: "This issue was closed because it has been stalled for 5 days with no activity."
23 |           close-pr-message: "This PR was closed because it has been stalled for 10 days with no activity."
24 |           days-before-issue-stale: 90
25 |           days-before-pr-stale: 45
26 |           days-before-issue-close: 5
27 |           days-before-pr-close: 10
28 |           exempt-all-milestones: true
29 |           operations-per-run: 60
30 | 


--------------------------------------------------------------------------------
/.github/workflows/codeql-analysis.yml:
--------------------------------------------------------------------------------
 1 | name: CodeQL
 2 | 
 3 | on:
 4 |   push:
 5 |     branches:
 6 |       - main
 7 |   pull_request:
 8 |     branches:
 9 |       - main
10 |   schedule:
11 |     - cron: "33 15 * * 1"
12 | 
13 | concurrency:
14 |   group: ${{ github.workflow }}-${{ github.ref }}
15 | 
16 | jobs:
17 |   analyze:
18 |     name: Analyze
19 |     runs-on: ubuntu-latest
20 |     steps:
21 |       - name: Checkout repository
22 |         uses: actions/checkout@v4
23 |         with:
24 |           fetch-depth: 0
25 | 
26 |       - name: Setup Golang
27 |         uses: actions/setup-go@v5
28 |         with:
29 |           go-version-file: go.mod
30 | 
31 |       - name: Initialize CodeQL
32 |         uses: github/codeql-action/init@v3
33 |         with:
34 |           languages: go
35 | 
36 |       - name: Build with Makefile
37 |         run: make build
38 | 
39 |       - name: Perform CodeQL Analysis
40 |         uses: github/codeql-action/analyze@v3
41 | 


--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
 1 | name: docs
 2 | 
 3 | on:
 4 |   push:
 5 |     tags:
 6 |       - v*
 7 | 
 8 |     branches:
 9 |       - '**'
10 |     paths:
11 |       - .github/workflows/**
12 |       - mkdocs.yml
13 |       - docs/**
14 | 
15 | concurrency:
16 |   group: ${{ github.workflow }}
17 | 
18 | jobs:
19 |   deploy:
20 |     runs-on: ubuntu-latest
21 |     if: ${{ github.event.repository.has_pages && (github.repository_owner != '0xERR0R' || github.ref_type == 'tag' || github.ref_name == 'main') }}
22 |     steps:
23 |       - uses: actions/checkout@v4
24 |         with:
25 |           fetch-depth: 0
26 | 
27 |       - uses: actions/setup-python@v5
28 |         with:
29 |           python-version: 3.x
30 | 
31 |       - name: install tools
32 |         run: pip install mkdocs-material mike
33 | 
34 |       - name: Setup doc deploy
35 |         run: |
36 |           git config --local user.email "github-actions[bot]@users.noreply.github.com"
37 |           git config --local user.name "github-actions[bot]"
38 | 
39 |       - name: Deploy version
40 |         run: |
41 |           VERSION="$(sed 's:/:-:g' <<< "$GITHUB_REF_NAME")"
42 |           if [[ ${{github.ref}} =~ ^refs/tags/ ]]; then
43 |             EXTRA_ALIAS=latest
44 |           fi
45 |           mike deploy --push --update-aliases "$VERSION" $EXTRA_ALIAS
46 |           tr '[:upper:]' '[:lower:]' <<< "https://${{github.repository_owner}}.github.io/${{github.event.repository.name}}/$VERSION/" >> "$GITHUB_STEP_SUMMARY"
47 | 


--------------------------------------------------------------------------------
/.github/workflows/fork-sync.yml:
--------------------------------------------------------------------------------
 1 | name: Sync Fork
 2 | 
 3 | on:
 4 |   schedule:
 5 |     - cron: "*/30 * * * *"
 6 |   workflow_dispatch:
 7 | 
 8 | concurrency:
 9 |   group: ${{ github.workflow }}
10 |   cancel-in-progress: true
11 | 
12 | jobs:
13 |   sync:
14 |     name: Sync with Upstream
15 |     runs-on: ubuntu-latest
16 |     if: github.repository_owner != '0xERR0R'
17 |     steps:
18 |       - name: Enabled Check
19 |         id: check
20 |         shell: bash
21 |         run: |
22 |           if [[ "${{ secrets.FORK_SYNC_TOKEN }}" != "" ]]; then
23 |             echo "enabled=1" >> $GITHUB_OUTPUT
24 | 
25 |             echo "Workflow is enabled"
26 |           else
27 |             echo "enabled=0" >> $GITHUB_OUTPUT
28 | 
29 |             (
30 |               echo 'Workflow is disabled (create `FORK_SYNC_TOKEN` secret with repo write permission to enable it)'
31 |               echo
32 |               echo 'Alternatively, you can disable it for your repo from the web UI:'
33 |               echo 'https://docs.github.com/en/actions/using-workflows/disabling-and-enabling-a-workflow'
34 |             ) | tee "$GITHUB_STEP_SUMMARY"
35 |           fi
36 | 
37 |       - name: Sync
38 |         if: ${{ steps.check.outputs.enabled == 1 }}
39 |         env:
40 |           GH_TOKEN: ${{ secrets.FORK_SYNC_TOKEN }}
41 |         shell: bash
42 |         run: |
43 |           gh repo sync ${{ github.repository }} -b main
44 | 


--------------------------------------------------------------------------------
/.github/workflows/makefile.yml:
--------------------------------------------------------------------------------
 1 | name: Makefile
 2 | 
 3 | on:
 4 |   push:
 5 |     paths:
 6 |       - .github/workflows/makefile.yml
 7 |       - Dockerfile
 8 |       - Makefile
 9 |       - "**.go"
10 |       - "go.*"
11 |       - "helpertest/data/**"
12 |   pull_request:
13 | 
14 | permissions:
15 |   security-events: write
16 |   actions: read
17 |   contents: read
18 | 
19 | env:
20 |   GINKGO_PROCS: --procs=1
21 | 
22 | jobs:
23 |   make:
24 |     name: make
25 |     runs-on: ubuntu-latest
26 |     strategy:
27 |       matrix:
28 |         include:
29 |           - make: build
30 |             go: true
31 |             docker: false
32 |           - make: test
33 |             go: true
34 |             docker: false
35 |           - make: race
36 |             go: true
37 |             docker: false
38 |           - make: docker-build
39 |             go: false
40 |             docker: true
41 |           - make: e2e-test
42 |             go: true
43 |             docker: true
44 |           - make: goreleaser
45 |             go: false
46 |             docker: false
47 |           - make: lint
48 |             go: true
49 |             docker: false
50 | 
51 |     steps:
52 |       - name: Check out code into the Go module directory
53 |         uses: actions/checkout@v4
54 | 
55 |       - name: Setup Golang
56 |         uses: actions/setup-go@v5
57 |         if: matrix.go == true
58 |         with:
59 |           go-version-file: go.mod
60 | 
61 |       - name: Download dependencies
62 |         run: go mod download
63 |         if: matrix.go == true
64 | 
65 |       - name: Set up Docker Buildx
66 |         uses: docker/setup-buildx-action@v3
67 |         if: matrix.docker == true
68 | 
69 |       - name: make ${{ matrix.make }}
70 |         run: make ${{ matrix.make }}
71 |         if: matrix.make != 'goreleaser'
72 |         env:
73 |           GO_SKIP_GENERATE: 1
74 | 
75 |       - name: Upload results to codecov
76 |         uses: codecov/codecov-action@v5
77 |         if: matrix.make == 'test' && github.repository_owner == '0xERR0R'
78 | 
79 |       - name: Check GoReleaser configuration
80 |         uses: goreleaser/goreleaser-action@v5
81 |         if: matrix.make == 'goreleaser'
82 |         with:
83 |           args: check
84 | 


--------------------------------------------------------------------------------
/.github/workflows/mirror-repo.yml:
--------------------------------------------------------------------------------
 1 | name: mirror git repo
 2 | 
 3 | on:
 4 |   push:
 5 |     branches:
 6 |       - main
 7 | 
 8 | concurrency:
 9 |   group: ${{ github.workflow }}-${{ github.ref }}
10 | 
11 | jobs:
12 |   mirror:
13 |     runs-on: ubuntu-latest
14 |     if: github.repository_owner == '0xERR0R'
15 |     steps:
16 |       - uses: actions/checkout@v4
17 |         with:
18 |           fetch-depth: 0
19 | 
20 |       - uses: yesolutions/mirror-action@master
21 |         with:
22 |           REMOTE: "https://codeberg.org/0xERR0R/blocky.git"
23 |           GIT_USERNAME: 0xERR0R
24 |           GIT_PASSWORD: ${{ secrets.CODEBERG_TOKEN }}
25 | 


--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
 1 | name: Release
 2 | 
 3 | on:
 4 |   push:
 5 |     tags:
 6 |       - v*
 7 | 
 8 | concurrency:
 9 |   group: ${{ github.workflow }}-${{ github.ref }}
10 | 
11 | jobs:
12 |   build:
13 |     runs-on: ubuntu-latest
14 |     if: github.repository_owner == '0xERR0R'
15 |     steps:
16 |       - name: Checkout
17 |         uses: actions/checkout@v4
18 |         with:
19 |           fetch-depth: 0
20 | 
21 |       - name: Set up Go
22 |         uses: actions/setup-go@v5
23 |         with:
24 |           go-version-file: go.mod
25 |         id: go
26 | 
27 |       - name: Build
28 |         run: make build
29 | 
30 |       - name: Test
31 |         run: make test
32 | 
33 |       - name: Docker meta
34 |         id: docker_meta
35 |         uses: crazy-max/ghaction-docker-meta@v5
36 |         with:
37 |           images: spx01/blocky,ghcr.io/0xerr0r/blocky
38 | 
39 |       - name: Set up QEMU
40 |         uses: docker/setup-qemu-action@v3
41 |         with:
42 |           platforms: arm,arm64
43 | 
44 |       - name: Set up Docker Buildx
45 |         uses: docker/setup-buildx-action@v3
46 | 
47 |       - name: Login to GitHub Container Registry
48 |         uses: docker/login-action@v3
49 |         with:
50 |           registry: ghcr.io
51 |           username: ${{ github.repository_owner }}
52 |           password: ${{ secrets.CR_PAT }}
53 | 
54 |       - name: Login to DockerHub
55 |         uses: docker/login-action@v3
56 |         with:
57 |           username: ${{ secrets.DOCKER_USERNAME }}
58 |           password: ${{ secrets.DOCKER_TOKEN }}
59 | 
60 |       - name: Populate build variables
61 |         id: get_vars
62 |         shell: bash
63 |         run: |
64 |           VERSION=$(git describe --always --tags)
65 |           echo "version=${VERSION}" >> $GITHUB_OUTPUT
66 |           echo "VERSION: ${VERSION}"
67 | 
68 |           BUILD_TIME=$(date --iso-8601=seconds)
69 |           echo "build_time=${BUILD_TIME}" >> $GITHUB_OUTPUT
70 |           echo "BUILD_TIME: ${BUILD_TIME}"
71 | 
72 |           DOC_PATH=${VERSION%%-*}
73 |           echo "doc_path=${DOC_PATH}" >> $GITHUB_OUTPUT
74 |           echo "DOC_PATH: ${DOC_PATH}"
75 | 
76 |       - name: Build and push
77 |         uses: docker/build-push-action@v6
78 |         with:
79 |           context: .
80 |           platforms: linux/amd64,linux/arm/v6,linux/arm/v7,linux/arm64
81 |           push: ${{ github.event_name != 'pull_request' }}
82 |           tags: ${{ steps.docker_meta.outputs.tags }}
83 |           labels: ${{ steps.docker_meta.outputs.labels }}
84 |           build-args: |
85 |             VERSION=${{ steps.get_vars.outputs.version }}
86 |             BUILD_TIME=${{ steps.get_vars.outputs.build_time }}
87 |             DOC_PATH=${{ steps.get_vars.outputs.doc_path }}
88 | 
89 | 
90 |       - name: Run GoReleaser
91 |         uses: goreleaser/goreleaser-action@v5
92 |         with:
93 |           version: latest
94 |           args: release --clean
95 |         env:
96 |           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
97 | 


--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | .idea/
 2 | *.iml
 3 | *.test
 4 | /*.pem
 5 | bin/
 6 | dist/
 7 | docs/docs.go
 8 | site/
 9 | config.yml
10 | todo.txt
11 | !docs/config.yml
12 | node_modules
13 | package-lock.json
14 | vendor/
15 | coverage.html
16 | coverage.txt
17 | coverage/
18 | blocky
19 | 


--------------------------------------------------------------------------------
/.golangci.yml:
--------------------------------------------------------------------------------
  1 | linters:
  2 |   enable:
  3 |     - asciicheck
  4 |     - bidichk
  5 |     - bodyclose
  6 |     - dogsled
  7 |     - dupl
  8 |     - durationcheck
  9 |     - errcheck
 10 |     - errchkjson
 11 |     - errorlint
 12 |     - exhaustive
 13 |     - funlen
 14 |     - gochecknoglobals
 15 |     - gochecknoinits
 16 |     - gocognit
 17 |     - goconst
 18 |     - gocritic
 19 |     - gocyclo
 20 |     - godox
 21 |     - gofmt
 22 |     - goimports
 23 |     - mnd
 24 |     - gomodguard
 25 |     - gosimple
 26 |     - govet
 27 |     - grouper
 28 |     - importas
 29 |     - ineffassign
 30 |     - lll
 31 |     - makezero
 32 |     - misspell
 33 |     - nakedret
 34 |     - nestif
 35 |     - nilerr
 36 |     - nilnil
 37 |     - nlreturn
 38 |     - nolintlint
 39 |     - nosprintfhostport
 40 |     - prealloc
 41 |     - predeclared
 42 |     - revive
 43 |     - sqlclosecheck
 44 |     - staticcheck
 45 |     - stylecheck
 46 |     - typecheck
 47 |     - unconvert
 48 |     - unparam
 49 |     - unused
 50 |     - usestdlibvars
 51 |     - wastedassign
 52 |     - whitespace
 53 |     - ginkgolinter
 54 |     - noctx
 55 |     - containedctx
 56 |     - contextcheck
 57 |   disable:
 58 |     - forbidigo
 59 |     - gosmopolitan
 60 |     - gosec
 61 |     - recvcheck
 62 |   disable-all: false
 63 |   presets:
 64 |     - bugs
 65 |     - unused
 66 |   fast: false
 67 | 
 68 | linters-settings:
 69 |   mnd:
 70 |     ignored-numbers:
 71 |       - "0666"
 72 |       - "2"
 73 |       - "5"
 74 |   ginkgolinter:
 75 |     forbid-focus-container: true
 76 |   stylecheck:
 77 |     # Whitelist dot imports for test packages.
 78 |     dot-import-whitelist:
 79 |       - "github.com/onsi/ginkgo/v2"
 80 |       - "github.com/onsi/gomega"
 81 |       - "github.com/0xERR0R/blocky/config/migration"
 82 |       - "github.com/0xERR0R/blocky/helpertest"
 83 |   revive:
 84 |     rules:
 85 |       - name: dot-imports
 86 |         disabled: true # prefer stylecheck since it's more configurable
 87 | 
 88 | issues:
 89 |   exclude-rules:
 90 |     # Exclude some linters from running on tests files.
 91 |     - path: _test\.go
 92 |       linters:
 93 |         - dupl
 94 |         - funlen
 95 |         - gochecknoinits
 96 |         - gochecknoglobals
 97 |         - gosec
 98 |     - path: _test\.go
 99 |       linters:
100 |         - staticcheck
101 |       text: "SA1012:"
102 | 


--------------------------------------------------------------------------------
/.goreleaser.yml:
--------------------------------------------------------------------------------
 1 | project_name: blocky
 2 | 
 3 | before:
 4 |   hooks:
 5 |     - go mod tidy
 6 | builds:
 7 |   - goos:
 8 |       - linux
 9 |       - windows
10 |       - freebsd
11 |       - netbsd
12 |       - openbsd
13 |       - darwin
14 |     goarch:
15 |       - amd64
16 |       - arm
17 |       - arm64
18 |     goarm:
19 |       - 6
20 |       - 7
21 |     ignore:
22 |       - goos: windows
23 |         goarch: arm
24 |       - goos: windows
25 |         goarch: arm64
26 |     ldflags:
27 |       - -w
28 |       - -s
29 |       - -X github.com/0xERR0R/blocky/util.Version=v{{.Version}}
30 |       - -X github.com/0xERR0R/blocky/util.BuildTime={{time "20060102-150405"}}
31 |       - -X github.com/0xERR0R/blocky/util.Architecture={{.Arch}}{{.Arm}}
32 |     env:
33 |       - CGO_ENABLED=0
34 | release:
35 |   draft: true
36 | archives:
37 |   - format_overrides:
38 |       - goos: windows
39 |         format: zip
40 |     name_template: >-
41 |       {{ .ProjectName }}_v
42 |       {{- .Version }}_
43 |       {{- title .Os }}_
44 |       {{- if eq .Arch "amd64" }}x86_64
45 |       {{- else if eq .Arch "386" }}i386
46 |       {{- else }}{{ .Arch }}{{ end }}
47 |       {{- if .Arm }}v{{ .Arm }}{{ end }}
48 | 
49 | snapshot:
50 |   name_template: "{{ .Version }}-{{.ShortCommit}}"
51 | checksum:
52 |   name_template: "{{ .ProjectName }}_checksums.txt"
53 | changelog:
54 |   use: github
55 |   sort: asc
56 |   filters:
57 |     exclude:
58 |       - '^docs:'
59 |       - '^chore:'
60 |       - '^test:'
61 | 


--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "version": "0.2.0",
 3 |   "configurations": [
 4 |     {
 5 |       "name": "Launch Package",
 6 |       "type": "go",
 7 |       "request": "launch",
 8 |       "mode": "auto",
 9 |       "program": "${fileDirname}"
10 |     }
11 |   ]
12 | }
13 | 


--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "editor.tabSize": 2,
 3 |   "editor.insertSpaces": true,
 4 |   "editor.detectIndentation": false,
 5 |   "editor.formatOnSave": true,
 6 |   "editor.formatOnPaste": true,
 7 |   "editor.codeActionsOnSave": {
 8 |     "source.organizeImports": "explicit",
 9 |     "source.fixAll": "explicit"
10 |   },
11 |   "editor.rulers": [120],
12 |   "go.showWelcome": false,
13 |   "go.survey.prompt": false,
14 |   "go.useLanguageServer": true,
15 |   "go.formatTool": "gofumpt",
16 |   "go.lintTool": "golangci-lint",
17 |   "go.lintOnSave": "workspace",
18 |   "gopls": {
19 |     "ui.semanticTokens": true,
20 |     "formatting.gofumpt": true,
21 |     "build.standaloneTags": ["ignore", "tools"]
22 |   }
23 | }
24 | 


--------------------------------------------------------------------------------
/.vscode/tasks.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "version": "2.0.0",
 3 |   "tasks": [
 4 |     {
 5 |       "label": "Test",
 6 |       "type": "shell",
 7 |       "command": "make test",
 8 |       "group": {
 9 |         "kind": "test",
10 |         "isDefault": true
11 |       },
12 |       "presentation": {
13 |         "reveal": "always"
14 |       }
15 |     },
16 |     {
17 |       "label": "Race",
18 |       "type": "shell",
19 |       "command": "make race",
20 |       "group": {
21 |         "kind": "test",
22 |         "isDefault": false
23 |       },
24 |       "presentation": {
25 |         "reveal": "always"
26 |       }
27 |     },
28 |     {
29 |       "label": "e2e - Test",
30 |       "type": "shell",
31 |       "command": "make e2e-test",
32 |       "group": {
33 |         "kind": "test",
34 |         "isDefault": false
35 |       },
36 |       "presentation": {
37 |         "reveal": "always"
38 |       }
39 |     },
40 |     {
41 |       "label": "Build",
42 |       "type": "shell",
43 |       "command": "make build",
      "group": {
        "kind": "build",
        "isDefault": true
      },
51 |       "presentation": {
52 |         "reveal": "always"
53 |       }
54 |     },
55 |     {
56 |       "label": "FMT",
57 |       "type": "shell",
58 |       "command": "make fmt",
59 |       "group": "none",
60 |       "presentation": {
61 |         "reveal": "always"
62 |       }
63 |     },
64 |     {
65 |       "label": "Tidy",
66 |       "type": "shell",
67 |       "command": "go mod tidy",
68 |       "group": "none",
69 |       "presentation": {
70 |         "reveal": "always"
71 |       }
72 |     }
73 |   ]
74 | }
75 | 


--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
 1 | # Contributor Covenant Code of Conduct
 2 | 
 3 | ## Our Pledge
 4 | 
 5 | In the interest of fostering an open and welcoming environment, we as
 6 | contributors and maintainers pledge to making participation in our project and
 7 | our community a harassment-free experience for everyone, regardless of age, body
 8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
 9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 | 
12 | ## Our Standards
13 | 
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 | 
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 | 
23 | Examples of unacceptable behavior by participants include:
24 | 
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 |  advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 |  address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 |  professional setting
33 | 
34 | ## Our Responsibilities
35 | 
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 | 
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 | 
46 | ## Scope
47 | 
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 | 
55 | ## Enforcement
56 | 
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by creating an issue. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 | 
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 | 
68 | ## Attribution
69 | 
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 | 
73 | [homepage]: https://www.contributor-covenant.org
74 | 
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 | 


--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
 1 | # syntax=docker/dockerfile:1
 2 | 
 3 | # ----------- stage: build
 4 | FROM golang:alpine AS build
 5 | RUN apk add --no-cache make coreutils libcap
 6 | 
 7 | # required arguments
 8 | ARG VERSION
 9 | ARG BUILD_TIME
10 | 
11 | COPY . .
12 | # setup go
13 | ENV GO_SKIP_GENERATE=1\
14 |   GO_BUILD_FLAGS="-tags static -v " \
15 |   BIN_USER=100\
16 |   BIN_AUTOCAB=1 \
17 |   BIN_OUT_DIR="/bin"
18 | 
19 | RUN make build
20 | 
21 | # ----------- stage: final
22 | FROM scratch
23 | 
24 | ARG VERSION
25 | ARG BUILD_TIME
26 | ARG DOC_PATH
27 | 
28 | LABEL org.opencontainers.image.title="blocky" \
29 |   org.opencontainers.image.vendor="0xERR0R" \
30 |   org.opencontainers.image.licenses="Apache-2.0" \
31 |   org.opencontainers.image.version="${VERSION}" \
32 |   org.opencontainers.image.created="${BUILD_TIME}" \
33 |   org.opencontainers.image.description="Fast and lightweight DNS proxy as ad-blocker for local network with many features" \
34 |   org.opencontainers.image.url="https://github.com/0xERR0R/blocky#readme" \
35 |   org.opencontainers.image.source="https://github.com/0xERR0R/blocky" \
36 |   org.opencontainers.image.documentation="https://0xerr0r.github.io/blocky/${DOC_PATH}/"
37 | 
38 | 
39 | 
40 | USER 100
41 | WORKDIR /app
42 | 
43 | COPY --from=build /bin/blocky /app/blocky
44 | 
45 | ENV BLOCKY_CONFIG_FILE=/app/config.yml
46 | 
47 | ENTRYPOINT ["/app/blocky"]
48 | 
49 | HEALTHCHECK --start-period=1m --timeout=3s CMD ["/app/blocky", "healthcheck"]
50 | 


--------------------------------------------------------------------------------
/api/api_suite_test.go:
--------------------------------------------------------------------------------
 1 | package api_test
 2 | 
 3 | import (
 4 | 	"testing"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/log"
 7 | 	. "github.com/onsi/ginkgo/v2"
 8 | 	. "github.com/onsi/gomega"
 9 | )
10 | 
// init runs once when the test binary starts; it calls log.Silence(), presumably to keep
// log output out of the test report.
func init() {
	log.Silence()
}

// TestResolver is the `go test` entry point of this package: it routes Gomega assertion
// failures to Ginkgo (RegisterFailHandler) and runs all registered specs as "API Suite".
// NOTE(review): the name looks copied from another suite; TestAPI would describe it better.
func TestResolver(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "API Suite")
}
19 | 


--------------------------------------------------------------------------------
/api/api_types.gen.go:
--------------------------------------------------------------------------------
 1 | // Package api provides primitives to interact with the openapi HTTP API.
 2 | //
 3 | // Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.4.1 DO NOT EDIT.
 4 | package api
 5 | 
 6 | // ApiBlockingStatus defines model for api.BlockingStatus.
 7 | type ApiBlockingStatus struct {
 8 | 	// AutoEnableInSec If blocking is temporary disabled: amount of seconds until blocking will be enabled
 9 | 	AutoEnableInSec *int `json:"autoEnableInSec,omitempty"`
10 | 
11 | 	// DisabledGroups Disabled group names
12 | 	DisabledGroups *[]string `json:"disabledGroups,omitempty"`
13 | 
14 | 	// Enabled True if blocking is enabled
15 | 	Enabled bool `json:"enabled"`
16 | }
17 | 
18 | // ApiQueryRequest defines model for api.QueryRequest.
19 | type ApiQueryRequest struct {
20 | 	// Query query for DNS request
21 | 	Query string `json:"query"`
22 | 
23 | 	// Type request type (A, AAAA, ...)
24 | 	Type string `json:"type"`
25 | }
26 | 
27 | // ApiQueryResult defines model for api.QueryResult.
28 | type ApiQueryResult struct {
29 | 	// Reason blocky reason for resolution
30 | 	Reason string `json:"reason"`
31 | 
32 | 	// Response actual DNS response
33 | 	Response string `json:"response"`
34 | 
35 | 	// ResponseType response type (CACHED, BLOCKED, ...)
36 | 	ResponseType string `json:"responseType"`
37 | 
38 | 	// ReturnCode DNS return code (NOERROR, NXDOMAIN, ...)
39 | 	ReturnCode string `json:"returnCode"`
40 | }
41 | 
42 | // DisableBlockingParams defines parameters for DisableBlocking.
43 | type DisableBlockingParams struct {
44 | 	// Duration duration of blocking (Example: 300s, 5m, 1h, 5m30s)
45 | 	Duration *string `form:"duration,omitempty" json:"duration,omitempty"`
46 | 
47 | 	// Groups groups to disable (comma separated). If empty, disable all groups
48 | 	Groups *string `form:"groups,omitempty" json:"groups,omitempty"`
49 | }
50 | 
51 | // QueryJSONRequestBody defines body for Query for application/json ContentType.
52 | type QueryJSONRequestBody = ApiQueryRequest
53 | 


--------------------------------------------------------------------------------
/api/client.cfg.yaml:
--------------------------------------------------------------------------------
1 | package: api
2 | generate:
3 |   client: true
4 | output: api_client.gen.go


--------------------------------------------------------------------------------
/api/server.cfg.yaml:
--------------------------------------------------------------------------------
1 | package: api
2 | generate:
3 |   chi-server: true
4 |   strict-server: true
5 |   embedded-spec: false
6 | output: api_server.gen.go


--------------------------------------------------------------------------------
/api/types.cfg.yaml:
--------------------------------------------------------------------------------
1 | package: api
2 | generate:
3 |   models: true
4 | output: api_types.gen.go


--------------------------------------------------------------------------------
/cache/expirationcache/cache_interface.go:
--------------------------------------------------------------------------------
 1 | package expirationcache
 2 | 
 3 | import "time"
 4 | 
 5 | type ExpiringCache[T any] interface {
 6 | 	// Put adds the value to the cache unter the passed key with expiration. If expiration <=0, entry will NOT be cached
 7 | 	Put(key string, val *T, expiration time.Duration)
 8 | 
 9 | 	// Get returns the value of cached entry with remained TTL. If entry is not cached, returns nil
10 | 	Get(key string) (val *T, expiration time.Duration)
11 | 
12 | 	// TotalCount returns the total count of valid (not expired) elements
13 | 	TotalCount() int
14 | 
15 | 	// Clear removes all cache entries
16 | 	Clear()
17 | }
18 | 


--------------------------------------------------------------------------------
/cache/expirationcache/expiration_cache_suite_test.go:
--------------------------------------------------------------------------------
 1 | package expirationcache_test
 2 | 
 3 | import (
 4 | 	"testing"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/log"
 7 | 
 8 | 	. "github.com/onsi/ginkgo/v2"
 9 | 	. "github.com/onsi/gomega"
10 | )
11 | 
// init runs once when the test binary starts; it calls log.Silence(), presumably to keep
// log output out of the test report.
func init() {
	log.Silence()
}

// TestCache is the `go test` entry point of this package: it routes Gomega assertion
// failures to Ginkgo and runs all registered specs as "Expiration cache suite".
func TestCache(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Expiration cache suite")
}
20 | 


--------------------------------------------------------------------------------
/cache/stringcache/chained_grouped_cache.go:
--------------------------------------------------------------------------------
 1 | package stringcache
 2 | 
 3 | import (
 4 | 	"sort"
 5 | 
 6 | 	"golang.org/x/exp/maps"
 7 | )
 8 | 
 9 | type ChainedGroupedCache struct {
10 | 	caches []GroupedStringCache
11 | }
12 | 
13 | func NewChainedGroupedCache(caches ...GroupedStringCache) *ChainedGroupedCache {
14 | 	return &ChainedGroupedCache{
15 | 		caches: caches,
16 | 	}
17 | }
18 | 
19 | func (c *ChainedGroupedCache) ElementCount(group string) int {
20 | 	sum := 0
21 | 	for _, cache := range c.caches {
22 | 		sum += cache.ElementCount(group)
23 | 	}
24 | 
25 | 	return sum
26 | }
27 | 
28 | func (c *ChainedGroupedCache) Contains(searchString string, groups []string) []string {
29 | 	groupMatchedMap := make(map[string]struct{}, len(groups))
30 | 
31 | 	for _, cache := range c.caches {
32 | 		for _, group := range cache.Contains(searchString, groups) {
33 | 			groupMatchedMap[group] = struct{}{}
34 | 		}
35 | 	}
36 | 
37 | 	matchedGroups := maps.Keys(groupMatchedMap)
38 | 
39 | 	sort.Strings(matchedGroups)
40 | 
41 | 	return matchedGroups
42 | }
43 | 
44 | func (c *ChainedGroupedCache) Refresh(group string) GroupFactory {
45 | 	cacheFactories := make([]GroupFactory, len(c.caches))
46 | 	for i, cache := range c.caches {
47 | 		cacheFactories[i] = cache.Refresh(group)
48 | 	}
49 | 
50 | 	return &chainedGroupFactory{
51 | 		cacheFactories: cacheFactories,
52 | 	}
53 | }
54 | 
55 | type chainedGroupFactory struct {
56 | 	cacheFactories []GroupFactory
57 | }
58 | 
59 | func (c *chainedGroupFactory) AddEntry(entry string) bool {
60 | 	for _, factory := range c.cacheFactories {
61 | 		if factory.AddEntry(entry) {
62 | 			return true
63 | 		}
64 | 	}
65 | 
66 | 	return false
67 | }
68 | 
69 | func (c *chainedGroupFactory) Count() int {
70 | 	var cnt int
71 | 	for _, factory := range c.cacheFactories {
72 | 		cnt += factory.Count()
73 | 	}
74 | 
75 | 	return cnt
76 | }
77 | 
78 | func (c *chainedGroupFactory) Finish() {
79 | 	for _, factory := range c.cacheFactories {
80 | 		factory.Finish()
81 | 	}
82 | }
83 | 


--------------------------------------------------------------------------------
/cache/stringcache/grouped_cache_interface.go:
--------------------------------------------------------------------------------
 1 | package stringcache
 2 | 
 3 | type GroupedStringCache interface {
 4 | 	// Contains checks if one or more groups in the cache contains the search string.
 5 | 	// Returns group(s) containing the string or empty slice if string was not found
 6 | 	Contains(searchString string, groups []string) []string
 7 | 
 8 | 	// Refresh creates new factory for the group to be refreshed.
 9 | 	// Calling Finish on the factory will perform the group refresh.
10 | 	Refresh(group string) GroupFactory
11 | 
12 | 	// ElementCount returns the amount of elements in the group
13 | 	ElementCount(group string) int
14 | }
15 | 
16 | type GroupFactory interface {
17 | 	// AddEntry adds a new string to the factory to be added later to the cache groups.
18 | 	AddEntry(entry string) bool
19 | 
20 | 	// Count returns amount of processed string in the factory
21 | 	Count() int
22 | 
23 | 	// Finish replaces the group in cache with factory's content
24 | 	Finish()
25 | }
26 | 


--------------------------------------------------------------------------------
/cache/stringcache/in_memory_grouped_cache.go:
--------------------------------------------------------------------------------
 1 | package stringcache
 2 | 
 3 | import "sync"
 4 | 
 5 | type stringCacheFactoryFn func() cacheFactory
 6 | 
 7 | type InMemoryGroupedCache struct {
 8 | 	caches    map[string]stringCache
 9 | 	lock      sync.RWMutex
10 | 	factoryFn stringCacheFactoryFn
11 | }
12 | 
13 | func NewInMemoryGroupedStringCache() *InMemoryGroupedCache {
14 | 	return &InMemoryGroupedCache{
15 | 		caches:    make(map[string]stringCache),
16 | 		factoryFn: newStringCacheFactory,
17 | 	}
18 | }
19 | 
20 | func NewInMemoryGroupedRegexCache() *InMemoryGroupedCache {
21 | 	return &InMemoryGroupedCache{
22 | 		caches:    make(map[string]stringCache),
23 | 		factoryFn: newRegexCacheFactory,
24 | 	}
25 | }
26 | 
27 | func NewInMemoryGroupedWildcardCache() *InMemoryGroupedCache {
28 | 	return &InMemoryGroupedCache{
29 | 		caches:    make(map[string]stringCache),
30 | 		factoryFn: newWildcardCacheFactory,
31 | 	}
32 | }
33 | 
34 | func (c *InMemoryGroupedCache) ElementCount(group string) int {
35 | 	c.lock.RLock()
36 | 	cache, found := c.caches[group]
37 | 	c.lock.RUnlock()
38 | 
39 | 	if !found {
40 | 		return 0
41 | 	}
42 | 
43 | 	return cache.elementCount()
44 | }
45 | 
46 | func (c *InMemoryGroupedCache) Contains(searchString string, groups []string) []string {
47 | 	var result []string
48 | 
49 | 	for _, group := range groups {
50 | 		c.lock.RLock()
51 | 		cache, found := c.caches[group]
52 | 		c.lock.RUnlock()
53 | 
54 | 		if found && cache.contains(searchString) {
55 | 			result = append(result, group)
56 | 		}
57 | 	}
58 | 
59 | 	return result
60 | }
61 | 
62 | func (c *InMemoryGroupedCache) Refresh(group string) GroupFactory {
63 | 	return &inMemoryGroupFactory{
64 | 		factory: c.factoryFn(),
65 | 		finishFn: func(sc stringCache) {
66 | 			c.lock.Lock()
67 | 			defer c.lock.Unlock()
68 | 
69 | 			if sc != nil {
70 | 				c.caches[group] = sc
71 | 			} else {
72 | 				delete(c.caches, group)
73 | 			}
74 | 		},
75 | 	}
76 | }
77 | 
78 | type inMemoryGroupFactory struct {
79 | 	factory  cacheFactory
80 | 	finishFn func(stringCache)
81 | }
82 | 
83 | func (c *inMemoryGroupFactory) AddEntry(entry string) bool {
84 | 	return c.factory.addEntry(entry)
85 | }
86 | 
87 | func (c *inMemoryGroupFactory) Count() int {
88 | 	return c.factory.count()
89 | }
90 | 
91 | func (c *inMemoryGroupFactory) Finish() {
92 | 	sc := c.factory.create()
93 | 	c.finishFn(sc)
94 | }
95 | 


--------------------------------------------------------------------------------
/cache/stringcache/string_cache_suite_test.go:
--------------------------------------------------------------------------------
 1 | package stringcache_test
 2 | 
 3 | import (
 4 | 	"testing"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/log"
 7 | 
 8 | 	. "github.com/onsi/ginkgo/v2"
 9 | 	. "github.com/onsi/gomega"
10 | )
11 | 
// init runs once when the test binary starts; it calls log.Silence(), presumably to keep
// log output out of the test report.
func init() {
	log.Silence()
}

// TestCache is the `go test` entry point of this package: it routes Gomega assertion
// failures to Ginkgo and runs all registered specs as "String cache suite".
func TestCache(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "String cache suite")
}
20 | 


--------------------------------------------------------------------------------
/cmd/cache.go:
--------------------------------------------------------------------------------
 1 | package cmd
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"fmt"
 6 | 
 7 | 	"github.com/0xERR0R/blocky/api"
 8 | 	"github.com/spf13/cobra"
 9 | )
10 | 
11 | func newCacheCommand() *cobra.Command {
12 | 	c := &cobra.Command{
13 | 		Use:               "cache",
14 | 		Short:             "Performs cache operations",
15 | 		PersistentPreRunE: initConfigPreRun,
16 | 	}
17 | 	c.AddCommand(&cobra.Command{
18 | 		Use:     "flush",
19 | 		Args:    cobra.NoArgs,
20 | 		Aliases: []string{"clear"},
21 | 		Short:   "Flush cache",
22 | 		RunE:    flushCache,
23 | 	})
24 | 
25 | 	return c
26 | }
27 | 
28 | func flushCache(_ *cobra.Command, _ []string) error {
29 | 	client, err := api.NewClientWithResponses(apiURL())
30 | 	if err != nil {
31 | 		return fmt.Errorf("can't create client: %w", err)
32 | 	}
33 | 
34 | 	resp, err := client.CacheFlushWithResponse(context.Background())
35 | 	if err != nil {
36 | 		return fmt.Errorf("can't execute %w", err)
37 | 	}
38 | 
39 | 	return printOkOrError(resp, string(resp.Body))
40 | }
41 | 


--------------------------------------------------------------------------------
/cmd/cache_test.go:
--------------------------------------------------------------------------------
 1 | package cmd
 2 | 
 3 | import (
 4 | 	"net/http"
 5 | 	"net/http/httptest"
 6 | 
 7 | 	"github.com/sirupsen/logrus/hooks/test"
 8 | 
 9 | 	"github.com/0xERR0R/blocky/log"
10 | 
11 | 	. "github.com/onsi/ginkgo/v2"
12 | 	. "github.com/onsi/gomega"
13 | )
14 | 
// Specs for the "cache" command. Each spec talks to a mock HTTP API server built from mockFn.
var _ = Describe("Cache command", func() {
	var (
		ts         *httptest.Server
		mockFn     func(w http.ResponseWriter, _ *http.Request)
		loggerHook *test.Hook
	)
	// JustBeforeEach runs after all BeforeEach blocks, so nested specs can replace mockFn first.
	JustBeforeEach(func() {
		ts = testHTTPAPIServer(mockFn)
	})
	JustAfterEach(func() {
		ts.Close()
	})
	BeforeEach(func() {
		// Default handler writes nothing (presumably yielding an empty 200 OK response).
		mockFn = func(w http.ResponseWriter, _ *http.Request) {}
		// Capture log entries so specs can assert on what the command printed.
		loggerHook = test.NewGlobal()
		log.Log().AddHook(loggerHook)
	})
	AfterEach(func() {
		loggerHook.Reset()
	})
	Describe("flush cache", func() {
		When("flush cache is called via REST", func() {
			It("should flush caches", func() {
				Expect(flushCache(newCacheCommand(), []string{})).Should(Succeed())
				Expect(loggerHook.LastEntry().Message).Should(Equal("OK"))
			})
		})
		When("Wrong url is used", func() {
			It("Should end with error", func() {
				// Port 0 makes the client target an address nothing listens on.
				// NOTE(review): apiPort is package-level state and is not restored here;
				// presumably testHTTPAPIServer sets it again for the next spec — confirm.
				apiPort = 0
				err := flushCache(newCacheCommand(), []string{})
				Expect(err).Should(HaveOccurred())
				Expect(err.Error()).Should(ContainSubstring("connection refused"))
			})
		})
	})
})
52 | 


--------------------------------------------------------------------------------
/cmd/cmd_suite_test.go:
--------------------------------------------------------------------------------
 1 | package cmd
 2 | 
 3 | import (
 4 | 	"testing"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/log"
 7 | 
 8 | 	. "github.com/onsi/ginkgo/v2"
 9 | 	. "github.com/onsi/gomega"
10 | )
11 | 
// init runs once when the test binary starts; it calls log.Silence(), presumably to keep
// log output out of the test report.
func init() {
	log.Silence()
}

// TestCmd is the `go test` entry point of this package: it routes Gomega assertion
// failures to Ginkgo and runs all registered specs as "Command Suite".
func TestCmd(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Command Suite")
}
20 | 


--------------------------------------------------------------------------------
/cmd/healthcheck.go:
--------------------------------------------------------------------------------
 1 | package cmd
 2 | 
 3 | import (
 4 | 	"fmt"
 5 | 	"net"
 6 | 
 7 | 	"github.com/miekg/dns"
 8 | 	"github.com/spf13/cobra"
 9 | )
10 | 
11 | const (
12 | 	defaultDNSPort   = 53
13 | 	defaultIPAddress = "127.0.0.1"
14 | )
15 | 
16 | func NewHealthcheckCommand() *cobra.Command {
17 | 	c := &cobra.Command{
18 | 		Use:   "healthcheck",
19 | 		Short: "performs healthcheck",
20 | 		RunE:  healthcheck,
21 | 	}
22 | 
23 | 	c.Flags().Uint16P("port", "p", defaultDNSPort, "blocky port")
24 | 	c.Flags().StringP("bindip", "b", defaultIPAddress, "blocky host binding ip address")
25 | 
26 | 	return c
27 | }
28 | 
29 | func healthcheck(cmd *cobra.Command, args []string) error {
30 | 	_ = args
31 | 	port, _ := cmd.Flags().GetUint16("port")
32 | 	bindIP, _ := cmd.Flags().GetString("bindip")
33 | 
34 | 	c := new(dns.Client)
35 | 	c.Net = "tcp"
36 | 	m := new(dns.Msg)
37 | 	m.SetQuestion("healthcheck.blocky.", dns.TypeA)
38 | 
39 | 	_, _, err := c.Exchange(m, net.JoinHostPort(bindIP, fmt.Sprintf("%d", port)))
40 | 
41 | 	if err == nil {
42 | 		fmt.Println("OK")
43 | 	} else {
44 | 		fmt.Println("NOT OK")
45 | 	}
46 | 
47 | 	return err
48 | }
49 | 


--------------------------------------------------------------------------------
/cmd/healthcheck_test.go:
--------------------------------------------------------------------------------
 1 | package cmd
 2 | 
 3 | import (
 4 | 	"fmt"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/helpertest"
 7 | 	"github.com/miekg/dns"
 8 | 
 9 | 	. "github.com/onsi/ginkgo/v2"
10 | 	. "github.com/onsi/gomega"
11 | )
12 | 
13 | var _ = Describe("Healthcheck command", func() {
14 | 	Describe("Call healthcheck command", func() {
15 | 		It("should fail", func() {
16 | 			c := NewHealthcheckCommand()
17 | 			c.SetArgs([]string{"-p", "533"})
18 | 
19 | 			err := c.Execute()
20 | 
21 | 			Expect(err).Should(HaveOccurred())
22 | 		})
23 | 
24 | 		It("should fail", func() {
25 | 			c := NewHealthcheckCommand()
26 | 			c.SetArgs([]string{"-b", "127.0.2.9"})
27 | 
28 | 			err := c.Execute()
29 | 
30 | 			Expect(err).Should(HaveOccurred())
31 | 		})
32 | 
33 | 		It("should succeed", func() {
34 | 			ip := "127.0.0.1"
35 | 			hostPort := helpertest.GetHostPort(ip, 65100)
36 | 			port := helpertest.GetStringPort(65100)
37 | 			srv := createMockServer(hostPort)
38 | 			go func() {
39 | 				defer GinkgoRecover()
40 | 				err := srv.ListenAndServe()
41 | 				Expect(err).Should(Succeed())
42 | 			}()
43 | 
44 | 			Eventually(func() error {
45 | 				c := NewHealthcheckCommand()
46 | 				c.SetArgs([]string{"-p", port, "-b", ip})
47 | 
48 | 				return c.Execute()
49 | 			}, "1s").Should(Succeed())
50 | 		})
51 | 	})
52 | })
53 | 
54 | func createMockServer(hostPort string) *dns.Server {
55 | 	res := &dns.Server{
56 | 		Addr:    hostPort,
57 | 		Net:     "tcp",
58 | 		Handler: dns.NewServeMux(),
59 | 		NotifyStartedFunc: func() {
60 | 			fmt.Printf("Mock healthcheck server is up: %s\n", hostPort)
61 | 		},
62 | 	}
63 | 
64 | 	th := res.Handler.(*dns.ServeMux)
65 | 	th.HandleFunc("healthcheck.blocky", func(w dns.ResponseWriter, request *dns.Msg) {
66 | 		resp := new(dns.Msg)
67 | 		resp.SetReply(request)
68 | 		resp.Rcode = dns.RcodeSuccess
69 | 
70 | 		err := w.WriteMsg(resp)
71 | 		Expect(err).Should(Succeed())
72 | 	})
73 | 
74 | 	DeferCleanup(res.Shutdown)
75 | 
76 | 	return res
77 | }
78 | 


--------------------------------------------------------------------------------
/cmd/lists.go:
--------------------------------------------------------------------------------
 1 | package cmd
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"fmt"
 6 | 
 7 | 	"github.com/0xERR0R/blocky/api"
 8 | 	"github.com/spf13/cobra"
 9 | )
10 | 
11 | // NewListsCommand creates new command instance
12 | func NewListsCommand() *cobra.Command {
13 | 	c := &cobra.Command{
14 | 		Use:               "lists",
15 | 		Short:             "lists operations",
16 | 		PersistentPreRunE: initConfigPreRun,
17 | 	}
18 | 
19 | 	c.AddCommand(newRefreshCommand())
20 | 
21 | 	return c
22 | }
23 | 
24 | func newRefreshCommand() *cobra.Command {
25 | 	return &cobra.Command{
26 | 		Use:   "refresh",
27 | 		Short: "refreshes all lists",
28 | 		RunE:  refreshList,
29 | 	}
30 | }
31 | 
32 | func refreshList(_ *cobra.Command, _ []string) error {
33 | 	client, err := api.NewClientWithResponses(apiURL())
34 | 	if err != nil {
35 | 		return fmt.Errorf("can't create client: %w", err)
36 | 	}
37 | 
38 | 	resp, err := client.ListRefreshWithResponse(context.Background())
39 | 	if err != nil {
40 | 		return fmt.Errorf("can't execute %w", err)
41 | 	}
42 | 
43 | 	return printOkOrError(resp, string(resp.Body))
44 | }
45 | 


--------------------------------------------------------------------------------
/cmd/lists_test.go:
--------------------------------------------------------------------------------
 1 | package cmd
 2 | 
 3 | import (
 4 | 	"net/http"
 5 | 	"net/http/httptest"
 6 | 
 7 | 	"github.com/0xERR0R/blocky/log"
 8 | 	"github.com/sirupsen/logrus/hooks/test"
 9 | 	"github.com/spf13/cobra"
10 | 
11 | 	. "github.com/onsi/ginkgo/v2"
12 | 	. "github.com/onsi/gomega"
13 | )
14 | 
// Specs for the "lists" command. Each spec talks to a mock HTTP API server built from mockFn.
var _ = Describe("Lists command", func() {
	var (
		ts         *httptest.Server
		mockFn     func(w http.ResponseWriter, _ *http.Request)
		loggerHook *test.Hook
		c          *cobra.Command
		err        error
	)
	// JustBeforeEach runs after all BeforeEach blocks, so nested specs can replace mockFn first.
	JustBeforeEach(func() {
		ts = testHTTPAPIServer(mockFn)
	})
	JustAfterEach(func() {
		ts.Close()
	})
	BeforeEach(func() {
		// Default handler writes nothing (presumably yielding an empty 200 OK response).
		mockFn = func(w http.ResponseWriter, _ *http.Request) {}
		// Capture log entries so specs can assert on what the command printed.
		loggerHook = test.NewGlobal()
		log.Log().AddHook(loggerHook)
	})
	AfterEach(func() {
		loggerHook.Reset()
	})
	Describe("Call list refresh command", func() {
		When("list refresh is executed", func() {
			BeforeEach(func() {
				// Run through the parent command so the sub-command dispatch is covered too.
				c = NewListsCommand()
				c.SetArgs([]string{"refresh"})
			})
			It("should print result", func() {
				err = c.Execute()
				Expect(err).Should(Succeed())

				Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("OK"))
			})
		})
		When("Server returns 500", func() {
			BeforeEach(func() {
				c = newRefreshCommand()
				c.SetArgs(make([]string, 0))
				mockFn = func(w http.ResponseWriter, _ *http.Request) {
					w.WriteHeader(http.StatusInternalServerError)
				}
			})
			It("should end with error", func() {
				err = c.Execute()
				Expect(err).Should(HaveOccurred())
				Expect(err.Error()).Should(ContainSubstring("500 Internal Server Error"))
			})
		})
		When("Url is wrong", func() {
			BeforeEach(func() {
				c = newRefreshCommand()
				c.SetArgs(make([]string, 0))
			})
			It("should end with error", func() {
				// Port 0 makes the client target an address nothing listens on.
				// NOTE(review): apiPort is package-level state and is not restored here;
				// presumably testHTTPAPIServer sets it again for the next spec — confirm.
				apiPort = 0
				err = c.Execute()
				Expect(err).Should(HaveOccurred())
				Expect(err.Error()).Should(ContainSubstring("connection refused"))
			})
		})
	})
})
78 | 


--------------------------------------------------------------------------------
/cmd/query.go:
--------------------------------------------------------------------------------
 1 | package cmd
 2 | 
import (
	"context"
	"fmt"
	"net/http"
	"strings"

	"github.com/0xERR0R/blocky/api"
	"github.com/0xERR0R/blocky/log"
	"github.com/miekg/dns"
	"github.com/spf13/cobra"
)
13 | 
14 | // NewQueryCommand creates new command instance
15 | func NewQueryCommand() *cobra.Command {
16 | 	c := &cobra.Command{
17 | 		Use:               "query <domain>",
18 | 		Args:              cobra.ExactArgs(1),
19 | 		Short:             "performs DNS query",
20 | 		RunE:              query,
21 | 		PersistentPreRunE: initConfigPreRun,
22 | 	}
23 | 
24 | 	c.Flags().StringP("type", "t", "A", "query type (A, AAAA, ...)")
25 | 
26 | 	return c
27 | }
28 | 
29 | func query(cmd *cobra.Command, args []string) error {
30 | 	typeFlag, _ := cmd.Flags().GetString("type")
31 | 	qType := dns.StringToType[typeFlag]
32 | 
33 | 	if qType == dns.TypeNone {
34 | 		return fmt.Errorf("unknown query type '%s'", typeFlag)
35 | 	}
36 | 
37 | 	client, err := api.NewClientWithResponses(apiURL())
38 | 	if err != nil {
39 | 		return fmt.Errorf("can't create client: %w", err)
40 | 	}
41 | 
42 | 	req := api.ApiQueryRequest{
43 | 		Query: args[0],
44 | 		Type:  typeFlag,
45 | 	}
46 | 
47 | 	resp, err := client.QueryWithResponse(context.Background(), req)
48 | 	if err != nil {
49 | 		return fmt.Errorf("can't execute %w", err)
50 | 	}
51 | 
52 | 	if resp.StatusCode() != http.StatusOK {
53 | 		return fmt.Errorf("response NOK, %s %s", resp.Status(), string(resp.Body))
54 | 	}
55 | 
56 | 	log.Log().Infof("Query result for '%s' (%s):", req.Query, req.Type)
57 | 	log.Log().Infof("\treason:        %20s", resp.JSON200.Reason)
58 | 	log.Log().Infof("\tresponse type: %20s", resp.JSON200.ResponseType)
59 | 	log.Log().Infof("\tresponse:      %20s", resp.JSON200.Response)
60 | 	log.Log().Infof("\treturn code:   %20s", resp.JSON200.ReturnCode)
61 | 
62 | 	return nil
63 | }
64 | 


--------------------------------------------------------------------------------
/cmd/query_test.go:
--------------------------------------------------------------------------------
  1 | package cmd
  2 | 
  3 | import (
  4 | 	"encoding/json"
  5 | 	"net/http"
  6 | 	"net/http/httptest"
  7 | 
  8 | 	"github.com/0xERR0R/blocky/log"
  9 | 	"github.com/sirupsen/logrus/hooks/test"
 10 | 
 11 | 	"github.com/0xERR0R/blocky/api"
 12 | 
 13 | 	. "github.com/onsi/ginkgo/v2"
 14 | 	. "github.com/onsi/gomega"
 15 | )
 16 | 
 17 | var _ = Describe("Blocking command", func() {
 18 | 	var (
 19 | 		ts         *httptest.Server
 20 | 		mockFn     func(w http.ResponseWriter, _ *http.Request)
 21 | 		loggerHook *test.Hook
 22 | 	)
 23 | 	JustBeforeEach(func() {
 24 | 		ts = testHTTPAPIServer(mockFn)
 25 | 	})
 26 | 	JustAfterEach(func() {
 27 | 		ts.Close()
 28 | 	})
 29 | 	BeforeEach(func() {
 30 | 		mockFn = func(w http.ResponseWriter, _ *http.Request) {}
 31 | 		loggerHook = test.NewGlobal()
 32 | 		log.Log().AddHook(loggerHook)
 33 | 	})
 34 | 	AfterEach(func() {
 35 | 		loggerHook.Reset()
 36 | 	})
 37 | 	Describe("Call query command", func() {
 38 | 		BeforeEach(func() {
 39 | 			mockFn = func(w http.ResponseWriter, _ *http.Request) {
 40 | 				response, err := json.Marshal(api.ApiQueryResult{
 41 | 					Reason:       "Reason",
 42 | 					ResponseType: "Type",
 43 | 					Response:     "Response",
 44 | 					ReturnCode:   "NOERROR",
 45 | 				})
 46 | 				Expect(err).Should(Succeed())
 47 | 
 48 | 				_, err = w.Write(response)
 49 | 				Expect(err).Should(Succeed())
 50 | 			}
 51 | 		})
 52 | 		When("query command is called via REST", func() {
 53 | 			BeforeEach(func() {
 54 | 				mockFn = func(w http.ResponseWriter, _ *http.Request) {
 55 | 					w.Header().Add("Content-Type", "application/json")
 56 | 					response, err := json.Marshal(api.ApiQueryResult{
 57 | 						Reason:       "Reason",
 58 | 						ResponseType: "Type",
 59 | 						Response:     "Response",
 60 | 						ReturnCode:   "NOERROR",
 61 | 					})
 62 | 					Expect(err).Should(Succeed())
 63 | 
 64 | 					_, err = w.Write(response)
 65 | 					Expect(err).Should(Succeed())
 66 | 				}
 67 | 			})
 68 | 			It("should print result", func() {
 69 | 				Expect(query(NewQueryCommand(), []string{"google.de"})).Should(Succeed())
 70 | 				Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("NOERROR"))
 71 | 			})
 72 | 		})
 73 | 		When("Server returns 500", func() {
 74 | 			BeforeEach(func() {
 75 | 				mockFn = func(w http.ResponseWriter, _ *http.Request) {
 76 | 					w.WriteHeader(http.StatusInternalServerError)
 77 | 				}
 78 | 			})
 79 | 			It("should end with error", func() {
 80 | 				err := query(NewQueryCommand(), []string{"google.de"})
 81 | 				Expect(err).Should(HaveOccurred())
 82 | 				Expect(err.Error()).Should(ContainSubstring("500 Internal Server Error"))
 83 | 			})
 84 | 		})
 85 | 		When("Type is wrong", func() {
 86 | 			It("should end with error", func() {
 87 | 				command := NewQueryCommand()
 88 | 				command.SetArgs([]string{"--type", "X", "google.de"})
 89 | 				err := command.Execute()
 90 | 				Expect(err).Should(HaveOccurred())
 91 | 				Expect(err.Error()).Should(ContainSubstring("unknown query type 'X'"))
 92 | 			})
 93 | 		})
 94 | 		When("Url is wrong", func() {
 95 | 			It("should end with error", func() {
 96 | 				apiPort = 0
 97 | 				err := query(NewQueryCommand(), []string{"google.de"})
 98 | 				Expect(err).Should(HaveOccurred())
 99 | 				Expect(err.Error()).Should(ContainSubstring("connection refused"))
100 | 			})
101 | 		})
102 | 	})
103 | })
104 | 


--------------------------------------------------------------------------------
/cmd/serve.go:
--------------------------------------------------------------------------------
  1 | package cmd
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"fmt"
  6 | 	"os"
  7 | 	"os/signal"
  8 | 	"syscall"
  9 | 
 10 | 	"github.com/0xERR0R/blocky/config"
 11 | 	"github.com/0xERR0R/blocky/evt"
 12 | 	"github.com/0xERR0R/blocky/log"
 13 | 	"github.com/0xERR0R/blocky/server"
 14 | 	"github.com/0xERR0R/blocky/util"
 15 | 
 16 | 	"github.com/spf13/cobra"
 17 | )
 18 | 
 19 | //nolint:gochecknoglobals
 20 | var (
 21 | 	done              = make(chan bool, 1)
 22 | 	isConfigMandatory = true
 23 | 	signals           = make(chan os.Signal, 1)
 24 | )
 25 | 
 26 | func newServeCommand() *cobra.Command {
 27 | 	return &cobra.Command{
 28 | 		Use:               "serve",
 29 | 		Args:              cobra.NoArgs,
 30 | 		Short:             "start blocky DNS server (default command)",
 31 | 		RunE:              startServer,
 32 | 		PersistentPreRunE: initConfigPreRun,
 33 | 		SilenceUsage:      true,
 34 | 	}
 35 | }
 36 | 
// startServer is the RunE of the serve command. It proceeds in order:
//  1. prints the banner;
//  2. loads the configuration from configPath (isConfigMandatory is passed
//     through to config.LoadConfig) and applies the log settings;
//  3. creates and starts the server;
//  4. blocks until SIGINT/SIGTERM arrives on signals or the server reports an
//     error on errChan.
//
// It returns nil after a signal-triggered shutdown, or the first error the
// server reported.
func startServer(_ *cobra.Command, _ []string) error {
	printBanner()

	cfg, err := config.LoadConfig(configPath, isConfigMandatory)
	if err != nil {
		return fmt.Errorf("unable to load configuration: %w", err)
	}

	log.Configure(&cfg.Log)

	// From here on, SIGINT/SIGTERM are delivered to signals instead of
	// terminating the process.
	// NOTE(review): signal.Stop is never called, so the registration outlives
	// this function, including the early-return error path below.
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)

	// ctx is cancelled when startServer returns.
	ctx, cancelFn := context.WithCancel(context.Background())
	defer cancelFn()

	srv, err := server.NewServer(ctx, cfg)
	if err != nil {
		return fmt.Errorf("can't start server: %w", err)
	}

	// Buffered so several errors can be sent without a ready receiver; only
	// the first one is consumed below.
	const errChanSize = 10
	errChan := make(chan error, errChanSize)

	srv.Start(ctx, errChan)

	// Written by the goroutine below before it sends on done, and read only
	// after <-done, so the channel operation orders the two accesses.
	var terminationErr error

	// Handle whichever happens first:
	//   - a shutdown signal: stop the server;
	//   - a server error: remember it as the return value.
	go func() {
		select {
		case <-signals:
			log.Log().Infof("Terminating...")
			util.LogOnError(ctx, "can't stop server: ", srv.Stop(ctx))
			done <- true

		case err := <-errChan:
			log.Log().Error("server start failed: ", err)
			terminationErr = err
			done <- true
		}
	}()

	evt.Bus().Publish(evt.ApplicationStarted, util.Version, util.BuildTime)
	// Block until the goroutine above has handled a signal or an error.
	<-done

	return terminationErr
}
 83 | 
 84 | func printBanner() {
 85 | 	log.Log().Info("_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/")
 86 | 	log.Log().Info("_/                                                              _/")
 87 | 	log.Log().Info("_/                                                              _/")
 88 | 	log.Log().Info("_/       _/        _/                      _/                   _/")
 89 | 	log.Log().Info("_/      _/_/_/    _/    _/_/      _/_/_/  _/  _/    _/    _/    _/")
 90 | 	log.Log().Info("_/     _/    _/  _/  _/    _/  _/        _/_/      _/    _/     _/")
 91 | 	log.Log().Info("_/    _/    _/  _/  _/    _/  _/        _/  _/    _/    _/      _/")
 92 | 	log.Log().Info("_/   _/_/_/    _/    _/_/      _/_/_/  _/    _/    _/_/_/       _/")
 93 | 	log.Log().Info("_/                                                    _/        _/")
 94 | 	log.Log().Info("_/                                               _/_/           _/")
 95 | 	log.Log().Info("_/                                                              _/")
 96 | 	log.Log().Info("_/                                                              _/")
 97 | 	log.Log().Infof("_/  Version: %-18s Build time: %-18s  _/", util.Version, util.BuildTime)
 98 | 	log.Log().Info("_/                                                              _/")
 99 | 	log.Log().Info("_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/")
100 | }
101 | 


--------------------------------------------------------------------------------
/cmd/validate.go:
--------------------------------------------------------------------------------
 1 | package cmd
 2 | 
 3 | import (
 4 | 	"errors"
 5 | 	"os"
 6 | 
 7 | 	"github.com/0xERR0R/blocky/log"
 8 | 
 9 | 	"github.com/spf13/cobra"
10 | )
11 | 
12 | // NewValidateCommand creates new command instance
13 | func NewValidateCommand() *cobra.Command {
14 | 	return &cobra.Command{
15 | 		Use:   "validate",
16 | 		Args:  cobra.NoArgs,
17 | 		Short: "Validates the configuration",
18 | 		RunE:  validateConfiguration,
19 | 	}
20 | }
21 | 
22 | func validateConfiguration(_ *cobra.Command, _ []string) error {
23 | 	log.Log().Infof("Validating configuration file: %s", configPath)
24 | 
25 | 	_, err := os.Stat(configPath)
26 | 	if err != nil && errors.Is(err, os.ErrNotExist) {
27 | 		return errors.New("configuration path does not exist")
28 | 	}
29 | 
30 | 	err = initConfig()
31 | 	if err != nil {
32 | 		return err
33 | 	}
34 | 
35 | 	log.Log().Info("Configuration is valid")
36 | 
37 | 	return nil
38 | }
39 | 


--------------------------------------------------------------------------------
/cmd/validate_test.go:
--------------------------------------------------------------------------------
 1 | package cmd
 2 | 
 3 | import (
 4 | 	"github.com/0xERR0R/blocky/helpertest"
 5 | 	. "github.com/onsi/ginkgo/v2"
 6 | 	. "github.com/onsi/gomega"
 7 | )
 8 | 
// Specs for the `validate` command. They are driven through the root command,
// so --config is parsed as it would be on the command line.
var _ = Describe("Validate command", func() {
	var tmpDir *helpertest.TmpFolder
	BeforeEach(func() {
		// Temporary folder for the config files written by the specs below.
		tmpDir = helpertest.NewTmpFolder("config")
	})
	When("Validate is called with not existing configuration file", func() {
		It("should terminate with error", func() {
			c := NewRootCommand()
			c.SetArgs([]string{"validate", "--config", "/notexisting/path.yaml"})

			Expect(c.Execute()).Should(HaveOccurred())
		})
	})

	When("Validate is called with existing valid configuration file", func() {
		It("should terminate without error", func() {
			// Minimal config: a single default upstream group.
			cfgFile := tmpDir.CreateStringFile("config.yaml",
				"upstreams:",
				"  groups:",
				"    default:",
				"      - 1.1.1.1")

			c := NewRootCommand()
			c.SetArgs([]string{"validate", "--config", cfgFile.Path})

			Expect(c.Execute()).Should(Succeed())
		})
	})

	When("Validate is called with existing invalid configuration file", func() {
		It("should terminate with error", func() {
			// The file exists but its upstream entry is not parseable.
			cfgFile := tmpDir.CreateStringFile("config.yaml",
				"upstreams:",
				"  groups:",
				"    default:",
				"      - 1.broken file")

			c := NewRootCommand()
			c.SetArgs([]string{"validate", "--config", cfgFile.Path})

			Expect(c.Execute()).Should(HaveOccurred())
		})
	})
})
53 | 


--------------------------------------------------------------------------------
/cmd/version.go:
--------------------------------------------------------------------------------
 1 | package cmd
 2 | 
 3 | import (
 4 | 	"fmt"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/util"
 7 | 	"github.com/spf13/cobra"
 8 | )
 9 | 
10 | // NewVersionCommand creates new command instance
11 | func NewVersionCommand() *cobra.Command {
12 | 	return &cobra.Command{
13 | 		Use:   "version",
14 | 		Args:  cobra.NoArgs,
15 | 		Short: "Print the version number of blocky",
16 | 		Run:   printVersion,
17 | 	}
18 | }
19 | 
20 | func printVersion(_ *cobra.Command, _ []string) {
21 | 	fmt.Println("blocky")
22 | 	fmt.Printf("Version: %s\n", util.Version)
23 | 	fmt.Printf("Build time: %s\n", util.BuildTime)
24 | 	fmt.Printf("Architecture: %s\n", util.Architecture)
25 | }
26 | 


--------------------------------------------------------------------------------
/cmd/version_test.go:
--------------------------------------------------------------------------------
 1 | package cmd
 2 | 
 3 | import (
 4 | 	. "github.com/onsi/ginkgo/v2"
 5 | 	. "github.com/onsi/gomega"
 6 | )
 7 | 
// Smoke test: the version command runs with no arguments and returns no error.
var _ = Describe("Version command", func() {
	When("Version command is called", func() {
		It("should execute without error", func() {
			c := NewVersionCommand()
			// Explicit empty args, so the test binary's own arguments are not parsed.
			c.SetArgs(make([]string, 0))
			err := c.Execute()
			Expect(err).Should(Succeed())
		})
	})
})
18 | 


--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
# Codecov settings: coverage status checks and paths excluded from coverage.
coverage:
  status:
    # Status check for the overall project coverage.
    project:
      default:
        # "auto" compares against the coverage of the base commit.
        target: auto
        # Tolerated coverage drop before the status fails.
        # NOTE(review): 80% tolerates almost any drop; confirm this is intended.
        threshold: 80%
    # Status check for the lines changed by a pull request.
    patch:
      default:
        # Auto target with zero tolerance: patch coverage may not fall below it.
        target: auto
        threshold: 0%
        base: auto 
        # Report the patch status on pull requests only.
        only_pulls: true
# Excluded from coverage:
#   - files named mock_*;
#   - go-enum output (*_enum.go);
#   - generated code (*.gen.go, e.g. in the api package);
#   - Go files under e2e/ (presumably the end-to-end test suite).
ignore:
  - "**/mock_*"
  - "**/*_enum.go"
  - "**/*.gen.go"
  - "e2e/*.go"
19 | 


--------------------------------------------------------------------------------
/config/blocking_test.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"time"
 5 | 
 6 | 	"github.com/creasty/defaults"
 7 | 	. "github.com/onsi/ginkgo/v2"
 8 | 	. "github.com/onsi/gomega"
 9 | )
10 | 
// Specs for the Blocking config section: IsEnabled, LogConfig and migrate.
// migrate copies the deprecated BlackLists/WhiteLists into Denylists/Allowlists.
var _ = Describe("BlockingConfig", func() {
	var cfg Blocking

	// suiteBeforeEach is defined elsewhere in this package. It presumably
	// (re)creates the shared logger/hook asserted on below.
	suiteBeforeEach()

	// Baseline: an enabled config with one deny group used by the default client group.
	BeforeEach(func() {
		cfg = Blocking{
			BlockType: "ZEROIP",
			BlockTTL:  Duration(time.Minute),
			Denylists: map[string][]BytesSource{
				"gr1": NewBytesSources("/a/file/path"),
			},
			ClientGroupsBlock: map[string][]string{
				"default": {"gr1"},
			},
		}
	})

	Describe("IsEnabled", func() {
		It("should be false by default", func() {
			cfg := Blocking{}
			Expect(defaults.Set(&cfg)).Should(Succeed())

			Expect(cfg.IsEnabled()).Should(BeFalse())
		})

		When("enabled", func() {
			It("should be true", func() {
				Expect(cfg.IsEnabled()).Should(BeTrue())
			})
		})

		When("disabled", func() {
			It("should be false", func() {
				cfg := Blocking{
					BlockTTL: Duration(-1),
				}

				Expect(cfg.IsEnabled()).Should(BeFalse())
			})
		})
	})

	Describe("LogConfig", func() {
		It("should log configuration", func() {
			cfg.LogConfig(logger)

			Expect(hook.Calls).ShouldNot(BeEmpty())
			Expect(hook.Messages[0]).Should(Equal("clientGroupsBlock:"))
			Expect(hook.Messages).Should(ContainElement(Equal("blockType = ZEROIP")))
		})
	})

	Describe("migrate", func() {
		It("should copy values", func() {
			cfg, err := WithDefaults[Blocking]()
			Expect(err).Should(Succeed())

			cfg.Deprecated.BlackLists = &map[string][]BytesSource{
				"deny-group": NewBytesSources("/deny.txt"),
			}
			cfg.Deprecated.WhiteLists = &map[string][]BytesSource{
				"allow-group": NewBytesSources("/allow.txt"),
			}

			// migrate should report that it migrated something and log the
			// replacement keys.
			migrated := cfg.migrate(logger)
			Expect(migrated).Should(BeTrue())

			Expect(hook.Calls).ShouldNot(BeEmpty())
			Expect(hook.Messages).Should(ContainElements(
				ContainSubstring("blocking.allowlists"),
				ContainSubstring("blocking.denylists"),
			))

			Expect(cfg.Allowlists).Should(Equal(*cfg.Deprecated.WhiteLists))
			Expect(cfg.Denylists).Should(Equal(*cfg.Deprecated.BlackLists))
		})
	})
})
90 | 


--------------------------------------------------------------------------------
/config/bytes_source.go:
--------------------------------------------------------------------------------
  1 | //go:generate go tool go-enum -f=$GOFILE --marshal --names --values
  2 | package config
  3 | 
  4 | import (
  5 | 	"fmt"
  6 | 	"strings"
  7 | )
  8 | 
// maxTextSourceDisplayLen is the maximum number of bytes of an inline text
// source that BytesSource.String shows before appending "...".
const maxTextSourceDisplayLen = 12

// var BytesSourceNone = BytesSource{}

// BytesSourceType supported BytesSource types. ENUM(
// text=1 // Inline YAML block.
// http   // HTTP(S).
// file   // Local file.
// )
type BytesSourceType uint16

// BytesSource identifies where list content is read from. Type selects how
// From is interpreted:
//   - text: From is the inline text itself;
//   - http: From is an HTTP(S) URL;
//   - file: From is a local file path, stored without any "file://" prefix
//     (UnmarshalText strips it).
type BytesSource struct {
	Type BytesSourceType
	From string
}
 24 | 
 25 | func (s BytesSource) String() string {
 26 | 	switch s.Type {
 27 | 	case BytesSourceTypeText:
 28 | 		break
 29 | 
 30 | 	case BytesSourceTypeHttp:
 31 | 		return s.From
 32 | 
 33 | 	case BytesSourceTypeFile:
 34 | 		return fmt.Sprintf("file://%s", s.From)
 35 | 
 36 | 	default:
 37 | 		return fmt.Sprintf("unknown source (%s: %s)", s.Type, s.From)
 38 | 	}
 39 | 
 40 | 	text := s.From
 41 | 	truncated := false
 42 | 
 43 | 	if idx := strings.IndexRune(text, '\n'); idx != -1 {
 44 | 		text = text[:idx]           // first line only
 45 | 		truncated = idx < len(text) // don't count removing last char
 46 | 	}
 47 | 
 48 | 	if len(text) > maxTextSourceDisplayLen { // truncate
 49 | 		text = text[:maxTextSourceDisplayLen]
 50 | 		truncated = true
 51 | 	}
 52 | 
 53 | 	if truncated {
 54 | 		return fmt.Sprintf("%s...", text[:maxTextSourceDisplayLen])
 55 | 	}
 56 | 
 57 | 	return text
 58 | }
 59 | 
 60 | // UnmarshalText implements `encoding.TextUnmarshaler`.
 61 | func (s *BytesSource) UnmarshalText(data []byte) error {
 62 | 	source := string(data)
 63 | 
 64 | 	switch {
 65 | 	// Inline definition in YAML (with literal style Block Scalar)
 66 | 	case strings.ContainsAny(source, "\n"):
 67 | 		*s = BytesSource{Type: BytesSourceTypeText, From: source}
 68 | 
 69 | 	// HTTP(S)
 70 | 	case strings.HasPrefix(source, "http"):
 71 | 		*s = BytesSource{Type: BytesSourceTypeHttp, From: source}
 72 | 
 73 | 	// Probably path to a local file
 74 | 	default:
 75 | 		*s = BytesSource{Type: BytesSourceTypeFile, From: strings.TrimPrefix(source, "file://")}
 76 | 	}
 77 | 
 78 | 	return nil
 79 | }
 80 | 
 81 | func newBytesSource(source string) BytesSource {
 82 | 	var res BytesSource
 83 | 
 84 | 	// UnmarshalText never returns an error
 85 | 	_ = res.UnmarshalText([]byte(source))
 86 | 
 87 | 	return res
 88 | }
 89 | 
 90 | func NewBytesSources(sources ...string) []BytesSource {
 91 | 	res := make([]BytesSource, 0, len(sources))
 92 | 
 93 | 	for _, source := range sources {
 94 | 		res = append(res, newBytesSource(source))
 95 | 	}
 96 | 
 97 | 	return res
 98 | }
 99 | 
100 | func TextBytesSource(lines ...string) BytesSource {
101 | 	return BytesSource{Type: BytesSourceTypeText, From: inlineList(lines...)}
102 | }
103 | 
// inlineList joins lines with "\n" and always appends one more "\n". Even a
// single line therefore contains a newline and is parsed as an inline block.
func inlineList(lines ...string) string {
	return strings.Join(lines, "\n") + "\n"
}
112 | 


--------------------------------------------------------------------------------
/config/bytes_source_enum.go:
--------------------------------------------------------------------------------
  1 | // Code generated by go-enum DO NOT EDIT.
  2 | // Version:
  3 | // Revision:
  4 | // Build Date:
  5 | // Built By:
  6 | 
  7 | package config
  8 | 
  9 | import (
 10 | 	"fmt"
 11 | 	"strings"
 12 | )
 13 | 
// NOTE: everything in this file is generated by go-enum from the ENUM(...)
// comment on BytesSourceType (see the go:generate directive in bytes_source.go).
// Regenerate instead of editing by hand; hand-written comments are lost on
// regeneration.
const (
	// BytesSourceTypeText is a BytesSourceType of type Text.
	// Inline YAML block.
	BytesSourceTypeText BytesSourceType = iota + 1
	// BytesSourceTypeHttp is a BytesSourceType of type Http.
	// HTTP(S).
	BytesSourceTypeHttp
	// BytesSourceTypeFile is a BytesSourceType of type File.
	// Local file.
	BytesSourceTypeFile
)

// ErrInvalidBytesSourceType is wrapped (with %w) by ParseBytesSourceType for
// unknown names, so callers can detect it with errors.Is.
var ErrInvalidBytesSourceType = fmt.Errorf("not a valid BytesSourceType, try [%s]", strings.Join(_BytesSourceTypeNames, ", "))

// _BytesSourceTypeName holds all names back to back; the tables below slice it.
const _BytesSourceTypeName = "texthttpfile"

// _BytesSourceTypeNames lists the names in declaration order.
var _BytesSourceTypeNames = []string{
	_BytesSourceTypeName[0:4],
	_BytesSourceTypeName[4:8],
	_BytesSourceTypeName[8:12],
}

// BytesSourceTypeNames returns a list of possible string values of BytesSourceType.
// It returns a copy, so callers cannot modify the internal list.
func BytesSourceTypeNames() []string {
	tmp := make([]string, len(_BytesSourceTypeNames))
	copy(tmp, _BytesSourceTypeNames)
	return tmp
}

// BytesSourceTypeValues returns a list of the values for BytesSourceType
func BytesSourceTypeValues() []BytesSourceType {
	return []BytesSourceType{
		BytesSourceTypeText,
		BytesSourceTypeHttp,
		BytesSourceTypeFile,
	}
}

// _BytesSourceTypeMap maps each valid value to its name.
var _BytesSourceTypeMap = map[BytesSourceType]string{
	BytesSourceTypeText: _BytesSourceTypeName[0:4],
	BytesSourceTypeHttp: _BytesSourceTypeName[4:8],
	BytesSourceTypeFile: _BytesSourceTypeName[8:12],
}

// String implements the Stringer interface.
// Unknown values are rendered as "BytesSourceType(<n>)".
func (x BytesSourceType) String() string {
	if str, ok := _BytesSourceTypeMap[x]; ok {
		return str
	}
	return fmt.Sprintf("BytesSourceType(%d)", x)
}

// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x BytesSourceType) IsValid() bool {
	_, ok := _BytesSourceTypeMap[x]
	return ok
}

// _BytesSourceTypeValue maps each name back to its value (case-sensitive).
var _BytesSourceTypeValue = map[string]BytesSourceType{
	_BytesSourceTypeName[0:4]:  BytesSourceTypeText,
	_BytesSourceTypeName[4:8]:  BytesSourceTypeHttp,
	_BytesSourceTypeName[8:12]: BytesSourceTypeFile,
}

// ParseBytesSourceType attempts to convert a string to a BytesSourceType.
// Unknown names yield BytesSourceType(0) and an error wrapping ErrInvalidBytesSourceType.
func ParseBytesSourceType(name string) (BytesSourceType, error) {
	if x, ok := _BytesSourceTypeValue[name]; ok {
		return x, nil
	}
	return BytesSourceType(0), fmt.Errorf("%s is %w", name, ErrInvalidBytesSourceType)
}

// MarshalText implements the text marshaller method.
func (x BytesSourceType) MarshalText() ([]byte, error) {
	return []byte(x.String()), nil
}

// UnmarshalText implements the text unmarshaller method.
// On a parse error *x is left unchanged.
func (x *BytesSourceType) UnmarshalText(text []byte) error {
	name := string(text)
	tmp, err := ParseBytesSourceType(name)
	if err != nil {
		return err
	}
	*x = tmp
	return nil
}
102 | 


--------------------------------------------------------------------------------
/config/caching.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"time"
 5 | 
 6 | 	"github.com/sirupsen/logrus"
 7 | )
 8 | 
// Caching configuration for domain caching.
//
// Fields (YAML key in parentheses):
//   - MinCachingTime (minTime), MaxCachingTime (maxTime): caching time bounds.
//     A negative maxTime disables caching (see IsEnabled).
//   - CacheTimeNegative (cacheTimeNegative): defaults to 30m. Presumably the
//     caching time for negative answers; confirm in the caching resolver.
//   - MaxItemsCount (maxItemsCount): presumably a cap on cached entries.
//   - Prefetching (prefetching) and the Prefetch* fields configure prefetching.
//     PrefetchExpires defaults to 2h and PrefetchThreshold to 5.
//   - Exclude (exclude): listed by LogConfig. Presumably entries that must not
//     be cached; verify in the resolver.
type Caching struct {
	MinCachingTime        Duration `yaml:"minTime"`
	MaxCachingTime        Duration `yaml:"maxTime"`
	CacheTimeNegative     Duration `yaml:"cacheTimeNegative" default:"30m"`
	MaxItemsCount         int      `yaml:"maxItemsCount"`
	Prefetching           bool     `yaml:"prefetching"`
	PrefetchExpires       Duration `yaml:"prefetchExpires" default:"2h"`
	PrefetchThreshold     int      `yaml:"prefetchThreshold" default:"5"`
	PrefetchMaxItemsCount int      `yaml:"prefetchMaxItemsCount"`
	Exclude               []string `yaml:"exclude"`
}
21 | 
22 | // IsEnabled implements `config.Configurable`.
23 | func (c *Caching) IsEnabled() bool {
24 | 	return c.MaxCachingTime.IsAtLeastZero()
25 | }
26 | 
27 | // LogConfig implements `config.Configurable`.
28 | func (c *Caching) LogConfig(logger *logrus.Entry) {
29 | 	logger.Infof("minTime = %s", c.MinCachingTime)
30 | 	logger.Infof("maxTime = %s", c.MaxCachingTime)
31 | 	logger.Infof("cacheTimeNegative = %s", c.CacheTimeNegative)
32 | 	logger.Infof("exclude:")
33 | 	for _, val := range c.Exclude {
34 | 		logger.Infof("- %v", val)
35 | 	}
36 | 
37 | 	if c.Prefetching {
38 | 		logger.Infof("prefetching:")
39 | 		logger.Infof("  expires   = %s", c.PrefetchExpires)
40 | 		logger.Infof("  threshold = %d", c.PrefetchThreshold)
41 | 		logger.Infof("  maxItems  = %d", c.PrefetchMaxItemsCount)
42 | 	} else {
43 | 		logger.Debug("prefetching: disabled")
44 | 	}
45 | }
46 | 
47 | func (c *Caching) EnablePrefetch() {
48 | 	const day = Duration(24 * time.Hour)
49 | 
50 | 	if !c.IsEnabled() {
51 | 		// make sure resolver gets enabled
52 | 		c.MaxCachingTime = day
53 | 	}
54 | 
55 | 	c.Prefetching = true
56 | 	c.PrefetchThreshold = 0
57 | }
58 | 


--------------------------------------------------------------------------------
/config/caching_test.go:
--------------------------------------------------------------------------------
  1 | package config
  2 | 
  3 | import (
  4 | 	"time"
  5 | 
  6 | 	"github.com/creasty/defaults"
  7 | 	. "github.com/onsi/ginkgo/v2"
  8 | 	. "github.com/onsi/gomega"
  9 | )
 10 | 
 11 | var _ = Describe("CachingConfig", func() {
 12 | 	var cfg Caching
 13 | 
 14 | 	suiteBeforeEach()
 15 | 
 16 | 	BeforeEach(func() {
 17 | 		cfg = Caching{
 18 | 			MaxCachingTime: Duration(time.Hour),
 19 | 		}
 20 | 	})
 21 | 
 22 | 	Describe("IsEnabled", func() {
 23 | 		It("should be true by default", func() {
 24 | 			cfg := Caching{}
 25 | 			Expect(defaults.Set(&cfg)).Should(Succeed())
 26 | 
 27 | 			Expect(cfg.IsEnabled()).Should(BeTrue())
 28 | 		})
 29 | 
 30 | 		When("the config is disabled", func() {
 31 | 			BeforeEach(func() {
 32 | 				cfg = Caching{
 33 | 					MaxCachingTime: Duration(time.Hour * -1),
 34 | 				}
 35 | 			})
 36 | 			It("should be false", func() {
 37 | 				Expect(cfg.IsEnabled()).Should(BeFalse())
 38 | 			})
 39 | 		})
 40 | 
 41 | 		When("the config is enabled", func() {
 42 | 			It("should be true", func() {
 43 | 				Expect(cfg.IsEnabled()).Should(BeTrue())
 44 | 			})
 45 | 		})
 46 | 
 47 | 		When("the config is disabled", func() {
 48 | 			It("should be false", func() {
 49 | 				cfg := Caching{
 50 | 					MaxCachingTime: Duration(-1),
 51 | 				}
 52 | 
 53 | 				Expect(cfg.IsEnabled()).Should(BeFalse())
 54 | 			})
 55 | 		})
 56 | 	})
 57 | 
 58 | 	Describe("LogConfig", func() {
 59 | 		When("prefetching is enabled", func() {
 60 | 			BeforeEach(func() {
 61 | 				cfg = Caching{
 62 | 					Prefetching: true,
 63 | 				}
 64 | 			})
 65 | 
 66 | 			It("should return configuration", func() {
 67 | 				cfg.LogConfig(logger)
 68 | 
 69 | 				Expect(hook.Calls).ShouldNot(BeEmpty())
 70 | 				Expect(hook.Messages).Should(ContainElement(ContainSubstring("prefetching:")))
 71 | 			})
 72 | 		})
 73 | 		When("has any settings", func() {
 74 | 			BeforeEach(func() {
 75 | 				cfg = Caching{}
 76 | 			})
 77 | 			It("should return Exclude", func() {
 78 | 				cfg.LogConfig(logger)
 79 | 
 80 | 				Expect(hook.Calls).ShouldNot(BeEmpty())
 81 | 				Expect(hook.Messages).Should(ContainElement(ContainSubstring("exclude:")))
 82 | 			})
 83 | 		})
 84 | 		When("Exclude any settings", func() {
 85 | 			BeforeEach(func() {
 86 | 				cfg = Caching{Exclude: []string{"local"}}
 87 | 			})
 88 | 			It("should return Exclude and a list with values in it", func() {
 89 | 				cfg.LogConfig(logger)
 90 | 
 91 | 				Expect(hook.Calls).ShouldNot(BeEmpty())
 92 | 				Expect(hook.Messages).Should(ContainElement(ContainSubstring("exclude:")))
 93 | 				Expect(hook.Messages).Should(ContainElement(ContainSubstring("- local")))
 94 | 			})
 95 | 		})
 96 | 	})
 97 | 
 98 | 	Describe("EnablePrefetch", func() {
 99 | 		When("prefetching is enabled", func() {
100 | 			BeforeEach(func() {
101 | 				cfg = Caching{}
102 | 			})
103 | 
104 | 			It("should return configuration", func() {
105 | 				cfg.EnablePrefetch()
106 | 
107 | 				Expect(cfg.Prefetching).Should(BeTrue())
108 | 				Expect(cfg.PrefetchThreshold).Should(Equal(0))
109 | 				Expect(cfg.MaxCachingTime).Should(BeZero())
110 | 			})
111 | 		})
112 | 	})
113 | 
114 | 	Describe("Exclude", func() {
115 | 		It("should be empty by default", func() {
116 | 			cfg := Caching{}
117 | 			Expect(defaults.Set(&cfg)).Should(Succeed())
118 | 
119 | 			Expect(cfg.Exclude).Should(BeEmpty())
120 | 		})
121 | 	})
122 | })
123 | 


--------------------------------------------------------------------------------
/config/client_lookup.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"net"
 5 | 
 6 | 	"github.com/sirupsen/logrus"
 7 | )
 8 | 
// ClientLookup configuration for the client lookup.
//
// Fields (YAML key in parentheses):
//   - ClientnameIPMapping (clients): static client name to IP mapping.
//   - Upstream (upstream): upstream used for the lookup. Configuring either a
//     non-default upstream or at least one mapping enables the feature.
//   - SingleNameOrder (singleNameOrder): only logged here. Its meaning is
//     defined by the consumer (presumably the preferred order of names;
//     verify there).
type ClientLookup struct {
	ClientnameIPMapping map[string][]net.IP `yaml:"clients"`
	Upstream            Upstream            `yaml:"upstream"`
	SingleNameOrder     []uint              `yaml:"singleNameOrder"`
}
15 | 
16 | // IsEnabled implements `config.Configurable`.
17 | func (c *ClientLookup) IsEnabled() bool {
18 | 	return !c.Upstream.IsDefault() || len(c.ClientnameIPMapping) != 0
19 | }
20 | 
21 | // LogConfig implements `config.Configurable`.
22 | func (c *ClientLookup) LogConfig(logger *logrus.Entry) {
23 | 	if !c.Upstream.IsDefault() {
24 | 		logger.Infof("upstream = %s", c.Upstream)
25 | 	}
26 | 
27 | 	logger.Infof("singleNameOrder = %v", c.SingleNameOrder)
28 | 
29 | 	if len(c.ClientnameIPMapping) > 0 {
30 | 		logger.Infof("client IP mapping:")
31 | 
32 | 		for k, v := range c.ClientnameIPMapping {
33 | 			logger.Infof("  %s = %s", k, v)
34 | 		}
35 | 	}
36 | }
37 | 


--------------------------------------------------------------------------------
/config/client_lookup_test.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"net"
 5 | 
 6 | 	"github.com/creasty/defaults"
 7 | 	. "github.com/onsi/ginkgo/v2"
 8 | 	. "github.com/onsi/gomega"
 9 | )
10 | 
// Specs for the ClientLookup config section: IsEnabled and LogConfig.
var _ = Describe("ClientLookupConfig", func() {
	var cfg ClientLookup

	// suiteBeforeEach is defined elsewhere in this package. It presumably
	// (re)creates the shared logger/hook asserted on below.
	suiteBeforeEach()

	// Baseline: an upstream, a name order and one static client mapping.
	BeforeEach(func() {
		cfg = ClientLookup{
			Upstream:        Upstream{Net: NetProtocolTcpUdp, Host: "host"},
			SingleNameOrder: []uint{1, 2},
			ClientnameIPMapping: map[string][]net.IP{
				"client8": {net.ParseIP("1.2.3.5")},
			},
		}
	})

	Describe("IsEnabled", func() {
		It("should be false by default", func() {
			cfg = ClientLookup{}
			Expect(defaults.Set(&cfg)).Should(Succeed())

			Expect(cfg.IsEnabled()).Should(BeFalse())
		})

		// Either an upstream or a static mapping alone is enough to enable the lookup.
		When("enabled", func() {
			It("should be true", func() {
				By("upstream", func() {
					cfg := ClientLookup{
						Upstream:            Upstream{Net: NetProtocolTcpUdp, Host: "host"},
						ClientnameIPMapping: nil,
					}

					Expect(cfg.IsEnabled()).Should(BeTrue())
				})

				By("mapping", func() {
					cfg := ClientLookup{
						ClientnameIPMapping: map[string][]net.IP{
							"client8": {net.ParseIP("1.2.3.5")},
						},
					}

					Expect(cfg.IsEnabled()).Should(BeTrue())
				})
			})
		})
	})

	Describe("LogConfig", func() {
		It("should log configuration", func() {
			cfg.LogConfig(logger)

			Expect(hook.Calls).ShouldNot(BeEmpty())
			Expect(hook.Messages).Should(ContainElement(ContainSubstring("client IP mapping:")))
		})
	})
})
67 | 


--------------------------------------------------------------------------------
/config/conditional_upstream.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
import (
	"fmt"
	"maps"
	"slices"
	"strings"

	"github.com/sirupsen/logrus"
)
 9 | 
// ConditionalUpstream conditional upstream configuration.
//
// Mapping assigns upstreams to domain keys. The embedded RewriterConfig is
// inlined into the same YAML section. The feature is enabled as soon as
// Mapping has at least one entry.
type ConditionalUpstream struct {
	RewriterConfig `yaml:",inline"`
	Mapping        ConditionalUpstreamMapping `yaml:"mapping"`
}

// ConditionalUpstreamMapping mapping for conditional configuration.
//
// In YAML it is written as a map of key to a comma-separated list of upstreams
// and is decoded by UnmarshalYAML.
type ConditionalUpstreamMapping struct {
	// Upstreams maps each key to its parsed upstreams, in the listed order.
	Upstreams map[string][]Upstream
}
20 | 
21 | // IsEnabled implements `config.Configurable`.
22 | func (c *ConditionalUpstream) IsEnabled() bool {
23 | 	return len(c.Mapping.Upstreams) != 0
24 | }
25 | 
26 | // LogConfig implements `config.Configurable`.
27 | func (c *ConditionalUpstream) LogConfig(logger *logrus.Entry) {
28 | 	for key, val := range c.Mapping.Upstreams {
29 | 		logger.Infof("%s = %v", key, val)
30 | 	}
31 | }
32 | 
33 | // UnmarshalYAML implements `yaml.Unmarshaler`.
34 | func (c *ConditionalUpstreamMapping) UnmarshalYAML(unmarshal func(interface{}) error) error {
35 | 	var input map[string]string
36 | 	if err := unmarshal(&input); err != nil {
37 | 		return err
38 | 	}
39 | 
40 | 	result := make(map[string][]Upstream, len(input))
41 | 
42 | 	for k, v := range input {
43 | 		var upstreams []Upstream
44 | 
45 | 		for _, part := range strings.Split(v, ",") {
46 | 			upstream, err := ParseUpstream(strings.TrimSpace(part))
47 | 			if err != nil {
48 | 				return fmt.Errorf("can't convert upstream '%s': %w", strings.TrimSpace(part), err)
49 | 			}
50 | 
51 | 			upstreams = append(upstreams, upstream)
52 | 		}
53 | 
54 | 		result[k] = upstreams
55 | 	}
56 | 
57 | 	c.Upstreams = result
58 | 
59 | 	return nil
60 | }
61 | 


--------------------------------------------------------------------------------
/config/conditional_upstream_test.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"errors"
 5 | 
 6 | 	"github.com/creasty/defaults"
 7 | 	. "github.com/onsi/ginkgo/v2"
 8 | 	. "github.com/onsi/gomega"
 9 | )
10 | 
// Specs for ConditionalUpstream / ConditionalUpstreamMapping: enablement,
// configuration logging and YAML decoding of the mapping.
var _ = Describe("ConditionalUpstreamConfig", func() {
	// cfg is rebuilt before every spec with three mapping entries.
	var cfg ConditionalUpstream

	// Gives each spec a fresh mock logger/hook (used by the LogConfig spec).
	suiteBeforeEach()

	BeforeEach(func() {
		cfg = ConditionalUpstream{
			Mapping: ConditionalUpstreamMapping{
				Upstreams: map[string][]Upstream{
					"fritz.box": {Upstream{Net: NetProtocolTcpUdp, Host: "fbTest"}},
					"other.box": {Upstream{Net: NetProtocolTcpUdp, Host: "otherTest"}},
					".":         {Upstream{Net: NetProtocolTcpUdp, Host: "dotTest"}},
				},
			},
		}
	})

	Describe("IsEnabled", func() {
		It("should be false by default", func() {
			// Shadows the outer cfg: a zero value with struct-tag defaults applied.
			cfg := ConditionalUpstream{}
			Expect(defaults.Set(&cfg)).Should(Succeed())

			Expect(cfg.IsEnabled()).Should(BeFalse())
		})

		When("enabled", func() {
			It("should be true", func() {
				Expect(cfg.IsEnabled()).Should(BeTrue())
			})
		})

		When("disabled", func() {
			It("should be false", func() {
				// An empty but non-nil mapping must also count as disabled.
				cfg := ConditionalUpstream{
					Mapping: ConditionalUpstreamMapping{Upstreams: map[string][]Upstream{}},
				}

				Expect(cfg.IsEnabled()).Should(BeFalse())
			})
		})
	})

	Describe("LogConfig", func() {
		It("should log configuration", func() {
			cfg.LogConfig(logger)

			Expect(hook.Calls).ShouldNot(BeEmpty())
			Expect(hook.Messages).Should(ContainElement(ContainSubstring("fritz.box = ")))
		})
	})

	Describe("UnmarshalYAML", func() {
		It("Should parse config as map", func() {
			c := &ConditionalUpstreamMapping{}
			// Fake unmarshal callback: fills the *map[string]string target directly.
			err := c.UnmarshalYAML(func(i interface{}) error {
				*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}

				return nil
			})
			Expect(err).Should(Succeed())
			Expect(c.Upstreams).Should(HaveLen(1))
			Expect(c.Upstreams["key"]).Should(HaveLen(1))
			// A bare IP is expected to become a tcp+udp upstream on port 53.
			Expect(c.Upstreams["key"][0]).Should(Equal(Upstream{
				Net: NetProtocolTcpUdp, Host: "1.2.3.4", Port: 53,
			}))
		})

		It("should fail if wrong YAML format", func() {
			c := &ConditionalUpstreamMapping{}
			// The error from the unmarshal callback must be passed through unchanged.
			err := c.UnmarshalYAML(func(i interface{}) error {
				return errors.New("some err")
			})
			Expect(err).Should(HaveOccurred())
			Expect(err).Should(MatchError("some err"))
		})
	})
})
88 | 


--------------------------------------------------------------------------------
/config/config_suite_test.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"testing"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/log"
 7 | 	. "github.com/onsi/ginkgo/v2"
 8 | 	. "github.com/onsi/gomega"
 9 | 	"github.com/sirupsen/logrus"
10 | )
11 | 
// Shared test logger and the hook that records its calls. Both are replaced
// with a fresh mock entry before each spec of containers that call
// suiteBeforeEach.
var (
	logger *logrus.Entry
	hook   *log.MockLoggerHook
)

// init silences the package-global logger.
// NOTE(review): presumably to keep `go test` output free of config log noise.
func init() {
	log.Silence()
}

// TestConfig is the `go test` entry point that runs all Ginkgo specs of this
// package; the Gomega fail handler must be registered before RunSpecs.
func TestConfig(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Config Suite")
}

// suiteBeforeEach registers, in the calling Ginkgo container, a BeforeEach that
// recreates logger and hook, so every spec only sees its own log calls.
func suiteBeforeEach() {
	BeforeEach(func() {
		logger, hook = log.NewMockEntry()
	})
}
31 | 


--------------------------------------------------------------------------------
/config/custom_dns.go:
--------------------------------------------------------------------------------
  1 | package config
  2 | 
  3 | import (
  4 | 	"fmt"
  5 | 	"net"
  6 | 	"strings"
  7 | 
  8 | 	"github.com/miekg/dns"
  9 | 	"github.com/sirupsen/logrus"
 10 | )
 11 | 
// CustomDNS custom DNS configuration
//
// Fields:
//   - RewriterConfig: embedded inline, so its YAML keys sit next to the ones below.
//   - CustomTTL: logged as "TTL"; defaults to 1h.
//   - Mapping: records keyed by name; each entry is decoded from a
//     comma-separated IP list (see CustomDNSEntries.UnmarshalYAML).
//   - Zone: records parsed from inline zone-file text; empty by default.
//   - FilterUnmappedTypes: defaults to true. NOTE(review): presumably controls
//     how queries for a mapped name but an unmapped type are answered — confirm
//     in the resolver.
type CustomDNS struct {
	RewriterConfig      `yaml:",inline"`
	CustomTTL           Duration         `yaml:"customTTL" default:"1h"`
	Mapping             CustomDNSMapping `yaml:"mapping"`
	Zone                ZoneFileDNS      `yaml:"zone" default:""`
	FilterUnmappedTypes bool             `yaml:"filterUnmappedTypes" default:"true"`
}

type (
	// CustomDNSMapping maps a DNS name to the records configured for it.
	CustomDNSMapping map[string]CustomDNSEntries
	// CustomDNSEntries is the list of records configured for one name.
	CustomDNSEntries []dns.RR

	// ZoneFileDNS holds the records parsed from zone-file text, grouped by the
	// owner name of each record. configPath is handed to the zone parser as its
	// file name; it is not set anywhere in this file.
	ZoneFileDNS struct {
		RRs        CustomDNSMapping
		configPath string
	}
)
 30 | 
 31 | func (z *ZoneFileDNS) UnmarshalYAML(unmarshal func(interface{}) error) error {
 32 | 	var input string
 33 | 	if err := unmarshal(&input); err != nil {
 34 | 		return err
 35 | 	}
 36 | 
 37 | 	result := make(CustomDNSMapping)
 38 | 
 39 | 	zoneParser := dns.NewZoneParser(strings.NewReader(input), "", z.configPath)
 40 | 	zoneParser.SetIncludeAllowed(true)
 41 | 
 42 | 	for {
 43 | 		zoneRR, ok := zoneParser.Next()
 44 | 
 45 | 		if !ok {
 46 | 			if zoneParser.Err() != nil {
 47 | 				return zoneParser.Err()
 48 | 			}
 49 | 
 50 | 			// Done
 51 | 			break
 52 | 		}
 53 | 
 54 | 		domain := zoneRR.Header().Name
 55 | 
 56 | 		if _, ok := result[domain]; !ok {
 57 | 			result[domain] = make(CustomDNSEntries, 0, 1)
 58 | 		}
 59 | 
 60 | 		result[domain] = append(result[domain], zoneRR)
 61 | 	}
 62 | 
 63 | 	z.RRs = result
 64 | 
 65 | 	return nil
 66 | }
 67 | 
 68 | func (c *CustomDNSEntries) UnmarshalYAML(unmarshal func(interface{}) error) error {
 69 | 	var input string
 70 | 	if err := unmarshal(&input); err != nil {
 71 | 		return err
 72 | 	}
 73 | 
 74 | 	parts := strings.Split(input, ",")
 75 | 	result := make(CustomDNSEntries, len(parts))
 76 | 
 77 | 	for i, part := range parts {
 78 | 		rr, err := configToRR(strings.TrimSpace(part))
 79 | 		if err != nil {
 80 | 			return err
 81 | 		}
 82 | 
 83 | 		result[i] = rr
 84 | 	}
 85 | 
 86 | 	*c = result
 87 | 
 88 | 	return nil
 89 | }
 90 | 
 91 | // IsEnabled implements `config.Configurable`.
 92 | func (c *CustomDNS) IsEnabled() bool {
 93 | 	return len(c.Mapping) != 0
 94 | }
 95 | 
 96 | // LogConfig implements `config.Configurable`.
 97 | func (c *CustomDNS) LogConfig(logger *logrus.Entry) {
 98 | 	logger.Debugf("TTL = %s", c.CustomTTL)
 99 | 	logger.Debugf("filterUnmappedTypes = %t", c.FilterUnmappedTypes)
100 | 
101 | 	logger.Info("mapping:")
102 | 
103 | 	for key, val := range c.Mapping {
104 | 		logger.Infof("  %s = %s", key, val)
105 | 	}
106 | }
107 | 
108 | func configToRR(ipStr string) (dns.RR, error) {
109 | 	ip := net.ParseIP(ipStr)
110 | 	if ip == nil {
111 | 		return nil, fmt.Errorf("invalid IP address '%s'", ipStr)
112 | 	}
113 | 
114 | 	if ip.To4() != nil {
115 | 		a := new(dns.A)
116 | 		a.A = ip
117 | 
118 | 		return a, nil
119 | 	}
120 | 
121 | 	aaaa := new(dns.AAAA)
122 | 	aaaa.AAAA = ip
123 | 
124 | 	return aaaa, nil
125 | }
126 | 


--------------------------------------------------------------------------------
/config/duration.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"strconv"
 5 | 	"time"
 6 | 
 7 | 	"github.com/0xERR0R/blocky/log"
 8 | 	"github.com/hako/durafmt"
 9 | )
10 | 
// Duration is a wrapper for time.Duration that can be decoded from
// configuration text and printed in a human-readable form.
type Duration time.Duration

// ToDuration returns the wrapped value as a plain time.Duration.
func (c Duration) ToDuration() time.Duration {
	return time.Duration(c)
}

// IsAboveZero reports whether the duration is strictly positive.
func (c Duration) IsAboveZero() bool {
	return time.Duration(c) > 0
}

// IsAtLeastZero reports whether the duration is zero or positive.
func (c Duration) IsAtLeastZero() bool {
	return time.Duration(c) >= 0
}

// Seconds returns the duration as a floating-point number of seconds.
func (c Duration) Seconds() float64 {
	return time.Duration(c).Seconds()
}

// SecondsU32 returns the duration in seconds as a uint32. The fractional part
// is dropped by the float-to-integer conversion.
func (c Duration) SecondsU32() uint32 {
	secs := c.Seconds()

	return uint32(secs)
}
38 | 
39 | // String implements `fmt.Stringer`
40 | func (c Duration) String() string {
41 | 	return durafmt.Parse(c.ToDuration()).String()
42 | }
43 | 
44 | // UnmarshalText implements `encoding.TextUnmarshaler`.
45 | func (c *Duration) UnmarshalText(data []byte) error {
46 | 	input := string(data)
47 | 
48 | 	if minutes, err := strconv.Atoi(input); err == nil {
49 | 		// number without unit: use minutes to ensure back compatibility
50 | 		*c = Duration(time.Duration(minutes) * time.Minute)
51 | 
52 | 		log.Log().Warnf("Setting a duration without a unit is deprecated. Please use '%s min' instead.", input)
53 | 
54 | 		return nil
55 | 	}
56 | 
57 | 	duration, err := time.ParseDuration(input)
58 | 	if err == nil {
59 | 		*c = Duration(duration)
60 | 
61 | 		return nil
62 | 	}
63 | 
64 | 	return err
65 | }
66 | 


--------------------------------------------------------------------------------
/config/duration_test.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"time"
 5 | 
 6 | 	. "github.com/onsi/ginkgo/v2"
 7 | 	. "github.com/onsi/gomega"
 8 | )
 9 | 
// Specs for Duration: text decoding, human-readable formatting, and the
// sign/seconds helpers.
// NOTE(review): the deprecated bare-integer (minutes) input of UnmarshalText is
// not exercised here.
var _ = Describe("Duration", func() {
	var d Duration

	// Reset d to the zero value so specs do not see each other's decoding results.
	BeforeEach(func() {
		var zero Duration

		d = zero
	})

	Describe("UnmarshalText", func() {
		It("should parse duration with unit", func() {
			err := d.UnmarshalText([]byte("1m20s"))
			Expect(err).Should(Succeed())
			Expect(d).Should(Equal(Duration(80 * time.Second)))
			Expect(d.String()).Should(Equal("1 minute 20 seconds"))
		})

		It("should fail if duration is in wrong format", func() {
			err := d.UnmarshalText([]byte("wrong"))
			Expect(err).Should(HaveOccurred())
			// The time.ParseDuration error is expected to be returned unwrapped.
			Expect(err).Should(MatchError("time: invalid duration \"wrong\""))
		})
	})

	Describe("IsAboveZero", func() {
		It("should be false for zero", func() {
			Expect(d.IsAboveZero()).Should(BeFalse())
			Expect(Duration(0).IsAboveZero()).Should(BeFalse())
		})

		It("should be false for negative", func() {
			Expect(Duration(-1).IsAboveZero()).Should(BeFalse())
		})

		It("should be true for positive", func() {
			Expect(Duration(1).IsAboveZero()).Should(BeTrue())
		})
	})

	Describe("SecondsU32", func() {
		It("should return the seconds", func() {
			Expect(Duration(time.Minute).SecondsU32()).Should(Equal(uint32(60)))
			Expect(Duration(time.Hour).SecondsU32()).Should(Equal(uint32(3600)))
		})
	})
})
56 | 


--------------------------------------------------------------------------------
/config/ecs.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"fmt"
 5 | 	"strconv"
 6 | 
 7 | 	"github.com/sirupsen/logrus"
 8 | )
 9 | 
10 | const (
11 | 	ecsIpv4MaskMax = uint8(32)
12 | 	ecsIpv6MaskMax = uint8(128)
13 | )
14 | 
15 | // ECSv4Mask is the subnet mask to be added as EDNS0 option for IPv4
16 | type ECSv4Mask uint8
17 | 
18 | // UnmarshalText implements the encoding.TextUnmarshaler interface
19 | func (x *ECSv4Mask) UnmarshalText(text []byte) error {
20 | 	res, err := unmarshalInternal(text, ecsIpv4MaskMax, "IPv4")
21 | 	if err != nil {
22 | 		return err
23 | 	}
24 | 
25 | 	*x = ECSv4Mask(res)
26 | 
27 | 	return nil
28 | }
29 | 
30 | // ECSv6Mask is the subnet mask to be added as EDNS0 option for IPv6
31 | type ECSv6Mask uint8
32 | 
33 | // UnmarshalText implements the encoding.TextUnmarshaler interface
34 | func (x *ECSv6Mask) UnmarshalText(text []byte) error {
35 | 	res, err := unmarshalInternal(text, ecsIpv6MaskMax, "IPv6")
36 | 	if err != nil {
37 | 		return err
38 | 	}
39 | 
40 | 	*x = ECSv6Mask(res)
41 | 
42 | 	return nil
43 | }
44 | 
// ECS is the configuration of the ECS resolver
//
// All options default to off/0; the resolver counts as enabled as soon as
// either flag is set or a non-zero mask is configured (see IsEnabled).
// IPv4Mask and IPv6Mask are subnet prefix lengths whose decoding rejects values
// above 32 and 128 respectively.
type ECS struct {
	UseAsClient bool      `yaml:"useAsClient" default:"false"`
	Forward     bool      `yaml:"forward" default:"false"`
	IPv4Mask    ECSv4Mask `yaml:"ipv4Mask" default:"0"`
	IPv6Mask    ECSv6Mask `yaml:"ipv6Mask" default:"0"`
}
52 | 
53 | // IsEnabled returns true if the ECS resolver is enabled
54 | func (c *ECS) IsEnabled() bool {
55 | 	return c.UseAsClient || c.Forward || c.IPv4Mask > 0 || c.IPv6Mask > 0
56 | }
57 | 
58 | // LogConfig logs the configuration
59 | func (c *ECS) LogConfig(logger *logrus.Entry) {
60 | 	logger.Infof("Use as client = %t", c.UseAsClient)
61 | 	logger.Infof("Forward       = %t", c.Forward)
62 | 	logger.Infof("IPv4 netmask  = %d", c.IPv4Mask)
63 | 	logger.Infof("IPv6 netmask  = %d", c.IPv6Mask)
64 | }
65 | 
// unmarshalInternal parses text as a decimal subnet prefix length and checks
// that it does not exceed maxvalue. It is shared by the UnmarshalText methods of
// ECSv4Mask and ECSv6Mask; name ("IPv4"/"IPv6") only appears in the error text.
// Text that is not a decimal number fitting into a uint8 yields the strconv error.
func unmarshalInternal(text []byte, maxvalue uint8, name string) (uint8, error) {
	strVal := string(text)

	parsed, err := strconv.ParseUint(strVal, 10, 8)

	switch {
	case err != nil:
		return 0, err
	case parsed > uint64(maxvalue):
		return 0, fmt.Errorf("mask value (%s) is too large for %s(max: %d)", strVal, name, maxvalue)
	default:
		return uint8(parsed), nil
	}
}
82 | 


--------------------------------------------------------------------------------
/config/filtering.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"github.com/sirupsen/logrus"
 5 | )
 6 | 
// Filtering holds the set of DNS query types configured under `queryTypes`;
// filtering counts as enabled once the set is non-empty (see IsEnabled).
// NOTE(review): presumably queries of these types are filtered out by the
// filtering resolver — confirm there.
type Filtering struct {
	QueryTypes QTypeSet `yaml:"queryTypes"`
}
10 | 
11 | // IsEnabled implements `config.Configurable`.
12 | func (c *Filtering) IsEnabled() bool {
13 | 	return len(c.QueryTypes) != 0
14 | }
15 | 
16 | // LogConfig implements `config.Configurable`.
17 | func (c *Filtering) LogConfig(logger *logrus.Entry) {
18 | 	logger.Info("query types:")
19 | 
20 | 	for qType := range c.QueryTypes {
21 | 		logger.Infof("  - %s", qType)
22 | 	}
23 | }
24 | 


--------------------------------------------------------------------------------
/config/filtering_test.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	. "github.com/0xERR0R/blocky/helpertest"
 5 | 
 6 | 	"github.com/creasty/defaults"
 7 | 	. "github.com/onsi/ginkgo/v2"
 8 | 	. "github.com/onsi/gomega"
 9 | )
10 | 
// Specs for Filtering: enablement and configuration logging.
var _ = Describe("FilteringConfig", func() {
	// cfg is rebuilt before every spec with the AAAA and MX query types.
	var cfg Filtering

	// Gives each spec a fresh mock logger/hook.
	suiteBeforeEach()

	BeforeEach(func() {
		cfg = Filtering{
			QueryTypes: NewQTypeSet(AAAA, MX),
		}
	})

	Describe("IsEnabled", func() {
		It("should be false by default", func() {
			cfg := Filtering{}
			Expect(defaults.Set(&cfg)).Should(Succeed())

			Expect(cfg.IsEnabled()).Should(BeFalse())
		})

		When("enabled", func() {
			It("should be true", func() {
				Expect(cfg.IsEnabled()).Should(BeTrue())
			})
		})

		When("disabled", func() {
			It("should be false", func() {
				cfg := Filtering{}

				Expect(cfg.IsEnabled()).Should(BeFalse())
			})
		})
	})

	Describe("LogConfig", func() {
		It("should log configuration", func() {
			cfg.LogConfig(logger)

			// One header line plus one line per configured type; ContainElements
			// is order-insensitive because the set is a map.
			Expect(hook.Calls).Should(HaveLen(3))
			Expect(hook.Messages).Should(ContainElements(
				ContainSubstring("query types:"),
				ContainSubstring("  - AAAA"),
				ContainSubstring("  - MX"),
			))
		})
	})
})
58 | 


--------------------------------------------------------------------------------
/config/hosts_file.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	. "github.com/0xERR0R/blocky/config/migration"
 5 | 	"github.com/0xERR0R/blocky/log"
 6 | 	"github.com/sirupsen/logrus"
 7 | )
 8 | 
// HostsFile configures the hosts file resolver.
//
// Fields:
//   - Sources: hosts file sources; the resolver counts as enabled once one is set.
//   - HostsTTL: logged as "TTL"; defaults to 1h.
//   - FilterLoopback: logged as "filter loopback addresses". NOTE(review):
//     presumably drops loopback entries when parsing — confirm in the resolver.
//   - Loading: source loading options (e.g. refresh period).
//   - Deprecated: legacy inline options that migrate moves into Loading and Sources.
type HostsFile struct {
	Sources        []BytesSource `yaml:"sources"`
	HostsTTL       Duration      `yaml:"hostsTTL" default:"1h"`
	FilterLoopback bool          `yaml:"filterLoopback"`
	Loading        SourceLoading `yaml:"loading"`

	// Deprecated options
	Deprecated struct {
		RefreshPeriod *Duration    `yaml:"refreshPeriod"`
		Filepath      *BytesSource `yaml:"filePath"`
	} `yaml:",inline"`
}
21 | 
22 | func (c *HostsFile) migrate(logger *logrus.Entry) bool {
23 | 	return Migrate(logger, "hostsFile", c.Deprecated, map[string]Migrator{
24 | 		"refreshPeriod": Move(To("loading.refreshPeriod", &c.Loading)),
25 | 		"filePath": Apply(To("sources", c), func(value BytesSource) {
26 | 			c.Sources = append(c.Sources, value)
27 | 		}),
28 | 	})
29 | }
30 | 
31 | // IsEnabled implements `config.Configurable`.
32 | func (c *HostsFile) IsEnabled() bool {
33 | 	return len(c.Sources) != 0
34 | }
35 | 
36 | // LogConfig implements `config.Configurable`.
37 | func (c *HostsFile) LogConfig(logger *logrus.Entry) {
38 | 	logger.Infof("TTL: %s", c.HostsTTL)
39 | 	logger.Infof("filter loopback addresses: %t", c.FilterLoopback)
40 | 
41 | 	logger.Info("loading:")
42 | 	log.WithIndent(logger, "  ", c.Loading.LogConfig)
43 | 
44 | 	logger.Info("sources:")
45 | 
46 | 	for _, source := range c.Sources {
47 | 		logger.Infof("  - %s", source)
48 | 	}
49 | }
50 | 


--------------------------------------------------------------------------------
/config/hosts_file_test.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"time"
 5 | 
 6 | 	"github.com/creasty/defaults"
 7 | 	. "github.com/onsi/ginkgo/v2"
 8 | 	. "github.com/onsi/gomega"
 9 | )
10 | 
// Specs for HostsFile: enablement, configuration logging and migration of the
// deprecated options.
var _ = Describe("HostsFileConfig", func() {
	// cfg is rebuilt before every spec with one file source and one inline text source.
	var cfg HostsFile

	// Gives each spec a fresh mock logger/hook.
	suiteBeforeEach()

	BeforeEach(func() {
		cfg = HostsFile{
			Sources: append(
				NewBytesSources("/a/file/path"),
				TextBytesSource("127.0.0.1 localhost"),
			),
			HostsTTL:       Duration(29 * time.Minute),
			Loading:        SourceLoading{RefreshPeriod: Duration(30 * time.Minute)},
			FilterLoopback: true,
		}
	})

	Describe("IsEnabled", func() {
		It("should be false by default", func() {
			cfg := HostsFile{}
			Expect(defaults.Set(&cfg)).Should(Succeed())

			Expect(cfg.IsEnabled()).Should(BeFalse())
		})

		When("enabled", func() {
			It("should be true", func() {
				Expect(cfg.IsEnabled()).Should(BeTrue())
			})
		})

		When("disabled", func() {
			It("should be false", func() {
				cfg := HostsFile{}

				Expect(cfg.IsEnabled()).Should(BeFalse())
			})
		})
	})

	Describe("LogConfig", func() {
		It("should log configuration", func() {
			cfg.LogConfig(logger)

			// File sources are expected to print as file:// URLs and inline
			// text sources in abbreviated form.
			Expect(hook.Calls).ShouldNot(BeEmpty())
			Expect(hook.Messages).Should(ContainElements(
				ContainSubstring("- file:///a/file/path"),
				ContainSubstring("- 127.0.0.1 lo..."),
			))
		})
	})

	Describe("migrate", func() {
		// NOTE(review): the spec description "should" is incomplete.
		It("should", func() {
			cfg, err := WithDefaults[HostsFile]()
			Expect(err).Should(Succeed())

			cfg.Deprecated.Filepath = ptrOf(newBytesSource("/a/file/path"))
			cfg.Deprecated.RefreshPeriod = ptrOf(Duration(time.Hour))

			migrated := cfg.migrate(logger)
			Expect(migrated).Should(BeTrue())

			// Migration is expected to log the new option paths.
			Expect(hook.Calls).ShouldNot(BeEmpty())
			Expect(hook.Messages).Should(ContainElements(
				ContainSubstring("hostsFile.loading.refreshPeriod"),
				ContainSubstring("hostsFile.sources"),
			))

			Expect(cfg.Sources).Should(Equal([]BytesSource{*cfg.Deprecated.Filepath}))
			Expect(cfg.Loading.RefreshPeriod).Should(Equal(*cfg.Deprecated.RefreshPeriod))
		})
	})
})
85 | 


--------------------------------------------------------------------------------
/config/metrics.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import "github.com/sirupsen/logrus"
 4 | 
// Metrics contains the config values for prometheus
//
// Enable (default false) switches metrics on; Path (default "/metrics") is the
// URL path logged as "url path".
type Metrics struct {
	Enable bool   `yaml:"enable" default:"false"`
	Path   string `yaml:"path" default:"/metrics"`
}
10 | 
11 | // IsEnabled implements `config.Configurable`.
12 | func (c *Metrics) IsEnabled() bool {
13 | 	return c.Enable
14 | }
15 | 
16 | // LogConfig implements `config.Configurable`.
17 | func (c *Metrics) LogConfig(logger *logrus.Entry) {
18 | 	logger.Infof("url path: %s", c.Path)
19 | }
20 | 


--------------------------------------------------------------------------------
/config/metrics_test.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"github.com/creasty/defaults"
 5 | 	. "github.com/onsi/ginkgo/v2"
 6 | 	. "github.com/onsi/gomega"
 7 | )
 8 | 
// Specs for Metrics: enablement and configuration logging.
var _ = Describe("MetricsConfig", func() {
	// cfg is rebuilt before every spec as an enabled config with a custom path.
	var cfg Metrics

	// Gives each spec a fresh mock logger/hook.
	suiteBeforeEach()

	BeforeEach(func() {
		cfg = Metrics{
			Enable: true,
			Path:   "/custom/path",
		}
	})

	Describe("IsEnabled", func() {
		It("should be false by default", func() {
			cfg := Metrics{}
			Expect(defaults.Set(&cfg)).Should(Succeed())

			Expect(cfg.IsEnabled()).Should(BeFalse())
		})

		When("enabled", func() {
			It("should be true", func() {
				Expect(cfg.IsEnabled()).Should(BeTrue())
			})
		})

		When("disabled", func() {
			It("should be false", func() {
				cfg := Metrics{}

				Expect(cfg.IsEnabled()).Should(BeFalse())
			})
		})
	})

	Describe("LogConfig", func() {
		It("should log configuration", func() {
			cfg.LogConfig(logger)

			// Exactly one log call is expected: the URL path.
			Expect(hook.Calls).Should(HaveLen(1))
			Expect(hook.Messages).Should(ContainElement(ContainSubstring("url path: /custom/path")))
		})
	})
})
53 | 


--------------------------------------------------------------------------------
/config/qtype_set.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"fmt"
 5 | 	"sort"
 6 | 	"strings"
 7 | 
 8 | 	"github.com/miekg/dns"
 9 | 	"golang.org/x/exp/maps"
10 | )
11 | 
12 | type QTypeSet map[QType]struct{}
13 | 
14 | func NewQTypeSet(qTypes ...dns.Type) QTypeSet {
15 | 	s := make(QTypeSet, len(qTypes))
16 | 
17 | 	for _, qType := range qTypes {
18 | 		s.Insert(qType)
19 | 	}
20 | 
21 | 	return s
22 | }
23 | 
24 | func (s QTypeSet) Contains(qType dns.Type) bool {
25 | 	_, found := s[QType(qType)]
26 | 
27 | 	return found
28 | }
29 | 
30 | func (s *QTypeSet) Insert(qType dns.Type) {
31 | 	if *s == nil {
32 | 		*s = make(QTypeSet, 1)
33 | 	}
34 | 
35 | 	(*s)[QType(qType)] = struct{}{}
36 | }
37 | 
38 | func (s *QTypeSet) UnmarshalYAML(unmarshal func(interface{}) error) error {
39 | 	var input []QType
40 | 	if err := unmarshal(&input); err != nil {
41 | 		return err
42 | 	}
43 | 
44 | 	*s = make(QTypeSet, len(input))
45 | 
46 | 	for _, qType := range input {
47 | 		(*s)[qType] = struct{}{}
48 | 	}
49 | 
50 | 	return nil
51 | }
52 | 
53 | type QType dns.Type
54 | 
55 | func (c QType) String() string {
56 | 	return dns.Type(c).String()
57 | }
58 | 
59 | // UnmarshalText implements `encoding.TextUnmarshaler`.
60 | func (c *QType) UnmarshalText(data []byte) error {
61 | 	input := string(data)
62 | 
63 | 	t, found := dns.StringToType[input]
64 | 	if !found {
65 | 		types := maps.Keys(dns.StringToType)
66 | 
67 | 		sort.Strings(types)
68 | 
69 | 		return fmt.Errorf("unknown DNS query type: '%s'. Please use following types '%s'",
70 | 			input, strings.Join(types, ", "))
71 | 	}
72 | 
73 | 	*c = QType(t)
74 | 
75 | 	return nil
76 | }
77 | 


--------------------------------------------------------------------------------
/config/qtype_set_test.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"github.com/miekg/dns"
 5 | 	. "github.com/onsi/ginkgo/v2"
 6 | 	. "github.com/onsi/gomega"
 7 | )
 8 | 
// Specs for QTypeSet construction and insertion.
var _ = Describe("QTypeSet", func() {
	Describe("NewQTypeSet", func() {
		It("should insert given qTypes", func() {
			set := NewQTypeSet(dns.Type(dns.TypeA))
			Expect(set).Should(HaveKey(QType(dns.TypeA)))
			Expect(set.Contains(dns.Type(dns.TypeA))).Should(BeTrue())

			Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA)))
			Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue())
		})
	})

	Describe("Insert", func() {
		It("should insert given qTypes", func() {
			// Starts from an empty set built without arguments.
			set := NewQTypeSet()

			Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA)))
			Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue())

			set.Insert(dns.Type(dns.TypeAAAA))

			Expect(set).Should(HaveKey(QType(dns.TypeAAAA)))
			Expect(set.Contains(dns.Type(dns.TypeAAAA))).Should(BeTrue())
		})
	})
})

// Specs for decoding a single QType from text.
var _ = Describe("QType", func() {
	Describe("UnmarshalText", func() {
		It("Should parse existing DNS type as string", func() {
			t := QType(0)
			err := t.UnmarshalText([]byte("AAAA"))
			Expect(err).Should(Succeed())
			Expect(t).Should(Equal(QType(dns.TypeAAAA)))
			Expect(t.String()).Should(Equal("AAAA"))
		})

		It("should fail if DNS type does not exist", func() {
			t := QType(0)
			err := t.UnmarshalText([]byte("WRONGTYPE"))
			Expect(err).Should(HaveOccurred())
			Expect(err.Error()).Should(ContainSubstring("unknown DNS query type: 'WRONGTYPE'"))
		})

		// NOTE(review): despite its name, this spec only feeds an unknown type
		// name containing a space; no YAML is involved.
		It("should fail if wrong YAML format", func() {
			d := QType(0)
			err := d.UnmarshalText([]byte("some err"))
			Expect(err).Should(HaveOccurred())
			Expect(err.Error()).Should(ContainSubstring("unknown DNS query type: 'some err'"))
		})
	})
})
61 | 


--------------------------------------------------------------------------------
/config/query_log.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"net/url"
 5 | 	"strings"
 6 | 
 7 | 	"github.com/0xERR0R/blocky/log"
 8 | 	"github.com/sirupsen/logrus"
 9 | )
10 | 
// QueryLog configuration for the query logging
//
// Type selects the log backend; QueryLogTypeNone disables query logging (see
// IsEnabled). Target is the backend destination: a path in the CSV tests, a
// connection string for MySQL. Target may embed credentials, so it is censored
// before being logged. CreationAttempts/CreationCooldown default to 3/2s and
// FlushInterval to 30s. Fields defaults to all known fields via SetDefaults.
type QueryLog struct {
	Target           string          `yaml:"target"`
	Type             QueryLogType    `yaml:"type"`
	LogRetentionDays uint64          `yaml:"logRetentionDays"`
	CreationAttempts int             `yaml:"creationAttempts" default:"3"`
	CreationCooldown Duration        `yaml:"creationCooldown" default:"2s"`
	Fields           []QueryLogField `yaml:"fields"`
	FlushInterval    Duration        `yaml:"flushInterval" default:"30s"`
	Ignore           QueryLogIgnore  `yaml:"ignore"`
}

// QueryLogIgnore lists categories of queries excluded from the query log.
// NOTE(review): SUDN presumably refers to special-use domain names — confirm
// where the flag is evaluated.
type QueryLogIgnore struct {
	SUDN bool `yaml:"sudn" default:"false"`
}
26 | 
27 | // SetDefaults implements `defaults.Setter`.
28 | func (c *QueryLog) SetDefaults() {
29 | 	// Since the default depends on the enum values, set it dynamically
30 | 	// to avoid having to repeat the values in the annotation.
31 | 	c.Fields = QueryLogFieldValues()
32 | }
33 | 
34 | // IsEnabled implements `config.Configurable`.
35 | func (c *QueryLog) IsEnabled() bool {
36 | 	return c.Type != QueryLogTypeNone
37 | }
38 | 
39 | // LogConfig implements `config.Configurable`.
40 | func (c *QueryLog) LogConfig(logger *logrus.Entry) {
41 | 	logger.Infof("type: %s", c.Type)
42 | 
43 | 	if c.Target != "" {
44 | 		logger.Infof("target: %s", c.censoredTarget())
45 | 	}
46 | 
47 | 	logger.Infof("logRetentionDays: %d", c.LogRetentionDays)
48 | 	logger.Debugf("creationAttempts: %d", c.CreationAttempts)
49 | 	logger.Debugf("creationCooldown: %s", c.CreationCooldown)
50 | 	logger.Infof("flushInterval: %s", c.FlushInterval)
51 | 	logger.Infof("fields: %s", c.Fields)
52 | 
53 | 	logger.Infof("ignore:")
54 | 	log.WithIndent(logger, "  ", func(e *logrus.Entry) {
55 | 		logger.Infof("sudn: %t", c.Ignore.SUDN)
56 | 	})
57 | }
58 | 
59 | func (c *QueryLog) censoredTarget() string {
60 | 	// Make sure there's a scheme, otherwise the user is parsed as the scheme
61 | 	targetStr := c.Target
62 | 	if !strings.Contains(targetStr, "://") {
63 | 		targetStr = c.Type.String() + "://" + targetStr
64 | 	}
65 | 
66 | 	target, err := url.Parse(targetStr)
67 | 	if err != nil {
68 | 		return c.Target
69 | 	}
70 | 
71 | 	pass, ok := target.User.Password()
72 | 	if !ok {
73 | 		return c.Target
74 | 	}
75 | 
76 | 	return strings.ReplaceAll(c.Target, pass, secretObfuscator)
77 | }
78 | 


--------------------------------------------------------------------------------
/config/query_log_test.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"time"
 5 | 
 6 | 	"github.com/creasty/defaults"
 7 | 	. "github.com/onsi/ginkgo/v2"
 8 | 	. "github.com/onsi/gomega"
 9 | )
10 | 
// Specs for QueryLog: enablement, configuration logging (including secret
// censoring of the target) and SetDefaults.
var _ = Describe("QueryLogConfig", func() {
	// cfg is rebuilt before every spec as a CSV-client config writing to /dev/null.
	var cfg QueryLog

	// Gives each spec a fresh mock logger/hook.
	suiteBeforeEach()

	BeforeEach(func() {
		cfg = QueryLog{
			Target:           "/dev/null",
			Type:             QueryLogTypeCsvClient,
			LogRetentionDays: 0,
			CreationAttempts: 1,
			CreationCooldown: Duration(time.Millisecond),
		}
	})

	Describe("IsEnabled", func() {
		// Unlike most configs, the zero/default QueryLog is expected to be enabled.
		It("should be true by default", func() {
			cfg := QueryLog{}
			Expect(defaults.Set(&cfg)).Should(Succeed())

			Expect(cfg.IsEnabled()).Should(BeTrue())
		})

		When("enabled", func() {
			It("should be true", func() {
				Expect(cfg.IsEnabled()).Should(BeTrue())
			})
		})

		When("disabled", func() {
			It("should be false", func() {
				cfg := QueryLog{
					Type: QueryLogTypeNone,
				}

				Expect(cfg.IsEnabled()).Should(BeFalse())
			})
		})
	})

	Describe("LogConfig", func() {
		It("should log configuration", func() {
			cfg.LogConfig(logger)

			Expect(hook.Calls).ShouldNot(BeEmpty())
			Expect(hook.Messages).Should(ContainElement(ContainSubstring("logRetentionDays:")))
			Expect(hook.Messages).Should(ContainElement(ContainSubstring("sudn:")))
		})

		// No logged message may contain the literal password.
		// NOTE(review): no entry covers the MySQL driver DSN form
		// "user:password@tcp(host:3306)/db" or a percent-encoded password.
		DescribeTable("secret censoring", func(target string) {
			cfg.Type = QueryLogTypeMysql
			cfg.Target = target

			cfg.LogConfig(logger)

			Expect(hook.Calls).ShouldNot(BeEmpty())
			Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("password")))
		},
			Entry("without scheme", "user:password@localhost"),
			Entry("with scheme", "scheme://user:password@localhost"),
			Entry("no password", "localhost"),
			Entry("not a URL", "invalid!://"),
		)
	})

	Describe("SetDefaults", func() {
		// NOTE(review): the description is copy-pasted; this spec checks that
		// defaults.Set fills Fields, not logging.
		It("should log configuration", func() {
			cfg := QueryLog{}
			Expect(cfg.Fields).Should(BeEmpty())

			Expect(defaults.Set(&cfg)).Should(Succeed())

			Expect(cfg.Fields).ShouldNot(BeEmpty())
		})
	})
})
87 | 


--------------------------------------------------------------------------------
/config/redis.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"github.com/sirupsen/logrus"
 5 | )
 6 | 
// Redis configuration for the redis connection.
//
// The connection counts as enabled only when Address is non-empty (see
// IsEnabled). When SentinelAddresses is non-empty, LogConfig reports Address
// as the sentinel master and additionally lists SentinelUsername and the
// sentinel addresses. Password and SentinelPassword are never logged in clear
// text.
//
// Defaults from the struct tags: database 0, not required, 3 connection
// attempts and a 1s connection cooldown.
type Redis struct {
	Address            string   `yaml:"address"`
	Username           string   `yaml:"username" default:""`
	Password           string   `yaml:"password" default:""`
	Database           int      `yaml:"database" default:"0"`
	Required           bool     `yaml:"required" default:"false"`
	ConnectionAttempts int      `yaml:"connectionAttempts" default:"3"`
	ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"`
	SentinelUsername   string   `yaml:"sentinelUsername" default:""`
	SentinelPassword   string   `yaml:"sentinelPassword" default:""`
	SentinelAddresses  []string `yaml:"sentinelAddresses"`
}
20 | 
21 | // IsEnabled implements `config.Configurable`
22 | func (c *Redis) IsEnabled() bool {
23 | 	return c.Address != ""
24 | }
25 | 
26 | // LogConfig implements `config.Configurable`
27 | func (c *Redis) LogConfig(logger *logrus.Entry) {
28 | 	if len(c.SentinelAddresses) == 0 {
29 | 		logger.Info("address: ", c.Address)
30 | 	}
31 | 
32 | 	logger.Info("username: ", c.Username)
33 | 	logger.Info("password: ", secretObfuscator)
34 | 	logger.Info("database: ", c.Database)
35 | 	logger.Info("required: ", c.Required)
36 | 	logger.Info("connectionAttempts: ", c.ConnectionAttempts)
37 | 	logger.Info("connectionCooldown: ", c.ConnectionCooldown)
38 | 
39 | 	if len(c.SentinelAddresses) > 0 {
40 | 		logger.Info("sentinel:")
41 | 		logger.Info("  master: ", c.Address)
42 | 		logger.Info("  username: ", c.SentinelUsername)
43 | 		logger.Info("  password: ", secretObfuscator)
44 | 		logger.Info("  addresses:")
45 | 
46 | 		for _, addr := range c.SentinelAddresses {
47 | 			logger.Info("    - ", addr)
48 | 		}
49 | 	}
50 | }
51 | 


--------------------------------------------------------------------------------
/config/redis_test.go:
--------------------------------------------------------------------------------
  1 | package config
  2 | 
  3 | import (
  4 | 	"github.com/0xERR0R/blocky/log"
  5 | 	"github.com/creasty/defaults"
  6 | 	. "github.com/onsi/ginkgo/v2"
  7 | 	. "github.com/onsi/gomega"
  8 | )
  9 | 
 10 | var _ = Describe("Redis", func() {
 11 | 	var (
 12 | 		c   Redis
 13 | 		err error
 14 | 	)
 15 | 
 16 | 	suiteBeforeEach()
 17 | 
 18 | 	BeforeEach(func() {
 19 | 		err = defaults.Set(&c)
 20 | 		Expect(err).Should(Succeed())
 21 | 	})
 22 | 
 23 | 	Describe("IsEnabled", func() {
 24 | 		When("all fields are default", func() {
 25 | 			It("should be disabled", func() {
 26 | 				Expect(c.IsEnabled()).Should(BeFalse())
 27 | 			})
 28 | 		})
 29 | 
 30 | 		When("Address is set", func() {
 31 | 			BeforeEach(func() {
 32 | 				c.Address = "localhost:6379"
 33 | 			})
 34 | 
 35 | 			It("should be enabled", func() {
 36 | 				Expect(c.IsEnabled()).Should(BeTrue())
 37 | 			})
 38 | 		})
 39 | 	})
 40 | 
 41 | 	Describe("LogConfig", func() {
 42 | 		BeforeEach(func() {
 43 | 			logger, hook = log.NewMockEntry()
 44 | 		})
 45 | 
 46 | 		When("all fields are default", func() {
 47 | 			It("should log default values", func() {
 48 | 				c.LogConfig(logger)
 49 | 
 50 | 				Expect(hook.Messages).Should(
 51 | 					SatisfyAll(ContainElement(ContainSubstring("address: ")),
 52 | 						ContainElement(ContainSubstring("username: ")),
 53 | 						ContainElement(ContainSubstring("password: ")),
 54 | 						ContainElement(ContainSubstring("database: ")),
 55 | 						ContainElement(ContainSubstring("required: ")),
 56 | 						ContainElement(ContainSubstring("connectionAttempts: ")),
 57 | 						ContainElement(ContainSubstring("connectionCooldown: "))))
 58 | 			})
 59 | 		})
 60 | 
 61 | 		When("Address is set", func() {
 62 | 			BeforeEach(func() {
 63 | 				c.Address = "localhost:6379"
 64 | 			})
 65 | 
 66 | 			It("should log address", func() {
 67 | 				c.LogConfig(logger)
 68 | 
 69 | 				Expect(hook.Messages).Should(ContainElement(ContainSubstring("address: localhost:6379")))
 70 | 			})
 71 | 		})
 72 | 
 73 | 		When("SentinelAddresses is set", func() {
 74 | 			BeforeEach(func() {
 75 | 				c.SentinelAddresses = []string{"localhost:26379", "localhost:26380"}
 76 | 			})
 77 | 
 78 | 			It("should log sentinel addresses", func() {
 79 | 				c.LogConfig(logger)
 80 | 
 81 | 				Expect(hook.Messages).Should(
 82 | 					SatisfyAll(
 83 | 						ContainElement(ContainSubstring("sentinel:")),
 84 | 						ContainElement(ContainSubstring("  addresses:")),
 85 | 						ContainElement(ContainSubstring("  - localhost:26379")),
 86 | 						ContainElement(ContainSubstring("  - localhost:26380"))))
 87 | 			})
 88 | 		})
 89 | 
 90 | 		const secretValue = "secret-value"
 91 | 
 92 | 		It("should not log the password", func() {
 93 | 			c.Password = secretValue
 94 | 			c.LogConfig(logger)
 95 | 
 96 | 			Expect(hook.Calls).ShouldNot(BeEmpty())
 97 | 			Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring(secretValue)))
 98 | 		})
 99 | 
100 | 		It("should not log the sentinel password", func() {
101 | 			c.SentinelPassword = secretValue
102 | 			c.LogConfig(logger)
103 | 
104 | 			Expect(hook.Calls).ShouldNot(BeEmpty())
105 | 			Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring(secretValue)))
106 | 		})
107 | 	})
108 | })
109 | 


--------------------------------------------------------------------------------
/config/rewriter.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
import (
	"sort"

	"github.com/sirupsen/logrus"
)
 6 | 
// RewriterConfig holds the name rewrite rules.
//
// Each Rewrite entry maps an original name (key) to its rewritten form
// (value); LogConfig prints them as "key = value" under "rules:". The
// rewriter is enabled only when at least one rule is present (see IsEnabled).
//
// FallbackUpstream defaults to false. NOTE(review): presumably it controls
// whether the original, un-rewritten query is still resolved upstream when the
// rewritten one yields no answer — confirm against the rewriter resolver.
type RewriterConfig struct {
	Rewrite          map[string]string `yaml:"rewrite"`
	FallbackUpstream bool              `yaml:"fallbackUpstream" default:"false"`
}
12 | 
13 | // IsEnabled implements `config.Configurable`.
14 | func (c *RewriterConfig) IsEnabled() bool {
15 | 	return len(c.Rewrite) != 0
16 | }
17 | 
18 | // LogConfig implements `config.Configurable`.
19 | func (c *RewriterConfig) LogConfig(logger *logrus.Entry) {
20 | 	logger.Infof("fallbackUpstream = %t", c.FallbackUpstream)
21 | 
22 | 	logger.Info("rules:")
23 | 
24 | 	for key, val := range c.Rewrite {
25 | 		logger.Infof("  %s = %s", key, val)
26 | 	}
27 | }
28 | 


--------------------------------------------------------------------------------
/config/rewriter_test.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"github.com/creasty/defaults"
 5 | 	. "github.com/onsi/ginkgo/v2"
 6 | 	. "github.com/onsi/gomega"
 7 | )
 8 | 
 9 | var _ = Describe("RewriterConfig", func() {
10 | 	var cfg RewriterConfig
11 | 
12 | 	suiteBeforeEach()
13 | 
14 | 	BeforeEach(func() {
15 | 		cfg = RewriterConfig{
16 | 			Rewrite: map[string]string{
17 | 				"original1": "rewritten1",
18 | 				"original2": "rewritten2",
19 | 			},
20 | 		}
21 | 	})
22 | 
23 | 	Describe("IsEnabled", func() {
24 | 		It("should be false by default", func() {
25 | 			cfg := RewriterConfig{}
26 | 			Expect(defaults.Set(&cfg)).Should(Succeed())
27 | 
28 | 			Expect(cfg.IsEnabled()).Should(BeFalse())
29 | 		})
30 | 
31 | 		When("enabled", func() {
32 | 			It("should be true", func() {
33 | 				Expect(cfg.IsEnabled()).Should(BeTrue())
34 | 			})
35 | 		})
36 | 
37 | 		When("disabled", func() {
38 | 			It("should be false", func() {
39 | 				cfg := RewriterConfig{}
40 | 
41 | 				Expect(cfg.IsEnabled()).Should(BeFalse())
42 | 			})
43 | 		})
44 | 	})
45 | 
46 | 	Describe("LogConfig", func() {
47 | 		It("should log configuration", func() {
48 | 			cfg.LogConfig(logger)
49 | 
50 | 			Expect(hook.Calls).ShouldNot(BeEmpty())
51 | 			Expect(hook.Messages).Should(ContainElements(
52 | 				ContainSubstring("rules:"),
53 | 				ContainSubstring("original2 ="),
54 | 			))
55 | 		})
56 | 	})
57 | })
58 | 


--------------------------------------------------------------------------------
/config/sudn.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"github.com/sirupsen/logrus"
 5 | )
 6 | 
// SUDN configuration for Special Use Domain Names.
//
// Enable (default true) switches the whole feature on or off (see IsEnabled).
// RFC6762AppendixG (default true) covers the names that RFC 6762 Appendix G
// recommends for private use; see the field comment for why defaulting it to
// true is considered safe.
type SUDN struct {
	// These are "recommended for private use" but not mandatory.
	// If a user wishes to use one, it will most likely be via conditional
	// upstream or custom DNS, which come before SUDN in the resolver chain.
	// Thus defaulting to `true` and returning NXDOMAIN here should not conflict.
	RFC6762AppendixG bool `yaml:"rfc6762-appendixG" default:"true"`
	Enable           bool `yaml:"enable" default:"true"`
}
16 | 
17 | // IsEnabled implements `config.Configurable`.
18 | func (c *SUDN) IsEnabled() bool {
19 | 	return c.Enable
20 | }
21 | 
22 | // LogConfig implements `config.Configurable`.
23 | func (c *SUDN) LogConfig(logger *logrus.Entry) {
24 | 	logger.Debugf("rfc6762-appendixG = %v", c.RFC6762AppendixG)
25 | }
26 | 


--------------------------------------------------------------------------------
/config/sudn_test.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	. "github.com/onsi/ginkgo/v2"
 5 | 	. "github.com/onsi/gomega"
 6 | )
 7 | 
 8 | var _ = Describe("SUDNConfig", func() {
 9 | 	var cfg SUDN
10 | 
11 | 	suiteBeforeEach()
12 | 
13 | 	BeforeEach(func() {
14 | 		var err error
15 | 
16 | 		cfg, err = WithDefaults[SUDN]()
17 | 		Expect(err).Should(Succeed())
18 | 	})
19 | 
20 | 	Describe("IsEnabled", func() {
21 | 		It("should be true by default", func() {
22 | 			Expect(cfg.IsEnabled()).Should(BeTrue())
23 | 		})
24 | 
25 | 		When("enabled", func() {
26 | 			It("should be true", func() {
27 | 				cfg := SUDN{
28 | 					Enable: true,
29 | 				}
30 | 				Expect(cfg.IsEnabled()).Should(BeTrue())
31 | 			})
32 | 		})
33 | 
34 | 		When("disabled", func() {
35 | 			It("should be false", func() {
36 | 				cfg := SUDN{
37 | 					Enable: false,
38 | 				}
39 | 				Expect(cfg.IsEnabled()).Should(BeFalse())
40 | 			})
41 | 		})
42 | 	})
43 | 
44 | 	Describe("LogConfig", func() {
45 | 		It("should log configuration", func() {
46 | 			cfg.LogConfig(logger)
47 | 
48 | 			Expect(hook.Calls).ShouldNot(BeEmpty())
49 | 			Expect(hook.Messages).Should(ContainElement(ContainSubstring("rfc6762-appendixG = true")))
50 | 		})
51 | 	})
52 | })
53 | 


--------------------------------------------------------------------------------
/config/upstreams.go:
--------------------------------------------------------------------------------
 1 | package config
 2 | 
 3 | import (
 4 | 	"github.com/0xERR0R/blocky/log"
 5 | 	"github.com/sirupsen/logrus"
 6 | )
 7 | 
// UpstreamDefaultCfgName is the key of the default group in Upstreams.Groups.
const UpstreamDefaultCfgName = "default"

// Upstreams upstream servers configuration.
//
// Groups maps a group name to its upstream servers; the configuration counts
// as enabled only when at least one group is defined (see IsEnabled). Timeout
// defaults to 2s and must stay positive: validate replaces a non-positive value
// with that default. Strategy defaults to "parallel_best".
type Upstreams struct {
	Init      Init             `yaml:"init"`
	Timeout   Duration         `yaml:"timeout" default:"2s"` // always > 0
	Groups    UpstreamGroups   `yaml:"groups"`
	Strategy  UpstreamStrategy `yaml:"strategy" default:"parallel_best"`
	UserAgent string           `yaml:"userAgent"`
}

// UpstreamGroups maps an upstream group name to the upstreams of that group.
type UpstreamGroups map[string][]Upstream
20 | 
21 | func (c *Upstreams) validate(logger *logrus.Entry) {
22 | 	defaults := mustDefault[Upstreams]()
23 | 
24 | 	if !c.Timeout.IsAboveZero() {
25 | 		logger.Warnf("upstreams.timeout <= 0, setting to %s", defaults.Timeout)
26 | 		c.Timeout = defaults.Timeout
27 | 	}
28 | }
29 | 
30 | // IsEnabled implements `config.Configurable`.
31 | func (c *Upstreams) IsEnabled() bool {
32 | 	return len(c.Groups) != 0
33 | }
34 | 
35 | // LogConfig implements `config.Configurable`.
36 | func (c *Upstreams) LogConfig(logger *logrus.Entry) {
37 | 	logger.Info("init:")
38 | 	log.WithIndent(logger, "  ", c.Init.LogConfig)
39 | 
40 | 	logger.Info("timeout: ", c.Timeout)
41 | 	logger.Info("strategy: ", c.Strategy)
42 | 	logger.Info("groups:")
43 | 
44 | 	for name, upstreams := range c.Groups {
45 | 		logger.Infof("  %s:", name)
46 | 
47 | 		for _, upstream := range upstreams {
48 | 			logger.Infof("    - %s", upstream)
49 | 		}
50 | 	}
51 | }
52 | 
// UpstreamGroup represents the config for one group (upstream branch).
//
// It embeds the shared Upstreams settings. When built with NewUpstreamGroup,
// the embedded Groups map holds a single entry keyed by Name, which is what
// GroupUpstreams reads.
type UpstreamGroup struct {
	Upstreams

	Name string // group name
}
59 | 
60 | // NewUpstreamGroup creates an UpstreamGroup with the given name and upstreams.
61 | //
62 | // The upstreams from `cfg.Groups` are ignored.
63 | func NewUpstreamGroup(name string, cfg Upstreams, upstreams []Upstream) UpstreamGroup {
64 | 	group := UpstreamGroup{
65 | 		Name:      name,
66 | 		Upstreams: cfg,
67 | 	}
68 | 
69 | 	group.Groups = UpstreamGroups{name: upstreams}
70 | 
71 | 	return group
72 | }
73 | 
74 | func (c *UpstreamGroup) GroupUpstreams() []Upstream {
75 | 	return c.Groups[c.Name]
76 | }
77 | 
78 | // IsEnabled implements `config.Configurable`.
79 | func (c *UpstreamGroup) IsEnabled() bool {
80 | 	return len(c.GroupUpstreams()) != 0
81 | }
82 | 
83 | // LogConfig implements `config.Configurable`.
84 | func (c *UpstreamGroup) LogConfig(logger *logrus.Entry) {
85 | 	logger.Info("group: ", c.Name)
86 | 	logger.Info("upstreams:")
87 | 
88 | 	for _, upstream := range c.GroupUpstreams() {
89 | 		logger.Infof("  - %s", upstream)
90 | 	}
91 | }
92 | 


--------------------------------------------------------------------------------
/config/upstreams_test.go:
--------------------------------------------------------------------------------
  1 | package config
  2 | 
  3 | import (
  4 | 	"time"
  5 | 
  6 | 	"github.com/creasty/defaults"
  7 | 	. "github.com/onsi/ginkgo/v2"
  8 | 	. "github.com/onsi/gomega"
  9 | )
 10 | 
 11 | var _ = Describe("ParallelBestConfig", func() {
 12 | 	suiteBeforeEach()
 13 | 
 14 | 	Context("Upstreams", func() {
 15 | 		var cfg Upstreams
 16 | 
 17 | 		BeforeEach(func() {
 18 | 			cfg = Upstreams{
 19 | 				Timeout: Duration(5 * time.Second),
 20 | 				Groups: UpstreamGroups{
 21 | 					UpstreamDefaultCfgName: {
 22 | 						{Host: "host1"},
 23 | 						{Host: "host2"},
 24 | 					},
 25 | 				},
 26 | 			}
 27 | 		})
 28 | 
 29 | 		Describe("IsEnabled", func() {
 30 | 			It("should be false by default", func() {
 31 | 				cfg := Upstreams{}
 32 | 				Expect(defaults.Set(&cfg)).Should(Succeed())
 33 | 
 34 | 				Expect(cfg.IsEnabled()).Should(BeFalse())
 35 | 			})
 36 | 
 37 | 			When("enabled", func() {
 38 | 				It("should be true", func() {
 39 | 					Expect(cfg.IsEnabled()).Should(BeTrue())
 40 | 				})
 41 | 			})
 42 | 
 43 | 			When("disabled", func() {
 44 | 				It("should be false", func() {
 45 | 					cfg := Upstreams{}
 46 | 
 47 | 					Expect(cfg.IsEnabled()).Should(BeFalse())
 48 | 				})
 49 | 			})
 50 | 		})
 51 | 
 52 | 		Describe("LogConfig", func() {
 53 | 			It("should log configuration", func() {
 54 | 				cfg.LogConfig(logger)
 55 | 
 56 | 				Expect(hook.Calls).ShouldNot(BeEmpty())
 57 | 				Expect(hook.Messages).Should(ContainElements(
 58 | 					ContainSubstring("timeout:"),
 59 | 					ContainSubstring("groups:"),
 60 | 					ContainSubstring(":host2:"),
 61 | 				))
 62 | 			})
 63 | 		})
 64 | 
 65 | 		Describe("validate", func() {
 66 | 			It("should compute defaults", func() {
 67 | 				cfg.Timeout = -1
 68 | 
 69 | 				cfg.validate(logger)
 70 | 
 71 | 				Expect(cfg.Timeout).Should(BeNumerically(">", 0))
 72 | 
 73 | 				Expect(hook.Calls).ShouldNot(BeEmpty())
 74 | 				Expect(hook.Messages).Should(ContainElement(ContainSubstring("timeout")))
 75 | 			})
 76 | 
 77 | 			It("should not override valid user values", func() {
 78 | 				cfg.validate(logger)
 79 | 
 80 | 				Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("timeout")))
 81 | 			})
 82 | 		})
 83 | 	})
 84 | 
 85 | 	Context("UpstreamGroupConfig", func() {
 86 | 		var cfg UpstreamGroup
 87 | 
 88 | 		BeforeEach(func() {
 89 | 			upstreamsCfg, err := WithDefaults[Upstreams]()
 90 | 			Expect(err).Should(Succeed())
 91 | 
 92 | 			cfg = NewUpstreamGroup("test", upstreamsCfg, []Upstream{
 93 | 				{Host: "host1"},
 94 | 				{Host: "host2"},
 95 | 			})
 96 | 		})
 97 | 
 98 | 		Describe("IsEnabled", func() {
 99 | 			It("should be false by default", func() {
100 | 				cfg := UpstreamGroup{}
101 | 				Expect(defaults.Set(&cfg)).Should(Succeed())
102 | 
103 | 				Expect(cfg.IsEnabled()).Should(BeFalse())
104 | 			})
105 | 
106 | 			When("enabled", func() {
107 | 				It("should be true", func() {
108 | 					Expect(cfg.IsEnabled()).Should(BeTrue())
109 | 				})
110 | 			})
111 | 
112 | 			When("disabled", func() {
113 | 				It("should be false", func() {
114 | 					cfg := UpstreamGroup{}
115 | 
116 | 					Expect(cfg.IsEnabled()).Should(BeFalse())
117 | 				})
118 | 			})
119 | 		})
120 | 
121 | 		Describe("LogConfig", func() {
122 | 			It("should log configuration", func() {
123 | 				cfg.LogConfig(logger)
124 | 
125 | 				Expect(hook.Calls).ShouldNot(BeEmpty())
126 | 				Expect(hook.Messages).Should(ContainElements(
127 | 					ContainSubstring("group: test"),
128 | 					ContainSubstring("upstreams:"),
129 | 					ContainSubstring(":host1:"),
130 | 					ContainSubstring(":host2:"),
131 | 				))
132 | 			})
133 | 		})
134 | 	})
135 | })
136 | 


--------------------------------------------------------------------------------
/docs/embed.go:
--------------------------------------------------------------------------------
1 | package docs
2 | 
3 | import _ "embed"
4 | 
// OpenAPI holds the contents of api/openapi.yaml (the REST API
// specification), embedded into the binary at build time.
//
//go:embed api/openapi.yaml
var OpenAPI string
7 | 


--------------------------------------------------------------------------------
/docs/fb_dns_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0xERR0R/blocky/62610f657e052b9f5dba77bdfd6c2229d558ab6b/docs/fb_dns_config.png


--------------------------------------------------------------------------------
/docs/grafana-dashboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0xERR0R/blocky/62610f657e052b9f5dba77bdfd6c2229d558ab6b/docs/grafana-dashboard.png


--------------------------------------------------------------------------------
/docs/grafana-query-dashboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0xERR0R/blocky/62610f657e052b9f5dba77bdfd6c2229d558ab6b/docs/grafana-query-dashboard.png


--------------------------------------------------------------------------------
/docs/includes/abbreviations.md:
--------------------------------------------------------------------------------
 1 | *[DNS]: Domain Name System
 2 | *[k8s]: Kubernetes
 3 | *[UDP]: User Datagram Protocol
 4 | *[TCP]: Transmission Control Protocol
 5 | *[HTTP]: Hypertext Transfer Protocol
 6 | *[HTTPS]: Hypertext Transfer Protocol Secure
 7 | *[DoH]: DNS-over-HTTPS
 8 | *[DoT]: DNS-over-TLS
 9 | *[DNSSEC]: Domain Name System Security Extensions
10 | *[eDNS]: Extended DNS
11 | *[REST]: Representational State Transfer
12 | *[API]: Application Programming Interface
13 | *[CLI]: Command Line Interface
14 | *[YAML]: YAML Ain't Markup Language
15 | *[Helm]: package manager for Kubernetes
16 | *[CNAME]: Canonical Name
17 | *[CIDR]: Classless Inter-Domain Routing
18 | *[NXDOMAIN]: Non-Existent Domain
19 | *[TTL]: Time-To-Live
20 | *[rDNS]: Reverse DNS
21 | *[SSL]: Secure Sockets Layer
22 | *[CSV]: Comma-separated values
23 | *[SAMBA]: Server Message Block Protocol (Windows Network File System)
24 | *[DHCP]: Dynamic Host Configuration Protocol
25 | *[duration format]: Example: "300ms", "1.5h" or "2h45m". Valid time units are "ns", "us", "ms", "s", "m", "h".
26 | *[regex]: Regular expression


--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
 1 | # Blocky
 2 | 
 3 | <figure>
 4 |   <img src="https://raw.githubusercontent.com/0xERR0R/blocky/main/docs/blocky.svg" width="200" />
 5 | </figure>
 6 | 
 7 | Blocky is a DNS proxy and ad-blocker for the local network written in Go with the following features:
 8 | 
 9 | ## Features
10 | 
11 | - **Blocking** - :no_entry: Blocking of DNS queries with external lists (Ad-block, malware) and allowlisting
12 | 
13 |     * Definition of allow/denylists per client group (Kids, Smart home devices, etc.)
14 |     * Periodical reload of external allow/denylists
15 |     * Regex support
16 |     * Blocking of request domain, response CNAME (deep CNAME inspection) and response IP addresses (against IP lists)
17 | 
18 | - **Advanced DNS configuration** - :nerd: not just an ad-blocker
19 | 
20 |     * Custom DNS resolution for certain domain names
21 |     * Conditional forwarding to external DNS server
22 |     * Upstream resolvers can be defined per client group
23 | 
24 | - **Performance** - :rocket: Improves speed and performance in your network
25 | 
26 |     * Customizable caching of DNS answers for queries -> improves DNS resolution speed and reduces the number of external
27 |       DNS queries
28 |     * Prefetching and caching of often used queries
29 |     * Use of multiple external resolvers simultaneously
30 |     * Low memory footprint
31 | 
32 | - **Various Protocols** - :computer: Supports modern DNS protocols
33 | 
34 |     * DNS over UDP and TCP
35 |     * DNS over HTTPS (aka DoH)
36 |     * DNS over TLS (aka DoT)
37 | 
38 | - **Security and Privacy** - :dark_sunglasses: Secure communication
39 | 
40 |     * Supports modern DNS extensions: DNSSEC, eDNS, ...
41 |     * Free configurable blocking lists - no hidden filtering etc.
42 |     * Provides DoH Endpoint
43 |     * Uses random upstream resolvers from the configuration - increases your privacy by distributing your DNS
44 |       traffic over multiple providers
45 |     * Open source development
46 |     * Blocky does **NOT** collect any user data, telemetry, statistics etc.
47 | 
48 | - **Integration** - :notebook_with_decorative_cover: various integrations
49 | 
50 |     * [Prometheus](https://prometheus.io/) metrics
51 |     * Prepared [Grafana](https://grafana.com/) dashboards (Prometheus and database)
52 |     * Logging of DNS queries per day / per client in CSV format or MySQL/MariaDB/PostgreSQL/Timescale database - easy to
53 |       analyze
54 |     * Various REST API endpoints
55 |     * CLI tool
56 | 
57 | - **Simple configuration** - :baby: single configuration file in YAML format
58 | 
59 |     * Simple to maintain
60 |     * Simple to backup
61 | 
62 | - **Simple installation/configuration** - :cloud: blocky was designed for simple installation
63 | 
64 |     * Stateless (no database, no temporary files)
65 |     * Docker image with Multi-arch support
66 |     * Single binary
67 |     * Supports x86-64 and ARM architectures -> runs fine on Raspberry PI
68 |     * Community supported Helm chart for k8s deployment
69 | 
70 | 
71 | ## Contribution
72 | 
73 | Issues, feature suggestions and pull requests are welcome! Blocky lives on :material-github:[GitHub](https://github.com/0xERR0R/blocky).
74 | 
75 | --8<-- "docs/includes/abbreviations.md"
76 | 


--------------------------------------------------------------------------------
/docs/interfaces.md:
--------------------------------------------------------------------------------
 1 | # Interfaces
 2 | 
 3 | ## REST API
 4 | 
 5 | 
 6 | ??? abstract "OpenAPI specification"
 7 | 
 8 |     ```yaml
 9 |     --8<-- "docs/api/openapi.yaml"
10 |     ```
11 | 
12 | If the HTTP listener is enabled, blocky provides a REST API. You can download the [OpenAPI YAML](api/openapi.yaml) interface specification.
13 | 
14 | You can also browse the interactive API documentation (RapiDoc) [online](rapidoc.html).
15 | 
16 | ## CLI
17 | 
18 | Blocky provides a command-line interface to control it. Internally, this interface uses the REST API.
19 | 
20 | To run the CLI, please ensure that the blocky DNS server is running, then execute `blocky help` for help, or use one of the following commands:
21 | 
22 | - `./blocky blocking enable` to enable blocking
23 | - `./blocky blocking disable` to disable blocking
24 | - `./blocky blocking disable --duration [duration]` to disable blocking for a certain amount of time (30s, 5m, 10m30s,
25 |   ...)
26 | - `./blocky blocking disable --groups ads,othergroup` to disable blocking only for special groups
27 | - `./blocky blocking status` to print current status of blocking
28 | - `./blocky query <domain>` executes a DNS query (A) (simple replacement for dig, useful for debug purposes)
29 | - `./blocky query <domain> --type <queryType>` executes a DNS query with the passed query type (A, AAAA, MX, ...)
30 | - `./blocky lists refresh` reloads all allow/denylists
31 | - `./blocky validate [--config /path/to/config.yaml]` validates configuration file
32 | 
33 | !!! tip 
34 | 
35 |     To run this inside Docker, use `docker exec blocky ./blocky blocking status`
36 | 
37 | --8<-- "docs/includes/abbreviations.md"
38 | 


--------------------------------------------------------------------------------
/docs/network_configuration.md:
--------------------------------------------------------------------------------
 1 | # Network configuration
 2 | 
 3 | In order to benefit from all the advantages of blocky, like ad-blocking, privacy and speed, it is necessary to use
 4 | blocky as DNS server for your devices. You can configure DNS server on each device manually or use DHCP in your network
 5 | router and push the right settings to your device. With this approach, you will configure blocky only once in your
 6 | router and each device in your network will automatically use blocky as DNS server.
 7 | 
 8 | ## Transparent configuration with DHCP
 9 | 
10 | Let us assume blocky is installed on a Raspberry PI with the fixed IP address `192.168.178.2`. Each device that connects to
11 | the router will obtain an IP address and receive the network configuration. The IP address of the Raspberry PI should be
12 | pushed to the device as DNS server.
13 | 
14 | ```
15 | ┌──────────────┐         ┌─────────────────┐
16 | │              │         │ Raspberry PI    │
17 | │  Router      │         │   blocky        │        
18 | │              │         │ 192.168.178.2   │            
19 | └─▲─────┬──────┘         └────▲────────────┘        
20 |   │1    │                     │  3                  
21 |   │     │                     │                         
22 |   │     │                     │ 
23 |   │     │                     │                     
24 |   │     │                     │
25 |   │     │                     │
26 |   │     │                     │
27 |   │     │       ┌─────────────┴──────┐
28 |   │     │   2   │                    │
29 |   │     └───────►  Network device    │
30 |   │             │    Android         │
31 |   └─────────────┤                    │
32 |                 └────────────────────┘
33 | ```
34 | 
35 | **1** - Network device asks the DHCP server (on Router) for the network configuration
36 | 
37 | **2** - Router assigns a free IP address to the device and says "Use 192.168.178.2" as DNS server
38 | 
39 | **3** - The client makes DNS queries and is happy to use **blocky** :smile:
40 | 
41 | !!! warning
42 | 
43 |     It is necessary to assign the server which runs blocky (e.g. Raspberry PI) a fixed IP address.
44 | 
45 | ### Example configuration with FritzBox
46 | 
47 | To configure the DNS server in the FritzBox, open the FritzBox web interface and navigate as follows:
48 | 
49 | * in navigation menu on the left side: Home Network -> Network
50 | * Network Settings tab on the top
51 | * "IPv4 Configuration" button at the bottom of the page
52 | * Enter the IP address of blocky under "Local DNS server", see screenshot
53 | 
54 | ![FritzBox DNS configuration](fb_dns_config.png "FritzBox DNS configuration")
55 | 
56 | --8<-- "docs/includes/abbreviations.md"


--------------------------------------------------------------------------------
/docs/rapidoc.html:
--------------------------------------------------------------------------------
 1 | <!doctype html>
 2 | <html>
 3 | <head>
 4 |   <meta charset="utf-8">
 5 |   <script type="module" src="https://unpkg.com/rapidoc/dist/rapidoc-min.js"></script>
 6 | </head>
 7 | <body>
 8 |   <rapi-doc
 9 |     spec-url="api/openapi.yaml"
10 |     theme = "light"
11 | 	  allow-authentication = "false"
12 |     show-header = "false"
13 |     bg-color = "#fdf8ed"
14 |     nav-bg-color = "#3f4d67"
15 |     nav-text-color = "#a9b7d0"
16 |     nav-hover-bg-color = "#333f54"
17 |     nav-hover-text-color = "#fff"
18 |     nav-accent-color = "#f87070"
19 |     primary-color = "#5c7096"
20 |     allow-try = "false"
21 |   > </rapi-doc>
22 | </body>
23 | </html>


--------------------------------------------------------------------------------
/e2e/e2e_suite_test.go:
--------------------------------------------------------------------------------
 1 | package e2e
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"testing"
 6 | 	"time"
 7 | 
 8 | 	"github.com/0xERR0R/blocky/log"
 9 | 	. "github.com/onsi/ginkgo/v2"
10 | 	. "github.com/onsi/gomega"
11 | )
12 | 
13 | func init() {
14 | 	log.Silence()
15 | }
16 | 
17 | func TestLists(t *testing.T) {
18 | 	RegisterFailHandler(Fail)
19 | 	RunSpecs(t, "e2e Suite", Label("e2e"))
20 | }
21 | 
22 | var _ = BeforeSuite(func(ctx context.Context) {
23 | 	SetDefaultEventuallyTimeout(5 * time.Second)
24 | })
25 | 


--------------------------------------------------------------------------------
/evt/events.go:
--------------------------------------------------------------------------------
 1 | package evt
 2 | 
 3 | import (
 4 | 	"github.com/asaskevich/EventBus"
 5 | )
 6 | 
// Topic names published on the global event bus (see Bus). Each comment
// states when the event fires and, where known, the published arguments.
const (
	// BlockingEnabledEvent fires if the blocking status changes. Parameter: boolean (enabled = true)
	BlockingEnabledEvent = "blocking:enabled"

	// BlockingCacheGroupChanged fires if a list group is changed. Parameter: list type, group name, element count
	BlockingCacheGroupChanged = "blocking:cachingGroupChanged"

	// CachingDomainPrefetched fires if a domain is prefetched. Parameter: domain name
	CachingDomainPrefetched = "caching:prefetched"

	// CachingResultCacheChanged fires if a result cache was changed. Parameter: new cache size
	CachingResultCacheChanged = "caching:resultCacheChanged"

	// CachingPrefetchCacheHit fires if a query result was found in the prefetch cache. Parameter: domain name
	CachingPrefetchCacheHit = "caching:prefetchHit"

	// CachingDomainsToPrefetchCountChanged fires if the number of domains being prefetched changes. Parameter: new count
	CachingDomainsToPrefetchCountChanged = "caching:domainsToPrefetchCountChanged"

	// CachingFailedDownloadChanged fires if a download of a blocking list or hosts file fails.
	// NOTE(review): the published parameter is not documented here — confirm with the publishers.
	CachingFailedDownloadChanged = "caching:failedDownload"

	// ApplicationStarted fires on start of the application. Parameter: version number, build time
	ApplicationStarted = "application:started"
)
32 | 
33 | //nolint:gochecknoglobals
34 | var evtBus = EventBus.New()
35 | 
36 | // Bus returns the global bus instance
37 | func Bus() EventBus.Bus {
38 | 	return evtBus
39 | }
40 | 


--------------------------------------------------------------------------------
/helpertest/http.go:
--------------------------------------------------------------------------------
 1 | package helpertest
 2 | 
 3 | import (
 4 | 	"fmt"
 5 | 	"net"
 6 | 	"net/http"
 7 | 	"net/url"
 8 | 	"sync/atomic"
 9 | 
10 | 	"github.com/onsi/ginkgo/v2"
11 | )
12 | 
// HTTPProxy is a minimal test proxy: it records the Host of each request it
// receives and answers every request with 501 Not Implemented.
type HTTPProxy struct {
	Addr          net.Addr     // address the proxy listens on
	requestTarget atomic.Value // string: HTTP Host of the latest request
}
17 | 
18 | // TestHTTPProxy returns a new HTTPProxy server.
19 | //
20 | // All requests return http.StatusNotImplemented.
21 | func TestHTTPProxy() *HTTPProxy {
22 | 	proxyListener, err := net.ListenTCP("tcp4", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
23 | 	if err != nil {
24 | 		ginkgo.Fail(fmt.Sprintf("could not create HTTP proxy listener: %s", err))
25 | 	}
26 | 
27 | 	proxy := &HTTPProxy{
28 | 		Addr: proxyListener.Addr(),
29 | 	}
30 | 
31 | 	proxySrv := http.Server{ //nolint:gosec
32 | 		Addr:    "127.0.0.1:0",
33 | 		Handler: proxy,
34 | 	}
35 | 
36 | 	go func() { _ = proxySrv.Serve(proxyListener) }()
37 | 	ginkgo.DeferCleanup(proxySrv.Close)
38 | 
39 | 	return proxy
40 | }
41 | 
42 | // URL returns the proxy's URL for use by clients.
43 | func (p *HTTPProxy) URL() *url.URL {
44 | 	return &url.URL{
45 | 		Scheme: "http",
46 | 		Host:   p.Addr.String(),
47 | 	}
48 | }
49 | 
50 | // Check ReqURL has the right type signature for http.Transport.Proxy
51 | var _ = http.Transport{Proxy: (*HTTPProxy)(nil).ReqURL}
52 | 
53 | func (p *HTTPProxy) ReqURL(*http.Request) (*url.URL, error) {
54 | 	return p.URL(), nil
55 | }
56 | 
57 | // RequestTarget returns the target of the last request.
58 | func (p *HTTPProxy) RequestTarget() string {
59 | 	val := p.requestTarget.Load()
60 | 	if val == nil {
61 | 		ginkgo.Fail(fmt.Sprintf("http proxy %s received no requests", p.Addr))
62 | 	}
63 | 
64 | 	return val.(string)
65 | }
66 | 
67 | func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) {
68 | 	p.requestTarget.Store(req.Host)
69 | 
70 | 	w.WriteHeader(http.StatusNotImplemented)
71 | }
72 | 


--------------------------------------------------------------------------------
/helpertest/mock_call_sequence.go:
--------------------------------------------------------------------------------
 1 | package helpertest
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"fmt"
 6 | 	"sync"
 7 | 	"time"
 8 | )
 9 | 
const mockCallTimeout = 2 * time.Second

// MockCallSequence hands out a scripted sequence of results to successive Call invocations.
//
// On the first Call, the driver starts in its own goroutine. For each Call it is expected
// to send exactly one value on the T channel or one error on the error channel. Once the
// driver returns, the sequence is closed and any further Call panics.
type MockCallSequence[T any] struct {
	driver    func(chan<- T, chan<- error)
	res       chan T
	err       chan error
	callCount uint
	initOnce  sync.Once
	closeOnce sync.Once
}

// NewMockCallSequence creates a sequence whose results are produced by driver.
func NewMockCallSequence[T any](driver func(chan<- T, chan<- error)) MockCallSequence[T] {
	return MockCallSequence[T]{
		driver: driver,
	}
}

// Call returns the next value or error produced by the driver.
//
// It panics if the driver produces nothing within mockCallTimeout, or if the sequence
// is already closed (the driver returned or Close was called).
func (m *MockCallSequence[T]) Call() (T, error) {
	m.callCount++

	m.initOnce.Do(func() {
		m.makeChannels()

		// Lives until the driver returns. A driver that blocks forever leaks this goroutine.
		go func() {
			defer m.Close()

			m.driver(m.res, m.err)
		}()
	})

	ctx, cancel := context.WithTimeout(context.Background(), mockCallTimeout)
	defer cancel()

	select {
	case t, ok := <-m.res:
		if !ok {
			break
		}

		return t, nil

	case err, ok := <-m.err:
		if !ok {
			break
		}

		var zero T

		return zero, err

	case <-ctx.Done():
		panic(fmt.Sprintf("mock call sequence driver timed-out on call %d", m.CallCount()))
	}

	panic("mock call sequence called after driver returned (or sequence Close was called explicitly)")
}

// CallCount returns how many times Call has been invoked, including calls that panicked.
func (m *MockCallSequence[T]) CallCount() uint {
	return m.callCount
}

// Close ends the sequence; it is idempotent.
//
// Closing before the first Call is allowed: the channels are created (without starting
// the driver) so they can be closed, and a later Call panics as after any other Close.
// NOTE: closing explicitly while the driver is still sending makes the driver panic.
func (m *MockCallSequence[T]) Close() {
	m.initOnce.Do(m.makeChannels)

	m.closeOnce.Do(func() {
		close(m.res)
		close(m.err)
	})
}

// makeChannels allocates the unbuffered result and error channels.
func (m *MockCallSequence[T]) makeChannels() {
	m.res = make(chan T)
	m.err = make(chan error)
}
79 | 


--------------------------------------------------------------------------------
/helpertest/tmpdata.go:
--------------------------------------------------------------------------------
  1 | package helpertest
  2 | 
  3 | import (
  4 | 	"bufio"
  5 | 	"io/fs"
  6 | 	"os"
  7 | 	"path/filepath"
  8 | 
  9 | 	. "github.com/onsi/ginkgo/v2"
 10 | 	. "github.com/onsi/gomega"
 11 | )
 12 | 
 13 | type TmpFolder struct {
 14 | 	Path   string
 15 | 	prefix string
 16 | }
 17 | 
 18 | type TmpFile struct {
 19 | 	Path   string
 20 | 	Folder *TmpFolder
 21 | }
 22 | 
 23 | func NewTmpFolder(prefix string) *TmpFolder {
 24 | 	ipref := prefix
 25 | 
 26 | 	if len(ipref) == 0 {
 27 | 		ipref = "blocky"
 28 | 	}
 29 | 
 30 | 	path, err := os.MkdirTemp("", ipref)
 31 | 	Expect(err).Should(Succeed())
 32 | 
 33 | 	res := &TmpFolder{
 34 | 		Path:   path,
 35 | 		prefix: ipref,
 36 | 	}
 37 | 
 38 | 	DeferCleanup(res.Clean)
 39 | 
 40 | 	return res
 41 | }
 42 | 
 43 | func (tf *TmpFolder) Clean() error {
 44 | 	if len(tf.Path) > 0 {
 45 | 		return os.RemoveAll(tf.Path)
 46 | 	}
 47 | 
 48 | 	return nil
 49 | }
 50 | 
 51 | func (tf *TmpFolder) CreateSubFolder(name string) *TmpFolder {
 52 | 	var path string
 53 | 
 54 | 	var err error
 55 | 
 56 | 	if len(name) > 0 {
 57 | 		path = filepath.Join(tf.Path, name)
 58 | 		err = os.Mkdir(path, fs.ModePerm)
 59 | 	} else {
 60 | 		path, err = os.MkdirTemp(tf.Path, tf.prefix)
 61 | 	}
 62 | 
 63 | 	Expect(err).Should(Succeed())
 64 | 
 65 | 	res := &TmpFolder{
 66 | 		Path:   path,
 67 | 		prefix: tf.prefix,
 68 | 	}
 69 | 
 70 | 	return res
 71 | }
 72 | 
 73 | func (tf *TmpFolder) CreateEmptyFile(name string) *TmpFile {
 74 | 	f, res := tf.createFile(name)
 75 | 	defer f.Close()
 76 | 
 77 | 	return res
 78 | }
 79 | 
 80 | func (tf *TmpFolder) CreateStringFile(name string, lines ...string) *TmpFile {
 81 | 	f, res := tf.createFile(name)
 82 | 	defer f.Close()
 83 | 
 84 | 	first := true
 85 | 	w := bufio.NewWriter(f)
 86 | 
 87 | 	for _, l := range lines {
 88 | 		if first {
 89 | 			first = false
 90 | 		} else {
 91 | 			_, err := w.WriteString("\n")
 92 | 			Expect(err).Should(Succeed())
 93 | 		}
 94 | 
 95 | 		_, err := w.WriteString(l)
 96 | 		Expect(err).Should(Succeed())
 97 | 	}
 98 | 
 99 | 	w.Flush()
100 | 
101 | 	return res
102 | }
103 | 
104 | func (tf *TmpFolder) JoinPath(name string) string {
105 | 	return filepath.Join(tf.Path, name)
106 | }
107 | 
108 | func (tf *TmpFolder) createFile(name string) (*os.File, *TmpFile) {
109 | 	var (
110 | 		f   *os.File
111 | 		err error
112 | 	)
113 | 
114 | 	if len(name) > 0 {
115 | 		f, err = os.Create(filepath.Join(tf.Path, name))
116 | 	} else {
117 | 		f, err = os.CreateTemp(tf.Path, "temp")
118 | 	}
119 | 
120 | 	Expect(err).Should(Succeed())
121 | 
122 | 	return f, &TmpFile{
123 | 		Path:   f.Name(),
124 | 		Folder: tf,
125 | 	}
126 | }
127 | 


--------------------------------------------------------------------------------
/lists/downloader.go:
--------------------------------------------------------------------------------
  1 | package lists
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"errors"
  6 | 	"fmt"
  7 | 	"io"
  8 | 	"net"
  9 | 	"net/http"
 10 | 
 11 | 	"github.com/0xERR0R/blocky/config"
 12 | 	"github.com/0xERR0R/blocky/evt"
 13 | 	"github.com/avast/retry-go/v4"
 14 | )
 15 | 
 16 | // TransientError represents a temporary error like timeout, network errors...
 17 | type TransientError struct {
 18 | 	inner error
 19 | }
 20 | 
 21 | func (e *TransientError) Error() string {
 22 | 	return fmt.Sprintf("temporary error occurred: %v", e.inner)
 23 | }
 24 | 
 25 | func (e *TransientError) Unwrap() error {
 26 | 	return e.inner
 27 | }
 28 | 
 29 | // FileDownloader is able to download some text file
 30 | type FileDownloader interface {
 31 | 	DownloadFile(ctx context.Context, link string) (io.ReadCloser, error)
 32 | }
 33 | 
 34 | // httpDownloader downloads files via HTTP protocol
 35 | type httpDownloader struct {
 36 | 	cfg config.Downloader
 37 | 
 38 | 	client http.Client
 39 | }
 40 | 
 41 | func NewDownloader(cfg config.Downloader, transport http.RoundTripper) FileDownloader {
 42 | 	return newDownloader(cfg, transport)
 43 | }
 44 | 
 45 | func newDownloader(cfg config.Downloader, transport http.RoundTripper) *httpDownloader {
 46 | 	return &httpDownloader{
 47 | 		cfg: cfg,
 48 | 
 49 | 		client: http.Client{
 50 | 			Transport: transport,
 51 | 			Timeout:   cfg.Timeout.ToDuration(),
 52 | 		},
 53 | 	}
 54 | }
 55 | 
 56 | func (d *httpDownloader) DownloadFile(ctx context.Context, link string) (io.ReadCloser, error) {
 57 | 	var body io.ReadCloser
 58 | 
 59 | 	err := retry.Do(
 60 | 		func() error {
 61 | 			req, err := http.NewRequestWithContext(ctx, http.MethodGet, link, nil)
 62 | 			if err != nil {
 63 | 				return err
 64 | 			}
 65 | 
 66 | 			resp, httpErr := d.client.Do(req)
 67 | 			if httpErr == nil {
 68 | 				if resp.StatusCode == http.StatusOK {
 69 | 					body = resp.Body
 70 | 
 71 | 					return nil
 72 | 				}
 73 | 
 74 | 				_ = resp.Body.Close()
 75 | 
 76 | 				return fmt.Errorf("got status code %d", resp.StatusCode)
 77 | 			}
 78 | 
 79 | 			var netErr net.Error
 80 | 			if errors.As(httpErr, &netErr) && netErr.Timeout() {
 81 | 				return &TransientError{inner: netErr}
 82 | 			}
 83 | 
 84 | 			return httpErr
 85 | 		},
 86 | 		retry.Attempts(d.cfg.Attempts),
 87 | 		retry.DelayType(retry.FixedDelay),
 88 | 		retry.Delay(d.cfg.Cooldown.ToDuration()),
 89 | 		retry.LastErrorOnly(true),
 90 | 		retry.OnRetry(func(n uint, err error) {
 91 | 			var transientErr *TransientError
 92 | 
 93 | 			var dnsErr *net.DNSError
 94 | 
 95 | 			logger := logger().
 96 | 				WithField("link", link).
 97 | 				WithField("attempt", fmt.Sprintf("%d/%d", n+1, d.cfg.Attempts))
 98 | 
 99 | 			switch {
100 | 			case errors.As(err, &transientErr):
101 | 				logger.Warnf("Temporary network err / Timeout occurred: %s", transientErr)
102 | 			case errors.As(err, &dnsErr):
103 | 				logger.Warnf("Name resolution err: %s", dnsErr.Err)
104 | 			default:
105 | 				logger.Warnf("Can't download file: %s", err)
106 | 			}
107 | 
108 | 			onDownloadError(link)
109 | 		}))
110 | 
111 | 	return body, err
112 | }
113 | 
114 | func onDownloadError(link string) {
115 | 	evt.Bus().Publish(evt.CachingFailedDownloadChanged, link)
116 | }
117 | 


--------------------------------------------------------------------------------
/lists/list_cache_benchmark_test.go:
--------------------------------------------------------------------------------
 1 | package lists
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"testing"
 6 | 
 7 | 	"github.com/0xERR0R/blocky/config"
 8 | )
 9 | 
10 | func BenchmarkRefresh(b *testing.B) {
11 | 	file1, _ := createTestListFile(b.TempDir(), 100000)
12 | 	file2, _ := createTestListFile(b.TempDir(), 150000)
13 | 	file3, _ := createTestListFile(b.TempDir(), 130000)
14 | 	lists := map[string][]config.BytesSource{
15 | 		"gr1": config.NewBytesSources(file1, file2, file3),
16 | 	}
17 | 
18 | 	cfg := config.SourceLoading{
19 | 		Concurrency:   5,
20 | 		RefreshPeriod: config.Duration(-1),
21 | 	}
22 | 	downloader := NewDownloader(config.Downloader{}, nil)
23 | 	cache, _ := NewListCache(context.Background(), ListCacheTypeDenylist, cfg, lists, downloader)
24 | 
25 | 	b.ReportAllocs()
26 | 
27 | 	for n := 0; n < b.N; n++ {
28 | 		_ = cache.Refresh()
29 | 	}
30 | }
31 | 


--------------------------------------------------------------------------------
/lists/list_cache_enum.go:
--------------------------------------------------------------------------------
 1 | // Code generated by go-enum DO NOT EDIT.
 2 | // Version:
 3 | // Revision:
 4 | // Build Date:
 5 | // Built By:
 6 | 
 7 | package lists
 8 | 
 9 | import (
10 | 	"fmt"
11 | 	"strings"
12 | )
13 | 
// NOTE: everything below is generated by go-enum (see the file header); change the
// enum definition and regenerate instead of editing by hand.
const (
	// ListCacheTypeDenylist is a ListCacheType of type Denylist.
	// is a list with blocked domains
	ListCacheTypeDenylist ListCacheType = iota
	// ListCacheTypeAllowlist is a ListCacheType of type Allowlist.
	// is a list with allowlisted domains / IPs
	ListCacheTypeAllowlist
)

// ErrInvalidListCacheType is wrapped by ParseListCacheType (and so UnmarshalText)
// for unknown names; its message lists the valid names.
var ErrInvalidListCacheType = fmt.Errorf("not a valid ListCacheType, try [%s]", strings.Join(_ListCacheTypeNames, ", "))

// _ListCacheTypeName holds all names back to back; each name is a substring of it.
const _ListCacheTypeName = "denylistallowlist"

var _ListCacheTypeNames = []string{
	_ListCacheTypeName[0:8],
	_ListCacheTypeName[8:17],
}

// ListCacheTypeNames returns a list of possible string values of ListCacheType.
// The result is a copy, so callers may modify it.
func ListCacheTypeNames() []string {
	tmp := make([]string, len(_ListCacheTypeNames))
	copy(tmp, _ListCacheTypeNames)
	return tmp
}

var _ListCacheTypeMap = map[ListCacheType]string{
	ListCacheTypeDenylist:  _ListCacheTypeName[0:8],
	ListCacheTypeAllowlist: _ListCacheTypeName[8:17],
}

// String implements the Stringer interface.
// Unknown values render as "ListCacheType(<n>)".
func (x ListCacheType) String() string {
	if str, ok := _ListCacheTypeMap[x]; ok {
		return str
	}
	return fmt.Sprintf("ListCacheType(%d)", x)
}

// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x ListCacheType) IsValid() bool {
	_, ok := _ListCacheTypeMap[x]
	return ok
}

var _ListCacheTypeValue = map[string]ListCacheType{
	_ListCacheTypeName[0:8]:  ListCacheTypeDenylist,
	_ListCacheTypeName[8:17]: ListCacheTypeAllowlist,
}

// ParseListCacheType attempts to convert a string to a ListCacheType.
// Matching is exact (case-sensitive map lookup); unknown names return an error
// wrapping ErrInvalidListCacheType.
func ParseListCacheType(name string) (ListCacheType, error) {
	if x, ok := _ListCacheTypeValue[name]; ok {
		return x, nil
	}
	return ListCacheType(0), fmt.Errorf("%s is %w", name, ErrInvalidListCacheType)
}

// MarshalText implements the text marshaller method.
func (x ListCacheType) MarshalText() ([]byte, error) {
	return []byte(x.String()), nil
}

// UnmarshalText implements the text unmarshaller method.
// On error the receiver is left unchanged.
func (x *ListCacheType) UnmarshalText(text []byte) error {
	name := string(text)
	tmp, err := ParseListCacheType(name)
	if err != nil {
		return err
	}
	*x = tmp
	return nil
}
87 | 


--------------------------------------------------------------------------------
/lists/list_suite_test.go:
--------------------------------------------------------------------------------
 1 | package lists
 2 | 
 3 | import (
 4 | 	"testing"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/log"
 7 | 
 8 | 	. "github.com/onsi/ginkgo/v2"
 9 | 	. "github.com/onsi/gomega"
10 | )
11 | 
12 | func init() {
13 | 	log.Silence()
14 | }
15 | 
16 | func TestLists(t *testing.T) {
17 | 	RegisterFailHandler(Fail)
18 | 	RunSpecs(t, "Lists Suite")
19 | }
20 | 


--------------------------------------------------------------------------------
/lists/parsers/adapt.go:
--------------------------------------------------------------------------------
 1 | package parsers
 2 | 
 3 | import "context"
 4 | 
 5 | // Adapt returns a parser that wraps `inner` converting each parsed value.
 6 | func Adapt[From, To any](inner SeriesParser[From], adapt func(From) To) SeriesParser[To] {
 7 | 	return TryAdapt(inner, func(from From) (To, error) {
 8 | 		return adapt(from), nil
 9 | 	})
10 | }
11 | 
12 | // TryAdapt returns a parser that wraps `inner` and tries to convert each parsed value.
13 | func TryAdapt[From, To any](inner SeriesParser[From], adapt func(From) (To, error)) SeriesParser[To] {
14 | 	return newAdapter(inner, adapt)
15 | }
16 | 
17 | // TryAdaptMethod returns a parser that wraps `inner` and tries to convert each parsed value
18 | // using the given method with pointer receiver of `To`.
19 | func TryAdaptMethod[ToPtr *To, From, To any](
20 | 	inner SeriesParser[From], method func(ToPtr, From) error,
21 | ) SeriesParser[*To] {
22 | 	return TryAdapt(inner, func(from From) (*To, error) {
23 | 		res := new(To)
24 | 
25 | 		err := method(res, from)
26 | 		if err != nil {
27 | 			return nil, err
28 | 		}
29 | 
30 | 		return res, nil
31 | 	})
32 | }
33 | 
34 | type adapter[From, To any] struct {
35 | 	inner SeriesParser[From]
36 | 	adapt func(From) (To, error)
37 | }
38 | 
39 | func newAdapter[From, To any](inner SeriesParser[From], adapt func(From) (To, error)) SeriesParser[To] {
40 | 	return &adapter[From, To]{inner, adapt}
41 | }
42 | 
43 | func (a *adapter[From, To]) Position() string {
44 | 	return a.inner.Position()
45 | }
46 | 
47 | func (a *adapter[From, To]) Next(ctx context.Context) (To, error) {
48 | 	from, err := a.inner.Next(ctx)
49 | 	if err != nil {
50 | 		var zero To
51 | 
52 | 		return zero, err
53 | 	}
54 | 
55 | 	res, err := a.adapt(from)
56 | 	if err != nil {
57 | 		var zero To
58 | 
59 | 		return zero, err
60 | 	}
61 | 
62 | 	return res, nil
63 | }
64 | 


--------------------------------------------------------------------------------
/lists/parsers/filtererrors.go:
--------------------------------------------------------------------------------
 1 | package parsers
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"errors"
 6 | )
 7 | 
 8 | // NoErrorLimit can be used to continue parsing until EOF.
 9 | const NoErrorLimit = -1
10 | 
// ErrTooManyErrors is returned by an AllowErrors parser once its error budget is spent.
var ErrTooManyErrors = errors.New("too many parse errors")
12 | 
13 | type FilteredSeriesParser[T any] interface {
14 | 	SeriesParser[T]
15 | 
16 | 	// OnErr registers a callback invoked for each error encountered.
17 | 	OnErr(func(error))
18 | }
19 | 
20 | // AllowErrors returns a parser that wraps `inner` and tries to continue parsing.
21 | //
22 | // After `n` errors, it returns any error `inner` does.
23 | func FilterErrors[T any](inner SeriesParser[T], filter func(error) error) FilteredSeriesParser[T] {
24 | 	return newErrorFilter(inner, filter)
25 | }
26 | 
27 | // AllowErrors returns a parser that wraps `inner` and tries to continue parsing.
28 | //
29 | // After `n` errors, it returns any error `inner` does.
30 | func AllowErrors[T any](inner SeriesParser[T], n int) FilteredSeriesParser[T] {
31 | 	if n == NoErrorLimit {
32 | 		return FilterErrors(inner, func(error) error { return nil })
33 | 	}
34 | 
35 | 	count := 0
36 | 
37 | 	return FilterErrors(inner, func(err error) error {
38 | 		count++
39 | 
40 | 		if count > n {
41 | 			return ErrTooManyErrors
42 | 		}
43 | 
44 | 		return nil
45 | 	})
46 | }
47 | 
48 | type errorFilter[T any] struct {
49 | 	inner  SeriesParser[T]
50 | 	filter func(error) error
51 | }
52 | 
53 | func newErrorFilter[T any](inner SeriesParser[T], filter func(error) error) FilteredSeriesParser[T] {
54 | 	return &errorFilter[T]{inner, filter}
55 | }
56 | 
57 | func (f *errorFilter[T]) OnErr(callback func(error)) {
58 | 	filter := f.filter
59 | 
60 | 	f.filter = func(err error) error {
61 | 		callback(ErrWithPosition(f.inner, err))
62 | 
63 | 		return filter(err)
64 | 	}
65 | }
66 | 
67 | func (f *errorFilter[T]) Position() string {
68 | 	return f.inner.Position()
69 | }
70 | 
71 | func (f *errorFilter[T]) Next(ctx context.Context) (T, error) {
72 | 	var zero T
73 | 
74 | 	for {
75 | 		res, err := f.inner.Next(ctx)
76 | 		if err != nil {
77 | 			if IsNonResumableErr(err) {
78 | 				// bypass the filter, and just propagate the error
79 | 				return zero, err
80 | 			}
81 | 
82 | 			err = f.filter(err)
83 | 			if err != nil {
84 | 				return zero, err
85 | 			}
86 | 
87 | 			continue
88 | 		}
89 | 
90 | 		return res, nil
91 | 	}
92 | }
93 | 


--------------------------------------------------------------------------------
/lists/parsers/lines.go:
--------------------------------------------------------------------------------
 1 | package parsers
 2 | 
 3 | import (
 4 | 	"bufio"
 5 | 	"context"
 6 | 	"encoding"
 7 | 	"fmt"
 8 | 	"io"
 9 | 	"strings"
10 | 	"unicode"
11 | )
12 | 
13 | // Lines splits `r` into a series of lines.
14 | //
15 | // Empty lines are skipped, and comments are stripped.
16 | func Lines(r io.Reader) SeriesParser[string] {
17 | 	return newLines(r)
18 | }
19 | 
20 | // LinesAs returns a parser that parses each line of `r` as a `T`.
21 | func LinesAs[TPtr TextUnmarshaler[T], T any](r io.Reader) SeriesParser[*T] {
22 | 	return UnmarshalEach[TPtr](Lines(r))
23 | }
24 | 
25 | // UnmarshalEach returns a parser that unmarshals each string of `inner` as a `T`.
26 | func UnmarshalEach[TPtr TextUnmarshaler[T], T any](inner SeriesParser[string]) SeriesParser[*T] {
27 | 	stringToBytes := func(s string) []byte {
28 | 		return []byte(s)
29 | 	}
30 | 
31 | 	return TryAdaptMethod(Adapt(inner, stringToBytes), TPtr.UnmarshalText)
32 | }
33 | 
// TextUnmarshaler constrains a type parameter to be exactly `*T`, where `*T`
// implements encoding.TextUnmarshaler.
type TextUnmarshaler[T any] interface {
	*T
	encoding.TextUnmarshaler
}
38 | 
// lines is the SeriesParser behind Lines.
type lines struct {
	scanner *bufio.Scanner
	lineNo  uint // 1-based number of the line most recently attempted
}
43 | 
44 | func newLines(r io.Reader) SeriesParser[string] {
45 | 	scanner := bufio.NewScanner(r)
46 | 	scanner.Split(bufio.ScanLines)
47 | 
48 | 	return &lines{scanner: scanner}
49 | }
50 | 
51 | func (l *lines) Position() string {
52 | 	return fmt.Sprintf("line %d", l.lineNo)
53 | }
54 | 
55 | func (l *lines) Next(ctx context.Context) (string, error) {
56 | 	for {
57 | 		l.lineNo++
58 | 
59 | 		if err := ctx.Err(); err != nil {
60 | 			return "", NewNonResumableError(err)
61 | 		}
62 | 
63 | 		if !l.scanner.Scan() {
64 | 			break
65 | 		}
66 | 
67 | 		text := strings.TrimSpace(l.scanner.Text())
68 | 
69 | 		if len(text) == 0 {
70 | 			continue // empty line
71 | 		}
72 | 
73 | 		if idx := strings.IndexRune(text, '#'); idx != -1 {
74 | 			if idx == 0 {
75 | 				continue // commented line
76 | 			}
77 | 
78 | 			// end of line comment
79 | 			text = text[:idx]
80 | 			text = strings.TrimRightFunc(text, unicode.IsSpace)
81 | 		}
82 | 
83 | 		return text, nil
84 | 	}
85 | 
86 | 	err := l.scanner.Err()
87 | 	if err != nil {
88 | 		// bufio.Scanner does not support continuing after an error
89 | 		return "", NewNonResumableError(err)
90 | 	}
91 | 
92 | 	return "", NewNonResumableError(io.EOF)
93 | }
94 | 


--------------------------------------------------------------------------------
/lists/parsers/parser.go:
--------------------------------------------------------------------------------
 1 | package parsers
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"errors"
 6 | 	"fmt"
 7 | 	"io"
 8 | )
 9 | 
// SeriesParser yields a sequence of `T` values parsed from some data source.
type SeriesParser[T any] interface {
	// Next moves the cursor forward and returns the next `T`, or an error.
	//
	// A `NonResumableError` is fatal: Next must not be called again afterwards.
	// Any other error concerns only the item being parsed; the rest of the series
	// is unaffected.
	Next(context.Context) (T, error)

	// Position describes, in terms a user easily understands, where in the
	// underlying data source the cursor currently is.
	Position() string
}
27 | 
28 | // ForEach is a helper for consuming a parser.
29 | //
30 | // It stops iteration at the first error encountered.
31 | // If that error is `io.EOF`, `nil` is returned instead.
32 | // Any other error is wrapped with the parser's position using `ErrWithPosition`.
33 | //
34 | // To continue iteration on resumable errors, use with `FilterErrors`.
35 | func ForEach[T any](ctx context.Context, parser SeriesParser[T], callback func(T) error) (rerr error) {
36 | 	defer func() {
37 | 		rerr = ErrWithPosition(parser, rerr)
38 | 	}()
39 | 
40 | 	for {
41 | 		if err := ctx.Err(); err != nil {
42 | 			return err
43 | 		}
44 | 
45 | 		res, err := parser.Next(ctx)
46 | 		if err != nil {
47 | 			if errors.Is(err, io.EOF) {
48 | 				return nil
49 | 			}
50 | 
51 | 			return err
52 | 		}
53 | 
54 | 		err = callback(res)
55 | 		if err != nil {
56 | 			return err
57 | 		}
58 | 	}
59 | }
60 | 
61 | // ErrWithPosition adds the `parser`'s position to the given `err`.
62 | func ErrWithPosition[T any](parser SeriesParser[T], err error) error {
63 | 	if err == nil {
64 | 		return nil
65 | 	}
66 | 
67 | 	return fmt.Errorf("%s: %w", parser.Position(), err)
68 | }
69 | 
// IsNonResumableErr reports whether err, or any error it wraps, is a
// *NonResumableError, i.e. the parser that returned it must not be called again.
func IsNonResumableErr(err error) bool {
	var nonResumableError *NonResumableError

	return errors.As(err, &nonResumableError)
}

// NonResumableError represents an error from which a parser cannot recover.
type NonResumableError struct {
	inner error
}

// NewNonResumableError wraps inner in a *NonResumableError.
func NewNonResumableError(inner error) error {
	return &NonResumableError{inner}
}

// Error implements error. %v renders a nil inner error as "<nil>" instead of
// panicking, which calling e.inner.Error() directly would do.
func (e *NonResumableError) Error() string {
	return fmt.Sprintf("non resumable parse error: %v", e.inner)
}

// Unwrap exposes the wrapped error to errors.Is / errors.As.
func (e *NonResumableError) Unwrap() error {
	return e.inner
}
94 | 


--------------------------------------------------------------------------------
/lists/parsers/parsers_suite_test.go:
--------------------------------------------------------------------------------
 1 | package parsers
 2 | 
 3 | import (
 4 | 	"testing"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/log"
 7 | 
 8 | 	. "github.com/onsi/ginkgo/v2"
 9 | 	. "github.com/onsi/gomega"
10 | )
11 | 
12 | func init() {
13 | 	log.Silence()
14 | }
15 | 
16 | func TestLists(t *testing.T) {
17 | 	RegisterFailHandler(Fail)
18 | 	RunSpecs(t, "Parsers Suite")
19 | }
20 | 


--------------------------------------------------------------------------------
/lists/sourcereader.go:
--------------------------------------------------------------------------------
 1 | package lists
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"fmt"
 6 | 	"io"
 7 | 	"os"
 8 | 	"strings"
 9 | 
10 | 	"github.com/0xERR0R/blocky/config"
11 | )
12 | 
// SourceOpener gives read access to one configured list source and describes
// itself via String.
type SourceOpener interface {
	fmt.Stringer

	// Open returns a reader over the source content; the caller must close it.
	Open(ctx context.Context) (io.ReadCloser, error)
}
18 | 
19 | func NewSourceOpener(txtLocInfo string, source config.BytesSource, downloader FileDownloader) (SourceOpener, error) {
20 | 	switch source.Type {
21 | 	case config.BytesSourceTypeText:
22 | 		return &textOpener{source: source, locInfo: txtLocInfo}, nil
23 | 
24 | 	case config.BytesSourceTypeHttp:
25 | 		return &httpOpener{source: source, downloader: downloader}, nil
26 | 
27 | 	case config.BytesSourceTypeFile:
28 | 		return &fileOpener{source: source}, nil
29 | 	}
30 | 
31 | 	return nil, fmt.Errorf("cannot open %s", source)
32 | }
33 | 
34 | type textOpener struct {
35 | 	source  config.BytesSource
36 | 	locInfo string
37 | }
38 | 
39 | func (o *textOpener) Open(_ context.Context) (io.ReadCloser, error) {
40 | 	return io.NopCloser(strings.NewReader(o.source.From)), nil
41 | }
42 | 
43 | func (o *textOpener) String() string {
44 | 	return fmt.Sprintf("%s: %s", o.locInfo, o.source)
45 | }
46 | 
47 | type httpOpener struct {
48 | 	source     config.BytesSource
49 | 	downloader FileDownloader
50 | }
51 | 
52 | func (o *httpOpener) Open(ctx context.Context) (io.ReadCloser, error) {
53 | 	return o.downloader.DownloadFile(ctx, o.source.From)
54 | }
55 | 
56 | func (o *httpOpener) String() string {
57 | 	return o.source.String()
58 | }
59 | 
60 | type fileOpener struct {
61 | 	source config.BytesSource
62 | }
63 | 
64 | func (o *fileOpener) Open(_ context.Context) (io.ReadCloser, error) {
65 | 	return os.Open(o.source.From)
66 | }
67 | 
68 | func (o *fileOpener) String() string {
69 | 	return o.source.String()
70 | }
71 | 


--------------------------------------------------------------------------------
/log/context.go:
--------------------------------------------------------------------------------
 1 | package log
 2 | 
 3 | import (
 4 | 	"context"
 5 | 
 6 | 	"github.com/sirupsen/logrus"
 7 | )
 8 | 
 9 | type ctxKey struct{}
10 | 
11 | func NewCtx(ctx context.Context, logger *logrus.Entry) (context.Context, *logrus.Entry) {
12 | 	ctx = context.WithValue(ctx, ctxKey{}, logger)
13 | 
14 | 	return ctx, entryWithCtx(ctx, logger)
15 | }
16 | 
17 | func FromCtx(ctx context.Context) *logrus.Entry {
18 | 	logger, ok := ctx.Value(ctxKey{}).(*logrus.Entry)
19 | 	if !ok {
20 | 		// Fallback to the global logger
21 | 		return logrus.NewEntry(Log())
22 | 	}
23 | 
24 | 	// Ensure `logger.Context == ctx`, not always the case since `ctx` could be a child of `logger.Context`
25 | 	return entryWithCtx(ctx, logger)
26 | }
27 | 
28 | func entryWithCtx(ctx context.Context, logger *logrus.Entry) *logrus.Entry {
29 | 	loggerCopy := *logger
30 | 	loggerCopy.Context = ctx
31 | 
32 | 	return &loggerCopy
33 | }
34 | 
35 | func WrapCtx(ctx context.Context, wrap func(*logrus.Entry) *logrus.Entry) (context.Context, *logrus.Entry) {
36 | 	logger := FromCtx(ctx)
37 | 	logger = wrap(logger)
38 | 
39 | 	return NewCtx(ctx, logger)
40 | }
41 | 
42 | func CtxWithFields(ctx context.Context, fields logrus.Fields) (context.Context, *logrus.Entry) {
43 | 	return WrapCtx(ctx, func(e *logrus.Entry) *logrus.Entry {
44 | 		return e.WithFields(fields)
45 | 	})
46 | }
47 | 


--------------------------------------------------------------------------------
/log/logger_enum.go:
--------------------------------------------------------------------------------
 1 | // Code generated by go-enum DO NOT EDIT.
 2 | // Version:
 3 | // Revision:
 4 | // Build Date:
 5 | // Built By:
 6 | 
 7 | package log
 8 | 
 9 | import (
10 | 	"fmt"
11 | 	"strings"
12 | )
13 | 
// NOTE: everything below is generated by go-enum (see the file header); change the
// enum definition and regenerate instead of editing by hand.
const (
	// FormatTypeText is a FormatType of type Text.
	// logging as text
	FormatTypeText FormatType = iota
	// FormatTypeJson is a FormatType of type Json.
	// JSON format
	FormatTypeJson
)

// ErrInvalidFormatType is wrapped by ParseFormatType (and so UnmarshalText) for
// unknown names; its message lists the valid names.
var ErrInvalidFormatType = fmt.Errorf("not a valid FormatType, try [%s]", strings.Join(_FormatTypeNames, ", "))

// _FormatTypeName holds all names back to back; each name is a substring of it.
const _FormatTypeName = "textjson"

var _FormatTypeNames = []string{
	_FormatTypeName[0:4],
	_FormatTypeName[4:8],
}

// FormatTypeNames returns a list of possible string values of FormatType.
// The result is a copy, so callers may modify it.
func FormatTypeNames() []string {
	tmp := make([]string, len(_FormatTypeNames))
	copy(tmp, _FormatTypeNames)
	return tmp
}

var _FormatTypeMap = map[FormatType]string{
	FormatTypeText: _FormatTypeName[0:4],
	FormatTypeJson: _FormatTypeName[4:8],
}

// String implements the Stringer interface.
// Unknown values render as "FormatType(<n>)".
func (x FormatType) String() string {
	if str, ok := _FormatTypeMap[x]; ok {
		return str
	}
	return fmt.Sprintf("FormatType(%d)", x)
}

// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x FormatType) IsValid() bool {
	_, ok := _FormatTypeMap[x]
	return ok
}

var _FormatTypeValue = map[string]FormatType{
	_FormatTypeName[0:4]: FormatTypeText,
	_FormatTypeName[4:8]: FormatTypeJson,
}

// ParseFormatType attempts to convert a string to a FormatType.
// Matching is exact (case-sensitive map lookup); unknown names return an error
// wrapping ErrInvalidFormatType.
func ParseFormatType(name string) (FormatType, error) {
	if x, ok := _FormatTypeValue[name]; ok {
		return x, nil
	}
	return FormatType(0), fmt.Errorf("%s is %w", name, ErrInvalidFormatType)
}

// MarshalText implements the text marshaller method.
func (x FormatType) MarshalText() ([]byte, error) {
	return []byte(x.String()), nil
}

// UnmarshalText implements the text unmarshaller method.
// On error the receiver is left unchanged.
func (x *FormatType) UnmarshalText(text []byte) error {
	name := string(text)
	tmp, err := ParseFormatType(name)
	if err != nil {
		return err
	}
	*x = tmp
	return nil
}
87 | 


--------------------------------------------------------------------------------
/log/mock_entry.go:
--------------------------------------------------------------------------------
 1 | package log
 2 | 
 3 | import (
 4 | 	"github.com/sirupsen/logrus"
 5 | 	"github.com/sirupsen/logrus/hooks/test"
 6 | 	"github.com/stretchr/testify/mock"
 7 | )
 8 | 
 9 | func NewMockEntry() (*logrus.Entry, *MockLoggerHook) {
10 | 	logger, _ := test.NewNullLogger()
11 | 	logger.Level = logrus.TraceLevel
12 | 
13 | 	entry := logrus.Entry{Logger: logger}
14 | 	hook := MockLoggerHook{}
15 | 
16 | 	entry.Logger.AddHook(&hook)
17 | 
18 | 	hook.On("Fire", mock.Anything).Return(nil)
19 | 
20 | 	return &entry, &hook
21 | }
22 | 
23 | type MockLoggerHook struct {
24 | 	mock.Mock
25 | 
26 | 	Messages []string
27 | }
28 | 
29 | // Levels implements `logrus.Hook`.
30 | func (h *MockLoggerHook) Levels() []logrus.Level {
31 | 	return logrus.AllLevels
32 | }
33 | 
34 | // Fire implements `logrus.Hook`.
35 | func (h *MockLoggerHook) Fire(entry *logrus.Entry) error {
36 | 	_ = h.Called()
37 | 
38 | 	h.Messages = append(h.Messages, entry.Message)
39 | 
40 | 	return nil
41 | }
42 | 


--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
 1 | package main
 2 | 
 3 | import (
 4 | 	"os"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/cmd"
 7 | )
 8 | 
// main is the blocky entry point. It runs the command tree via cmd.Execute
// and, once that returns, terminates the process with exit status 0.
// How command failures are reported is decided inside cmd.Execute (not
// visible here).
func main() {
	cmd.Execute()
	os.Exit(0)
}
13 | 


--------------------------------------------------------------------------------
/main_static.go:
--------------------------------------------------------------------------------
 1 | //go:build linux
 2 | // +build linux
 3 | 
 4 | package main
 5 | 
 6 | import (
 7 | 	"os"
 8 | 	"time"
 9 | 	_ "time/tzdata"
10 | 
11 | 	_ "github.com/breml/rootcerts"
12 | 
13 | 	reaper "github.com/ramr/go-reaper"
14 | )
15 | 
16 | //nolint:gochecknoinits
17 | func init() {
18 | 	go reaper.Start(reaper.Config{
19 | 		DisablePid1Check: true,
20 | 	})
21 | 
22 | 	setLocaltime()
23 | }
24 | 
// setLocaltime chooses the zone used as time.Local.
//
// If /etc/localtime is readable and contains valid TZif data, that zone is
// used; /etc/localtime is only read, never modified. Otherwise, if the TZ
// environment variable is non-empty and names a loadable zone, that zone is
// used. This file embeds time/tzdata, so LoadLocation can resolve TZ even
// without system zoneinfo files. If neither succeeds, time.Local is left
// untouched.
//
// The zone read from /etc/localtime is named "Local", the same name the Go
// runtime gives its default local zone, so time.Local.String() does not
// report an empty name.
func setLocaltime() {
	if lt, err := os.ReadFile("/etc/localtime"); err == nil {
		if t, err := time.LoadLocationFromTZData("Local", lt); err == nil {
			time.Local = t

			return
		}
	}

	if tz := os.Getenv("TZ"); tz != "" {
		if t, err := time.LoadLocation(tz); err == nil {
			time.Local = t
		}
	}
}
46 | 


--------------------------------------------------------------------------------
/metrics/metrics.go:
--------------------------------------------------------------------------------
 1 | package metrics
 2 | 
 3 | import (
 4 | 	"github.com/0xERR0R/blocky/config"
 5 | 
 6 | 	"github.com/go-chi/chi/v5"
 7 | 	"github.com/prometheus/client_golang/prometheus"
 8 | 	"github.com/prometheus/client_golang/prometheus/collectors"
 9 | 	"github.com/prometheus/client_golang/prometheus/promhttp"
10 | )
11 | 
// Reg is the Prometheus registry used for blocky's metrics. Collectors are
// added via RegisterMetric and Start, and Start serves it over HTTP.
//
//nolint:gochecknoglobals
var Reg = prometheus.NewRegistry()

// RegisterMetric registers prometheus collector c with Reg.
// Any registration error is deliberately discarded (best effort).
func RegisterMetric(c prometheus.Collector) {
	_ = Reg.Register(c)
}
19 | 
20 | // Start starts prometheus endpoint
21 | func Start(router *chi.Mux, cfg config.Metrics) {
22 | 	if cfg.Enable {
23 | 		_ = Reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
24 | 		_ = Reg.Register(collectors.NewGoCollector())
25 | 		router.Handle(cfg.Path, promhttp.InstrumentMetricHandler(Reg,
26 | 			promhttp.HandlerFor(Reg, promhttp.HandlerOpts{})))
27 | 	}
28 | }
29 | 


--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
# MkDocs configuration for the blocky documentation site.
site_name: blocky
site_description: blocky Documentation
# Material for MkDocs theme with a teal primary/accent palette.
theme:
  name: material
  palette:
    primary: teal
    accent: teal
extra:
  # Version selector backed by mike.
  version:
    provider: mike
  # Social links: GitHub repository, Codeberg repository, Docker Hub image.
  social:
    - icon: fontawesome/brands/github
      link: https://github.com/0xERR0R/blocky
    - icon: simple/codeberg
      link: https://codeberg.org/0xERR0R/blocky
    - icon: fontawesome/brands/docker
      link: https://hub.docker.com/r/spx01/blocky
repo_url: https://github.com/0xERR0R/blocky

# Markdown extensions enabled for the documentation pages.
markdown_extensions:
  - abbr
  - pymdownx.snippets
  - pymdownx.emoji:
      emoji_index: !!python/name:material.extensions.emoji.twemoji
      emoji_generator: !!python/name:material.extensions.emoji.to_svg
  - pymdownx.highlight
  - pymdownx.superfences
  - admonition
  - pymdownx.details

# Site navigation: page title -> Markdown source file.
nav:
  - 'Welcome': 'index.md'
  - 'Configuration': 'configuration.md'
  - 'Installation': 'installation.md'
  - 'Prometheus / Grafana': 'prometheus_grafana.md'
  - 'Interfaces': 'interfaces.md'
  - 'Network configuration': 'network_configuration.md'
  - 'Additional information': 'additional_information.md'


--------------------------------------------------------------------------------
/model/models.go:
--------------------------------------------------------------------------------
 1 | package model
 2 | 
 3 | //go:generate go tool go-enum -f=$GOFILE --marshal --names
 4 | import (
 5 | 	"net"
 6 | 	"time"
 7 | 
 8 | 	"github.com/miekg/dns"
 9 | )
10 | 
// ResponseType represents the type of the response ENUM(
// RESOLVED // the response was resolved by the external upstream resolver
// CACHED // the response was resolved from cache
// BLOCKED // the query was blocked
// CONDITIONAL // the query was resolved by the conditional upstream resolver
// CUSTOMDNS // the query was resolved by a custom rule
// HOSTSFILE // the query was resolved by looking up the hosts file
// FILTERED // the query was filtered by query type
// NOTFQDN // the query was filtered as it is not fqdn conform
// SPECIAL // the query was resolved by the special use domain name resolver
// )
type ResponseType int // go-enum (see go:generate) derives ResponseTypeXXX constants/helpers from the ENUM list
23 | 
24 | func (t ResponseType) ToExtendedErrorCode() uint16 {
25 | 	switch t {
26 | 	case ResponseTypeRESOLVED:
27 | 		return dns.ExtendedErrorCodeOther
28 | 	case ResponseTypeCACHED:
29 | 		return dns.ExtendedErrorCodeCachedError
30 | 	case ResponseTypeCONDITIONAL:
31 | 		return dns.ExtendedErrorCodeForgedAnswer
32 | 	case ResponseTypeCUSTOMDNS:
33 | 		return dns.ExtendedErrorCodeForgedAnswer
34 | 	case ResponseTypeHOSTSFILE:
35 | 		return dns.ExtendedErrorCodeForgedAnswer
36 | 	case ResponseTypeNOTFQDN:
37 | 		return dns.ExtendedErrorCodeBlocked
38 | 	case ResponseTypeBLOCKED:
39 | 		return dns.ExtendedErrorCodeBlocked
40 | 	case ResponseTypeFILTERED:
41 | 		return dns.ExtendedErrorCodeFiltered
42 | 	case ResponseTypeSPECIAL:
43 | 		return dns.ExtendedErrorCodeFiltered
44 | 	default:
45 | 		return dns.ExtendedErrorCodeOther
46 | 	}
47 | }
48 | 
// Response represents the response of a DNS query.
//
// Res is the DNS message itself. Reason is a short human-readable
// explanation (e.g. "NOTFQDN"); EDEResolver sends it as the Extended DNS
// Error extra text. RType classifies how the response was produced.
type Response struct {
	Res    *dns.Msg
	Reason string
	RType  ResponseType
}

// RequestProtocol represents the server protocol ENUM(
// TCP // is the TCP protocol
// UDP // is the UDP protocol
// )
type RequestProtocol uint8 // go-enum (see go:generate) derives RequestProtocolXXX constants/helpers from the ENUM list

// Request represents client's DNS request.
//
// ClientIP is the client's address and ClientNames the (possibly several)
// names associated with that client. Protocol is the transport the query
// arrived over (TCP or UDP), and Req is the DNS query message.
// RequestClientID and RequestTS are presumably a client identifier supplied
// with the request and the time it was received; TODO confirm against the
// server code that fills them.
type Request struct {
	ClientIP        net.IP
	RequestClientID string
	Protocol        RequestProtocol
	ClientNames     []string
	Req             *dns.Msg
	RequestTS       time.Time
}
71 | 


--------------------------------------------------------------------------------
/querylog/logger_writer.go:
--------------------------------------------------------------------------------
 1 | package querylog
 2 | 
 3 | import (
 4 | 	"reflect"
 5 | 	"strings"
 6 | 
 7 | 	"github.com/0xERR0R/blocky/log"
 8 | 	"github.com/sirupsen/logrus"
 9 | )
10 | 
11 | const loggerPrefixLoggerWriter = "queryLog"
12 | 
13 | type LoggerWriter struct {
14 | 	logger *logrus.Entry
15 | }
16 | 
17 | func NewLoggerWriter() *LoggerWriter {
18 | 	return &LoggerWriter{logger: log.PrefixedLog(loggerPrefixLoggerWriter)}
19 | }
20 | 
21 | func (d *LoggerWriter) Write(entry *LogEntry) {
22 | 	fields := LogEntryFields(entry)
23 | 
24 | 	d.logger.WithFields(fields).Infof("query resolved")
25 | }
26 | 
27 | func (d *LoggerWriter) CleanUp() {
28 | 	// Nothing to do
29 | }
30 | 
31 | func LogEntryFields(entry *LogEntry) logrus.Fields {
32 | 	return withoutZeroes(logrus.Fields{
33 | 		"client_ip":       entry.ClientIP,
34 | 		"client_names":    strings.Join(entry.ClientNames, "; "),
35 | 		"response_reason": entry.ResponseReason,
36 | 		"response_type":   entry.ResponseType,
37 | 		"response_code":   entry.ResponseCode,
38 | 		"question_name":   entry.QuestionName,
39 | 		"question_type":   entry.QuestionType,
40 | 		"answer":          entry.Answer,
41 | 		"duration_ms":     entry.DurationMs,
42 | 		"instance":        entry.BlockyInstance,
43 | 	})
44 | }
45 | 
46 | func withoutZeroes(fields logrus.Fields) logrus.Fields {
47 | 	for k, v := range fields {
48 | 		if reflect.ValueOf(v).IsZero() {
49 | 			delete(fields, k)
50 | 		}
51 | 	}
52 | 
53 | 	return fields
54 | }
55 | 


--------------------------------------------------------------------------------
/querylog/logger_writer_test.go:
--------------------------------------------------------------------------------
 1 | package querylog
 2 | 
 3 | import (
 4 | 	"time"
 5 | 
 6 | 	"github.com/sirupsen/logrus"
 7 | 	"github.com/sirupsen/logrus/hooks/test"
 8 | 
 9 | 	. "github.com/onsi/gomega"
10 | 
11 | 	. "github.com/onsi/ginkgo/v2"
12 | )
13 | 
// Specs for LoggerWriter (one log line per entry), LogEntryFields (field
// mapping with zero values omitted) and withoutZeroes.
var _ = Describe("LoggerWriter", func() {
	Describe("logger query log", func() {
		When("New log entry was created", func() {
			It("should be logged", func() {
				writer := NewLoggerWriter()
				// Replace the writer's logger with a null logger so the emitted
				// entry can be inspected through the returned test hook.
				logger, hook := test.NewNullLogger()
				writer.logger = logger.WithField("k", "v")

				writer.Write(&LogEntry{
					Start:      time.Now(),
					DurationMs: 20,
				})

				Expect(hook.Entries).Should(HaveLen(1))
				Expect(hook.LastEntry().Message).Should(Equal("query resolved"))
			})
		})
		When("Cleanup is called", func() {
			It("should do nothing", func() {
				writer := NewLoggerWriter()
				writer.CleanUp()
			})
		})
	})

	Describe("LogEntryFields", func() {
		It("should return log fields", func() {
			entry := LogEntry{
				ClientIP:     "ip",
				DurationMs:   100,
				QuestionType: "qtype",
				ResponseCode: "rcode",
			}

			fields := LogEntryFields(&entry)

			Expect(fields).Should(HaveKeyWithValue("client_ip", entry.ClientIP))
			Expect(fields).Should(HaveKeyWithValue("duration_ms", entry.DurationMs))
			Expect(fields).Should(HaveKeyWithValue("question_type", entry.QuestionType))
			Expect(fields).Should(HaveKeyWithValue("response_code", entry.ResponseCode))

			// Attributes left at their zero value must not appear as fields.
			Expect(fields).ShouldNot(HaveKey("client_names"))
			Expect(fields).ShouldNot(HaveKey("question_name"))
		})
	})

	// Each entry wraps a single value in a one-key map; isZero states whether
	// withoutZeroes is expected to remove it.
	DescribeTable("withoutZeroes",
		func(value any, isZero bool) {
			fields := withoutZeroes(logrus.Fields{"a": value})

			if isZero {
				Expect(fields).Should(BeEmpty())
			} else {
				Expect(fields).ShouldNot(BeEmpty())
			}
		},
		Entry("empty string",
			"",
			true),
		Entry("non-empty string",
			"something",
			false),
		Entry("zero int",
			0,
			true),
		Entry("non-zero int",
			1,
			false),
	)
})
84 | 


--------------------------------------------------------------------------------
/querylog/none_writer.go:
--------------------------------------------------------------------------------
 1 | package querylog
 2 | 
 3 | type NoneWriter struct{}
 4 | 
 5 | func NewNoneWriter() *NoneWriter {
 6 | 	return &NoneWriter{}
 7 | }
 8 | 
 9 | func (d *NoneWriter) Write(*LogEntry) {
10 | 	// Nothing to do
11 | }
12 | 
13 | func (d *NoneWriter) CleanUp() {
14 | 	// Nothing to do
15 | }
16 | 


--------------------------------------------------------------------------------
/querylog/none_writer_test.go:
--------------------------------------------------------------------------------
 1 | package querylog
 2 | 
 3 | import (
 4 | 	. "github.com/onsi/ginkgo/v2"
 5 | )
 6 | 
// Specs for NoneWriter: both methods must be callable without panicking.
var _ = Describe("NoneWriter", func() {
	Describe("NoneWriter", func() {
		When("write is called", func() {
			It("should do nothing", func() {
				// A nil entry is fine because NoneWriter.Write never reads it.
				NewNoneWriter().Write(nil)
			})
		})
		When("cleanUp is called", func() {
			It("should do nothing", func() {
				NewNoneWriter().CleanUp()
			})
		})
	})
})
21 | 


--------------------------------------------------------------------------------
/querylog/querylog_suite_test.go:
--------------------------------------------------------------------------------
 1 | package querylog
 2 | 
 3 | import (
 4 | 	"testing"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/log"
 7 | 
 8 | 	. "github.com/onsi/ginkgo/v2"
 9 | 	. "github.com/onsi/gomega"
10 | )
11 | 
// init calls log.Silence() before any spec runs, keeping blocky's own logging
// out of the querylog test output.
func init() {
	log.Silence()
}

// TestResolver is the `go test` entry point; it hooks Gomega failures into
// Ginkgo and runs every spec of the querylog package ("Querylog Suite").
func TestResolver(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Querylog Suite")
}
20 | 


--------------------------------------------------------------------------------
/querylog/writer.go:
--------------------------------------------------------------------------------
 1 | package querylog
 2 | 
 3 | import (
 4 | 	"time"
 5 | )
 6 | 
// LogEntry is a single query log record handed to a Writer.
//
// ClientNames may hold several names (LoggerWriter joins them with "; ").
// DurationMs is presumably the resolution time in milliseconds, as its name
// suggests. BlockyInstance identifies the producing blocky instance
// (LoggerWriter logs it as "instance").
type LogEntry struct {
	Start          time.Time
	ClientIP       string
	ClientNames    []string
	DurationMs     int64
	ResponseReason string
	ResponseType   string
	ResponseCode   string
	QuestionType   string
	QuestionName   string
	Answer         string
	BlockyInstance string
}

// Writer is a destination for query log entries.
type Writer interface {
	// Write records or emits a single entry.
	Write(entry *LogEntry)
	// CleanUp performs writer-specific housekeeping. It is a no-op for
	// LoggerWriter and NoneWriter; other writers presumably prune or flush.
	CleanUp()
}
25 | 


--------------------------------------------------------------------------------
/redis/redis_suite_test.go:
--------------------------------------------------------------------------------
 1 | package redis
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"testing"
 6 | 
 7 | 	"github.com/0xERR0R/blocky/log"
 8 | 	"github.com/go-redis/redis/v8"
 9 | 	. "github.com/onsi/ginkgo/v2"
10 | 	. "github.com/onsi/gomega"
11 | )
12 | 
// init prepares the redis test binary: it calls log.Silence() and installs
// NoLogs as the go-redis logger, keeping test output free of client logs.
func init() {
	log.Silence()
	redis.SetLogger(NoLogs{})
}

// TestRedisClient is the `go test` entry point; it hooks Gomega failures into
// Ginkgo and runs every spec of the redis package ("Redis Suite").
func TestRedisClient(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Redis Suite")
}
22 | 
// NoLogs is a go-redis logger that drops every message; init installs it via
// redis.SetLogger.
type NoLogs struct{}

// Printf discards its arguments.
func (l NoLogs) Printf(_ context.Context, _ string, _ ...any) {}
26 | 


--------------------------------------------------------------------------------
/resolver/ede_resolver.go:
--------------------------------------------------------------------------------
 1 | package resolver
 2 | 
 3 | import (
 4 | 	"context"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/config"
 7 | 	"github.com/0xERR0R/blocky/model"
 8 | 	"github.com/0xERR0R/blocky/util"
 9 | 	"github.com/miekg/dns"
10 | )
11 | 
12 | // A EDEResolver is responsible for adding the reason for the response as EDNS0 option
13 | type EDEResolver struct {
14 | 	configurable[*config.EDE]
15 | 	NextResolver
16 | 	typed
17 | }
18 | 
19 | // NewEDEResolver creates new resolver instance which adds the reason for
20 | // the response as EDNS0 option to the response if it is enabled in the configuration
21 | func NewEDEResolver(cfg config.EDE) *EDEResolver {
22 | 	return &EDEResolver{
23 | 		configurable: withConfig(&cfg),
24 | 		typed:        withType("extended_error_code"),
25 | 	}
26 | }
27 | 
28 | // Resolve adds the reason as EDNS0 option to the response of the next resolver
29 | // if it is enabled in the configuration
30 | func (r *EDEResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
31 | 	if !r.cfg.Enable {
32 | 		return r.next.Resolve(ctx, request)
33 | 	}
34 | 
35 | 	resp, err := r.next.Resolve(ctx, request)
36 | 	if err != nil {
37 | 		return nil, err
38 | 	}
39 | 
40 | 	r.addExtraReasoning(resp)
41 | 
42 | 	return resp, nil
43 | }
44 | 
45 | // addExtraReasoning adds the reason for the response as EDNS0 option
46 | func (r *EDEResolver) addExtraReasoning(res *model.Response) {
47 | 	infocode := res.RType.ToExtendedErrorCode()
48 | 
49 | 	if infocode == dns.ExtendedErrorCodeOther {
50 | 		// dns.ExtendedErrorCodeOther seams broken in some clients
51 | 		return
52 | 	}
53 | 
54 | 	edeOption := new(dns.EDNS0_EDE)
55 | 	edeOption.InfoCode = infocode
56 | 	edeOption.ExtraText = res.Reason
57 | 
58 | 	util.SetEdns0Option(res.Res, edeOption)
59 | }
60 | 


--------------------------------------------------------------------------------
/resolver/filtering_resolver.go:
--------------------------------------------------------------------------------
 1 | package resolver
 2 | 
 3 | import (
 4 | 	"context"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/config"
 7 | 	"github.com/0xERR0R/blocky/model"
 8 | 	"github.com/miekg/dns"
 9 | )
10 | 
11 | // FilteringResolver filters DNS queries (for example can drop all AAAA query)
12 | // returns empty ANSWER with NOERROR
13 | type FilteringResolver struct {
14 | 	configurable[*config.Filtering]
15 | 	NextResolver
16 | 	typed
17 | }
18 | 
19 | func NewFilteringResolver(cfg config.Filtering) *FilteringResolver {
20 | 	return &FilteringResolver{
21 | 		configurable: withConfig(&cfg),
22 | 		typed:        withType("filtering"),
23 | 	}
24 | }
25 | 
26 | func (r *FilteringResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
27 | 	qType := request.Req.Question[0].Qtype
28 | 	if r.cfg.QueryTypes.Contains(dns.Type(qType)) {
29 | 		response := new(dns.Msg)
30 | 		response.SetRcode(request.Req, dns.RcodeSuccess)
31 | 
32 | 		return &model.Response{Res: response, RType: model.ResponseTypeFILTERED}, nil
33 | 	}
34 | 
35 | 	return r.next.Resolve(ctx, request)
36 | }
37 | 


--------------------------------------------------------------------------------
/resolver/filtering_resolver_test.go:
--------------------------------------------------------------------------------
  1 | package resolver
  2 | 
  3 | import (
  4 | 	"context"
  5 | 
  6 | 	"github.com/0xERR0R/blocky/config"
  7 | 	. "github.com/0xERR0R/blocky/helpertest"
  8 | 	"github.com/0xERR0R/blocky/log"
  9 | 	. "github.com/0xERR0R/blocky/model"
 10 | 
 11 | 	"github.com/miekg/dns"
 12 | 	. "github.com/onsi/ginkgo/v2"
 13 | 	. "github.com/onsi/gomega"
 14 | 	"github.com/stretchr/testify/mock"
 15 | )
 16 | 
// Specs for FilteringResolver. Queries whose type is listed in
// config.Filtering.QueryTypes are answered locally (ResponseTypeFILTERED,
// NOERROR, no answer). All others are delegated to the mocked next resolver.
var _ = Describe("FilteringResolver", func() {
	var (
		sut        *FilteringResolver
		sutConfig  config.Filtering
		m          *mockResolver
		mockAnswer *dns.Msg

		ctx      context.Context
		cancelFn context.CancelFunc
	)

	Describe("Type", func() {
		It("follows conventions", func() {
			expectValidResolverType(sut)
		})
	})

	BeforeEach(func() {
		ctx, cancelFn = context.WithCancel(context.Background())
		DeferCleanup(cancelFn)

		mockAnswer = new(dns.Msg)
	})

	// JustBeforeEach runs after all BeforeEach blocks (including nested ones),
	// so the resolver is built from the sutConfig chosen by the enclosing When.
	JustBeforeEach(func() {
		sut = NewFilteringResolver(sutConfig)
		m = &mockResolver{}
		m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
		sut.Next(m)
	})

	Describe("IsEnabled", func() {
		It("is false", func() {
			Expect(sut.IsEnabled()).Should(BeFalse())
		})
	})

	Describe("LogConfig", func() {
		It("should log something", func() {
			logger, hook := log.NewMockEntry()

			sut.LogConfig(logger)

			Expect(hook.Calls).ShouldNot(BeEmpty())
		})
	})

	When("Filtering query types are defined", func() {
		BeforeEach(func() {
			sutConfig = config.Filtering{
				QueryTypes: config.NewQTypeSet(AAAA, MX),
			}
		})
		It("Should delegate to next resolver if request query has other type", func() {
			Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
				Should(
					SatisfyAll(
						HaveNoAnswer(),
						HaveResponseType(ResponseTypeRESOLVED),
						HaveReturnCode(dns.RcodeSuccess),
					))

			// delegated to next resolver
			Expect(m.Calls).Should(HaveLen(1))
		})
		It("Should return empty answer for defined query type", func() {
			Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
				Should(
					SatisfyAll(
						HaveNoAnswer(),
						HaveResponseType(ResponseTypeFILTERED),
						HaveReturnCode(dns.RcodeSuccess),
					))

			// no call of next resolver
			Expect(m.Calls).Should(BeZero())
		})
	})

	When("No filtering query types are defined", func() {
		BeforeEach(func() {
			sutConfig = config.Filtering{}
		})
		// With an empty type set nothing is filtered: the query reaches the
		// mocked next resolver, whose empty answer is returned unchanged.
		It("Should return empty answer without error", func() {
			Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
				Should(
					SatisfyAll(
						HaveNoAnswer(),
						HaveResponseType(ResponseTypeRESOLVED),
						HaveReturnCode(dns.RcodeSuccess),
					))

			// delegated to next resolver
			Expect(m.Calls).Should(HaveLen(1))
		})
	})
})
114 | 


--------------------------------------------------------------------------------
/resolver/fqdn_only_resolver.go:
--------------------------------------------------------------------------------
 1 | package resolver
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"strings"
 6 | 
 7 | 	"github.com/0xERR0R/blocky/config"
 8 | 	"github.com/0xERR0R/blocky/model"
 9 | 	"github.com/0xERR0R/blocky/util"
10 | 	"github.com/miekg/dns"
11 | )
12 | 
13 | type FQDNOnlyResolver struct {
14 | 	configurable[*config.FQDNOnly]
15 | 	NextResolver
16 | 	typed
17 | }
18 | 
19 | func NewFQDNOnlyResolver(cfg config.FQDNOnly) *FQDNOnlyResolver {
20 | 	return &FQDNOnlyResolver{
21 | 		configurable: withConfig(&cfg),
22 | 		typed:        withType("fqdn_only"),
23 | 	}
24 | }
25 | 
26 | func (r *FQDNOnlyResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
27 | 	if r.IsEnabled() {
28 | 		domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
29 | 		if !strings.Contains(domainFromQuestion, ".") {
30 | 			response := new(dns.Msg)
31 | 			response.Rcode = dns.RcodeNameError
32 | 
33 | 			return &model.Response{Res: response, RType: model.ResponseTypeNOTFQDN, Reason: "NOTFQDN"}, nil
34 | 		}
35 | 	}
36 | 
37 | 	return r.next.Resolve(ctx, request)
38 | }
39 | 


--------------------------------------------------------------------------------
/resolver/metrics_resolver_test.go:
--------------------------------------------------------------------------------
 1 | package resolver
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"errors"
 6 | 
 7 | 	"github.com/0xERR0R/blocky/config"
 8 | 	"github.com/0xERR0R/blocky/log"
 9 | 
10 | 	. "github.com/0xERR0R/blocky/helpertest"
11 | 	. "github.com/0xERR0R/blocky/model"
12 | 
13 | 	"github.com/miekg/dns"
14 | 	. "github.com/onsi/ginkgo/v2"
15 | 	. "github.com/onsi/gomega"
16 | 	"github.com/prometheus/client_golang/prometheus"
17 | 	"github.com/prometheus/client_golang/prometheus/testutil"
18 | 	"github.com/stretchr/testify/mock"
19 | )
20 | 
// Specs for MetricsResolver. Every request is delegated to a mocked next
// resolver, and the Prometheus counters for queries and errors are checked
// afterwards.
var _ = Describe("MetricResolver", func() {
	var (
		sut *MetricsResolver
		m   *mockResolver

		ctx      context.Context
		cancelFn context.CancelFunc
	)

	Describe("Type", func() {
		It("follows conventions", func() {
			expectValidResolverType(sut)
		})
	})

	// A new enabled resolver with a successful mock is created for every spec.
	BeforeEach(func() {
		ctx, cancelFn = context.WithCancel(context.Background())
		DeferCleanup(cancelFn)

		sut = NewMetricsResolver(config.Metrics{Enable: true})
		m = &mockResolver{}
		m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
		sut.Next(m)
	})

	Describe("IsEnabled", func() {
		It("is true", func() {
			Expect(sut.IsEnabled()).Should(BeTrue())
		})
	})

	Describe("LogConfig", func() {
		It("should log something", func() {
			logger, hook := log.NewMockEntry()

			sut.LogConfig(logger)

			Expect(hook.Calls).ShouldNot(BeEmpty())
		})
	})

	Describe("Recording prometheus metrics", func() {
		Context("Recording request metrics", func() {
			When("Request will be performed", func() {
				It("Should record metrics", func() {
					Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client"))).
						Should(
							SatisfyAll(
								HaveResponseType(ResponseTypeRESOLVED),
								HaveReturnCode(dns.RcodeSuccess),
							))

					// The query counter is labelled by client name and query type.
					cnt, err := sut.totalQueries.GetMetricWith(prometheus.Labels{"client": "client", "type": "A"})
					Expect(err).Should(Succeed())

					Expect(testutil.ToFloat64(cnt)).Should(BeNumerically("==", 1))
					m.AssertExpectations(GinkgoT())
				})
			})
			When("Error occurs while request processing", func() {
				// Runs after the outer BeforeEach and swaps in a failing mock.
				BeforeEach(func() {
					m = &mockResolver{}
					m.On("Resolve", mock.Anything).Return(nil, errors.New("error"))
					sut.Next(m)
				})
				It("Error should be recorded", func() {
					_, err := sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client"))

					Expect(err).Should(HaveOccurred())

					Expect(testutil.ToFloat64(sut.totalErrors)).Should(BeNumerically("==", 1))
				})
			})
		})
	})
})
97 | 


--------------------------------------------------------------------------------
/resolver/noop_resolver.go:
--------------------------------------------------------------------------------
 1 | package resolver
 2 | 
 3 | import (
 4 | 	"context"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/model"
 7 | 	"github.com/sirupsen/logrus"
 8 | )
 9 | 
10 | var NoResponse = &model.Response{} //nolint:gochecknoglobals
11 | 
// NoOpResolver is used to finish a resolver branch as created in RewriterResolver
type NoOpResolver struct{}

// NewNoOpResolver returns a ready-to-use NoOpResolver.
func NewNoOpResolver() *NoOpResolver {
	return new(NoOpResolver)
}

// Type implements `Resolver`; the type name is "noop".
func (NoOpResolver) Type() string {
	return "noop"
}

// String implements `fmt.Stringer` by reporting the resolver type.
func (r NoOpResolver) String() string {
	return r.Type()
}

// IsEnabled implements `config.Configurable`; a NoOpResolver is always enabled.
func (NoOpResolver) IsEnabled() bool {
	return true
}
35 | func (NoOpResolver) LogConfig(*logrus.Entry) {
36 | }
37 | 
38 | func (NoOpResolver) Resolve(context.Context, *model.Request) (*model.Response, error) {
39 | 	return NoResponse, nil
40 | }
41 | 


--------------------------------------------------------------------------------
/resolver/noop_resolver_test.go:
--------------------------------------------------------------------------------
 1 | package resolver
 2 | 
 3 | import (
 4 | 	"context"
 5 | 
 6 | 	. "github.com/0xERR0R/blocky/helpertest"
 7 | 	"github.com/0xERR0R/blocky/log"
 8 | 	. "github.com/onsi/ginkgo/v2"
 9 | 	. "github.com/onsi/gomega"
10 | )
11 | 
// Specs for NoOpResolver: it always answers with NoResponse, is always
// enabled and logs nothing.
var _ = Describe("NoOpResolver", func() {
	var (
		sut *NoOpResolver

		ctx      context.Context
		cancelFn context.CancelFunc
	)

	Describe("Type", func() {
		It("follows conventions", func() {
			expectValidResolverType(sut)
		})
	})

	BeforeEach(func() {
		ctx, cancelFn = context.WithCancel(context.Background())
		DeferCleanup(cancelFn)

		sut = NewNoOpResolver()
	})

	Describe("Resolving", func() {
		It("returns no response", func() {
			resp, err := sut.Resolve(ctx, newRequest("test.tld", A))
			Expect(err).Should(Succeed())
			Expect(resp).Should(Equal(NoResponse))
		})
	})

	Describe("IsEnabled", func() {
		It("is true", func() {
			Expect(sut.IsEnabled()).Should(BeTrue())
		})
	})

	Describe("LogConfig", func() {
		It("should not log anything", func() {
			logger, hook := log.NewMockEntry()

			sut.LogConfig(logger)

			// No hook calls means LogConfig emitted no log entries at all.
			Expect(hook.Calls).Should(BeEmpty())
		})
	})
})
57 | 


--------------------------------------------------------------------------------
/resolver/resolver_suite_test.go:
--------------------------------------------------------------------------------
 1 | package resolver
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"testing"
 6 | 	"time"
 7 | 
 8 | 	"github.com/0xERR0R/blocky/config"
 9 | 	"github.com/0xERR0R/blocky/log"
10 | 	"github.com/go-redis/redis/v8"
11 | 
12 | 	. "github.com/onsi/ginkgo/v2"
13 | 	. "github.com/onsi/gomega"
14 | )
15 | 
const (
	// timeout is the upstream timeout used by the resolver tests; it replaces
	// the default so tests against unresponsive upstreams finish quickly.
	timeout = 50 * time.Millisecond
)

// defaultUpstreamsConfig holds the default upstream settings, with the test
// timeout applied, for use by resolver specs. It is filled in by init.
var defaultUpstreamsConfig config.Upstreams

// init calls log.Silence(), installs NoLogs as the go-redis logger and builds
// defaultUpstreamsConfig. It panics if the defaults cannot be produced,
// because the specs depend on them.
func init() {
	log.Silence()
	redis.SetLogger(NoLogs{})

	var err error

	defaultUpstreamsConfig, err = config.WithDefaults[config.Upstreams]()
	if err != nil {
		panic(err)
	}

	// Shorter timeout for tests
	defaultUpstreamsConfig.Timeout = config.Duration(timeout)
}

// TestResolver is the `go test` entry point; it hooks Gomega failures into
// Ginkgo and runs every spec of the resolver package ("Resolver Suite").
func TestResolver(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Resolver Suite")
}
41 | 
// NoLogs is a go-redis logger that drops every message; the suite's init
// installs it via redis.SetLogger.
type NoLogs struct{}

// Printf discards its arguments.
func (l NoLogs) Printf(context.Context, string, ...any) {}
45 | 


--------------------------------------------------------------------------------
/resolver/strict_resolver.go:
--------------------------------------------------------------------------------
  1 | package resolver
  2 | 
  3 | import (
  4 | 	"context"
  5 | 	"errors"
  6 | 	"fmt"
  7 | 	"strings"
  8 | 	"sync/atomic"
  9 | 
 10 | 	"github.com/0xERR0R/blocky/config"
 11 | 	"github.com/0xERR0R/blocky/model"
 12 | 	"github.com/0xERR0R/blocky/util"
 13 | 
 14 | 	"github.com/sirupsen/logrus"
 15 | )
 16 | 
 17 | const (
 18 | 	strictResolverType = "strict"
 19 | )
 20 | 
 21 | // StrictResolver delegates the DNS message strictly to the first configured upstream resolver
 22 | // if it can't provide the answer in time the next resolver is used
 23 | type StrictResolver struct {
 24 | 	configurable[*config.UpstreamGroup]
 25 | 	typed
 26 | 
 27 | 	resolvers atomic.Pointer[[]*upstreamResolverStatus]
 28 | }
 29 | 
 30 | // NewStrictResolver creates a new strict resolver instance
 31 | func NewStrictResolver(
 32 | 	ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap,
 33 | ) (*StrictResolver, error) {
 34 | 	r := newStrictResolver(
 35 | 		cfg,
 36 | 		[]Resolver{bootstrap}, // if init strategy is fast, use bootstrap until init finishes
 37 | 	)
 38 | 
 39 | 	return initGroupResolvers(ctx, r, cfg, bootstrap)
 40 | }
 41 | 
 42 | func newStrictResolver(
 43 | 	cfg config.UpstreamGroup, resolvers []Resolver,
 44 | ) *StrictResolver {
 45 | 	r := StrictResolver{
 46 | 		configurable: withConfig(&cfg),
 47 | 		typed:        withType(strictResolverType),
 48 | 	}
 49 | 
 50 | 	r.setResolvers(newUpstreamResolverStatuses(resolvers))
 51 | 
 52 | 	return &r
 53 | }
 54 | 
// setResolvers atomically replaces the ordered list of upstreams used by
// Resolve and String.
func (r *StrictResolver) setResolvers(resolvers []*upstreamResolverStatus) {
	r.resolvers.Store(&resolvers)
}
 58 | 
// Name returns the same human-readable description as String.
func (r *StrictResolver) Name() string {
	return r.String()
}
 62 | 
 63 | func (r *StrictResolver) String() string {
 64 | 	resolvers := *r.resolvers.Load()
 65 | 
 66 | 	upstreams := make([]string, len(resolvers))
 67 | 	for i, s := range resolvers {
 68 | 		upstreams[i] = s.resolver.String()
 69 | 	}
 70 | 
 71 | 	return fmt.Sprintf("%s upstreams '%s (%s)'", strictResolverType, r.cfg.Name, strings.Join(upstreams, ","))
 72 | }
 73 | 
 74 | // Resolve sends the query request in a strict order to the upstream resolvers
 75 | func (r *StrictResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
 76 | 	ctx, logger := r.log(ctx)
 77 | 
 78 | 	// start with first resolver
 79 | 	for _, resolver := range *r.resolvers.Load() {
 80 | 		logger.Debugf("using %s as resolver", resolver.resolver)
 81 | 
 82 | 		resp, err := resolver.resolve(ctx, request)
 83 | 		if err != nil {
 84 | 			// log error and try next upstream
 85 | 			logger.WithField("resolver", resolver.resolver).Debug("resolution failed from resolver, cause: ", err)
 86 | 
 87 | 			continue
 88 | 		}
 89 | 
 90 | 		logger.WithFields(logrus.Fields{
 91 | 			"resolver": *resolver,
 92 | 			"answer":   util.AnswerToString(resp.Res.Answer),
 93 | 		}).Debug("using response from resolver")
 94 | 
 95 | 		return resp, nil
 96 | 	}
 97 | 
 98 | 	return nil, errors.New("resolution was not successful, no resolver returned an answer in time")
 99 | }
100 | 


--------------------------------------------------------------------------------
/server/http.go:
--------------------------------------------------------------------------------
 1 | package server
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"net"
 6 | 	"net/http"
 7 | 	"time"
 8 | 
 9 | 	"github.com/0xERR0R/blocky/config"
10 | 	"github.com/go-chi/chi/v5"
11 | 	"github.com/go-chi/cors"
12 | )
13 | 
// httpServer wraps a net/http server together with a name used to identify it.
type httpServer struct {
	inner http.Server

	// name is what String returns.
	name string
}
19 | 
20 | func newHTTPServer(name string, handler http.Handler, cfg *config.Config) *httpServer {
21 | 	var (
22 | 		writeTimeout      = cfg.Blocking.Loading.Downloads.WriteTimeout
23 | 		readTimeout       = cfg.Blocking.Loading.Downloads.ReadTimeout
24 | 		readHeaderTimeout = cfg.Blocking.Loading.Downloads.ReadHeaderTimeout
25 | 	)
26 | 
27 | 	return &httpServer{
28 | 		inner: http.Server{
29 | 			ReadTimeout:       time.Duration(readTimeout),
30 | 			ReadHeaderTimeout: time.Duration(readHeaderTimeout),
31 | 			WriteTimeout:      time.Duration(writeTimeout),
32 | 			Handler:           withCommonMiddleware(handler),
33 | 		},
34 | 
35 | 		name: name,
36 | 	}
37 | }
38 | 
// String returns the name the server was created with.
func (s *httpServer) String() string {
	return s.name
}
42 | 
43 | func (s *httpServer) Serve(ctx context.Context, l net.Listener) error {
44 | 	go func() {
45 | 		<-ctx.Done()
46 | 
47 | 		s.inner.Close()
48 | 	}()
49 | 
50 | 	return s.inner.Serve(l)
51 | }
52 | 
53 | func withCommonMiddleware(inner http.Handler) *chi.Mux {
54 | 	// Middleware must be defined before routes, so
55 | 	// create a new router and mount the inner handler
56 | 	mux := chi.NewMux()
57 | 
58 | 	mux.Use(
59 | 		secureHeadersMiddleware,
60 | 		newCORSMiddleware(),
61 | 	)
62 | 
63 | 	mux.Mount("/", inner)
64 | 
65 | 	return mux
66 | }
67 | 
// httpMiddleware is the usual middleware shape: it wraps a handler and returns a new one.
type httpMiddleware = func(http.Handler) http.Handler
69 | 
// secureHeadersMiddleware adds browser security headers (HSTS, frame, content-type
// sniffing and XSS protection) to responses of requests that arrived over TLS,
// then hands the request to next. Plain-HTTP requests pass through unchanged.
func secureHeadersMiddleware(next http.Handler) http.Handler {
	secureHeaders := [...][2]string{
		{"strict-transport-security", "max-age=63072000"},
		{"x-frame-options", "DENY"},
		{"x-content-type-options", "nosniff"},
		{"x-xss-protection", "1; mode=block"},
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.TLS != nil {
			hdr := w.Header()
			for _, kv := range secureHeaders {
				hdr.Set(kv[0], kv[1])
			}
		}

		next.ServeHTTP(w, r)
	})
}
82 | 
83 | func newCORSMiddleware() httpMiddleware {
84 | 	const corsMaxAge = 5 * time.Minute
85 | 
86 | 	options := cors.Options{
87 | 		AllowCredentials: true,
88 | 		AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
89 | 		AllowedMethods:   []string{"GET", "POST"},
90 | 		AllowedOrigins:   []string{"*"},
91 | 		ExposedHeaders:   []string{"Link"},
92 | 		MaxAge:           int(corsMaxAge.Seconds()),
93 | 	}
94 | 
95 | 	return cors.New(options).Handler
96 | }
97 | 


--------------------------------------------------------------------------------
/server/server_config_trigger.go:
--------------------------------------------------------------------------------
 1 | //go:build !windows
 2 | // +build !windows
 3 | 
 4 | package server
 5 | 
 6 | import (
 7 | 	"context"
 8 | 	"os"
 9 | 	"os/signal"
10 | 	"syscall"
11 | )
12 | 
13 | func registerPrintConfigurationTrigger(ctx context.Context, s *Server) {
14 | 	signals := make(chan os.Signal, 1)
15 | 	signal.Notify(signals, syscall.SIGUSR1)
16 | 
17 | 	go func() {
18 | 		for {
19 | 			select {
20 | 			case <-signals:
21 | 				s.printConfiguration()
22 | 
23 | 			case <-ctx.Done():
24 | 				return
25 | 			}
26 | 		}
27 | 	}()
28 | }
29 | 


--------------------------------------------------------------------------------
/server/server_config_trigger_windows.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import "context"
4 | 
// registerPrintConfigurationTrigger is a no-op on Windows. The non-Windows variant
// prints the configuration on SIGUSR1, presumably unavailable here — confirm.
func registerPrintConfigurationTrigger(ctx context.Context, s *Server) {
}
7 | 


--------------------------------------------------------------------------------
/server/server_suite_test.go:
--------------------------------------------------------------------------------
 1 | package server
 2 | 
 3 | import (
 4 | 	"testing"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/log"
 7 | 	. "github.com/onsi/ginkgo/v2"
 8 | 	. "github.com/onsi/gomega"
 9 | )
10 | 
// init silences the application logger so test output stays readable.
func init() {
	log.Silence()
}

// TestDNSServer is the `go test` entry point that runs the Ginkgo specs of the
// server package as the "Server Suite".
func TestDNSServer(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Server Suite")
}
19 | 


--------------------------------------------------------------------------------
/trie/split.go:
--------------------------------------------------------------------------------
 1 | package trie
 2 | 
 3 | import "strings"
 4 | 
// SplitFunc splits a key into its next label and the remaining part of the key
// (see SplitTLD for the domain-name implementation).
type SplitFunc func(string) (label, rest string)
 6 | 
 7 | // www.example.com -> ("com", "www.example")
 8 | func SplitTLD(domain string) (label, rest string) {
 9 | 	domain = strings.TrimRight(domain, ".")
10 | 
11 | 	idx := strings.LastIndexByte(domain, '.')
12 | 	if idx == -1 {
13 | 		return domain, ""
14 | 	}
15 | 
16 | 	label = domain[idx+1:]
17 | 	rest = domain[:idx]
18 | 
19 | 	return label, rest
20 | }
21 | 


--------------------------------------------------------------------------------
/trie/split_test.go:
--------------------------------------------------------------------------------
 1 | package trie
 2 | 
 3 | import (
 4 | 	. "github.com/onsi/ginkgo/v2"
 5 | 	. "github.com/onsi/gomega"
 6 | )
 7 | 
 8 | var _ = Describe("SpltTLD", func() {
 9 | 	It("should split a tld", func() {
10 | 		key, rest := SplitTLD("www.example.com")
11 | 		Expect(key).Should(Equal("com"))
12 | 		Expect(rest).Should(Equal("www.example"))
13 | 	})
14 | 
15 | 	It("should not split a plain string", func() {
16 | 		key, rest := SplitTLD("example")
17 | 		Expect(key).Should(Equal("example"))
18 | 		Expect(rest).Should(Equal(""))
19 | 	})
20 | 
21 | 	It("should not crash with an empty string", func() {
22 | 		key, rest := SplitTLD("")
23 | 		Expect(key).Should(Equal(""))
24 | 		Expect(rest).Should(Equal(""))
25 | 	})
26 | 
27 | 	It("should ignore trailing dots", func() {
28 | 		key, rest := SplitTLD("www.example.com.")
29 | 		Expect(key).Should(Equal("com"))
30 | 		Expect(rest).Should(Equal("www.example"))
31 | 
32 | 		key, rest = SplitTLD(rest)
33 | 		Expect(key).Should(Equal("example"))
34 | 		Expect(rest).Should(Equal("www"))
35 | 	})
36 | 
37 | 	It("should skip empty parts", func() {
38 | 		key, rest := SplitTLD("www.example..com")
39 | 		Expect(key).Should(Equal("com"))
40 | 		Expect(rest).Should(Equal("www.example."))
41 | 
42 | 		key, rest = SplitTLD(rest)
43 | 		Expect(key).Should(Equal("example"))
44 | 		Expect(rest).Should(Equal("www"))
45 | 	})
46 | })
47 | 


--------------------------------------------------------------------------------
/trie/trie_suite_test.go:
--------------------------------------------------------------------------------
 1 | package trie
 2 | 
 3 | import (
 4 | 	"testing"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/log"
 7 | 
 8 | 	. "github.com/onsi/ginkgo/v2"
 9 | 	. "github.com/onsi/gomega"
10 | )
11 | 
// init silences the application logger so test output stays readable.
func init() {
	log.Silence()
}

// TestTrie is the `go test` entry point that runs the Ginkgo specs of the trie
// package as the "Trie Suite".
func TestTrie(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Trie Suite")
}
20 | 


--------------------------------------------------------------------------------
/trie/trie_test.go:
--------------------------------------------------------------------------------
  1 | package trie
  2 | 
  3 | import (
  4 | 	. "github.com/onsi/ginkgo/v2"
  5 | 	. "github.com/onsi/gomega"
  6 | )
  7 | 
// Specs for a Trie keyed by SplitTLD: HasParentOf must report true for an
// inserted domain and its subdomains, but not for parents that were never inserted.
var _ = Describe("Trie", func() {
	var sut *Trie

	BeforeEach(func() {
		sut = NewTrie(SplitTLD)
	})

	Describe("Basic operations", func() {
		When("Trie is created", func() {
			It("should be empty", func() {
				Expect(sut.IsEmpty()).Should(BeTrue())
			})

			It("should not find domains", func() {
				Expect(sut.HasParentOf("example.com")).Should(BeFalse())
			})

			It("should not insert the empty string", func() {
				sut.Insert("")
				Expect(sut.HasParentOf("")).Should(BeFalse())
			})
		})

		When("Adding a domain", func() {
			var (
				domainOkTLD = "com"
				domainOk    = "example." + domainOkTLD

				domainKo = "example.org"
			)

			// Every spec below starts with domainOk inserted ...
			BeforeEach(func() {
				Expect(sut.HasParentOf(domainOk)).Should(BeFalse())
				sut.Insert(domainOk)
				Expect(sut.HasParentOf(domainOk)).Should(BeTrue())
			})

			// ... and must not have lost it afterwards, whatever else it inserted.
			AfterEach(func() {
				Expect(sut.HasParentOf(domainOk)).Should(BeTrue())
			})

			// The assertions live in BeforeEach/AfterEach.
			It("should be found", func() {})

			It("should contain subdomains", func() {
				subdomain := "www." + domainOk

				Expect(sut.HasParentOf(subdomain)).Should(BeTrue())
			})

			It("should support inserting subdomains", func() {
				subdomain := "www." + domainOk

				Expect(sut.HasParentOf(subdomain)).Should(BeTrue())
				sut.Insert(subdomain)
				Expect(sut.HasParentOf(subdomain)).Should(BeTrue())
			})

			It("should not find unrelated", func() {
				Expect(sut.HasParentOf(domainKo)).Should(BeFalse())
			})

			It("should not find uninserted parent", func() {
				Expect(sut.HasParentOf(domainOkTLD)).Should(BeFalse())
			})

			It("should not find deep uninserted parent", func() {
				sut.Insert("sub.sub.sub.test")

				Expect(sut.HasParentOf("sub.sub.test")).Should(BeFalse())
			})

			It("should find inserted parent", func() {
				sut.Insert(domainOkTLD)
				Expect(sut.HasParentOf(domainOkTLD)).Should(BeTrue())
			})

			It("should insert sibling", func() {
				sibling := "other." + domainOkTLD

				sut.Insert(sibling)
				Expect(sut.HasParentOf(sibling)).Should(BeTrue())
			})

			// Two subdomains sharing an uninserted parent must both be found,
			// while the shared parent itself stays absent.
			It("should insert grand-children siblings", func() {
				base := "other.com"
				abcSub := "abc." + base
				xyzSub := "xyz." + base

				sut.Insert(abcSub)
				Expect(sut.HasParentOf(abcSub)).Should(BeTrue())
				Expect(sut.HasParentOf(xyzSub)).Should(BeFalse())
				Expect(sut.HasParentOf(base)).Should(BeFalse())

				sut.Insert(xyzSub)
				Expect(sut.HasParentOf(xyzSub)).Should(BeTrue())
				Expect(sut.HasParentOf(abcSub)).Should(BeTrue())
				Expect(sut.HasParentOf(base)).Should(BeFalse())
			})
		})
	})
})
109 | 


--------------------------------------------------------------------------------
/util/arpa.go:
--------------------------------------------------------------------------------
 1 | package util
 2 | 
 3 | import (
 4 | 	"errors"
 5 | 	"fmt"
 6 | 	"net"
 7 | 	"strconv"
 8 | 	"strings"
 9 | )
10 | 
const (
	// IPv4PtrSuffix is the fully-qualified suffix of reverse-lookup names for IPv4.
	IPv4PtrSuffix = ".in-addr.arpa."
	// IPv6PtrSuffix is the fully-qualified suffix of reverse-lookup names for IPv6.
	IPv6PtrSuffix = ".ip6.arpa."

	byteBits = 8
)

// ErrInvalidArpaAddrLen is returned when a reverse-lookup name does not have the
// number of labels its address family requires: 4 octets for IPv4, 32 nibbles for IPv6.
var ErrInvalidArpaAddrLen = errors.New("arpa hostname is not of expected length")

// ParseIPFromArpaAddr converts a reverse-lookup (PTR) name back into the IP it encodes.
//
// arpa must be fully qualified (end with a dot). Examples:
// "4.3.2.1.in-addr.arpa." yields 1.2.3.4. For IPv6, the 32 nibbles are listed
// least-significant first before ".ip6.arpa.".
//
// An error is returned in these cases:
//   - the name ends in neither suffix;
//   - it has the wrong number of labels (ErrInvalidArpaAddrLen);
//   - a label is not a valid decimal octet (IPv4) or a single hex nibble (IPv6).
func ParseIPFromArpaAddr(arpa string) (net.IP, error) {
	if strings.HasSuffix(arpa, IPv4PtrSuffix) {
		return parseIPv4FromArpaAddr(arpa)
	}

	if strings.HasSuffix(arpa, IPv6PtrSuffix) {
		return parseIPv6FromArpaAddr(arpa)
	}

	return nil, fmt.Errorf("invalid arpa hostname: %s", arpa)
}

// parseIPv4FromArpaAddr decodes "d.c.b.a.in-addr.arpa." into the address a.b.c.d.
func parseIPv4FromArpaAddr(arpa string) (net.IP, error) {
	const base10 = 10

	revAddr := strings.TrimSuffix(arpa, IPv4PtrSuffix)

	parts := strings.Split(revAddr, ".")
	if len(parts) != net.IPv4len {
		return nil, ErrInvalidArpaAddrLen
	}

	buf := make([]byte, 0, net.IPv4len)

	// Parse and add each byte, in reverse, to the buffer
	for i := len(parts) - 1; i >= 0; i-- {
		part, err := strconv.ParseUint(parts[i], base10, byteBits)
		if err != nil {
			return nil, err
		}

		buf = append(buf, byte(part))
	}

	return net.IPv4(buf[0], buf[1], buf[2], buf[3]), nil
}

// parseIPv6FromArpaAddr decodes a name of 32 hex-nibble labels (least significant
// nibble first) followed by ".ip6.arpa." into a 16-byte IPv6 address.
func parseIPv6FromArpaAddr(arpa string) (net.IP, error) {
	const (
		base16      = 16
		ipv6Nibbles = 2 * net.IPv6len // every address byte is written as two labels
		nibbleBits  = byteBits / 2
	)

	revAddr := strings.TrimSuffix(arpa, IPv6PtrSuffix)

	parts := strings.Split(revAddr, ".")
	if len(parts) != ipv6Nibbles {
		return nil, ErrInvalidArpaAddrLen
	}

	buf := make([]byte, 0, net.IPv6len)

	// Walk the labels backwards two at a time: parts[i] is the high nibble and
	// parts[i-1] the low nibble of the next address byte.
	for i := len(parts) - 1; i >= 0; i -= 2 {
		// bitSize nibbleBits rejects labels above 0xf (e.g. "10" or "ff"), which would
		// otherwise overflow into the neighbouring nibble and yield a wrong address.
		msNibble, err := strconv.ParseUint(parts[i], base16, nibbleBits)
		if err != nil {
			return nil, err
		}

		lsNibble, err := strconv.ParseUint(parts[i-1], base16, nibbleBits)
		if err != nil {
			return nil, err
		}

		buf = append(buf, byte(msNibble<<nibbleBits|lsNibble))
	}

	return net.IP(buf), nil
}
92 | 


--------------------------------------------------------------------------------
/util/buildinfo.go:
--------------------------------------------------------------------------------
 1 | package util
 2 | 
// Build metadata of the binary. Every value defaults to "undefined"; they are
// presumably overridden at build time (e.g. via -ldflags "-X ...") — confirm in the build config.
//
//nolint:gochecknoglobals
var (
	// Version is the current version number.
	Version = "undefined"
	// BuildTime is the build time of the binary.
	BuildTime = "undefined"
	// Architecture is the current CPU architecture.
	Architecture = "undefined"
)
12 | 


--------------------------------------------------------------------------------
/util/context.go:
--------------------------------------------------------------------------------
 1 | package util
 2 | 
 3 | import "context"
 4 | 
 5 | // CtxSend sends a value to a channel while the context isn't done.
 6 | // If the message is sent, it returns true.
 7 | // If the context is done or the channel is closed, it returns false.
 8 | func CtxSend[T any](ctx context.Context, ch chan T, val T) (ok bool) {
 9 | 	if ctx == nil || ch == nil || ctx.Err() != nil {
10 | 		ok = false
11 | 
12 | 		return
13 | 	}
14 | 
15 | 	defer func() {
16 | 		if err := recover(); err != nil {
17 | 			ok = false
18 | 		}
19 | 	}()
20 | 
21 | 	select {
22 | 	case <-ctx.Done():
23 | 		ok = false
24 | 	case ch <- val:
25 | 		ok = true
26 | 	}
27 | 
28 | 	return
29 | }
30 | 


--------------------------------------------------------------------------------
/util/edns0.go:
--------------------------------------------------------------------------------
  1 | package util
  2 | 
  3 | import (
  4 | 	"fmt"
  5 | 	"slices"
  6 | 
  7 | 	"github.com/miekg/dns"
  8 | )
  9 | 
// EDNS0Option is an interface for all EDNS0 options as type constraint for generics.
// It is satisfied by pointers to the listed miekg/dns option types, each of which
// reports its option code through Option().
type EDNS0Option interface {
	*dns.EDNS0_SUBNET | *dns.EDNS0_EDE | *dns.EDNS0_LOCAL | *dns.EDNS0_NSID | *dns.EDNS0_COOKIE | *dns.EDNS0_UL
	Option() uint16
}
 15 | 
 16 | // RemoveEdns0Record removes the OPT record from the Extra section of the given message.
 17 | // If the OPT record is removed, true will be returned.
 18 | func RemoveEdns0Record(msg *dns.Msg) bool {
 19 | 	if msg == nil || msg.IsEdns0() == nil {
 20 | 		return false
 21 | 	}
 22 | 
 23 | 	for i, rr := range msg.Extra {
 24 | 		if rr.Header().Rrtype == dns.TypeOPT {
 25 | 			msg.Extra = slices.Delete(msg.Extra, i, i+1)
 26 | 
 27 | 			return true
 28 | 		}
 29 | 	}
 30 | 
 31 | 	return false
 32 | }
 33 | 
 34 | // GetEdns0Option returns the option with the given code from the OPT record in the
 35 | // Extra section of the given message.
 36 | // If the option is not found, nil will be returned.
 37 | func GetEdns0Option[T EDNS0Option](msg *dns.Msg) T {
 38 | 	if msg == nil {
 39 | 		return nil
 40 | 	}
 41 | 
 42 | 	opt := msg.IsEdns0()
 43 | 	if opt == nil {
 44 | 		return nil
 45 | 	}
 46 | 
 47 | 	var t T
 48 | 
 49 | 	for _, o := range opt.Option {
 50 | 		if o.Option() == t.Option() {
 51 | 			t, ok := o.(T)
 52 | 			if !ok {
 53 | 				panic(fmt.Errorf("dns option with code %d is not of type %T", t.Option(), t))
 54 | 			}
 55 | 
 56 | 			return t
 57 | 		}
 58 | 	}
 59 | 
 60 | 	return nil
 61 | }
 62 | 
 63 | // RemoveEdns0Option removes the option according to the given type from the OPT record
 64 | // in the Extra section of the given message.
 65 | // If there are no more options in the OPT record, the OPT record will be removed.
 66 | // If the option is successfully removed, true will be returned.
 67 | func RemoveEdns0Option[T EDNS0Option](msg *dns.Msg) bool {
 68 | 	if msg == nil {
 69 | 		return false
 70 | 	}
 71 | 
 72 | 	opt := msg.IsEdns0()
 73 | 	if opt == nil {
 74 | 		return false
 75 | 	}
 76 | 
 77 | 	res := false
 78 | 
 79 | 	var t T
 80 | 
 81 | 	for i, o := range opt.Option {
 82 | 		if o.Option() == t.Option() {
 83 | 			opt.Option = slices.Delete(opt.Option, i, i+1)
 84 | 
 85 | 			res = true
 86 | 
 87 | 			break
 88 | 		}
 89 | 	}
 90 | 
 91 | 	if len(opt.Option) == 0 {
 92 | 		RemoveEdns0Record(msg)
 93 | 	}
 94 | 
 95 | 	return res
 96 | }
 97 | 
 98 | // SetEdns0Option adds the given option to the OPT record in the Extra section of the
 99 | // given message.
100 | // If the option already exists, it will be replaced.
101 | // If the option is successfully set, true will be returned.
102 | func SetEdns0Option(msg *dns.Msg, opt dns.EDNS0) bool {
103 | 	if msg == nil || opt == nil {
104 | 		return false
105 | 	}
106 | 
107 | 	optRecord := msg.IsEdns0()
108 | 
109 | 	if optRecord == nil {
110 | 		optRecord = new(dns.OPT)
111 | 		optRecord.Hdr.Name = "."
112 | 		optRecord.Hdr.Rrtype = dns.TypeOPT
113 | 		msg.Extra = append(msg.Extra, optRecord)
114 | 	}
115 | 
116 | 	newOpts := make([]dns.EDNS0, 0, len(optRecord.Option)+1)
117 | 
118 | 	for _, o := range optRecord.Option {
119 | 		if o.Option() != opt.Option() {
120 | 			newOpts = append(newOpts, o)
121 | 		}
122 | 	}
123 | 
124 | 	newOpts = append(newOpts, opt)
125 | 	optRecord.Option = newOpts
126 | 
127 | 	return true
128 | }
129 | 


--------------------------------------------------------------------------------
/util/http.go:
--------------------------------------------------------------------------------
 1 | package util
 2 | 
import (
	"fmt"
	"net"
	"net/http"
	"strings"
)
 8 | 
 9 | //nolint:gochecknoglobals
10 | var baseTransport *http.Transport
11 | 
12 | //nolint:gochecknoinits
13 | func init() {
14 | 	base, ok := http.DefaultTransport.(*http.Transport)
15 | 	if !ok {
16 | 		panic(fmt.Errorf(
17 | 			"unsupported Go version: http.DefaultTransport is not of type *http.Transport: it is a %T",
18 | 			http.DefaultTransport,
19 | 		))
20 | 	}
21 | 
22 | 	baseTransport = base
23 | }
24 | 
25 | // DefaultHTTPTransport returns a new Transport with the same defaults as net/http.
26 | func DefaultHTTPTransport() *http.Transport {
27 | 	return &http.Transport{
28 | 		DialContext:            baseTransport.DialContext,
29 | 		DialTLSContext:         baseTransport.DialTLSContext,
30 | 		DisableCompression:     baseTransport.DisableCompression,
31 | 		DisableKeepAlives:      baseTransport.DisableKeepAlives,
32 | 		ExpectContinueTimeout:  baseTransport.ExpectContinueTimeout,
33 | 		ForceAttemptHTTP2:      baseTransport.ForceAttemptHTTP2,
34 | 		GetProxyConnectHeader:  baseTransport.GetProxyConnectHeader,
35 | 		IdleConnTimeout:        baseTransport.IdleConnTimeout,
36 | 		MaxConnsPerHost:        baseTransport.MaxConnsPerHost,
37 | 		MaxIdleConns:           baseTransport.MaxIdleConns,
38 | 		MaxIdleConnsPerHost:    baseTransport.MaxConnsPerHost,
39 | 		MaxResponseHeaderBytes: baseTransport.MaxResponseHeaderBytes,
40 | 		OnProxyConnectResponse: baseTransport.OnProxyConnectResponse,
41 | 		Proxy:                  baseTransport.Proxy,
42 | 		ProxyConnectHeader:     baseTransport.ProxyConnectHeader,
43 | 		ReadBufferSize:         baseTransport.ReadBufferSize,
44 | 		ResponseHeaderTimeout:  baseTransport.ResponseHeaderTimeout,
45 | 		TLSClientConfig:        baseTransport.TLSClientConfig,
46 | 		TLSHandshakeTimeout:    baseTransport.TLSHandshakeTimeout,
47 | 		TLSNextProto:           baseTransport.TLSNextProto,
48 | 		WriteBufferSize:        baseTransport.WriteBufferSize,
49 | 	}
50 | }
51 | 
// HTTPClientIP returns the IP address of the client that sent r.
//
// The first entry of the X-Forwarded-For header is preferred when it is a valid IP.
// The header may list several comma-separated hops: "client, proxy1, proxy2".
// Otherwise r.RemoteAddr is used. Either value may include a port, which is stripped.
// nil is returned when no IP can be parsed.
//
// NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy sets it;
// don't rely on the result for access control without such a proxy.
func HTTPClientIP(r *http.Request) net.IP {
	// parse accepts a bare IP as well as "host:port" / "[ipv6]:port".
	parse := func(addr string) net.IP {
		if host, _, err := net.SplitHostPort(addr); err == nil {
			addr = host
		}

		return net.ParseIP(addr)
	}

	if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
		first, _, _ := strings.Cut(fwd, ",")
		if ip := parse(strings.TrimSpace(first)); ip != nil {
			return ip
		}
	}

	return parse(r.RemoteAddr)
}
65 | 


--------------------------------------------------------------------------------
/util/http_test.go:
--------------------------------------------------------------------------------
 1 | package util
 2 | 
 3 | import (
 4 | 	"context"
 5 | 	"net"
 6 | 	"net/http"
 7 | 	"net/url"
 8 | 	"reflect"
 9 | 
10 | 	"github.com/google/go-cmp/cmp"
11 | 	"github.com/google/go-cmp/cmp/cmpopts"
12 | 	. "github.com/onsi/ginkgo/v2"
13 | 	. "github.com/onsi/gomega"
14 | )
15 | 
16 | var _ = Describe("HTTP Util", func() {
17 | 	Describe("DefaultHTTPTransport", func() {
18 | 		It("returns a new transport", func() {
19 | 			a := DefaultHTTPTransport()
20 | 			Expect(a).Should(BeIdenticalTo(a))
21 | 
22 | 			b := DefaultHTTPTransport()
23 | 			Expect(a).ShouldNot(BeIdenticalTo(b))
24 | 		})
25 | 
26 | 		It("returns a copy of http.DefaultTransport", func() {
27 | 			Expect(cmp.Diff(
28 | 				DefaultHTTPTransport(), http.DefaultTransport,
29 | 				cmpopts.IgnoreUnexported(http.Transport{}),
30 | 				// Non nil func field comparers
31 | 				cmp.Comparer(cmpAsPtrs[func(context.Context, string, string) (net.Conn, error)]),
32 | 				cmp.Comparer(cmpAsPtrs[func(*http.Request) (*url.URL, error)]),
33 | 			)).Should(BeEmpty())
34 | 		})
35 | 	})
36 | 
37 | 	Describe("HTTPClientIP", func() {
38 | 		It("extracts the IP from RemoteAddr", func() {
39 | 			r, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
40 | 			Expect(err).Should(Succeed())
41 | 
42 | 			ip := net.IPv4allrouter
43 | 			r.RemoteAddr = net.JoinHostPort(ip.String(), "78954")
44 | 
45 | 			Expect(HTTPClientIP(r)).Should(Equal(ip))
46 | 		})
47 | 
48 | 		It("extracts the IP from RemoteAddr without a port", func() {
49 | 			r, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
50 | 			Expect(err).Should(Succeed())
51 | 
52 | 			ip := net.IPv4allrouter
53 | 			r.RemoteAddr = ip.String()
54 | 
55 | 			Expect(HTTPClientIP(r)).Should(Equal(ip))
56 | 		})
57 | 
58 | 		It("extracts the IP from the X-Forwarded-For header", func() {
59 | 			r, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
60 | 			Expect(err).Should(Succeed())
61 | 
62 | 			ip := net.IPv4bcast
63 | 			r.RemoteAddr = ip.String()
64 | 
65 | 			r.Header.Set("X-Forwarded-For", ip.String())
66 | 
67 | 			Expect(HTTPClientIP(r)).Should(Equal(ip))
68 | 		})
69 | 	})
70 | })
71 | 
// cmpAsPtrs reports whether x and y refer to the same underlying pointer.
//
// Neither Go nor go-cmp can compare func values (except against nil), so func
// fields are compared by their code pointers instead.
// See https://github.com/google/go-cmp/issues/162
func cmpAsPtrs[T any](x, y T) bool {
	xPtr := reflect.ValueOf(x).Pointer()
	yPtr := reflect.ValueOf(y).Pointer()

	return xPtr == yPtr
}
78 | 


--------------------------------------------------------------------------------
/util/tls.go:
--------------------------------------------------------------------------------
 1 | package util
 2 | 
 3 | import (
 4 | 	"crypto/ecdsa"
 5 | 	"crypto/elliptic"
 6 | 	"crypto/rand"
 7 | 	"crypto/tls"
 8 | 	"crypto/x509"
 9 | 	"crypto/x509/pkix"
10 | 	"fmt"
11 | 	"math/big"
12 | 	"time"
13 | )
14 | 
const (
	// certSerialMaxBits bounds the random serial number to [0, 2^128).
	certSerialMaxBits = 128
	// certExpiryYears is how long generated certificates stay valid.
	certExpiryYears = 5
)

// TLSGenerateSelfSignedCert returns a new self-signed cert for the given domains.
//
// The certificate uses a fresh ECDSA P-256 key and is valid for server authentication
// from now for certExpiryYears years. Leaf is pre-populated.
//
// Being self-signed, no client will trust this certificate.
func TLSGenerateSelfSignedCert(domains []string) (tls.Certificate, error) {
	serialMax := new(big.Int).Lsh(big.NewInt(1), certSerialMaxBits)

	serial, err := rand.Int(rand.Reader, serialMax)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("unable to generate serial number: %w", err)
	}

	// Read the clock once so the validity window is exactly certExpiryYears long.
	now := time.Now()

	template := &x509.Certificate{
		SerialNumber: serial,

		Subject:  pkix.Name{Organization: []string{"Blocky"}},
		DNSNames: domains,

		NotBefore: now,
		NotAfter:  now.AddDate(certExpiryYears, 0, 0),

		KeyUsage:    x509.KeyUsageDigitalSignature,
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}

	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("unable to generate private key: %w", err)
	}

	// Template doubles as parent: that is what makes the certificate self-signed.
	der, err := x509.CreateCertificate(rand.Reader, template, template, &privKey.PublicKey, privKey)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("cert creation from template failed: %w", err)
	}

	// Parse the generated DER back into a useable cert
	// This avoids needing to do it for each TLS handshake (see tls.Certificate.Leaf comment)
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("generated cert DER could not be parsed: %w", err)
	}

	return tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  privKey,
		Leaf:        cert,
	}, nil
}
68 | 


--------------------------------------------------------------------------------
/util/tls_test.go:
--------------------------------------------------------------------------------
 1 | package util
 2 | 
 3 | import (
 4 | 	"crypto/x509"
 5 | 
 6 | 	. "github.com/onsi/ginkgo/v2"
 7 | 	. "github.com/onsi/gomega"
 8 | )
 9 | 
// Specs for TLSGenerateSelfSignedCert: the returned certificate must be internally
// consistent (Leaf matches the DER) and verify as a self-signed server cert for its domain.
var _ = Describe("TLS Util", func() {
	Describe("TLSGenerateSelfSignedCert", func() {
		It("returns a good value", func() {
			const domain = "whatever.test.blocky.invalid"

			cert, err := TLSGenerateSelfSignedCert([]string{domain})
			Expect(err).Should(Succeed())

			Expect(cert.Certificate).ShouldNot(BeEmpty())

			By("having the right Leaf", func() {
				fromDER, err := x509.ParseCertificate(cert.Certificate[0])
				Expect(err).Should(Succeed())

				Expect(cert.Leaf).Should(Equal(fromDER))
			})

			By("being valid as self-signed for server TLS on the given domain", func() {
				// Trust only the certificate itself: verification succeeds only if it is self-signed.
				pool := x509.NewCertPool()
				pool.AddCert(cert.Leaf)

				chain, err := cert.Leaf.Verify(x509.VerifyOptions{
					DNSName:   domain,
					Roots:     pool,
					KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
				})
				Expect(err).Should(Succeed())
				Expect(chain).Should(Equal([][]*x509.Certificate{{cert.Leaf}}))
			})

			By("mentioning Blocky", func() {
				Expect(cert.Leaf.Subject.Organization).Should(Equal([]string{"Blocky"}))
			})
		})
	})
})
46 | 


--------------------------------------------------------------------------------
/util/util_suite_test.go:
--------------------------------------------------------------------------------
 1 | package util
 2 | 
 3 | import (
 4 | 	"testing"
 5 | 
 6 | 	"github.com/0xERR0R/blocky/log"
 7 | 	. "github.com/onsi/ginkgo/v2"
 8 | 	. "github.com/onsi/gomega"
 9 | )
10 | 
11 | func init() {
12 | 	log.Silence()
13 | }
14 | 
15 | func TestLists(t *testing.T) {
16 | 	RegisterFailHandler(Fail)
17 | 	RunSpecs(t, "Util Suite")
18 | }
19 | 


--------------------------------------------------------------------------------
/web/index.go:
--------------------------------------------------------------------------------
 1 | package web
 2 | 
 3 | import (
 4 | 	"embed"
 5 | 	"io/fs"
 6 | )
 7 | 
// IndexTmpl is the HTML template of the start page, embedded from index.html.
//
//go:embed index.html
var IndexTmpl string

// static holds the embedded static directory. The "all:" prefix also embeds files
// whose names start with "." or "_", which a plain directory pattern would skip.
//
//go:embed all:static
var static embed.FS
15 | 
16 | func Assets() (fs.FS, error) {
17 | 	return fs.Sub(static, "static")
18 | }
19 | 


--------------------------------------------------------------------------------
/web/index.html:
--------------------------------------------------------------------------------
 1 | <!DOCTYPE html>
 2 | <html>
 3 | <head>
 4 |     <title>blocky</title>
 5 | </head>
 6 | <body>
 7 |     <h1>blocky</h1>
 8 |     <ul>
 9 |     {{range .Links}}
10 |         <li><a href="{{.URL}}">{{.Title}}</a></li>
11 |     {{end}}
12 |     </ul>
13 | 
14 |     <p><span class="small">Version {{.Version}}   Build time {{.BuildTime}}</span></p> 
15 |     </body>
16 | </html>


--------------------------------------------------------------------------------
/web/static/rapidoc.html:
--------------------------------------------------------------------------------
 1 | <!doctype html>
 2 | <html>
 3 | <head>
 4 |   <meta charset="utf-8">
 5 |   <script type="module" src="/static/rapidoc-min.js"></script>
 6 | </head>
 7 | <body>
 8 |   <rapi-doc
 9 |     spec-url="/docs/openapi.yaml"
10 |     theme = "light"
11 | 	  allow-authentication = "false"
12 |     show-header = "false"
13 |     bg-color = "#fdf8ed"
14 |     nav-bg-color = "#3f4d67"
15 |     nav-text-color = "#a9b7d0"
16 |     nav-hover-bg-color = "#333f54"
17 |     nav-hover-text-color = "#fff"
18 |     nav-accent-color = "#f87070"
19 |     primary-color = "#5c7096"
20 |   > </rapi-doc>
21 | </body>
22 | </html>


--------------------------------------------------------------------------------