├── .devcontainer ├── devcontainer.json └── scripts │ ├── postStart.sh │ └── runItOnGo.sh ├── .dockerignore ├── .gitattributes ├── .github ├── FUNDING.yml ├── dependabot.yml └── workflows │ ├── build-bin.yml │ ├── close_stale.yml │ ├── codeql-analysis.yml │ ├── development-docker.yml │ ├── docs.yml │ ├── fork-sync.yml │ ├── goreleaser-test.yml │ ├── makefile.yml │ ├── mirror-repo.yml │ └── release.yml ├── .gitignore ├── .golangci.yml ├── .goreleaser.yml ├── .vscode ├── launch.json ├── settings.json └── tasks.json ├── CODE_OF_CONDUCT.md ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── api ├── api_client.gen.go ├── api_interface_impl.go ├── api_interface_impl_test.go ├── api_server.gen.go ├── api_suite_test.go ├── api_types.gen.go ├── client.cfg.yaml ├── server.cfg.yaml └── types.cfg.yaml ├── cache ├── expirationcache │ ├── cache_interface.go │ ├── expiration_cache.go │ ├── expiration_cache_suite_test.go │ ├── expiration_cache_test.go │ ├── prefetching_cache.go │ └── prefetching_cache_test.go └── stringcache │ ├── chained_grouped_cache.go │ ├── chained_grouped_cache_test.go │ ├── grouped_cache_interface.go │ ├── in_memory_grouped_cache.go │ ├── in_memory_grouped_cache_test.go │ ├── string_cache_suite_test.go │ ├── string_caches.go │ ├── string_caches_benchmark_test.go │ └── string_caches_test.go ├── cmd ├── blocking.go ├── blocking_test.go ├── cache.go ├── cache_test.go ├── cmd_suite_test.go ├── healthcheck.go ├── healthcheck_test.go ├── lists.go ├── lists_test.go ├── query.go ├── query_test.go ├── root.go ├── root_test.go ├── serve.go ├── serve_test.go ├── validate.go ├── validate_test.go ├── version.go └── version_test.go ├── codecov.yml ├── config ├── blocking.go ├── blocking_test.go ├── bytes_source.go ├── bytes_source_enum.go ├── caching.go ├── caching_test.go ├── client_lookup.go ├── client_lookup_test.go ├── conditional_upstream.go ├── conditional_upstream_test.go ├── config.go ├── config_enum.go ├── config_suite_test.go ├── config_test.go ├── 
custom_dns.go ├── custom_dns_test.go ├── duration.go ├── duration_test.go ├── ecs.go ├── ecs_test.go ├── filtering.go ├── filtering_test.go ├── hosts_file.go ├── hosts_file_test.go ├── metrics.go ├── metrics_test.go ├── migration │ └── migration.go ├── qtype_set.go ├── qtype_set_test.go ├── query_log.go ├── query_log_test.go ├── redis.go ├── redis_test.go ├── rewriter.go ├── rewriter_test.go ├── sudn.go ├── sudn_test.go ├── upstream.go ├── upstreams.go └── upstreams_test.go ├── docs ├── additional_information.md ├── api │ └── openapi.yaml ├── blocky-grafana.json ├── blocky-query-grafana-postgres.json ├── blocky-query-grafana.json ├── blocky.svg ├── config.yml ├── configuration.md ├── embed.go ├── fb_dns_config.png ├── grafana-dashboard.png ├── grafana-query-dashboard.png ├── includes │ └── abbreviations.md ├── index.md ├── installation.md ├── interfaces.md ├── network_configuration.md ├── prometheus_grafana.md └── rapidoc.html ├── e2e ├── basic_test.go ├── blocking_test.go ├── containers.go ├── custom_dns_test.go ├── e2e_suite_test.go ├── helper.go ├── metrics_test.go ├── querylog_test.go ├── redis_test.go └── upstream_test.go ├── evt └── events.go ├── go.mod ├── go.sum ├── helpertest ├── data │ ├── oisd-big-plain.txt │ └── oisd-big-wildcard.txt ├── helper.go ├── http.go ├── mock_call_sequence.go └── tmpdata.go ├── lists ├── downloader.go ├── downloader_test.go ├── list_cache.go ├── list_cache_benchmark_test.go ├── list_cache_enum.go ├── list_cache_test.go ├── list_suite_test.go ├── parsers │ ├── adapt.go │ ├── filtererrors.go │ ├── filtererrors_test.go │ ├── hosts.go │ ├── hosts_test.go │ ├── lines.go │ ├── lines_test.go │ ├── parser.go │ ├── parser_test.go │ └── parsers_suite_test.go └── sourcereader.go ├── log ├── context.go ├── logger.go ├── logger_enum.go └── mock_entry.go ├── main.go ├── main_static.go ├── metrics ├── metrics.go └── metrics_event_publisher.go ├── mkdocs.yml ├── model ├── models.go └── models_enum.go ├── querylog ├── database_writer.go ├── 
database_writer_test.go ├── file_writer.go ├── file_writer_test.go ├── logger_writer.go ├── logger_writer_test.go ├── none_writer.go ├── none_writer_test.go ├── querylog_suite_test.go └── writer.go ├── redis ├── redis.go ├── redis_suite_test.go └── redis_test.go ├── resolver ├── blocking_resolver.go ├── blocking_resolver_test.go ├── bootstrap.go ├── bootstrap_test.go ├── caching_resolver.go ├── caching_resolver_test.go ├── client_names_resolver.go ├── client_names_resolver_test.go ├── conditional_upstream_resolver.go ├── conditional_upstream_resolver_test.go ├── custom_dns_resolver.go ├── custom_dns_resolver_test.go ├── ecs_resolver.go ├── ecs_resolver_test.go ├── ede_resolver.go ├── ede_resolver_test.go ├── filtering_resolver.go ├── filtering_resolver_test.go ├── fqdn_only_resolver.go ├── fqdn_only_resolver_test.go ├── hosts_file_resolver.go ├── hosts_file_resolver_test.go ├── metrics_resolver.go ├── metrics_resolver_test.go ├── mock_udp_upstream_server.go ├── mocks_test.go ├── noop_resolver.go ├── noop_resolver_test.go ├── parallel_best_resolver.go ├── parallel_best_resolver_test.go ├── query_logging_resolver.go ├── query_logging_resolver_test.go ├── resolver.go ├── resolver_suite_test.go ├── resolver_test.go ├── rewriter_resolver.go ├── rewriter_resolver_test.go ├── strict_resolver.go ├── strict_resolver_test.go ├── sudn_resolver.go ├── sudn_resolver_test.go ├── upstream_resolver.go ├── upstream_resolver_test.go ├── upstream_tree_resolver.go └── upstream_tree_resolver_test.go ├── server ├── http.go ├── server.go ├── server_config_trigger.go ├── server_config_trigger_windows.go ├── server_endpoints.go ├── server_suite_test.go └── server_test.go ├── trie ├── split.go ├── split_test.go ├── trie.go ├── trie_suite_test.go └── trie_test.go ├── util ├── arpa.go ├── arpa_test.go ├── buildinfo.go ├── common.go ├── common_test.go ├── context.go ├── context_test.go ├── edns0.go ├── edns0_test.go ├── http.go ├── http_test.go ├── tls.go ├── tls_test.go └── util_suite_test.go 
└── web ├── index.go ├── index.html └── static ├── rapidoc-min.js └── rapidoc.html /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "blocky development", 3 | "image": "mcr.microsoft.com/devcontainers/base:ubuntu-22.04", 4 | "features": { 5 | "ghcr.io/devcontainers/features/go:1": {}, 6 | "ghcr.io/jungaretti/features/make:1": {}, 7 | "ghcr.io/devcontainers/features/docker-in-docker:2": { 8 | "dockerDashComposeVersion": "v2" 9 | }, 10 | "ghcr.io/devcontainers/features/python:1": {}, 11 | "ghcr.io/devcontainers/features/github-cli:1": {}, 12 | "ghcr.io/rocker-org/devcontainer-features/apt-packages:1": { 13 | "packages": "dnsutils " 14 | } 15 | }, 16 | "remoteEnv": { 17 | "LOCAL_WORKSPACE_FOLDER": "${localWorkspaceFolder}", 18 | "WORKSPACE_FOLDER": "${containerWorkspaceFolder}", 19 | "GENERATE_LCOV": "true" 20 | }, 21 | "customizations": { 22 | "vscode": { 23 | "extensions": [ 24 | "golang.go", 25 | "esbenp.prettier-vscode", 26 | "yzhang.markdown-all-in-one", 27 | "joselitofilho.ginkgotestexplorer", 28 | "fsevenm.run-it-on", 29 | "markis.code-coverage", 30 | "tooltitudeteam.tooltitude", 31 | "GitHub.vscode-github-actions" 32 | ], 33 | "settings": { 34 | "go.lintFlags": [ 35 | "--config=${containerWorkspaceFolder}/.golangci.yml", 36 | "--fast" 37 | ], 38 | "go.alternateTools": { 39 | "go-langserver": "gopls" 40 | }, 41 | "markiscodecoverage.searchCriteria": "**/*.lcov", 42 | "runItOn": { 43 | "commands": [ 44 | { 45 | "match": "\\.go$", 46 | "cmd": "${workspaceRoot}/.devcontainer/scripts/runItOnGo.sh ${fileDirname} ${workspaceRoot}" 47 | } 48 | ] 49 | }, 50 | "[go]": { 51 | "editor.defaultFormatter": "golang.go" 52 | }, 53 | "[yaml][json][jsonc][github-actions-workflow]": { 54 | "editor.defaultFormatter": "esbenp.prettier-vscode" 55 | }, 56 | "[markdown]": { 57 | "editor.defaultFormatter": "yzhang.markdown-all-in-one" 58 | } 59 | } 60 | } 61 | }, 62 | "mounts": [ 63 | 
"type=bind,readonly,source=/etc/localtime,target=/usr/share/host/localtime", 64 | "type=bind,readonly,source=/etc/timezone,target=/usr/share/host/timezone", 65 | "type=volume,source=blocky-pkg_cache,target=/go/pkg" 66 | ], 67 | "postCreateCommand": "sudo chmod +x .devcontainer/scripts/*.sh", 68 | "postStartCommand": "sh .devcontainer/scripts/postStart.sh" 69 | } 70 | -------------------------------------------------------------------------------- /.devcontainer/scripts/postStart.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | 3 | echo "Setting up go environment..." 4 | # Use the host's timezone and time 5 | sudo ln -sf /usr/share/host/localtime /etc/localtime 6 | sudo ln -sf /usr/share/host/timezone /etc/timezone 7 | # Change permission on pkg volume 8 | sudo chown -R vscode:golang /go/pkg 9 | echo "" 10 | 11 | echo "Downloading Go modules..." 12 | go mod download -x 13 | echo "" 14 | 15 | echo "Tidying Go modules..." 16 | go mod tidy -x 17 | echo "" 18 | 19 | echo "Installing Go tools..." 
20 | echo " - ginkgo" 21 | go install github.com/onsi/ginkgo/v2/ginkgo@latest 22 | echo " - golangci-lint" 23 | go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest 24 | echo " - gofumpt" 25 | go install mvdan.cc/gofumpt@latest 26 | echo " - gcov2lcov" 27 | go install github.com/jandelgado/gcov2lcov@latest -------------------------------------------------------------------------------- /.devcontainer/scripts/runItOnGo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | 3 | FOLDER_PATH=$1 4 | if [ -z "${FOLDER_PATH}" ]; then 5 | FOLDER_PATH=$PWD 6 | fi 7 | 8 | BASE_PATH=$2 9 | if [ -z "${BASE_PATH}" ]; then 10 | BASE_PATH=$WORKSPACE_FOLDER 11 | fi 12 | 13 | if [ "$FOLDER_PATH" = "$BASE_PATH" ]; then 14 | echo "Skipping lcov creation for base path" 15 | exit 1 16 | fi 17 | 18 | FOLDER_NAME=${FOLDER_PATH#"$BASE_PATH/"} 19 | WORK_NAME="$(echo "$FOLDER_NAME" | sed 's/\//-/g')" 20 | WORK_FILE_NAME="$WORK_NAME.ginkgo" 21 | WORK_FILE_PATH="/tmp/$WORK_FILE_NAME" 22 | OUTPUT_FOLDER="$BASE_PATH/coverage" 23 | OUTPUT_FILE_PATH="$OUTPUT_FOLDER/$WORK_NAME.lcov" 24 | 25 | 26 | mkdir -p "$OUTPUT_FOLDER" 27 | 28 | echo "-- Start $FOLDER_NAME ($(date '+%T')) --" 29 | 30 | TIMEFORMAT=' - Ginkgo tests finished in: %R seconds' 31 | time ginkgo --label-filter="!e2e" --keep-going --timeout=5m --output-dir=/tmp --coverprofile="$WORK_FILE_NAME" --covermode=atomic --cover -r -p "$FOLDER_PATH" || true 32 | 33 | TIMEFORMAT=' - lcov convert finished in: %R seconds' 34 | time gcov2lcov -infile="$WORK_FILE_PATH" -outfile="$OUTPUT_FILE_PATH" || true 35 | 36 | TIMEFORMAT=' - cleanup finished in: %R seconds' 37 | time rm "$WORK_FILE_PATH" || true 38 | 39 | echo "-- Finished $FOLDER_NAME ($(date '+%T')) --" -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | bin 2 | dist 3 | site 4 | node_modules 5 | 
.git 6 | .idea 7 | .github 8 | .vscode 9 | .gitignore 10 | *.md 11 | LICENSE 12 | vendor 13 | e2e/ 14 | .devcontainer/ 15 | coverage.txt 16 | coverage/ 17 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: 0xERR0R 2 | ko_fi: 0xerr0r 3 | custom: ["paypal.me/spx01"] 4 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: gomod 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | open-pull-requests-limit: 10 8 | assignees: 9 | - 0xERR0R 10 | 11 | - package-ecosystem: github-actions 12 | directory: "/" 13 | schedule: 14 | interval: daily 15 | -------------------------------------------------------------------------------- /.github/workflows/close_stale.yml: -------------------------------------------------------------------------------- 1 | name: Close stale issues and PRs 2 | 3 | on: 4 | schedule: 5 | - cron: "0 4 * * *" 6 | 7 | concurrency: 8 | group: ${{ github.workflow }}-${{ github.ref }} 9 | 10 | jobs: 11 | stale: 12 | runs-on: ubuntu-latest 13 | if: github.repository_owner == '0xERR0R' 14 | permissions: 15 | issues: write 16 | pull-requests: write 17 | steps: 18 | - uses: actions/stale@v9 19 | with: 20 | stale-issue-message: "This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 5 days." 21 | stale-pr-message: "This PR is stale because it has been open 45 days with no activity. Remove stale label or comment or this will be closed in 10 days." 
22 | close-issue-message: "This issue was closed because it has been stalled for 5 days with no activity." 23 | close-pr-message: "This PR was closed because it has been stalled for 10 days with no activity." 24 | days-before-issue-stale: 90 25 | days-before-pr-stale: 45 26 | days-before-issue-close: 5 27 | days-before-pr-close: 10 28 | exempt-all-milestones: true 29 | operations-per-run: 60 30 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | name: CodeQL 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | schedule: 11 | - cron: "33 15 * * 1" 12 | 13 | concurrency: 14 | group: ${{ github.workflow }}-${{ github.ref }} 15 | 16 | jobs: 17 | analyze: 18 | name: Analyze 19 | runs-on: ubuntu-latest 20 | steps: 21 | - name: Checkout repository 22 | uses: actions/checkout@v4 23 | with: 24 | fetch-depth: 0 25 | 26 | - name: Setup Golang 27 | uses: actions/setup-go@v5 28 | with: 29 | go-version-file: go.mod 30 | 31 | - name: Initialize CodeQL 32 | uses: github/codeql-action/init@v3 33 | with: 34 | languages: go 35 | 36 | - name: Build with Makefile 37 | run: make build 38 | 39 | - name: Perform CodeQL Analysis 40 | uses: github/codeql-action/analyze@v3 41 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | 8 | branches: 9 | - '**' 10 | paths: 11 | - .github/workflows/** 12 | - mkdocs.yml 13 | - docs/** 14 | 15 | concurrency: 16 | group: ${{ github.workflow }} 17 | 18 | jobs: 19 | deploy: 20 | runs-on: ubuntu-latest 21 | if: ${{ github.event.repository.has_pages && (github.repository_owner != '0xERR0R' || github.ref_type == 'tag' || github.ref_name == 'main') }} 22 | steps: 23 | - 
uses: actions/checkout@v4 24 | with: 25 | fetch-depth: 0 26 | 27 | - uses: actions/setup-python@v5 28 | with: 29 | python-version: 3.x 30 | 31 | - name: install tools 32 | run: pip install mkdocs-material mike 33 | 34 | - name: Setup doc deploy 35 | run: | 36 | git config --local user.email "github-actions[bot]@users.noreply.github.com" 37 | git config --local user.name "github-actions[bot]" 38 | 39 | - name: Deploy version 40 | run: | 41 | VERSION="$(sed 's:/:-:g' <<< "$GITHUB_REF_NAME")" 42 | if [[ ${{github.ref}} =~ ^refs/tags/ ]]; then 43 | EXTRA_ALIAS=latest 44 | fi 45 | mike deploy --push --update-aliases "$VERSION" $EXTRA_ALIAS 46 | tr '[:upper:]' '[:lower:]' <<< "https://${{github.repository_owner}}.github.io/${{github.event.repository.name}}/$VERSION/" >> "$GITHUB_STEP_SUMMARY" 47 | -------------------------------------------------------------------------------- /.github/workflows/fork-sync.yml: -------------------------------------------------------------------------------- 1 | name: Sync Fork 2 | 3 | on: 4 | schedule: 5 | - cron: "*/30 * * * *" 6 | workflow_dispatch: 7 | 8 | concurrency: 9 | group: ${{ github.workflow }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | sync: 14 | name: Sync with Upstream 15 | runs-on: ubuntu-latest 16 | if: github.repository_owner != '0xERR0R' 17 | steps: 18 | - name: Enabled Check 19 | id: check 20 | shell: bash 21 | run: | 22 | if [[ "${{ secrets.FORK_SYNC_TOKEN }}" != "" ]]; then 23 | echo "enabled=1" >> $GITHUB_OUTPUT 24 | 25 | echo "Workflow is enabled" 26 | else 27 | echo "enabled=0" >> $GITHUB_OUTPUT 28 | 29 | ( 30 | echo 'Workflow is disabled (create `FORK_SYNC_TOKEN` secret with repo write permission to enable it)' 31 | echo 32 | echo 'Alternatively, you can disable it for your repo from the web UI:' 33 | echo 'https://docs.github.com/en/actions/using-workflows/disabling-and-enabling-a-workflow' 34 | ) | tee "$GITHUB_STEP_SUMMARY" 35 | fi 36 | 37 | - name: Sync 38 | if: ${{ steps.check.outputs.enabled == 1 }} 
39 | env: 40 | GH_TOKEN: ${{ secrets.FORK_SYNC_TOKEN }} 41 | shell: bash 42 | run: | 43 | gh repo sync ${{ github.repository }} -b main 44 | -------------------------------------------------------------------------------- /.github/workflows/makefile.yml: -------------------------------------------------------------------------------- 1 | name: Makefile 2 | 3 | on: 4 | push: 5 | paths: 6 | - .github/workflows/makefile.yml 7 | - Dockerfile 8 | - Makefile 9 | - "**.go" 10 | - "go.*" 11 | - "helpertest/data/**" 12 | pull_request: 13 | 14 | permissions: 15 | security-events: write 16 | actions: read 17 | contents: read 18 | 19 | env: 20 | GINKGO_PROCS: --procs=1 21 | 22 | jobs: 23 | make: 24 | name: make 25 | runs-on: ubuntu-latest 26 | strategy: 27 | matrix: 28 | include: 29 | - make: build 30 | go: true 31 | docker: false 32 | - make: test 33 | go: true 34 | docker: false 35 | - make: race 36 | go: true 37 | docker: false 38 | - make: docker-build 39 | go: false 40 | docker: true 41 | - make: e2e-test 42 | go: true 43 | docker: true 44 | - make: goreleaser 45 | go: false 46 | docker: false 47 | - make: lint 48 | go: true 49 | docker: false 50 | 51 | steps: 52 | - name: Check out code into the Go module directory 53 | uses: actions/checkout@v4 54 | 55 | - name: Setup Golang 56 | uses: actions/setup-go@v5 57 | if: matrix.go == true 58 | with: 59 | go-version-file: go.mod 60 | 61 | - name: Download dependencies 62 | run: go mod download 63 | if: matrix.go == true 64 | 65 | - name: Set up Docker Buildx 66 | uses: docker/setup-buildx-action@v3 67 | if: matrix.docker == true 68 | 69 | - name: make ${{ matrix.make }} 70 | run: make ${{ matrix.make }} 71 | if: matrix.make != 'goreleaser' 72 | env: 73 | GO_SKIP_GENERATE: 1 74 | 75 | - name: Upload results to codecov 76 | uses: codecov/codecov-action@v5 77 | if: matrix.make == 'test' && github.repository_owner == '0xERR0R' 78 | 79 | - name: Check GoReleaser configuration 80 | uses: goreleaser/goreleaser-action@v5 81 | if: 
matrix.make == 'goreleaser' 82 | with: 83 | args: check 84 | -------------------------------------------------------------------------------- /.github/workflows/mirror-repo.yml: -------------------------------------------------------------------------------- 1 | name: mirror git repo 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }} 10 | 11 | jobs: 12 | mirror: 13 | runs-on: ubuntu-latest 14 | if: github.repository_owner == '0xERR0R' 15 | steps: 16 | - uses: actions/checkout@v4 17 | with: 18 | fetch-depth: 0 19 | 20 | - uses: yesolutions/mirror-action@master 21 | with: 22 | REMOTE: "https://codeberg.org/0xERR0R/blocky.git" 23 | GIT_USERNAME: 0xERR0R 24 | GIT_PASSWORD: ${{ secrets.CODEBERG_TOKEN }} 25 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }} 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | if: github.repository_owner == '0xERR0R' 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v4 18 | with: 19 | fetch-depth: 0 20 | 21 | - name: Set up Go 22 | uses: actions/setup-go@v5 23 | with: 24 | go-version-file: go.mod 25 | id: go 26 | 27 | - name: Build 28 | run: make build 29 | 30 | - name: Test 31 | run: make test 32 | 33 | - name: Docker meta 34 | id: docker_meta 35 | uses: crazy-max/ghaction-docker-meta@v5 36 | with: 37 | images: spx01/blocky,ghcr.io/0xerr0r/blocky 38 | 39 | - name: Set up QEMU 40 | uses: docker/setup-qemu-action@v3 41 | with: 42 | platforms: arm,arm64 43 | 44 | - name: Set up Docker Buildx 45 | uses: docker/setup-buildx-action@v3 46 | 47 | - name: Login to GitHub Container Registry 48 | uses: docker/login-action@v3 49 | with: 50 | registry: ghcr.io 51 | username: ${{ 
github.repository_owner }} 52 | password: ${{ secrets.CR_PAT }} 53 | 54 | - name: Login to DockerHub 55 | uses: docker/login-action@v3 56 | with: 57 | username: ${{ secrets.DOCKER_USERNAME }} 58 | password: ${{ secrets.DOCKER_TOKEN }} 59 | 60 | - name: Populate build variables 61 | id: get_vars 62 | shell: bash 63 | run: | 64 | VERSION=$(git describe --always --tags) 65 | echo "version=${VERSION}" >> $GITHUB_OUTPUT 66 | echo "VERSION: ${VERSION}" 67 | 68 | BUILD_TIME=$(date --iso-8601=seconds) 69 | echo "build_time=${BUILD_TIME}" >> $GITHUB_OUTPUT 70 | echo "BUILD_TIME: ${BUILD_TIME}" 71 | 72 | DOC_PATH=${VERSION%%-*} 73 | echo "doc_path=${DOC_PATH}" >> $GITHUB_OUTPUT 74 | echo "DOC_PATH: ${DOC_PATH}" 75 | 76 | - name: Build and push 77 | uses: docker/build-push-action@v6 78 | with: 79 | context: . 80 | platforms: linux/amd64,linux/arm/v6,linux/arm/v7,linux/arm64 81 | push: ${{ github.event_name != 'pull_request' }} 82 | tags: ${{ steps.docker_meta.outputs.tags }} 83 | labels: ${{ steps.docker_meta.outputs.labels }} 84 | build-args: | 85 | VERSION=${{ steps.get_vars.outputs.version }} 86 | BUILD_TIME=${{ steps.get_vars.outputs.build_time }} 87 | DOC_PATH=${{ steps.get_vars.outputs.doc_path }} 88 | 89 | 90 | - name: Run GoReleaser 91 | uses: goreleaser/goreleaser-action@v5 92 | with: 93 | version: latest 94 | args: release --clean 95 | env: 96 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 97 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.iml 3 | *.test 4 | /*.pem 5 | bin/ 6 | dist/ 7 | docs/docs.go 8 | site/ 9 | config.yml 10 | todo.txt 11 | !docs/config.yml 12 | node_modules 13 | package-lock.json 14 | vendor/ 15 | coverage.html 16 | coverage.txt 17 | coverage/ 18 | blocky 19 | -------------------------------------------------------------------------------- /.golangci.yml: 
-------------------------------------------------------------------------------- 1 | linters: 2 | enable: 3 | - asciicheck 4 | - bidichk 5 | - bodyclose 6 | - dogsled 7 | - dupl 8 | - durationcheck 9 | - errcheck 10 | - errchkjson 11 | - errorlint 12 | - exhaustive 13 | - funlen 14 | - gochecknoglobals 15 | - gochecknoinits 16 | - gocognit 17 | - goconst 18 | - gocritic 19 | - gocyclo 20 | - godox 21 | - gofmt 22 | - goimports 23 | - mnd 24 | - gomodguard 25 | - gosimple 26 | - govet 27 | - grouper 28 | - importas 29 | - ineffassign 30 | - lll 31 | - makezero 32 | - misspell 33 | - nakedret 34 | - nestif 35 | - nilerr 36 | - nilnil 37 | - nlreturn 38 | - nolintlint 39 | - nosprintfhostport 40 | - prealloc 41 | - predeclared 42 | - revive 43 | - sqlclosecheck 44 | - staticcheck 45 | - stylecheck 46 | - typecheck 47 | - unconvert 48 | - unparam 49 | - unused 50 | - usestdlibvars 51 | - wastedassign 52 | - whitespace 53 | - ginkgolinter 54 | - noctx 55 | - containedctx 56 | - contextcheck 57 | disable: 58 | - forbidigo 59 | - gosmopolitan 60 | - gosec 61 | - recvcheck 62 | disable-all: false 63 | presets: 64 | - bugs 65 | - unused 66 | fast: false 67 | 68 | linters-settings: 69 | mnd: 70 | ignored-numbers: 71 | - "0666" 72 | - "2" 73 | - "5" 74 | ginkgolinter: 75 | forbid-focus-container: true 76 | stylecheck: 77 | # Whitelist dot imports for test packages. 78 | dot-import-whitelist: 79 | - "github.com/onsi/ginkgo/v2" 80 | - "github.com/onsi/gomega" 81 | - "github.com/0xERR0R/blocky/config/migration" 82 | - "github.com/0xERR0R/blocky/helpertest" 83 | revive: 84 | rules: 85 | - name: dot-imports 86 | disabled: true # prefer stylecheck since it's more configurable 87 | 88 | issues: 89 | exclude-rules: 90 | # Exclude some linters from running on test files. 
91 | - path: _test\.go 92 | linters: 93 | - dupl 94 | - funlen 95 | - gochecknoinits 96 | - gochecknoglobals 97 | - gosec 98 | - path: _test\.go 99 | linters: 100 | - staticcheck 101 | text: "SA1012:" 102 | -------------------------------------------------------------------------------- /.goreleaser.yml: -------------------------------------------------------------------------------- 1 | project_name: blocky 2 | 3 | before: 4 | hooks: 5 | - go mod tidy 6 | builds: 7 | - goos: 8 | - linux 9 | - windows 10 | - freebsd 11 | - netbsd 12 | - openbsd 13 | - darwin 14 | goarch: 15 | - amd64 16 | - arm 17 | - arm64 18 | goarm: 19 | - 6 20 | - 7 21 | ignore: 22 | - goos: windows 23 | goarch: arm 24 | - goos: windows 25 | goarch: arm64 26 | ldflags: 27 | - -w 28 | - -s 29 | - -X github.com/0xERR0R/blocky/util.Version=v{{.Version}} 30 | - -X github.com/0xERR0R/blocky/util.BuildTime={{time "20060102-150405"}} 31 | - -X github.com/0xERR0R/blocky/util.Architecture={{.Arch}}{{.Arm}} 32 | env: 33 | - CGO_ENABLED=0 34 | release: 35 | draft: true 36 | archives: 37 | - format_overrides: 38 | - goos: windows 39 | format: zip 40 | name_template: >- 41 | {{ .ProjectName }}_v 42 | {{- .Version }}_ 43 | {{- title .Os }}_ 44 | {{- if eq .Arch "amd64" }}x86_64 45 | {{- else if eq .Arch "386" }}i386 46 | {{- else }}{{ .Arch }}{{ end }} 47 | {{- if .Arm }}v{{ .Arm }}{{ end }} 48 | 49 | snapshot: 50 | name_template: "{{ .Version }}-{{.ShortCommit}}" 51 | checksum: 52 | name_template: "{{ .ProjectName }}_checksums.txt" 53 | changelog: 54 | use: github 55 | sort: asc 56 | filters: 57 | exclude: 58 | - '^docs:' 59 | - '^chore:' 60 | - '^test:' 61 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "Launch Package", 6 | "type": "go", 7 | "request": "launch", 8 | "mode": "auto", 9 | "program": 
"${fileDirname}" 10 | } 11 | ] 12 | } 13 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.tabSize": 2, 3 | "editor.insertSpaces": true, 4 | "editor.detectIndentation": false, 5 | "editor.formatOnSave": true, 6 | "editor.formatOnPaste": true, 7 | "editor.codeActionsOnSave": { 8 | "source.organizeImports": "explicit", 9 | "source.fixAll": "explicit" 10 | }, 11 | "editor.rulers": [120], 12 | "go.showWelcome": false, 13 | "go.survey.prompt": false, 14 | "go.useLanguageServer": true, 15 | "go.formatTool": "gofumpt", 16 | "go.lintTool": "golangci-lint", 17 | "go.lintOnSave": "workspace", 18 | "gopls": { 19 | "ui.semanticTokens": true, 20 | "formatting.gofumpt": true, 21 | "build.standaloneTags": ["ignore", "tools"] 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "2.0.0", 3 | "tasks": [ 4 | { 5 | "label": "Test", 6 | "type": "shell", 7 | "command": "make test", 8 | "group": { 9 | "kind": "test", 10 | "isDefault": true 11 | }, 12 | "presentation": { 13 | "reveal": "always" 14 | } 15 | }, 16 | { 17 | "label": "Race", 18 | "type": "shell", 19 | "command": "make race", 20 | "group": { 21 | "kind": "test", 22 | "isDefault": false 23 | }, 24 | "presentation": { 25 | "reveal": "always" 26 | } 27 | }, 28 | { 29 | "label": "e2e - Test", 30 | "type": "shell", 31 | "command": "make e2e-test", 32 | "group": { 33 | "kind": "test", 34 | "isDefault": false 35 | }, 36 | "presentation": { 37 | "reveal": "always" 38 | } 39 | }, 40 | { 41 | "label": "Build", 42 | "type": "shell", 43 | "command": "make build", 44 | "group": { 45 | "group": { 46 | "kind": "build", 47 | "isDefault": true 48 | }, 49 | "isDefault": true 50 | }, 51 | "presentation": { 52 | "reveal": 
"always" 53 | } 54 | }, 55 | { 56 | "label": "FMT", 57 | "type": "shell", 58 | "command": "make fmt", 59 | "group": "none", 60 | "presentation": { 61 | "reveal": "always" 62 | } 63 | }, 64 | { 65 | "label": "Tidy", 66 | "type": "shell", 67 | "command": "go mod tidy", 68 | "group": "none", 69 | "presentation": { 70 | "reveal": "always" 71 | } 72 | } 73 | ] 74 | } 75 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. 
Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by creating an issue. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | 3 | # ----------- stage: build 4 | FROM golang:alpine AS build 5 | RUN apk add --no-cache make coreutils libcap 6 | 7 | # required arguments 8 | ARG VERSION 9 | ARG BUILD_TIME 10 | 11 | COPY . . 
12 | # setup go 13 | ENV GO_SKIP_GENERATE=1\ 14 | GO_BUILD_FLAGS="-tags static -v " \ 15 | BIN_USER=100\ 16 | BIN_AUTOCAB=1 \ 17 | BIN_OUT_DIR="/bin" 18 | 19 | RUN make build 20 | 21 | # ----------- stage: final 22 | FROM scratch 23 | 24 | ARG VERSION 25 | ARG BUILD_TIME 26 | ARG DOC_PATH 27 | 28 | LABEL org.opencontainers.image.title="blocky" \ 29 | org.opencontainers.image.vendor="0xERR0R" \ 30 | org.opencontainers.image.licenses="Apache-2.0" \ 31 | org.opencontainers.image.version="${VERSION}" \ 32 | org.opencontainers.image.created="${BUILD_TIME}" \ 33 | org.opencontainers.image.description="Fast and lightweight DNS proxy as ad-blocker for local network with many features" \ 34 | org.opencontainers.image.url="https://github.com/0xERR0R/blocky#readme" \ 35 | org.opencontainers.image.source="https://github.com/0xERR0R/blocky" \ 36 | org.opencontainers.image.documentation="https://0xerr0r.github.io/blocky/${DOC_PATH}/" 37 | 38 | 39 | 40 | USER 100 41 | WORKDIR /app 42 | 43 | COPY --from=build /bin/blocky /app/blocky 44 | 45 | ENV BLOCKY_CONFIG_FILE=/app/config.yml 46 | 47 | ENTRYPOINT ["/app/blocky"] 48 | 49 | HEALTHCHECK --start-period=1m --timeout=3s CMD ["/app/blocky", "healthcheck"] 50 | -------------------------------------------------------------------------------- /api/api_suite_test.go: -------------------------------------------------------------------------------- 1 | package api_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/0xERR0R/blocky/log" 7 | . "github.com/onsi/ginkgo/v2" 8 | . "github.com/onsi/gomega" 9 | ) 10 | 11 | func init() { 12 | log.Silence() 13 | } 14 | 15 | func TestResolver(t *testing.T) { 16 | RegisterFailHandler(Fail) 17 | RunSpecs(t, "API Suite") 18 | } 19 | -------------------------------------------------------------------------------- /api/api_types.gen.go: -------------------------------------------------------------------------------- 1 | // Package api provides primitives to interact with the openapi HTTP API. 
2 | // 3 | // Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.4.1 DO NOT EDIT. 4 | package api 5 | 6 | // ApiBlockingStatus defines model for api.BlockingStatus. 7 | type ApiBlockingStatus struct { 8 | // AutoEnableInSec If blocking is temporary disabled: amount of seconds until blocking will be enabled 9 | AutoEnableInSec *int `json:"autoEnableInSec,omitempty"` 10 | 11 | // DisabledGroups Disabled group names 12 | DisabledGroups *[]string `json:"disabledGroups,omitempty"` 13 | 14 | // Enabled True if blocking is enabled 15 | Enabled bool `json:"enabled"` 16 | } 17 | 18 | // ApiQueryRequest defines model for api.QueryRequest. 19 | type ApiQueryRequest struct { 20 | // Query query for DNS request 21 | Query string `json:"query"` 22 | 23 | // Type request type (A, AAAA, ...) 24 | Type string `json:"type"` 25 | } 26 | 27 | // ApiQueryResult defines model for api.QueryResult. 28 | type ApiQueryResult struct { 29 | // Reason blocky reason for resolution 30 | Reason string `json:"reason"` 31 | 32 | // Response actual DNS response 33 | Response string `json:"response"` 34 | 35 | // ResponseType response type (CACHED, BLOCKED, ...) 36 | ResponseType string `json:"responseType"` 37 | 38 | // ReturnCode DNS return code (NOERROR, NXDOMAIN, ...) 39 | ReturnCode string `json:"returnCode"` 40 | } 41 | 42 | // DisableBlockingParams defines parameters for DisableBlocking. 43 | type DisableBlockingParams struct { 44 | // Duration duration of blocking (Example: 300s, 5m, 1h, 5m30s) 45 | Duration *string `form:"duration,omitempty" json:"duration,omitempty"` 46 | 47 | // Groups groups to disable (comma separated). If empty, disable all groups 48 | Groups *string `form:"groups,omitempty" json:"groups,omitempty"` 49 | } 50 | 51 | // QueryJSONRequestBody defines body for Query for application/json ContentType. 
52 | type QueryJSONRequestBody = ApiQueryRequest 53 | -------------------------------------------------------------------------------- /api/client.cfg.yaml: -------------------------------------------------------------------------------- 1 | package: api 2 | generate: 3 | client: true 4 | output: api_client.gen.go -------------------------------------------------------------------------------- /api/server.cfg.yaml: -------------------------------------------------------------------------------- 1 | package: api 2 | generate: 3 | chi-server: true 4 | strict-server: true 5 | embedded-spec: false 6 | output: api_server.gen.go -------------------------------------------------------------------------------- /api/types.cfg.yaml: -------------------------------------------------------------------------------- 1 | package: api 2 | generate: 3 | models: true 4 | output: api_types.gen.go -------------------------------------------------------------------------------- /cache/expirationcache/cache_interface.go: -------------------------------------------------------------------------------- 1 | package expirationcache 2 | 3 | import "time" 4 | 5 | type ExpiringCache[T any] interface { 6 | // Put adds the value to the cache under the passed key with expiration. If expiration <=0, entry will NOT be cached 7 | Put(key string, val *T, expiration time.Duration) 8 | 9 | // Get returns the value of cached entry with remaining TTL.
If entry is not cached, returns nil 10 | Get(key string) (val *T, expiration time.Duration) 11 | 12 | // TotalCount returns the total count of valid (not expired) elements 13 | TotalCount() int 14 | 15 | // Clear removes all cache entries 16 | Clear() 17 | } 18 | -------------------------------------------------------------------------------- /cache/expirationcache/expiration_cache_suite_test.go: -------------------------------------------------------------------------------- 1 | package expirationcache_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/0xERR0R/blocky/log" 7 | 8 | . "github.com/onsi/ginkgo/v2" 9 | . "github.com/onsi/gomega" 10 | ) 11 | 12 | func init() { 13 | log.Silence() 14 | } 15 | 16 | func TestCache(t *testing.T) { 17 | RegisterFailHandler(Fail) 18 | RunSpecs(t, "Expiration cache suite") 19 | } 20 | -------------------------------------------------------------------------------- /cache/stringcache/chained_grouped_cache.go: -------------------------------------------------------------------------------- 1 | package stringcache 2 | 3 | import ( 4 | "sort" 5 | 6 | "golang.org/x/exp/maps" 7 | ) 8 | 9 | type ChainedGroupedCache struct { 10 | caches []GroupedStringCache 11 | } 12 | 13 | func NewChainedGroupedCache(caches ...GroupedStringCache) *ChainedGroupedCache { 14 | return &ChainedGroupedCache{ 15 | caches: caches, 16 | } 17 | } 18 | 19 | func (c *ChainedGroupedCache) ElementCount(group string) int { 20 | sum := 0 21 | for _, cache := range c.caches { 22 | sum += cache.ElementCount(group) 23 | } 24 | 25 | return sum 26 | } 27 | 28 | func (c *ChainedGroupedCache) Contains(searchString string, groups []string) []string { 29 | groupMatchedMap := make(map[string]struct{}, len(groups)) 30 | 31 | for _, cache := range c.caches { 32 | for _, group := range cache.Contains(searchString, groups) { 33 | groupMatchedMap[group] = struct{}{} 34 | } 35 | } 36 | 37 | matchedGroups := maps.Keys(groupMatchedMap) 38 | 39 | sort.Strings(matchedGroups) 40 | 41 
| return matchedGroups 42 | } 43 | 44 | func (c *ChainedGroupedCache) Refresh(group string) GroupFactory { 45 | cacheFactories := make([]GroupFactory, len(c.caches)) 46 | for i, cache := range c.caches { 47 | cacheFactories[i] = cache.Refresh(group) 48 | } 49 | 50 | return &chainedGroupFactory{ 51 | cacheFactories: cacheFactories, 52 | } 53 | } 54 | 55 | type chainedGroupFactory struct { 56 | cacheFactories []GroupFactory 57 | } 58 | 59 | func (c *chainedGroupFactory) AddEntry(entry string) bool { 60 | for _, factory := range c.cacheFactories { 61 | if factory.AddEntry(entry) { 62 | return true 63 | } 64 | } 65 | 66 | return false 67 | } 68 | 69 | func (c *chainedGroupFactory) Count() int { 70 | var cnt int 71 | for _, factory := range c.cacheFactories { 72 | cnt += factory.Count() 73 | } 74 | 75 | return cnt 76 | } 77 | 78 | func (c *chainedGroupFactory) Finish() { 79 | for _, factory := range c.cacheFactories { 80 | factory.Finish() 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /cache/stringcache/grouped_cache_interface.go: -------------------------------------------------------------------------------- 1 | package stringcache 2 | 3 | type GroupedStringCache interface { 4 | // Contains checks if one or more groups in the cache contains the search string. 5 | // Returns group(s) containing the string or empty slice if string was not found 6 | Contains(searchString string, groups []string) []string 7 | 8 | // Refresh creates new factory for the group to be refreshed. 9 | // Calling Finish on the factory will perform the group refresh. 10 | Refresh(group string) GroupFactory 11 | 12 | // ElementCount returns the amount of elements in the group 13 | ElementCount(group string) int 14 | } 15 | 16 | type GroupFactory interface { 17 | // AddEntry adds a new string to the factory to be added later to the cache groups. 
18 | AddEntry(entry string) bool 19 | 20 | // Count returns amount of processed string in the factory 21 | Count() int 22 | 23 | // Finish replaces the group in cache with factory's content 24 | Finish() 25 | } 26 | -------------------------------------------------------------------------------- /cache/stringcache/in_memory_grouped_cache.go: -------------------------------------------------------------------------------- 1 | package stringcache 2 | 3 | import "sync" 4 | 5 | type stringCacheFactoryFn func() cacheFactory 6 | 7 | type InMemoryGroupedCache struct { 8 | caches map[string]stringCache 9 | lock sync.RWMutex 10 | factoryFn stringCacheFactoryFn 11 | } 12 | 13 | func NewInMemoryGroupedStringCache() *InMemoryGroupedCache { 14 | return &InMemoryGroupedCache{ 15 | caches: make(map[string]stringCache), 16 | factoryFn: newStringCacheFactory, 17 | } 18 | } 19 | 20 | func NewInMemoryGroupedRegexCache() *InMemoryGroupedCache { 21 | return &InMemoryGroupedCache{ 22 | caches: make(map[string]stringCache), 23 | factoryFn: newRegexCacheFactory, 24 | } 25 | } 26 | 27 | func NewInMemoryGroupedWildcardCache() *InMemoryGroupedCache { 28 | return &InMemoryGroupedCache{ 29 | caches: make(map[string]stringCache), 30 | factoryFn: newWildcardCacheFactory, 31 | } 32 | } 33 | 34 | func (c *InMemoryGroupedCache) ElementCount(group string) int { 35 | c.lock.RLock() 36 | cache, found := c.caches[group] 37 | c.lock.RUnlock() 38 | 39 | if !found { 40 | return 0 41 | } 42 | 43 | return cache.elementCount() 44 | } 45 | 46 | func (c *InMemoryGroupedCache) Contains(searchString string, groups []string) []string { 47 | var result []string 48 | 49 | for _, group := range groups { 50 | c.lock.RLock() 51 | cache, found := c.caches[group] 52 | c.lock.RUnlock() 53 | 54 | if found && cache.contains(searchString) { 55 | result = append(result, group) 56 | } 57 | } 58 | 59 | return result 60 | } 61 | 62 | func (c *InMemoryGroupedCache) Refresh(group string) GroupFactory { 63 | return 
&inMemoryGroupFactory{ 64 | factory: c.factoryFn(), 65 | finishFn: func(sc stringCache) { 66 | c.lock.Lock() 67 | defer c.lock.Unlock() 68 | 69 | if sc != nil { 70 | c.caches[group] = sc 71 | } else { 72 | delete(c.caches, group) 73 | } 74 | }, 75 | } 76 | } 77 | 78 | type inMemoryGroupFactory struct { 79 | factory cacheFactory 80 | finishFn func(stringCache) 81 | } 82 | 83 | func (c *inMemoryGroupFactory) AddEntry(entry string) bool { 84 | return c.factory.addEntry(entry) 85 | } 86 | 87 | func (c *inMemoryGroupFactory) Count() int { 88 | return c.factory.count() 89 | } 90 | 91 | func (c *inMemoryGroupFactory) Finish() { 92 | sc := c.factory.create() 93 | c.finishFn(sc) 94 | } 95 | -------------------------------------------------------------------------------- /cache/stringcache/string_cache_suite_test.go: -------------------------------------------------------------------------------- 1 | package stringcache_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/0xERR0R/blocky/log" 7 | 8 | . "github.com/onsi/ginkgo/v2" 9 | . 
"github.com/onsi/gomega" 10 | ) 11 | 12 | func init() { 13 | log.Silence() 14 | } 15 | 16 | func TestCache(t *testing.T) { 17 | RegisterFailHandler(Fail) 18 | RunSpecs(t, "String cache suite") 19 | } 20 | -------------------------------------------------------------------------------- /cmd/cache.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/0xERR0R/blocky/api" 8 | "github.com/spf13/cobra" 9 | ) 10 | 11 | func newCacheCommand() *cobra.Command { 12 | c := &cobra.Command{ 13 | Use: "cache", 14 | Short: "Performs cache operations", 15 | PersistentPreRunE: initConfigPreRun, 16 | } 17 | c.AddCommand(&cobra.Command{ 18 | Use: "flush", 19 | Args: cobra.NoArgs, 20 | Aliases: []string{"clear"}, 21 | Short: "Flush cache", 22 | RunE: flushCache, 23 | }) 24 | 25 | return c 26 | } 27 | 28 | func flushCache(_ *cobra.Command, _ []string) error { 29 | client, err := api.NewClientWithResponses(apiURL()) 30 | if err != nil { 31 | return fmt.Errorf("can't create client: %w", err) 32 | } 33 | 34 | resp, err := client.CacheFlushWithResponse(context.Background()) 35 | if err != nil { 36 | return fmt.Errorf("can't execute %w", err) 37 | } 38 | 39 | return printOkOrError(resp, string(resp.Body)) 40 | } 41 | -------------------------------------------------------------------------------- /cmd/cache_test.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | 7 | "github.com/sirupsen/logrus/hooks/test" 8 | 9 | "github.com/0xERR0R/blocky/log" 10 | 11 | . "github.com/onsi/ginkgo/v2" 12 | . 
"github.com/onsi/gomega" 13 | ) 14 | 15 | var _ = Describe("Cache command", func() { 16 | var ( 17 | ts *httptest.Server 18 | mockFn func(w http.ResponseWriter, _ *http.Request) 19 | loggerHook *test.Hook 20 | ) 21 | JustBeforeEach(func() { 22 | ts = testHTTPAPIServer(mockFn) 23 | }) 24 | JustAfterEach(func() { 25 | ts.Close() 26 | }) 27 | BeforeEach(func() { 28 | mockFn = func(w http.ResponseWriter, _ *http.Request) {} 29 | loggerHook = test.NewGlobal() 30 | log.Log().AddHook(loggerHook) 31 | }) 32 | AfterEach(func() { 33 | loggerHook.Reset() 34 | }) 35 | Describe("flush cache", func() { 36 | When("flush cache is called via REST", func() { 37 | It("should flush caches", func() { 38 | Expect(flushCache(newCacheCommand(), []string{})).Should(Succeed()) 39 | Expect(loggerHook.LastEntry().Message).Should(Equal("OK")) 40 | }) 41 | }) 42 | When("Wrong url is used", func() { 43 | It("Should end with error", func() { 44 | apiPort = 0 45 | err := flushCache(newCacheCommand(), []string{}) 46 | Expect(err).Should(HaveOccurred()) 47 | Expect(err.Error()).Should(ContainSubstring("connection refused")) 48 | }) 49 | }) 50 | }) 51 | }) 52 | -------------------------------------------------------------------------------- /cmd/cmd_suite_test.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/0xERR0R/blocky/log" 7 | 8 | . "github.com/onsi/ginkgo/v2" 9 | . 
"github.com/onsi/gomega" 10 | ) 11 | 12 | func init() { 13 | log.Silence() 14 | } 15 | 16 | func TestCmd(t *testing.T) { 17 | RegisterFailHandler(Fail) 18 | RunSpecs(t, "Command Suite") 19 | } 20 | -------------------------------------------------------------------------------- /cmd/healthcheck.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/spf13/cobra" 9 | ) 10 | 11 | const ( 12 | defaultDNSPort = 53 13 | defaultIPAddress = "127.0.0.1" 14 | ) 15 | 16 | func NewHealthcheckCommand() *cobra.Command { 17 | c := &cobra.Command{ 18 | Use: "healthcheck", 19 | Short: "performs healthcheck", 20 | RunE: healthcheck, 21 | } 22 | 23 | c.Flags().Uint16P("port", "p", defaultDNSPort, "blocky port") 24 | c.Flags().StringP("bindip", "b", defaultIPAddress, "blocky host binding ip address") 25 | 26 | return c 27 | } 28 | 29 | func healthcheck(cmd *cobra.Command, args []string) error { 30 | _ = args 31 | port, _ := cmd.Flags().GetUint16("port") 32 | bindIP, _ := cmd.Flags().GetString("bindip") 33 | 34 | c := new(dns.Client) 35 | c.Net = "tcp" 36 | m := new(dns.Msg) 37 | m.SetQuestion("healthcheck.blocky.", dns.TypeA) 38 | 39 | _, _, err := c.Exchange(m, net.JoinHostPort(bindIP, fmt.Sprintf("%d", port))) 40 | 41 | if err == nil { 42 | fmt.Println("OK") 43 | } else { 44 | fmt.Println("NOT OK") 45 | } 46 | 47 | return err 48 | } 49 | -------------------------------------------------------------------------------- /cmd/healthcheck_test.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/0xERR0R/blocky/helpertest" 7 | "github.com/miekg/dns" 8 | 9 | . "github.com/onsi/ginkgo/v2" 10 | . 
"github.com/onsi/gomega" 11 | ) 12 | 13 | var _ = Describe("Healthcheck command", func() { 14 | Describe("Call healthcheck command", func() { 15 | It("should fail", func() { 16 | c := NewHealthcheckCommand() 17 | c.SetArgs([]string{"-p", "533"}) 18 | 19 | err := c.Execute() 20 | 21 | Expect(err).Should(HaveOccurred()) 22 | }) 23 | 24 | It("should fail", func() { 25 | c := NewHealthcheckCommand() 26 | c.SetArgs([]string{"-b", "127.0.2.9"}) 27 | 28 | err := c.Execute() 29 | 30 | Expect(err).Should(HaveOccurred()) 31 | }) 32 | 33 | It("should succeed", func() { 34 | ip := "127.0.0.1" 35 | hostPort := helpertest.GetHostPort(ip, 65100) 36 | port := helpertest.GetStringPort(65100) 37 | srv := createMockServer(hostPort) 38 | go func() { 39 | defer GinkgoRecover() 40 | err := srv.ListenAndServe() 41 | Expect(err).Should(Succeed()) 42 | }() 43 | 44 | Eventually(func() error { 45 | c := NewHealthcheckCommand() 46 | c.SetArgs([]string{"-p", port, "-b", ip}) 47 | 48 | return c.Execute() 49 | }, "1s").Should(Succeed()) 50 | }) 51 | }) 52 | }) 53 | 54 | func createMockServer(hostPort string) *dns.Server { 55 | res := &dns.Server{ 56 | Addr: hostPort, 57 | Net: "tcp", 58 | Handler: dns.NewServeMux(), 59 | NotifyStartedFunc: func() { 60 | fmt.Printf("Mock healthcheck server is up: %s\n", hostPort) 61 | }, 62 | } 63 | 64 | th := res.Handler.(*dns.ServeMux) 65 | th.HandleFunc("healthcheck.blocky", func(w dns.ResponseWriter, request *dns.Msg) { 66 | resp := new(dns.Msg) 67 | resp.SetReply(request) 68 | resp.Rcode = dns.RcodeSuccess 69 | 70 | err := w.WriteMsg(resp) 71 | Expect(err).Should(Succeed()) 72 | }) 73 | 74 | DeferCleanup(res.Shutdown) 75 | 76 | return res 77 | } 78 | -------------------------------------------------------------------------------- /cmd/lists.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/0xERR0R/blocky/api" 8 | "github.com/spf13/cobra" 9 | ) 10 | 
11 | // NewListsCommand creates new command instance 12 | func NewListsCommand() *cobra.Command { 13 | c := &cobra.Command{ 14 | Use: "lists", 15 | Short: "lists operations", 16 | PersistentPreRunE: initConfigPreRun, 17 | } 18 | 19 | c.AddCommand(newRefreshCommand()) 20 | 21 | return c 22 | } 23 | 24 | func newRefreshCommand() *cobra.Command { 25 | return &cobra.Command{ 26 | Use: "refresh", 27 | Short: "refreshes all lists", 28 | RunE: refreshList, 29 | } 30 | } 31 | 32 | func refreshList(_ *cobra.Command, _ []string) error { 33 | client, err := api.NewClientWithResponses(apiURL()) 34 | if err != nil { 35 | return fmt.Errorf("can't create client: %w", err) 36 | } 37 | 38 | resp, err := client.ListRefreshWithResponse(context.Background()) 39 | if err != nil { 40 | return fmt.Errorf("can't execute %w", err) 41 | } 42 | 43 | return printOkOrError(resp, string(resp.Body)) 44 | } 45 | -------------------------------------------------------------------------------- /cmd/lists_test.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | 7 | "github.com/0xERR0R/blocky/log" 8 | "github.com/sirupsen/logrus/hooks/test" 9 | "github.com/spf13/cobra" 10 | 11 | . "github.com/onsi/ginkgo/v2" 12 | . 
"github.com/onsi/gomega" 13 | ) 14 | 15 | var _ = Describe("Lists command", func() { 16 | var ( 17 | ts *httptest.Server 18 | mockFn func(w http.ResponseWriter, _ *http.Request) 19 | loggerHook *test.Hook 20 | c *cobra.Command 21 | err error 22 | ) 23 | JustBeforeEach(func() { 24 | ts = testHTTPAPIServer(mockFn) 25 | }) 26 | JustAfterEach(func() { 27 | ts.Close() 28 | }) 29 | BeforeEach(func() { 30 | mockFn = func(w http.ResponseWriter, _ *http.Request) {} 31 | loggerHook = test.NewGlobal() 32 | log.Log().AddHook(loggerHook) 33 | }) 34 | AfterEach(func() { 35 | loggerHook.Reset() 36 | }) 37 | Describe("Call list refresh command", func() { 38 | When("list refresh is executed", func() { 39 | BeforeEach(func() { 40 | c = NewListsCommand() 41 | c.SetArgs([]string{"refresh"}) 42 | }) 43 | It("should print result", func() { 44 | err = c.Execute() 45 | Expect(err).Should(Succeed()) 46 | 47 | Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("OK")) 48 | }) 49 | }) 50 | When("Server returns 500", func() { 51 | BeforeEach(func() { 52 | c = newRefreshCommand() 53 | c.SetArgs(make([]string, 0)) 54 | mockFn = func(w http.ResponseWriter, _ *http.Request) { 55 | w.WriteHeader(http.StatusInternalServerError) 56 | } 57 | }) 58 | It("should end with error", func() { 59 | err = c.Execute() 60 | Expect(err).Should(HaveOccurred()) 61 | Expect(err.Error()).Should(ContainSubstring("500 Internal Server Error")) 62 | }) 63 | }) 64 | When("Url is wrong", func() { 65 | BeforeEach(func() { 66 | c = newRefreshCommand() 67 | c.SetArgs(make([]string, 0)) 68 | }) 69 | It("should end with error", func() { 70 | apiPort = 0 71 | err = c.Execute() 72 | Expect(err).Should(HaveOccurred()) 73 | Expect(err.Error()).Should(ContainSubstring("connection refused")) 74 | }) 75 | }) 76 | }) 77 | }) 78 | -------------------------------------------------------------------------------- /cmd/query.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 
3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | 8 | "github.com/0xERR0R/blocky/api" 9 | "github.com/0xERR0R/blocky/log" 10 | "github.com/miekg/dns" 11 | "github.com/spf13/cobra" 12 | ) 13 | 14 | // NewQueryCommand creates new command instance 15 | func NewQueryCommand() *cobra.Command { 16 | c := &cobra.Command{ 17 | Use: "query <domain>", 18 | Args: cobra.ExactArgs(1), 19 | Short: "performs DNS query", 20 | RunE: query, 21 | PersistentPreRunE: initConfigPreRun, 22 | } 23 | 24 | c.Flags().StringP("type", "t", "A", "query type (A, AAAA, ...)") 25 | 26 | return c 27 | } 28 | 29 | // query sends a DNS query for args[0], using the record type from the "type" flag, 30 | // to the blocky REST API and logs the result. It returns an error for an unknown 31 | // query type, a failed request, a non-200 status or a 200 response without a 32 | // decoded JSON payload. 33 | func query(cmd *cobra.Command, args []string) error { 34 | typeFlag, _ := cmd.Flags().GetString("type") 35 | qType := dns.StringToType[typeFlag] 36 | 37 | if qType == dns.TypeNone { 38 | return fmt.Errorf("unknown query type '%s'", typeFlag) 39 | } 40 | 41 | client, err := api.NewClientWithResponses(apiURL()) 42 | if err != nil { 43 | return fmt.Errorf("can't create client: %w", err) 44 | } 45 | 46 | req := api.ApiQueryRequest{ 47 | Query: args[0], 48 | Type: typeFlag, 49 | } 50 | 51 | resp, err := client.QueryWithResponse(context.Background(), req) 52 | if err != nil { 53 | return fmt.Errorf("can't execute %w", err) 54 | } 55 | 56 | if resp.StatusCode() != http.StatusOK { 57 | return fmt.Errorf("response NOK, %s %s", resp.Status(), string(resp.Body)) 58 | } 59 | 60 | // NOTE(review): the generated client presumably leaves JSON200 nil when the body was not decoded as JSON - confirm against api_client.gen.go 61 | if resp.JSON200 == nil { 62 | return fmt.Errorf("response NOK, no JSON payload, %s %s", resp.Status(), string(resp.Body)) 63 | } 64 | 65 | log.Log().Infof("Query result for '%s' (%s):", req.Query, req.Type) 66 | log.Log().Infof("\treason: %20s", resp.JSON200.Reason) 67 | log.Log().Infof("\tresponse type: %20s", resp.JSON200.ResponseType) 68 | log.Log().Infof("\tresponse: %20s", resp.JSON200.Response) 69 | log.Log().Infof("\treturn code: %20s", resp.JSON200.ReturnCode) 70 | 71 | return nil 72 | } 73 | -------------------------------------------------------------------------------- /cmd/query_test.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 |
"net/http/httptest" 7 | 8 | "github.com/0xERR0R/blocky/log" 9 | "github.com/sirupsen/logrus/hooks/test" 10 | 11 | "github.com/0xERR0R/blocky/api" 12 | 13 | . "github.com/onsi/ginkgo/v2" 14 | . "github.com/onsi/gomega" 15 | ) 16 | 17 | var _ = Describe("Blocking command", func() { 18 | var ( 19 | ts *httptest.Server 20 | mockFn func(w http.ResponseWriter, _ *http.Request) 21 | loggerHook *test.Hook 22 | ) 23 | JustBeforeEach(func() { 24 | ts = testHTTPAPIServer(mockFn) 25 | }) 26 | JustAfterEach(func() { 27 | ts.Close() 28 | }) 29 | BeforeEach(func() { 30 | mockFn = func(w http.ResponseWriter, _ *http.Request) {} 31 | loggerHook = test.NewGlobal() 32 | log.Log().AddHook(loggerHook) 33 | }) 34 | AfterEach(func() { 35 | loggerHook.Reset() 36 | }) 37 | Describe("Call query command", func() { 38 | BeforeEach(func() { 39 | mockFn = func(w http.ResponseWriter, _ *http.Request) { 40 | response, err := json.Marshal(api.ApiQueryResult{ 41 | Reason: "Reason", 42 | ResponseType: "Type", 43 | Response: "Response", 44 | ReturnCode: "NOERROR", 45 | }) 46 | Expect(err).Should(Succeed()) 47 | 48 | _, err = w.Write(response) 49 | Expect(err).Should(Succeed()) 50 | } 51 | }) 52 | When("query command is called via REST", func() { 53 | BeforeEach(func() { 54 | mockFn = func(w http.ResponseWriter, _ *http.Request) { 55 | w.Header().Add("Content-Type", "application/json") 56 | response, err := json.Marshal(api.ApiQueryResult{ 57 | Reason: "Reason", 58 | ResponseType: "Type", 59 | Response: "Response", 60 | ReturnCode: "NOERROR", 61 | }) 62 | Expect(err).Should(Succeed()) 63 | 64 | _, err = w.Write(response) 65 | Expect(err).Should(Succeed()) 66 | } 67 | }) 68 | It("should print result", func() { 69 | Expect(query(NewQueryCommand(), []string{"google.de"})).Should(Succeed()) 70 | Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("NOERROR")) 71 | }) 72 | }) 73 | When("Server returns 500", func() { 74 | BeforeEach(func() { 75 | mockFn = func(w http.ResponseWriter, _ 
*http.Request) { 76 | w.WriteHeader(http.StatusInternalServerError) 77 | } 78 | }) 79 | It("should end with error", func() { 80 | err := query(NewQueryCommand(), []string{"google.de"}) 81 | Expect(err).Should(HaveOccurred()) 82 | Expect(err.Error()).Should(ContainSubstring("500 Internal Server Error")) 83 | }) 84 | }) 85 | When("Type is wrong", func() { 86 | It("should end with error", func() { 87 | command := NewQueryCommand() 88 | command.SetArgs([]string{"--type", "X", "google.de"}) 89 | err := command.Execute() 90 | Expect(err).Should(HaveOccurred()) 91 | Expect(err.Error()).Should(ContainSubstring("unknown query type 'X'")) 92 | }) 93 | }) 94 | When("Url is wrong", func() { 95 | It("should end with error", func() { 96 | apiPort = 0 97 | err := query(NewQueryCommand(), []string{"google.de"}) 98 | Expect(err).Should(HaveOccurred()) 99 | Expect(err.Error()).Should(ContainSubstring("connection refused")) 100 | }) 101 | }) 102 | }) 103 | }) 104 | -------------------------------------------------------------------------------- /cmd/serve.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "os/signal" 8 | "syscall" 9 | 10 | "github.com/0xERR0R/blocky/config" 11 | "github.com/0xERR0R/blocky/evt" 12 | "github.com/0xERR0R/blocky/log" 13 | "github.com/0xERR0R/blocky/server" 14 | "github.com/0xERR0R/blocky/util" 15 | 16 | "github.com/spf13/cobra" 17 | ) 18 | 19 | //nolint:gochecknoglobals 20 | var ( 21 | done = make(chan bool, 1) 22 | isConfigMandatory = true 23 | signals = make(chan os.Signal, 1) 24 | ) 25 | 26 | func newServeCommand() *cobra.Command { 27 | return &cobra.Command{ 28 | Use: "serve", 29 | Args: cobra.NoArgs, 30 | Short: "start blocky DNS server (default command)", 31 | RunE: startServer, 32 | PersistentPreRunE: initConfigPreRun, 33 | SilenceUsage: true, 34 | } 35 | } 36 | 37 | func startServer(_ *cobra.Command, _ []string) error { 38 | printBanner() 39 | 
40 | cfg, err := config.LoadConfig(configPath, isConfigMandatory) 41 | if err != nil { 42 | return fmt.Errorf("unable to load configuration: %w", err) 43 | } 44 | 45 | log.Configure(&cfg.Log) 46 | 47 | signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) 48 | 49 | ctx, cancelFn := context.WithCancel(context.Background()) 50 | defer cancelFn() 51 | 52 | srv, err := server.NewServer(ctx, cfg) 53 | if err != nil { 54 | return fmt.Errorf("can't start server: %w", err) 55 | } 56 | 57 | const errChanSize = 10 58 | errChan := make(chan error, errChanSize) 59 | 60 | srv.Start(ctx, errChan) 61 | 62 | var terminationErr error 63 | 64 | go func() { 65 | select { 66 | case <-signals: 67 | log.Log().Infof("Terminating...") 68 | util.LogOnError(ctx, "can't stop server: ", srv.Stop(ctx)) 69 | done <- true 70 | 71 | case err := <-errChan: 72 | log.Log().Error("server start failed: ", err) 73 | terminationErr = err 74 | done <- true 75 | } 76 | }() 77 | 78 | evt.Bus().Publish(evt.ApplicationStarted, util.Version, util.BuildTime) 79 | <-done 80 | 81 | return terminationErr 82 | } 83 | 84 | func printBanner() { 85 | log.Log().Info("_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/") 86 | log.Log().Info("_/ _/") 87 | log.Log().Info("_/ _/") 88 | log.Log().Info("_/ _/ _/ _/ _/") 89 | log.Log().Info("_/ _/_/_/ _/ _/_/ _/_/_/ _/ _/ _/ _/ _/") 90 | log.Log().Info("_/ _/ _/ _/ _/ _/ _/ _/_/ _/ _/ _/") 91 | log.Log().Info("_/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/") 92 | log.Log().Info("_/ _/_/_/ _/ _/_/ _/_/_/ _/ _/ _/_/_/ _/") 93 | log.Log().Info("_/ _/ _/") 94 | log.Log().Info("_/ _/_/ _/") 95 | log.Log().Info("_/ _/") 96 | log.Log().Info("_/ _/") 97 | log.Log().Infof("_/ Version: %-18s Build time: %-18s _/", util.Version, util.BuildTime) 98 | log.Log().Info("_/ _/") 99 | log.Log().Info("_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/") 100 | } 101 | -------------------------------------------------------------------------------- /cmd/validate.go: 
-------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "errors" 5 | "os" 6 | 7 | "github.com/0xERR0R/blocky/log" 8 | 9 | "github.com/spf13/cobra" 10 | ) 11 | 12 | // NewValidateCommand creates new command instance 13 | func NewValidateCommand() *cobra.Command { 14 | return &cobra.Command{ 15 | Use: "validate", 16 | Args: cobra.NoArgs, 17 | Short: "Validates the configuration", 18 | RunE: validateConfiguration, 19 | } 20 | } 21 | 22 | func validateConfiguration(_ *cobra.Command, _ []string) error { 23 | log.Log().Infof("Validating configuration file: %s", configPath) 24 | 25 | _, err := os.Stat(configPath) 26 | if err != nil && errors.Is(err, os.ErrNotExist) { 27 | return errors.New("configuration path does not exist") 28 | } 29 | 30 | err = initConfig() 31 | if err != nil { 32 | return err 33 | } 34 | 35 | log.Log().Info("Configuration is valid") 36 | 37 | return nil 38 | } 39 | -------------------------------------------------------------------------------- /cmd/validate_test.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/0xERR0R/blocky/helpertest" 5 | . "github.com/onsi/ginkgo/v2" 6 | . 
"github.com/onsi/gomega" 7 | ) 8 | 9 | var _ = Describe("Validate command", func() { 10 | var tmpDir *helpertest.TmpFolder 11 | BeforeEach(func() { 12 | tmpDir = helpertest.NewTmpFolder("config") 13 | }) 14 | When("Validate is called with not existing configuration file", func() { 15 | It("should terminate with error", func() { 16 | c := NewRootCommand() 17 | c.SetArgs([]string{"validate", "--config", "/notexisting/path.yaml"}) 18 | 19 | Expect(c.Execute()).Should(HaveOccurred()) 20 | }) 21 | }) 22 | 23 | When("Validate is called with existing valid configuration file", func() { 24 | It("should terminate without error", func() { 25 | cfgFile := tmpDir.CreateStringFile("config.yaml", 26 | "upstreams:", 27 | " groups:", 28 | " default:", 29 | " - 1.1.1.1") 30 | 31 | c := NewRootCommand() 32 | c.SetArgs([]string{"validate", "--config", cfgFile.Path}) 33 | 34 | Expect(c.Execute()).Should(Succeed()) 35 | }) 36 | }) 37 | 38 | When("Validate is called with existing invalid configuration file", func() { 39 | It("should terminate with error", func() { 40 | cfgFile := tmpDir.CreateStringFile("config.yaml", 41 | "upstreams:", 42 | " groups:", 43 | " default:", 44 | " - 1.broken file") 45 | 46 | c := NewRootCommand() 47 | c.SetArgs([]string{"validate", "--config", cfgFile.Path}) 48 | 49 | Expect(c.Execute()).Should(HaveOccurred()) 50 | }) 51 | }) 52 | }) 53 | -------------------------------------------------------------------------------- /cmd/version.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/0xERR0R/blocky/util" 7 | "github.com/spf13/cobra" 8 | ) 9 | 10 | // NewVersionCommand creates new command instance 11 | func NewVersionCommand() *cobra.Command { 12 | return &cobra.Command{ 13 | Use: "version", 14 | Args: cobra.NoArgs, 15 | Short: "Print the version number of blocky", 16 | Run: printVersion, 17 | } 18 | } 19 | 20 | func printVersion(_ *cobra.Command, _ []string) { 21 
| fmt.Println("blocky") 22 | fmt.Printf("Version: %s\n", util.Version) 23 | fmt.Printf("Build time: %s\n", util.BuildTime) 24 | fmt.Printf("Architecture: %s\n", util.Architecture) 25 | } 26 | -------------------------------------------------------------------------------- /cmd/version_test.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | . "github.com/onsi/ginkgo/v2" 5 | . "github.com/onsi/gomega" 6 | ) 7 | 8 | var _ = Describe("Version command", func() { 9 | When("Version command is called", func() { 10 | It("should execute without error", func() { 11 | c := NewVersionCommand() 12 | c.SetArgs(make([]string, 0)) 13 | err := c.Execute() 14 | Expect(err).Should(Succeed()) 15 | }) 16 | }) 17 | }) 18 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: auto 6 | threshold: 80% 7 | patch: 8 | default: 9 | # basic 10 | target: auto 11 | threshold: 0% 12 | base: auto 13 | only_pulls: true 14 | ignore: 15 | - "**/mock_*" 16 | - "**/*_enum.go" 17 | - "**/*.gen.go" 18 | - "e2e/*.go" 19 | -------------------------------------------------------------------------------- /config/blocking_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/creasty/defaults" 7 | . "github.com/onsi/ginkgo/v2" 8 | . 
"github.com/onsi/gomega" 9 | ) 10 | 11 | var _ = Describe("BlockingConfig", func() { 12 | var cfg Blocking 13 | 14 | suiteBeforeEach() 15 | 16 | BeforeEach(func() { 17 | cfg = Blocking{ 18 | BlockType: "ZEROIP", 19 | BlockTTL: Duration(time.Minute), 20 | Denylists: map[string][]BytesSource{ 21 | "gr1": NewBytesSources("/a/file/path"), 22 | }, 23 | ClientGroupsBlock: map[string][]string{ 24 | "default": {"gr1"}, 25 | }, 26 | } 27 | }) 28 | 29 | Describe("IsEnabled", func() { 30 | It("should be false by default", func() { 31 | cfg := Blocking{} 32 | Expect(defaults.Set(&cfg)).Should(Succeed()) 33 | 34 | Expect(cfg.IsEnabled()).Should(BeFalse()) 35 | }) 36 | 37 | When("enabled", func() { 38 | It("should be true", func() { 39 | Expect(cfg.IsEnabled()).Should(BeTrue()) 40 | }) 41 | }) 42 | 43 | When("disabled", func() { 44 | It("should be false", func() { 45 | cfg := Blocking{ 46 | BlockTTL: Duration(-1), 47 | } 48 | 49 | Expect(cfg.IsEnabled()).Should(BeFalse()) 50 | }) 51 | }) 52 | }) 53 | 54 | Describe("LogConfig", func() { 55 | It("should log configuration", func() { 56 | cfg.LogConfig(logger) 57 | 58 | Expect(hook.Calls).ShouldNot(BeEmpty()) 59 | Expect(hook.Messages[0]).Should(Equal("clientGroupsBlock:")) 60 | Expect(hook.Messages).Should(ContainElement(Equal("blockType = ZEROIP"))) 61 | }) 62 | }) 63 | 64 | Describe("migrate", func() { 65 | It("should copy values", func() { 66 | cfg, err := WithDefaults[Blocking]() 67 | Expect(err).Should(Succeed()) 68 | 69 | cfg.Deprecated.BlackLists = &map[string][]BytesSource{ 70 | "deny-group": NewBytesSources("/deny.txt"), 71 | } 72 | cfg.Deprecated.WhiteLists = &map[string][]BytesSource{ 73 | "allow-group": NewBytesSources("/allow.txt"), 74 | } 75 | 76 | migrated := cfg.migrate(logger) 77 | Expect(migrated).Should(BeTrue()) 78 | 79 | Expect(hook.Calls).ShouldNot(BeEmpty()) 80 | Expect(hook.Messages).Should(ContainElements( 81 | ContainSubstring("blocking.allowlists"), 82 | ContainSubstring("blocking.denylists"), 83 | )) 84 
| 85 | Expect(cfg.Allowlists).Should(Equal(*cfg.Deprecated.WhiteLists)) 86 | Expect(cfg.Denylists).Should(Equal(*cfg.Deprecated.BlackLists)) 87 | }) 88 | }) 89 | }) 90 | -------------------------------------------------------------------------------- /config/bytes_source.go: -------------------------------------------------------------------------------- 1 | //go:generate go tool go-enum -f=$GOFILE --marshal --names --values 2 | package config 3 | 4 | import ( 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | const maxTextSourceDisplayLen = 12 10 | 11 | // var BytesSourceNone = BytesSource{} 12 | 13 | // BytesSourceType supported BytesSource types. ENUM( 14 | // text=1 // Inline YAML block. 15 | // http // HTTP(S). 16 | // file // Local file. 17 | // ) 18 | type BytesSourceType uint16 19 | 20 | type BytesSource struct { 21 | Type BytesSourceType 22 | From string 23 | } 24 | 25 | func (s BytesSource) String() string { 26 | switch s.Type { 27 | case BytesSourceTypeText: 28 | break 29 | 30 | case BytesSourceTypeHttp: 31 | return s.From 32 | 33 | case BytesSourceTypeFile: 34 | return fmt.Sprintf("file://%s", s.From) 35 | 36 | default: 37 | return fmt.Sprintf("unknown source (%s: %s)", s.Type, s.From) 38 | } 39 | 40 | text := s.From 41 | truncated := false 42 | 43 | if idx := strings.IndexRune(text, '\n'); idx != -1 { 44 | text = text[:idx] // first line only 45 | truncated = idx < len(text) // don't count removing last char 46 | } 47 | 48 | if len(text) > maxTextSourceDisplayLen { // truncate 49 | text = text[:maxTextSourceDisplayLen] 50 | truncated = true 51 | } 52 | 53 | if truncated { 54 | return fmt.Sprintf("%s...", text[:maxTextSourceDisplayLen]) 55 | } 56 | 57 | return text 58 | } 59 | 60 | // UnmarshalText implements `encoding.TextUnmarshaler`. 
61 | func (s *BytesSource) UnmarshalText(data []byte) error { 62 | source := string(data) 63 | 64 | switch { 65 | // Inline definition in YAML (with literal style Block Scalar) 66 | case strings.ContainsAny(source, "\n"): 67 | *s = BytesSource{Type: BytesSourceTypeText, From: source} 68 | 69 | // HTTP(S) 70 | case strings.HasPrefix(source, "http"): 71 | *s = BytesSource{Type: BytesSourceTypeHttp, From: source} 72 | 73 | // Probably path to a local file 74 | default: 75 | *s = BytesSource{Type: BytesSourceTypeFile, From: strings.TrimPrefix(source, "file://")} 76 | } 77 | 78 | return nil 79 | } 80 | 81 | func newBytesSource(source string) BytesSource { 82 | var res BytesSource 83 | 84 | // UnmarshalText never returns an error 85 | _ = res.UnmarshalText([]byte(source)) 86 | 87 | return res 88 | } 89 | 90 | func NewBytesSources(sources ...string) []BytesSource { 91 | res := make([]BytesSource, 0, len(sources)) 92 | 93 | for _, source := range sources { 94 | res = append(res, newBytesSource(source)) 95 | } 96 | 97 | return res 98 | } 99 | 100 | func TextBytesSource(lines ...string) BytesSource { 101 | return BytesSource{Type: BytesSourceTypeText, From: inlineList(lines...)} 102 | } 103 | 104 | func inlineList(lines ...string) string { 105 | res := strings.Join(lines, "\n") 106 | 107 | // ensure at least one line ending so it's parsed as an inline block 108 | res += "\n" 109 | 110 | return res 111 | } 112 | -------------------------------------------------------------------------------- /config/bytes_source_enum.go: -------------------------------------------------------------------------------- 1 | // Code generated by go-enum DO NOT EDIT. 2 | // Version: 3 | // Revision: 4 | // Build Date: 5 | // Built By: 6 | 7 | package config 8 | 9 | import ( 10 | "fmt" 11 | "strings" 12 | ) 13 | 14 | const ( 15 | // BytesSourceTypeText is a BytesSourceType of type Text. 16 | // Inline YAML block. 
17 | BytesSourceTypeText BytesSourceType = iota + 1 18 | // BytesSourceTypeHttp is a BytesSourceType of type Http. 19 | // HTTP(S). 20 | BytesSourceTypeHttp 21 | // BytesSourceTypeFile is a BytesSourceType of type File. 22 | // Local file. 23 | BytesSourceTypeFile 24 | ) 25 | 26 | var ErrInvalidBytesSourceType = fmt.Errorf("not a valid BytesSourceType, try [%s]", strings.Join(_BytesSourceTypeNames, ", ")) 27 | 28 | const _BytesSourceTypeName = "texthttpfile" 29 | 30 | var _BytesSourceTypeNames = []string{ 31 | _BytesSourceTypeName[0:4], 32 | _BytesSourceTypeName[4:8], 33 | _BytesSourceTypeName[8:12], 34 | } 35 | 36 | // BytesSourceTypeNames returns a list of possible string values of BytesSourceType. 37 | func BytesSourceTypeNames() []string { 38 | tmp := make([]string, len(_BytesSourceTypeNames)) 39 | copy(tmp, _BytesSourceTypeNames) 40 | return tmp 41 | } 42 | 43 | // BytesSourceTypeValues returns a list of the values for BytesSourceType 44 | func BytesSourceTypeValues() []BytesSourceType { 45 | return []BytesSourceType{ 46 | BytesSourceTypeText, 47 | BytesSourceTypeHttp, 48 | BytesSourceTypeFile, 49 | } 50 | } 51 | 52 | var _BytesSourceTypeMap = map[BytesSourceType]string{ 53 | BytesSourceTypeText: _BytesSourceTypeName[0:4], 54 | BytesSourceTypeHttp: _BytesSourceTypeName[4:8], 55 | BytesSourceTypeFile: _BytesSourceTypeName[8:12], 56 | } 57 | 58 | // String implements the Stringer interface. 
59 | func (x BytesSourceType) String() string { 60 | if str, ok := _BytesSourceTypeMap[x]; ok { 61 | return str 62 | } 63 | return fmt.Sprintf("BytesSourceType(%d)", x) 64 | } 65 | 66 | // IsValid provides a quick way to determine if the typed value is 67 | // part of the allowed enumerated values 68 | func (x BytesSourceType) IsValid() bool { 69 | _, ok := _BytesSourceTypeMap[x] 70 | return ok 71 | } 72 | 73 | var _BytesSourceTypeValue = map[string]BytesSourceType{ 74 | _BytesSourceTypeName[0:4]: BytesSourceTypeText, 75 | _BytesSourceTypeName[4:8]: BytesSourceTypeHttp, 76 | _BytesSourceTypeName[8:12]: BytesSourceTypeFile, 77 | } 78 | 79 | // ParseBytesSourceType attempts to convert a string to a BytesSourceType. 80 | func ParseBytesSourceType(name string) (BytesSourceType, error) { 81 | if x, ok := _BytesSourceTypeValue[name]; ok { 82 | return x, nil 83 | } 84 | return BytesSourceType(0), fmt.Errorf("%s is %w", name, ErrInvalidBytesSourceType) 85 | } 86 | 87 | // MarshalText implements the text marshaller method. 88 | func (x BytesSourceType) MarshalText() ([]byte, error) { 89 | return []byte(x.String()), nil 90 | } 91 | 92 | // UnmarshalText implements the text unmarshaller method. 
93 | func (x *BytesSourceType) UnmarshalText(text []byte) error { 94 | name := string(text) 95 | tmp, err := ParseBytesSourceType(name) 96 | if err != nil { 97 | return err 98 | } 99 | *x = tmp 100 | return nil 101 | } 102 | -------------------------------------------------------------------------------- /config/caching.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/sirupsen/logrus" 7 | ) 8 | 9 | // Caching configuration for domain caching 10 | type Caching struct { 11 | MinCachingTime Duration `yaml:"minTime"` 12 | MaxCachingTime Duration `yaml:"maxTime"` 13 | CacheTimeNegative Duration `yaml:"cacheTimeNegative" default:"30m"` 14 | MaxItemsCount int `yaml:"maxItemsCount"` 15 | Prefetching bool `yaml:"prefetching"` 16 | PrefetchExpires Duration `yaml:"prefetchExpires" default:"2h"` 17 | PrefetchThreshold int `yaml:"prefetchThreshold" default:"5"` 18 | PrefetchMaxItemsCount int `yaml:"prefetchMaxItemsCount"` 19 | Exclude []string `yaml:"exclude"` 20 | } 21 | 22 | // IsEnabled implements `config.Configurable`. 23 | func (c *Caching) IsEnabled() bool { 24 | return c.MaxCachingTime.IsAtLeastZero() 25 | } 26 | 27 | // LogConfig implements `config.Configurable`. 
28 | func (c *Caching) LogConfig(logger *logrus.Entry) { 29 | logger.Infof("minTime = %s", c.MinCachingTime) 30 | logger.Infof("maxTime = %s", c.MaxCachingTime) 31 | logger.Infof("cacheTimeNegative = %s", c.CacheTimeNegative) 32 | logger.Infof("exclude:") 33 | for _, val := range c.Exclude { 34 | logger.Infof("- %v", val) 35 | } 36 | 37 | if c.Prefetching { 38 | logger.Infof("prefetching:") 39 | logger.Infof(" expires = %s", c.PrefetchExpires) 40 | logger.Infof(" threshold = %d", c.PrefetchThreshold) 41 | logger.Infof(" maxItems = %d", c.PrefetchMaxItemsCount) 42 | } else { 43 | logger.Debug("prefetching: disabled") 44 | } 45 | } 46 | 47 | func (c *Caching) EnablePrefetch() { 48 | const day = Duration(24 * time.Hour) 49 | 50 | if !c.IsEnabled() { 51 | // make sure resolver gets enabled 52 | c.MaxCachingTime = day 53 | } 54 | 55 | c.Prefetching = true 56 | c.PrefetchThreshold = 0 57 | } 58 | -------------------------------------------------------------------------------- /config/caching_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/creasty/defaults" 7 | . "github.com/onsi/ginkgo/v2" 8 | . 
"github.com/onsi/gomega" 9 | ) 10 | 11 | var _ = Describe("CachingConfig", func() { 12 | var cfg Caching 13 | 14 | suiteBeforeEach() 15 | 16 | BeforeEach(func() { 17 | cfg = Caching{ 18 | MaxCachingTime: Duration(time.Hour), 19 | } 20 | }) 21 | 22 | Describe("IsEnabled", func() { 23 | It("should be true by default", func() { 24 | cfg := Caching{} 25 | Expect(defaults.Set(&cfg)).Should(Succeed()) 26 | 27 | Expect(cfg.IsEnabled()).Should(BeTrue()) 28 | }) 29 | 30 | When("the config is disabled", func() { 31 | BeforeEach(func() { 32 | cfg = Caching{ 33 | MaxCachingTime: Duration(time.Hour * -1), 34 | } 35 | }) 36 | It("should be false", func() { 37 | Expect(cfg.IsEnabled()).Should(BeFalse()) 38 | }) 39 | }) 40 | 41 | When("the config is enabled", func() { 42 | It("should be true", func() { 43 | Expect(cfg.IsEnabled()).Should(BeTrue()) 44 | }) 45 | }) 46 | 47 | When("the config is disabled", func() { 48 | It("should be false", func() { 49 | cfg := Caching{ 50 | MaxCachingTime: Duration(-1), 51 | } 52 | 53 | Expect(cfg.IsEnabled()).Should(BeFalse()) 54 | }) 55 | }) 56 | }) 57 | 58 | Describe("LogConfig", func() { 59 | When("prefetching is enabled", func() { 60 | BeforeEach(func() { 61 | cfg = Caching{ 62 | Prefetching: true, 63 | } 64 | }) 65 | 66 | It("should return configuration", func() { 67 | cfg.LogConfig(logger) 68 | 69 | Expect(hook.Calls).ShouldNot(BeEmpty()) 70 | Expect(hook.Messages).Should(ContainElement(ContainSubstring("prefetching:"))) 71 | }) 72 | }) 73 | When("has any settings", func() { 74 | BeforeEach(func() { 75 | cfg = Caching{} 76 | }) 77 | It("should return Exclude", func() { 78 | cfg.LogConfig(logger) 79 | 80 | Expect(hook.Calls).ShouldNot(BeEmpty()) 81 | Expect(hook.Messages).Should(ContainElement(ContainSubstring("exclude:"))) 82 | }) 83 | }) 84 | When("Exclude any settings", func() { 85 | BeforeEach(func() { 86 | cfg = Caching{Exclude: []string{"local"}} 87 | }) 88 | It("should return Exclude and a list with values in it", func() { 89 | 
cfg.LogConfig(logger) 90 | 91 | Expect(hook.Calls).ShouldNot(BeEmpty()) 92 | Expect(hook.Messages).Should(ContainElement(ContainSubstring("exclude:"))) 93 | Expect(hook.Messages).Should(ContainElement(ContainSubstring("- local"))) 94 | }) 95 | }) 96 | }) 97 | 98 | Describe("EnablePrefetch", func() { 99 | When("prefetching is enabled", func() { 100 | BeforeEach(func() { 101 | cfg = Caching{} 102 | }) 103 | 104 | It("should return configuration", func() { 105 | cfg.EnablePrefetch() 106 | 107 | Expect(cfg.Prefetching).Should(BeTrue()) 108 | Expect(cfg.PrefetchThreshold).Should(Equal(0)) 109 | Expect(cfg.MaxCachingTime).Should(BeZero()) 110 | }) 111 | }) 112 | }) 113 | 114 | Describe("Exclude", func() { 115 | It("should be empty by default", func() { 116 | cfg := Caching{} 117 | Expect(defaults.Set(&cfg)).Should(Succeed()) 118 | 119 | Expect(cfg.Exclude).Should(BeEmpty()) 120 | }) 121 | }) 122 | }) 123 | -------------------------------------------------------------------------------- /config/client_lookup.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/sirupsen/logrus" 7 | ) 8 | 9 | // ClientLookup configuration for the client lookup 10 | type ClientLookup struct { 11 | ClientnameIPMapping map[string][]net.IP `yaml:"clients"` 12 | Upstream Upstream `yaml:"upstream"` 13 | SingleNameOrder []uint `yaml:"singleNameOrder"` 14 | } 15 | 16 | // IsEnabled implements `config.Configurable`. 17 | func (c *ClientLookup) IsEnabled() bool { 18 | return !c.Upstream.IsDefault() || len(c.ClientnameIPMapping) != 0 19 | } 20 | 21 | // LogConfig implements `config.Configurable`. 
22 | func (c *ClientLookup) LogConfig(logger *logrus.Entry) { 23 | if !c.Upstream.IsDefault() { 24 | logger.Infof("upstream = %s", c.Upstream) 25 | } 26 | 27 | logger.Infof("singleNameOrder = %v", c.SingleNameOrder) 28 | 29 | if len(c.ClientnameIPMapping) > 0 { 30 | logger.Infof("client IP mapping:") 31 | 32 | for k, v := range c.ClientnameIPMapping { 33 | logger.Infof(" %s = %s", k, v) 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /config/client_lookup_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/creasty/defaults" 7 | . "github.com/onsi/ginkgo/v2" 8 | . "github.com/onsi/gomega" 9 | ) 10 | 11 | var _ = Describe("ClientLookupConfig", func() { 12 | var cfg ClientLookup 13 | 14 | suiteBeforeEach() 15 | 16 | BeforeEach(func() { 17 | cfg = ClientLookup{ 18 | Upstream: Upstream{Net: NetProtocolTcpUdp, Host: "host"}, 19 | SingleNameOrder: []uint{1, 2}, 20 | ClientnameIPMapping: map[string][]net.IP{ 21 | "client8": {net.ParseIP("1.2.3.5")}, 22 | }, 23 | } 24 | }) 25 | 26 | Describe("IsEnabled", func() { 27 | It("should be false by default", func() { 28 | cfg = ClientLookup{} 29 | Expect(defaults.Set(&cfg)).Should(Succeed()) 30 | 31 | Expect(cfg.IsEnabled()).Should(BeFalse()) 32 | }) 33 | 34 | When("enabled", func() { 35 | It("should be true", func() { 36 | By("upstream", func() { 37 | cfg := ClientLookup{ 38 | Upstream: Upstream{Net: NetProtocolTcpUdp, Host: "host"}, 39 | ClientnameIPMapping: nil, 40 | } 41 | 42 | Expect(cfg.IsEnabled()).Should(BeTrue()) 43 | }) 44 | 45 | By("mapping", func() { 46 | cfg := ClientLookup{ 47 | ClientnameIPMapping: map[string][]net.IP{ 48 | "client8": {net.ParseIP("1.2.3.5")}, 49 | }, 50 | } 51 | 52 | Expect(cfg.IsEnabled()).Should(BeTrue()) 53 | }) 54 | }) 55 | }) 56 | }) 57 | 58 | Describe("LogConfig", func() { 59 | It("should log configuration", func() { 60 | 
cfg.LogConfig(logger) 61 | 62 | Expect(hook.Calls).ShouldNot(BeEmpty()) 63 | Expect(hook.Messages).Should(ContainElement(ContainSubstring("client IP mapping:"))) 64 | }) 65 | }) 66 | }) 67 | -------------------------------------------------------------------------------- /config/conditional_upstream.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/sirupsen/logrus" 8 | ) 9 | 10 | // ConditionalUpstream conditional upstream configuration 11 | type ConditionalUpstream struct { 12 | RewriterConfig `yaml:",inline"` 13 | Mapping ConditionalUpstreamMapping `yaml:"mapping"` 14 | } 15 | 16 | // ConditionalUpstreamMapping mapping for conditional configuration 17 | type ConditionalUpstreamMapping struct { 18 | Upstreams map[string][]Upstream 19 | } 20 | 21 | // IsEnabled implements `config.Configurable`. 22 | func (c *ConditionalUpstream) IsEnabled() bool { 23 | return len(c.Mapping.Upstreams) != 0 24 | } 25 | 26 | // LogConfig implements `config.Configurable`. 27 | func (c *ConditionalUpstream) LogConfig(logger *logrus.Entry) { 28 | for key, val := range c.Mapping.Upstreams { 29 | logger.Infof("%s = %v", key, val) 30 | } 31 | } 32 | 33 | // UnmarshalYAML implements `yaml.Unmarshaler`. 
34 | func (c *ConditionalUpstreamMapping) UnmarshalYAML(unmarshal func(interface{}) error) error { 35 | var input map[string]string 36 | if err := unmarshal(&input); err != nil { 37 | return err 38 | } 39 | 40 | result := make(map[string][]Upstream, len(input)) 41 | 42 | for k, v := range input { 43 | var upstreams []Upstream 44 | 45 | for _, part := range strings.Split(v, ",") { 46 | upstream, err := ParseUpstream(strings.TrimSpace(part)) 47 | if err != nil { 48 | return fmt.Errorf("can't convert upstream '%s': %w", strings.TrimSpace(part), err) 49 | } 50 | 51 | upstreams = append(upstreams, upstream) 52 | } 53 | 54 | result[k] = upstreams 55 | } 56 | 57 | c.Upstreams = result 58 | 59 | return nil 60 | } 61 | -------------------------------------------------------------------------------- /config/conditional_upstream_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/creasty/defaults" 7 | . "github.com/onsi/ginkgo/v2" 8 | . 
"github.com/onsi/gomega" 9 | ) 10 | 11 | var _ = Describe("ConditionalUpstreamConfig", func() { 12 | var cfg ConditionalUpstream 13 | 14 | suiteBeforeEach() 15 | 16 | BeforeEach(func() { 17 | cfg = ConditionalUpstream{ 18 | Mapping: ConditionalUpstreamMapping{ 19 | Upstreams: map[string][]Upstream{ 20 | "fritz.box": {Upstream{Net: NetProtocolTcpUdp, Host: "fbTest"}}, 21 | "other.box": {Upstream{Net: NetProtocolTcpUdp, Host: "otherTest"}}, 22 | ".": {Upstream{Net: NetProtocolTcpUdp, Host: "dotTest"}}, 23 | }, 24 | }, 25 | } 26 | }) 27 | 28 | Describe("IsEnabled", func() { 29 | It("should be false by default", func() { 30 | cfg := ConditionalUpstream{} 31 | Expect(defaults.Set(&cfg)).Should(Succeed()) 32 | 33 | Expect(cfg.IsEnabled()).Should(BeFalse()) 34 | }) 35 | 36 | When("enabled", func() { 37 | It("should be true", func() { 38 | Expect(cfg.IsEnabled()).Should(BeTrue()) 39 | }) 40 | }) 41 | 42 | When("disabled", func() { 43 | It("should be false", func() { 44 | cfg := ConditionalUpstream{ 45 | Mapping: ConditionalUpstreamMapping{Upstreams: map[string][]Upstream{}}, 46 | } 47 | 48 | Expect(cfg.IsEnabled()).Should(BeFalse()) 49 | }) 50 | }) 51 | }) 52 | 53 | Describe("LogConfig", func() { 54 | It("should log configuration", func() { 55 | cfg.LogConfig(logger) 56 | 57 | Expect(hook.Calls).ShouldNot(BeEmpty()) 58 | Expect(hook.Messages).Should(ContainElement(ContainSubstring("fritz.box = "))) 59 | }) 60 | }) 61 | 62 | Describe("UnmarshalYAML", func() { 63 | It("Should parse config as map", func() { 64 | c := &ConditionalUpstreamMapping{} 65 | err := c.UnmarshalYAML(func(i interface{}) error { 66 | *i.(*map[string]string) = map[string]string{"key": "1.2.3.4"} 67 | 68 | return nil 69 | }) 70 | Expect(err).Should(Succeed()) 71 | Expect(c.Upstreams).Should(HaveLen(1)) 72 | Expect(c.Upstreams["key"]).Should(HaveLen(1)) 73 | Expect(c.Upstreams["key"][0]).Should(Equal(Upstream{ 74 | Net: NetProtocolTcpUdp, Host: "1.2.3.4", Port: 53, 75 | })) 76 | }) 77 | 78 | It("should 
fail if wrong YAML format", func() { 79 | c := &ConditionalUpstreamMapping{} 80 | err := c.UnmarshalYAML(func(i interface{}) error { 81 | return errors.New("some err") 82 | }) 83 | Expect(err).Should(HaveOccurred()) 84 | Expect(err).Should(MatchError("some err")) 85 | }) 86 | }) 87 | }) 88 | -------------------------------------------------------------------------------- /config/config_suite_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/0xERR0R/blocky/log" 7 | . "github.com/onsi/ginkgo/v2" 8 | . "github.com/onsi/gomega" 9 | "github.com/sirupsen/logrus" 10 | ) 11 | 12 | var ( 13 | logger *logrus.Entry 14 | hook *log.MockLoggerHook 15 | ) 16 | 17 | func init() { 18 | log.Silence() 19 | } 20 | 21 | func TestConfig(t *testing.T) { 22 | RegisterFailHandler(Fail) 23 | RunSpecs(t, "Config Suite") 24 | } 25 | 26 | func suiteBeforeEach() { 27 | BeforeEach(func() { 28 | logger, hook = log.NewMockEntry() 29 | }) 30 | } 31 | -------------------------------------------------------------------------------- /config/custom_dns.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "strings" 7 | 8 | "github.com/miekg/dns" 9 | "github.com/sirupsen/logrus" 10 | ) 11 | 12 | // CustomDNS custom DNS configuration 13 | type CustomDNS struct { 14 | RewriterConfig `yaml:",inline"` 15 | CustomTTL Duration `yaml:"customTTL" default:"1h"` 16 | Mapping CustomDNSMapping `yaml:"mapping"` 17 | Zone ZoneFileDNS `yaml:"zone" default:""` 18 | FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"` 19 | } 20 | 21 | type ( 22 | CustomDNSMapping map[string]CustomDNSEntries 23 | CustomDNSEntries []dns.RR 24 | 25 | ZoneFileDNS struct { 26 | RRs CustomDNSMapping 27 | configPath string 28 | } 29 | ) 30 | 31 | func (z *ZoneFileDNS) UnmarshalYAML(unmarshal func(interface{}) error) error { 32 | 
var input string 33 | if err := unmarshal(&input); err != nil { 34 | return err 35 | } 36 | 37 | result := make(CustomDNSMapping) 38 | 39 | zoneParser := dns.NewZoneParser(strings.NewReader(input), "", z.configPath) 40 | zoneParser.SetIncludeAllowed(true) 41 | 42 | for { 43 | zoneRR, ok := zoneParser.Next() 44 | 45 | if !ok { 46 | if zoneParser.Err() != nil { 47 | return zoneParser.Err() 48 | } 49 | 50 | // Done 51 | break 52 | } 53 | 54 | domain := zoneRR.Header().Name 55 | 56 | if _, ok := result[domain]; !ok { 57 | result[domain] = make(CustomDNSEntries, 0, 1) 58 | } 59 | 60 | result[domain] = append(result[domain], zoneRR) 61 | } 62 | 63 | z.RRs = result 64 | 65 | return nil 66 | } 67 | 68 | func (c *CustomDNSEntries) UnmarshalYAML(unmarshal func(interface{}) error) error { 69 | var input string 70 | if err := unmarshal(&input); err != nil { 71 | return err 72 | } 73 | 74 | parts := strings.Split(input, ",") 75 | result := make(CustomDNSEntries, len(parts)) 76 | 77 | for i, part := range parts { 78 | rr, err := configToRR(strings.TrimSpace(part)) 79 | if err != nil { 80 | return err 81 | } 82 | 83 | result[i] = rr 84 | } 85 | 86 | *c = result 87 | 88 | return nil 89 | } 90 | 91 | // IsEnabled implements `config.Configurable`. 92 | func (c *CustomDNS) IsEnabled() bool { 93 | return len(c.Mapping) != 0 94 | } 95 | 96 | // LogConfig implements `config.Configurable`. 
97 | func (c *CustomDNS) LogConfig(logger *logrus.Entry) { 98 | logger.Debugf("TTL = %s", c.CustomTTL) 99 | logger.Debugf("filterUnmappedTypes = %t", c.FilterUnmappedTypes) 100 | 101 | logger.Info("mapping:") 102 | 103 | for key, val := range c.Mapping { 104 | logger.Infof(" %s = %s", key, val) 105 | } 106 | } 107 | 108 | func configToRR(ipStr string) (dns.RR, error) { 109 | ip := net.ParseIP(ipStr) 110 | if ip == nil { 111 | return nil, fmt.Errorf("invalid IP address '%s'", ipStr) 112 | } 113 | 114 | if ip.To4() != nil { 115 | a := new(dns.A) 116 | a.A = ip 117 | 118 | return a, nil 119 | } 120 | 121 | aaaa := new(dns.AAAA) 122 | aaaa.AAAA = ip 123 | 124 | return aaaa, nil 125 | } 126 | -------------------------------------------------------------------------------- /config/duration.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "strconv" 5 | "time" 6 | 7 | "github.com/0xERR0R/blocky/log" 8 | "github.com/hako/durafmt" 9 | ) 10 | 11 | // Duration is a wrapper for time.Duration to support yaml unmarshalling 12 | type Duration time.Duration 13 | 14 | // ToDuration converts Duration to time.Duration 15 | func (c Duration) ToDuration() time.Duration { 16 | return time.Duration(c) 17 | } 18 | 19 | // IsAboveZero returns true if duration is strictly greater than zero. 20 | func (c Duration) IsAboveZero() bool { 21 | return c.ToDuration() > 0 22 | } 23 | 24 | // IsAtLeastZero returns true if duration is greater or equal to zero. 
25 | func (c Duration) IsAtLeastZero() bool { 26 | return c.ToDuration() >= 0 27 | } 28 | 29 | // Seconds returns duration in seconds 30 | func (c Duration) Seconds() float64 { 31 | return c.ToDuration().Seconds() 32 | } 33 | 34 | // SecondsU32 returns duration in seconds as uint32 35 | func (c Duration) SecondsU32() uint32 { 36 | return uint32(c.Seconds()) 37 | } 38 | 39 | // String implements `fmt.Stringer` 40 | func (c Duration) String() string { 41 | return durafmt.Parse(c.ToDuration()).String() 42 | } 43 | 44 | // UnmarshalText implements `encoding.TextUnmarshaler`. 45 | func (c *Duration) UnmarshalText(data []byte) error { 46 | input := string(data) 47 | 48 | if minutes, err := strconv.Atoi(input); err == nil { 49 | // number without unit: use minutes to ensure back compatibility 50 | *c = Duration(time.Duration(minutes) * time.Minute) 51 | 52 | log.Log().Warnf("Setting a duration without a unit is deprecated. Please use '%s min' instead.", input) 53 | 54 | return nil 55 | } 56 | 57 | duration, err := time.ParseDuration(input) 58 | if err == nil { 59 | *c = Duration(duration) 60 | 61 | return nil 62 | } 63 | 64 | return err 65 | } 66 | -------------------------------------------------------------------------------- /config/duration_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "time" 5 | 6 | . "github.com/onsi/ginkgo/v2" 7 | . 
"github.com/onsi/gomega" 8 | ) 9 | 10 | var _ = Describe("Duration", func() { 11 | var d Duration 12 | 13 | BeforeEach(func() { 14 | var zero Duration 15 | 16 | d = zero 17 | }) 18 | 19 | Describe("UnmarshalText", func() { 20 | It("should parse duration with unit", func() { 21 | err := d.UnmarshalText([]byte("1m20s")) 22 | Expect(err).Should(Succeed()) 23 | Expect(d).Should(Equal(Duration(80 * time.Second))) 24 | Expect(d.String()).Should(Equal("1 minute 20 seconds")) 25 | }) 26 | 27 | It("should fail if duration is in wrong format", func() { 28 | err := d.UnmarshalText([]byte("wrong")) 29 | Expect(err).Should(HaveOccurred()) 30 | Expect(err).Should(MatchError("time: invalid duration \"wrong\"")) 31 | }) 32 | }) 33 | 34 | Describe("IsAboveZero", func() { 35 | It("should be false for zero", func() { 36 | Expect(d.IsAboveZero()).Should(BeFalse()) 37 | Expect(Duration(0).IsAboveZero()).Should(BeFalse()) 38 | }) 39 | 40 | It("should be false for negative", func() { 41 | Expect(Duration(-1).IsAboveZero()).Should(BeFalse()) 42 | }) 43 | 44 | It("should be true for positive", func() { 45 | Expect(Duration(1).IsAboveZero()).Should(BeTrue()) 46 | }) 47 | }) 48 | 49 | Describe("SecondsU32", func() { 50 | It("should return the seconds", func() { 51 | Expect(Duration(time.Minute).SecondsU32()).Should(Equal(uint32(60))) 52 | Expect(Duration(time.Hour).SecondsU32()).Should(Equal(uint32(3600))) 53 | }) 54 | }) 55 | }) 56 | -------------------------------------------------------------------------------- /config/ecs.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | 7 | "github.com/sirupsen/logrus" 8 | ) 9 | 10 | const ( 11 | ecsIpv4MaskMax = uint8(32) 12 | ecsIpv6MaskMax = uint8(128) 13 | ) 14 | 15 | // ECSv4Mask is the subnet mask to be added as EDNS0 option for IPv4 16 | type ECSv4Mask uint8 17 | 18 | // UnmarshalText implements the encoding.TextUnmarshaler interface 19 | func (x 
// unmarshalInternal parses text as a base 10 unsigned 8 bit integer and
// verifies that it does not exceed maxvalue. name only appears in the error
// message of the range check.
// It is the shared implementation behind the UnmarshalText methods of
// ECSv4Mask and ECSv6Mask.
func unmarshalInternal(text []byte, maxvalue uint8, name string) (uint8, error) {
	raw := string(text)

	parsed, err := strconv.ParseUint(raw, 10, 8)

	switch {
	case err != nil:
		return 0, err

	case parsed > uint64(maxvalue):
		return 0, fmt.Errorf("mask value (%s) is too large for %s(max: %d)", raw, name, maxvalue)
	}

	return uint8(parsed), nil
}
"github.com/onsi/gomega" 9 | ) 10 | 11 | var _ = Describe("FilteringConfig", func() { 12 | var cfg Filtering 13 | 14 | suiteBeforeEach() 15 | 16 | BeforeEach(func() { 17 | cfg = Filtering{ 18 | QueryTypes: NewQTypeSet(AAAA, MX), 19 | } 20 | }) 21 | 22 | Describe("IsEnabled", func() { 23 | It("should be false by default", func() { 24 | cfg := Filtering{} 25 | Expect(defaults.Set(&cfg)).Should(Succeed()) 26 | 27 | Expect(cfg.IsEnabled()).Should(BeFalse()) 28 | }) 29 | 30 | When("enabled", func() { 31 | It("should be true", func() { 32 | Expect(cfg.IsEnabled()).Should(BeTrue()) 33 | }) 34 | }) 35 | 36 | When("disabled", func() { 37 | It("should be false", func() { 38 | cfg := Filtering{} 39 | 40 | Expect(cfg.IsEnabled()).Should(BeFalse()) 41 | }) 42 | }) 43 | }) 44 | 45 | Describe("LogConfig", func() { 46 | It("should log configuration", func() { 47 | cfg.LogConfig(logger) 48 | 49 | Expect(hook.Calls).Should(HaveLen(3)) 50 | Expect(hook.Messages).Should(ContainElements( 51 | ContainSubstring("query types:"), 52 | ContainSubstring(" - AAAA"), 53 | ContainSubstring(" - MX"), 54 | )) 55 | }) 56 | }) 57 | }) 58 | -------------------------------------------------------------------------------- /config/hosts_file.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | . 
"github.com/0xERR0R/blocky/config/migration" 5 | "github.com/0xERR0R/blocky/log" 6 | "github.com/sirupsen/logrus" 7 | ) 8 | 9 | type HostsFile struct { 10 | Sources []BytesSource `yaml:"sources"` 11 | HostsTTL Duration `yaml:"hostsTTL" default:"1h"` 12 | FilterLoopback bool `yaml:"filterLoopback"` 13 | Loading SourceLoading `yaml:"loading"` 14 | 15 | // Deprecated options 16 | Deprecated struct { 17 | RefreshPeriod *Duration `yaml:"refreshPeriod"` 18 | Filepath *BytesSource `yaml:"filePath"` 19 | } `yaml:",inline"` 20 | } 21 | 22 | func (c *HostsFile) migrate(logger *logrus.Entry) bool { 23 | return Migrate(logger, "hostsFile", c.Deprecated, map[string]Migrator{ 24 | "refreshPeriod": Move(To("loading.refreshPeriod", &c.Loading)), 25 | "filePath": Apply(To("sources", c), func(value BytesSource) { 26 | c.Sources = append(c.Sources, value) 27 | }), 28 | }) 29 | } 30 | 31 | // IsEnabled implements `config.Configurable`. 32 | func (c *HostsFile) IsEnabled() bool { 33 | return len(c.Sources) != 0 34 | } 35 | 36 | // LogConfig implements `config.Configurable`. 37 | func (c *HostsFile) LogConfig(logger *logrus.Entry) { 38 | logger.Infof("TTL: %s", c.HostsTTL) 39 | logger.Infof("filter loopback addresses: %t", c.FilterLoopback) 40 | 41 | logger.Info("loading:") 42 | log.WithIndent(logger, " ", c.Loading.LogConfig) 43 | 44 | logger.Info("sources:") 45 | 46 | for _, source := range c.Sources { 47 | logger.Infof(" - %s", source) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /config/hosts_file_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/creasty/defaults" 7 | . "github.com/onsi/ginkgo/v2" 8 | . 
"github.com/onsi/gomega" 9 | ) 10 | 11 | var _ = Describe("HostsFileConfig", func() { 12 | var cfg HostsFile 13 | 14 | suiteBeforeEach() 15 | 16 | BeforeEach(func() { 17 | cfg = HostsFile{ 18 | Sources: append( 19 | NewBytesSources("/a/file/path"), 20 | TextBytesSource("127.0.0.1 localhost"), 21 | ), 22 | HostsTTL: Duration(29 * time.Minute), 23 | Loading: SourceLoading{RefreshPeriod: Duration(30 * time.Minute)}, 24 | FilterLoopback: true, 25 | } 26 | }) 27 | 28 | Describe("IsEnabled", func() { 29 | It("should be false by default", func() { 30 | cfg := HostsFile{} 31 | Expect(defaults.Set(&cfg)).Should(Succeed()) 32 | 33 | Expect(cfg.IsEnabled()).Should(BeFalse()) 34 | }) 35 | 36 | When("enabled", func() { 37 | It("should be true", func() { 38 | Expect(cfg.IsEnabled()).Should(BeTrue()) 39 | }) 40 | }) 41 | 42 | When("disabled", func() { 43 | It("should be false", func() { 44 | cfg := HostsFile{} 45 | 46 | Expect(cfg.IsEnabled()).Should(BeFalse()) 47 | }) 48 | }) 49 | }) 50 | 51 | Describe("LogConfig", func() { 52 | It("should log configuration", func() { 53 | cfg.LogConfig(logger) 54 | 55 | Expect(hook.Calls).ShouldNot(BeEmpty()) 56 | Expect(hook.Messages).Should(ContainElements( 57 | ContainSubstring("- file:///a/file/path"), 58 | ContainSubstring("- 127.0.0.1 lo..."), 59 | )) 60 | }) 61 | }) 62 | 63 | Describe("migrate", func() { 64 | It("should", func() { 65 | cfg, err := WithDefaults[HostsFile]() 66 | Expect(err).Should(Succeed()) 67 | 68 | cfg.Deprecated.Filepath = ptrOf(newBytesSource("/a/file/path")) 69 | cfg.Deprecated.RefreshPeriod = ptrOf(Duration(time.Hour)) 70 | 71 | migrated := cfg.migrate(logger) 72 | Expect(migrated).Should(BeTrue()) 73 | 74 | Expect(hook.Calls).ShouldNot(BeEmpty()) 75 | Expect(hook.Messages).Should(ContainElements( 76 | ContainSubstring("hostsFile.loading.refreshPeriod"), 77 | ContainSubstring("hostsFile.sources"), 78 | )) 79 | 80 | Expect(cfg.Sources).Should(Equal([]BytesSource{*cfg.Deprecated.Filepath})) 81 | 
Expect(cfg.Loading.RefreshPeriod).Should(Equal(*cfg.Deprecated.RefreshPeriod)) 82 | }) 83 | }) 84 | }) 85 | -------------------------------------------------------------------------------- /config/metrics.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import "github.com/sirupsen/logrus" 4 | 5 | // Metrics contains the config values for prometheus 6 | type Metrics struct { 7 | Enable bool `yaml:"enable" default:"false"` 8 | Path string `yaml:"path" default:"/metrics"` 9 | } 10 | 11 | // IsEnabled implements `config.Configurable`. 12 | func (c *Metrics) IsEnabled() bool { 13 | return c.Enable 14 | } 15 | 16 | // LogConfig implements `config.Configurable`. 17 | func (c *Metrics) LogConfig(logger *logrus.Entry) { 18 | logger.Infof("url path: %s", c.Path) 19 | } 20 | -------------------------------------------------------------------------------- /config/metrics_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/creasty/defaults" 5 | . "github.com/onsi/ginkgo/v2" 6 | . 
"github.com/onsi/gomega" 7 | ) 8 | 9 | var _ = Describe("MetricsConfig", func() { 10 | var cfg Metrics 11 | 12 | suiteBeforeEach() 13 | 14 | BeforeEach(func() { 15 | cfg = Metrics{ 16 | Enable: true, 17 | Path: "/custom/path", 18 | } 19 | }) 20 | 21 | Describe("IsEnabled", func() { 22 | It("should be false by default", func() { 23 | cfg := Metrics{} 24 | Expect(defaults.Set(&cfg)).Should(Succeed()) 25 | 26 | Expect(cfg.IsEnabled()).Should(BeFalse()) 27 | }) 28 | 29 | When("enabled", func() { 30 | It("should be true", func() { 31 | Expect(cfg.IsEnabled()).Should(BeTrue()) 32 | }) 33 | }) 34 | 35 | When("disabled", func() { 36 | It("should be false", func() { 37 | cfg := Metrics{} 38 | 39 | Expect(cfg.IsEnabled()).Should(BeFalse()) 40 | }) 41 | }) 42 | }) 43 | 44 | Describe("LogConfig", func() { 45 | It("should log configuration", func() { 46 | cfg.LogConfig(logger) 47 | 48 | Expect(hook.Calls).Should(HaveLen(1)) 49 | Expect(hook.Messages).Should(ContainElement(ContainSubstring("url path: /custom/path"))) 50 | }) 51 | }) 52 | }) 53 | -------------------------------------------------------------------------------- /config/qtype_set.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "sort" 6 | "strings" 7 | 8 | "github.com/miekg/dns" 9 | "golang.org/x/exp/maps" 10 | ) 11 | 12 | type QTypeSet map[QType]struct{} 13 | 14 | func NewQTypeSet(qTypes ...dns.Type) QTypeSet { 15 | s := make(QTypeSet, len(qTypes)) 16 | 17 | for _, qType := range qTypes { 18 | s.Insert(qType) 19 | } 20 | 21 | return s 22 | } 23 | 24 | func (s QTypeSet) Contains(qType dns.Type) bool { 25 | _, found := s[QType(qType)] 26 | 27 | return found 28 | } 29 | 30 | func (s *QTypeSet) Insert(qType dns.Type) { 31 | if *s == nil { 32 | *s = make(QTypeSet, 1) 33 | } 34 | 35 | (*s)[QType(qType)] = struct{}{} 36 | } 37 | 38 | func (s *QTypeSet) UnmarshalYAML(unmarshal func(interface{}) error) error { 39 | var input []QType 40 | 
if err := unmarshal(&input); err != nil { 41 | return err 42 | } 43 | 44 | *s = make(QTypeSet, len(input)) 45 | 46 | for _, qType := range input { 47 | (*s)[qType] = struct{}{} 48 | } 49 | 50 | return nil 51 | } 52 | 53 | type QType dns.Type 54 | 55 | func (c QType) String() string { 56 | return dns.Type(c).String() 57 | } 58 | 59 | // UnmarshalText implements `encoding.TextUnmarshaler`. 60 | func (c *QType) UnmarshalText(data []byte) error { 61 | input := string(data) 62 | 63 | t, found := dns.StringToType[input] 64 | if !found { 65 | types := maps.Keys(dns.StringToType) 66 | 67 | sort.Strings(types) 68 | 69 | return fmt.Errorf("unknown DNS query type: '%s'. Please use following types '%s'", 70 | input, strings.Join(types, ", ")) 71 | } 72 | 73 | *c = QType(t) 74 | 75 | return nil 76 | } 77 | -------------------------------------------------------------------------------- /config/qtype_set_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/miekg/dns" 5 | . "github.com/onsi/ginkgo/v2" 6 | . 
"github.com/onsi/gomega" 7 | ) 8 | 9 | var _ = Describe("QTypeSet", func() { 10 | Describe("NewQTypeSet", func() { 11 | It("should insert given qTypes", func() { 12 | set := NewQTypeSet(dns.Type(dns.TypeA)) 13 | Expect(set).Should(HaveKey(QType(dns.TypeA))) 14 | Expect(set.Contains(dns.Type(dns.TypeA))).Should(BeTrue()) 15 | 16 | Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA))) 17 | Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue()) 18 | }) 19 | }) 20 | 21 | Describe("Insert", func() { 22 | It("should insert given qTypes", func() { 23 | set := NewQTypeSet() 24 | 25 | Expect(set).ShouldNot(HaveKey(QType(dns.TypeAAAA))) 26 | Expect(set.Contains(dns.Type(dns.TypeAAAA))).ShouldNot(BeTrue()) 27 | 28 | set.Insert(dns.Type(dns.TypeAAAA)) 29 | 30 | Expect(set).Should(HaveKey(QType(dns.TypeAAAA))) 31 | Expect(set.Contains(dns.Type(dns.TypeAAAA))).Should(BeTrue()) 32 | }) 33 | }) 34 | }) 35 | 36 | var _ = Describe("QType", func() { 37 | Describe("UnmarshalText", func() { 38 | It("Should parse existing DNS type as string", func() { 39 | t := QType(0) 40 | err := t.UnmarshalText([]byte("AAAA")) 41 | Expect(err).Should(Succeed()) 42 | Expect(t).Should(Equal(QType(dns.TypeAAAA))) 43 | Expect(t.String()).Should(Equal("AAAA")) 44 | }) 45 | 46 | It("should fail if DNS type does not exist", func() { 47 | t := QType(0) 48 | err := t.UnmarshalText([]byte("WRONGTYPE")) 49 | Expect(err).Should(HaveOccurred()) 50 | Expect(err.Error()).Should(ContainSubstring("unknown DNS query type: 'WRONGTYPE'")) 51 | }) 52 | 53 | It("should fail if wrong YAML format", func() { 54 | d := QType(0) 55 | err := d.UnmarshalText([]byte("some err")) 56 | Expect(err).Should(HaveOccurred()) 57 | Expect(err.Error()).Should(ContainSubstring("unknown DNS query type: 'some err'")) 58 | }) 59 | }) 60 | }) 61 | -------------------------------------------------------------------------------- /config/query_log.go: -------------------------------------------------------------------------------- 1 | 
package config 2 | 3 | import ( 4 | "net/url" 5 | "strings" 6 | 7 | "github.com/0xERR0R/blocky/log" 8 | "github.com/sirupsen/logrus" 9 | ) 10 | 11 | // QueryLog configuration for the query logging 12 | type QueryLog struct { 13 | Target string `yaml:"target"` 14 | Type QueryLogType `yaml:"type"` 15 | LogRetentionDays uint64 `yaml:"logRetentionDays"` 16 | CreationAttempts int `yaml:"creationAttempts" default:"3"` 17 | CreationCooldown Duration `yaml:"creationCooldown" default:"2s"` 18 | Fields []QueryLogField `yaml:"fields"` 19 | FlushInterval Duration `yaml:"flushInterval" default:"30s"` 20 | Ignore QueryLogIgnore `yaml:"ignore"` 21 | } 22 | 23 | type QueryLogIgnore struct { 24 | SUDN bool `yaml:"sudn" default:"false"` 25 | } 26 | 27 | // SetDefaults implements `defaults.Setter`. 28 | func (c *QueryLog) SetDefaults() { 29 | // Since the default depends on the enum values, set it dynamically 30 | // to avoid having to repeat the values in the annotation. 31 | c.Fields = QueryLogFieldValues() 32 | } 33 | 34 | // IsEnabled implements `config.Configurable`. 35 | func (c *QueryLog) IsEnabled() bool { 36 | return c.Type != QueryLogTypeNone 37 | } 38 | 39 | // LogConfig implements `config.Configurable`. 
40 | func (c *QueryLog) LogConfig(logger *logrus.Entry) { 41 | logger.Infof("type: %s", c.Type) 42 | 43 | if c.Target != "" { 44 | logger.Infof("target: %s", c.censoredTarget()) 45 | } 46 | 47 | logger.Infof("logRetentionDays: %d", c.LogRetentionDays) 48 | logger.Debugf("creationAttempts: %d", c.CreationAttempts) 49 | logger.Debugf("creationCooldown: %s", c.CreationCooldown) 50 | logger.Infof("flushInterval: %s", c.FlushInterval) 51 | logger.Infof("fields: %s", c.Fields) 52 | 53 | logger.Infof("ignore:") 54 | log.WithIndent(logger, " ", func(e *logrus.Entry) { 55 | e.Infof("sudn: %t", c.Ignore.SUDN) // log through the entry handed to the callback rather than the captured outer logger 56 | }) 57 | } 58 | 59 | func (c *QueryLog) censoredTarget() string { // returns c.Target with the URL password replaced by secretObfuscator; c.Target is returned unchanged when it does not parse or has no password 60 | // Make sure there's a scheme, otherwise the user is parsed as the scheme 61 | targetStr := c.Target 62 | if !strings.Contains(targetStr, "://") { 63 | targetStr = c.Type.String() + "://" + targetStr 64 | } 65 | 66 | target, err := url.Parse(targetStr) 67 | if err != nil { 68 | return c.Target 69 | } 70 | 71 | pass, ok := target.User.Password() 72 | if !ok || pass == "" { // an empty password ("user:@host") has nothing to censor, and ReplaceAll with an empty pattern would insert the obfuscator between every character 73 | return c.Target 74 | } 75 | 76 | return strings.ReplaceAll(c.Target, pass, secretObfuscator) 77 | } 78 | -------------------------------------------------------------------------------- /config/query_log_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/creasty/defaults" 7 | . "github.com/onsi/ginkgo/v2" 8 | . 
"github.com/onsi/gomega" 9 | ) 10 | 11 | var _ = Describe("QueryLogConfig", func() { 12 | var cfg QueryLog 13 | 14 | suiteBeforeEach() 15 | 16 | BeforeEach(func() { 17 | cfg = QueryLog{ 18 | Target: "/dev/null", 19 | Type: QueryLogTypeCsvClient, 20 | LogRetentionDays: 0, 21 | CreationAttempts: 1, 22 | CreationCooldown: Duration(time.Millisecond), 23 | } 24 | }) 25 | 26 | Describe("IsEnabled", func() { 27 | It("should be true by default", func() { 28 | cfg := QueryLog{} 29 | Expect(defaults.Set(&cfg)).Should(Succeed()) 30 | 31 | Expect(cfg.IsEnabled()).Should(BeTrue()) 32 | }) 33 | 34 | When("enabled", func() { 35 | It("should be true", func() { 36 | Expect(cfg.IsEnabled()).Should(BeTrue()) 37 | }) 38 | }) 39 | 40 | When("disabled", func() { 41 | It("should be false", func() { 42 | cfg := QueryLog{ 43 | Type: QueryLogTypeNone, 44 | } 45 | 46 | Expect(cfg.IsEnabled()).Should(BeFalse()) 47 | }) 48 | }) 49 | }) 50 | 51 | Describe("LogConfig", func() { 52 | It("should log configuration", func() { 53 | cfg.LogConfig(logger) 54 | 55 | Expect(hook.Calls).ShouldNot(BeEmpty()) 56 | Expect(hook.Messages).Should(ContainElement(ContainSubstring("logRetentionDays:"))) 57 | Expect(hook.Messages).Should(ContainElement(ContainSubstring("sudn:"))) 58 | }) 59 | 60 | DescribeTable("secret censoring", func(target string) { 61 | cfg.Type = QueryLogTypeMysql 62 | cfg.Target = target 63 | 64 | cfg.LogConfig(logger) 65 | 66 | Expect(hook.Calls).ShouldNot(BeEmpty()) 67 | Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("password"))) 68 | }, 69 | Entry("without scheme", "user:password@localhost"), 70 | Entry("with scheme", "scheme://user:password@localhost"), 71 | Entry("no password", "localhost"), 72 | Entry("not a URL", "invalid!://"), 73 | ) 74 | }) 75 | 76 | Describe("SetDefaults", func() { 77 | It("should log configuration", func() { 78 | cfg := QueryLog{} 79 | Expect(cfg.Fields).Should(BeEmpty()) 80 | 81 | Expect(defaults.Set(&cfg)).Should(Succeed()) 82 | 83 | 
Expect(cfg.Fields).ShouldNot(BeEmpty()) 84 | }) 85 | }) 86 | }) 87 | -------------------------------------------------------------------------------- /config/redis.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/sirupsen/logrus" 5 | ) 6 | 7 | // Redis configuration for the redis connection 8 | type Redis struct { 9 | Address string `yaml:"address"` 10 | Username string `yaml:"username" default:""` 11 | Password string `yaml:"password" default:""` 12 | Database int `yaml:"database" default:"0"` 13 | Required bool `yaml:"required" default:"false"` 14 | ConnectionAttempts int `yaml:"connectionAttempts" default:"3"` 15 | ConnectionCooldown Duration `yaml:"connectionCooldown" default:"1s"` 16 | SentinelUsername string `yaml:"sentinelUsername" default:""` 17 | SentinelPassword string `yaml:"sentinelPassword" default:""` 18 | SentinelAddresses []string `yaml:"sentinelAddresses"` 19 | } 20 | 21 | // IsEnabled implements `config.Configurable` 22 | func (c *Redis) IsEnabled() bool { 23 | return c.Address != "" 24 | } 25 | 26 | // LogConfig implements `config.Configurable` 27 | func (c *Redis) LogConfig(logger *logrus.Entry) { 28 | if len(c.SentinelAddresses) == 0 { 29 | logger.Info("address: ", c.Address) 30 | } 31 | 32 | logger.Info("username: ", c.Username) 33 | logger.Info("password: ", secretObfuscator) 34 | logger.Info("database: ", c.Database) 35 | logger.Info("required: ", c.Required) 36 | logger.Info("connectionAttempts: ", c.ConnectionAttempts) 37 | logger.Info("connectionCooldown: ", c.ConnectionCooldown) 38 | 39 | if len(c.SentinelAddresses) > 0 { 40 | logger.Info("sentinel:") 41 | logger.Info(" master: ", c.Address) 42 | logger.Info(" username: ", c.SentinelUsername) 43 | logger.Info(" password: ", secretObfuscator) 44 | logger.Info(" addresses:") 45 | 46 | for _, addr := range c.SentinelAddresses { 47 | logger.Info(" - ", addr) 48 | } 49 | } 50 | } 51 | 
-------------------------------------------------------------------------------- /config/redis_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/0xERR0R/blocky/log" 5 | "github.com/creasty/defaults" 6 | . "github.com/onsi/ginkgo/v2" 7 | . "github.com/onsi/gomega" 8 | ) 9 | 10 | var _ = Describe("Redis", func() { 11 | var ( 12 | c Redis 13 | err error 14 | ) 15 | 16 | suiteBeforeEach() 17 | 18 | BeforeEach(func() { 19 | err = defaults.Set(&c) 20 | Expect(err).Should(Succeed()) 21 | }) 22 | 23 | Describe("IsEnabled", func() { 24 | When("all fields are default", func() { 25 | It("should be disabled", func() { 26 | Expect(c.IsEnabled()).Should(BeFalse()) 27 | }) 28 | }) 29 | 30 | When("Address is set", func() { 31 | BeforeEach(func() { 32 | c.Address = "localhost:6379" 33 | }) 34 | 35 | It("should be enabled", func() { 36 | Expect(c.IsEnabled()).Should(BeTrue()) 37 | }) 38 | }) 39 | }) 40 | 41 | Describe("LogConfig", func() { 42 | BeforeEach(func() { 43 | logger, hook = log.NewMockEntry() 44 | }) 45 | 46 | When("all fields are default", func() { 47 | It("should log default values", func() { 48 | c.LogConfig(logger) 49 | 50 | Expect(hook.Messages).Should( 51 | SatisfyAll(ContainElement(ContainSubstring("address: ")), 52 | ContainElement(ContainSubstring("username: ")), 53 | ContainElement(ContainSubstring("password: ")), 54 | ContainElement(ContainSubstring("database: ")), 55 | ContainElement(ContainSubstring("required: ")), 56 | ContainElement(ContainSubstring("connectionAttempts: ")), 57 | ContainElement(ContainSubstring("connectionCooldown: ")))) 58 | }) 59 | }) 60 | 61 | When("Address is set", func() { 62 | BeforeEach(func() { 63 | c.Address = "localhost:6379" 64 | }) 65 | 66 | It("should log address", func() { 67 | c.LogConfig(logger) 68 | 69 | Expect(hook.Messages).Should(ContainElement(ContainSubstring("address: localhost:6379"))) 70 | }) 71 | }) 72 | 73 | 
When("SentinelAddresses is set", func() { 74 | BeforeEach(func() { 75 | c.SentinelAddresses = []string{"localhost:26379", "localhost:26380"} 76 | }) 77 | 78 | It("should log sentinel addresses", func() { 79 | c.LogConfig(logger) 80 | 81 | Expect(hook.Messages).Should( 82 | SatisfyAll( 83 | ContainElement(ContainSubstring("sentinel:")), 84 | ContainElement(ContainSubstring(" addresses:")), 85 | ContainElement(ContainSubstring(" - localhost:26379")), 86 | ContainElement(ContainSubstring(" - localhost:26380")))) 87 | }) 88 | }) 89 | 90 | const secretValue = "secret-value" 91 | 92 | It("should not log the password", func() { 93 | c.Password = secretValue 94 | c.LogConfig(logger) 95 | 96 | Expect(hook.Calls).ShouldNot(BeEmpty()) 97 | Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring(secretValue))) 98 | }) 99 | 100 | It("should not log the sentinel password", func() { 101 | c.SentinelPassword = secretValue 102 | c.LogConfig(logger) 103 | 104 | Expect(hook.Calls).ShouldNot(BeEmpty()) 105 | Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring(secretValue))) 106 | }) 107 | }) 108 | }) 109 | -------------------------------------------------------------------------------- /config/rewriter.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/sirupsen/logrus" 5 | ) 6 | 7 | // RewriterConfig custom DNS configuration 8 | type RewriterConfig struct { 9 | Rewrite map[string]string `yaml:"rewrite"` 10 | FallbackUpstream bool `yaml:"fallbackUpstream" default:"false"` 11 | } 12 | 13 | // IsEnabled implements `config.Configurable`. 14 | func (c *RewriterConfig) IsEnabled() bool { 15 | return len(c.Rewrite) != 0 16 | } 17 | 18 | // LogConfig implements `config.Configurable`. 
19 | func (c *RewriterConfig) LogConfig(logger *logrus.Entry) { 20 | logger.Infof("fallbackUpstream = %t", c.FallbackUpstream) 21 | 22 | logger.Info("rules:") 23 | 24 | for key, val := range c.Rewrite { 25 | logger.Infof(" %s = %s", key, val) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /config/rewriter_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/creasty/defaults" 5 | . "github.com/onsi/ginkgo/v2" 6 | . "github.com/onsi/gomega" 7 | ) 8 | 9 | var _ = Describe("RewriterConfig", func() { 10 | var cfg RewriterConfig 11 | 12 | suiteBeforeEach() 13 | 14 | BeforeEach(func() { 15 | cfg = RewriterConfig{ 16 | Rewrite: map[string]string{ 17 | "original1": "rewritten1", 18 | "original2": "rewritten2", 19 | }, 20 | } 21 | }) 22 | 23 | Describe("IsEnabled", func() { 24 | It("should be false by default", func() { 25 | cfg := RewriterConfig{} 26 | Expect(defaults.Set(&cfg)).Should(Succeed()) 27 | 28 | Expect(cfg.IsEnabled()).Should(BeFalse()) 29 | }) 30 | 31 | When("enabled", func() { 32 | It("should be true", func() { 33 | Expect(cfg.IsEnabled()).Should(BeTrue()) 34 | }) 35 | }) 36 | 37 | When("disabled", func() { 38 | It("should be false", func() { 39 | cfg := RewriterConfig{} 40 | 41 | Expect(cfg.IsEnabled()).Should(BeFalse()) 42 | }) 43 | }) 44 | }) 45 | 46 | Describe("LogConfig", func() { 47 | It("should log configuration", func() { 48 | cfg.LogConfig(logger) 49 | 50 | Expect(hook.Calls).ShouldNot(BeEmpty()) 51 | Expect(hook.Messages).Should(ContainElements( 52 | ContainSubstring("rules:"), 53 | ContainSubstring("original2 ="), 54 | )) 55 | }) 56 | }) 57 | }) 58 | -------------------------------------------------------------------------------- /config/sudn.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/sirupsen/logrus" 5 | ) 6 | 7 | // 
SUDN configuration for Special Use Domain Names 8 | type SUDN struct { 9 | // These are "recommended for private use" but not mandatory. 10 | // If a user wishes to use one, it will most likely be via conditional 11 | // upstream or custom DNS, which come before SUDN in the resolver chain. 12 | // Thus defaulting to `true` and returning NXDOMAIN here should not conflict. 13 | RFC6762AppendixG bool `yaml:"rfc6762-appendixG" default:"true"` 14 | Enable bool `yaml:"enable" default:"true"` 15 | } 16 | 17 | // IsEnabled implements `config.Configurable`. 18 | func (c *SUDN) IsEnabled() bool { 19 | return c.Enable 20 | } 21 | 22 | // LogConfig implements `config.Configurable`. 23 | func (c *SUDN) LogConfig(logger *logrus.Entry) { 24 | logger.Debugf("rfc6762-appendixG = %v", c.RFC6762AppendixG) 25 | } 26 | -------------------------------------------------------------------------------- /config/sudn_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | . "github.com/onsi/ginkgo/v2" 5 | . 
"github.com/onsi/gomega" 6 | ) 7 | 8 | var _ = Describe("SUDNConfig", func() { 9 | var cfg SUDN 10 | 11 | suiteBeforeEach() 12 | 13 | BeforeEach(func() { 14 | var err error 15 | 16 | cfg, err = WithDefaults[SUDN]() 17 | Expect(err).Should(Succeed()) 18 | }) 19 | 20 | Describe("IsEnabled", func() { 21 | It("should be true by default", func() { 22 | Expect(cfg.IsEnabled()).Should(BeTrue()) 23 | }) 24 | 25 | When("enabled", func() { 26 | It("should be true", func() { 27 | cfg := SUDN{ 28 | Enable: true, 29 | } 30 | Expect(cfg.IsEnabled()).Should(BeTrue()) 31 | }) 32 | }) 33 | 34 | When("disabled", func() { 35 | It("should be false", func() { 36 | cfg := SUDN{ 37 | Enable: false, 38 | } 39 | Expect(cfg.IsEnabled()).Should(BeFalse()) 40 | }) 41 | }) 42 | }) 43 | 44 | Describe("LogConfig", func() { 45 | It("should log configuration", func() { 46 | cfg.LogConfig(logger) 47 | 48 | Expect(hook.Calls).ShouldNot(BeEmpty()) 49 | Expect(hook.Messages).Should(ContainElement(ContainSubstring("rfc6762-appendixG = true"))) 50 | }) 51 | }) 52 | }) 53 | -------------------------------------------------------------------------------- /config/upstreams.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/0xERR0R/blocky/log" 5 | "github.com/sirupsen/logrus" 6 | ) 7 | 8 | const UpstreamDefaultCfgName = "default" 9 | 10 | // Upstreams upstream servers configuration 11 | type Upstreams struct { 12 | Init Init `yaml:"init"` 13 | Timeout Duration `yaml:"timeout" default:"2s"` // always > 0 14 | Groups UpstreamGroups `yaml:"groups"` 15 | Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"` 16 | UserAgent string `yaml:"userAgent"` 17 | } 18 | 19 | type UpstreamGroups map[string][]Upstream 20 | 21 | func (c *Upstreams) validate(logger *logrus.Entry) { 22 | defaults := mustDefault[Upstreams]() 23 | 24 | if !c.Timeout.IsAboveZero() { 25 | logger.Warnf("upstreams.timeout <= 0, setting to %s", 
defaults.Timeout) 26 | c.Timeout = defaults.Timeout 27 | } 28 | } 29 | 30 | // IsEnabled implements `config.Configurable`. 31 | func (c *Upstreams) IsEnabled() bool { 32 | return len(c.Groups) != 0 33 | } 34 | 35 | // LogConfig implements `config.Configurable`. 36 | func (c *Upstreams) LogConfig(logger *logrus.Entry) { 37 | logger.Info("init:") 38 | log.WithIndent(logger, " ", c.Init.LogConfig) 39 | 40 | logger.Info("timeout: ", c.Timeout) 41 | logger.Info("strategy: ", c.Strategy) 42 | logger.Info("groups:") 43 | 44 | for name, upstreams := range c.Groups { 45 | logger.Infof(" %s:", name) 46 | 47 | for _, upstream := range upstreams { 48 | logger.Infof(" - %s", upstream) 49 | } 50 | } 51 | } 52 | 53 | // UpstreamGroup represents the config for one group (upstream branch) 54 | type UpstreamGroup struct { 55 | Upstreams 56 | 57 | Name string // group name 58 | } 59 | 60 | // NewUpstreamGroup creates an UpstreamGroup with the given name and upstreams. 61 | // 62 | // The upstreams from `cfg.Groups` are ignored. 63 | func NewUpstreamGroup(name string, cfg Upstreams, upstreams []Upstream) UpstreamGroup { 64 | group := UpstreamGroup{ 65 | Name: name, 66 | Upstreams: cfg, 67 | } 68 | 69 | group.Groups = UpstreamGroups{name: upstreams} 70 | 71 | return group 72 | } 73 | 74 | func (c *UpstreamGroup) GroupUpstreams() []Upstream { 75 | return c.Groups[c.Name] 76 | } 77 | 78 | // IsEnabled implements `config.Configurable`. 79 | func (c *UpstreamGroup) IsEnabled() bool { 80 | return len(c.GroupUpstreams()) != 0 81 | } 82 | 83 | // LogConfig implements `config.Configurable`. 
84 | func (c *UpstreamGroup) LogConfig(logger *logrus.Entry) { 85 | logger.Info("group: ", c.Name) 86 | logger.Info("upstreams:") 87 | 88 | for _, upstream := range c.GroupUpstreams() { 89 | logger.Infof(" - %s", upstream) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /config/upstreams_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/creasty/defaults" 7 | . "github.com/onsi/ginkgo/v2" 8 | . "github.com/onsi/gomega" 9 | ) 10 | 11 | var _ = Describe("ParallelBestConfig", func() { 12 | suiteBeforeEach() 13 | 14 | Context("Upstreams", func() { 15 | var cfg Upstreams 16 | 17 | BeforeEach(func() { 18 | cfg = Upstreams{ 19 | Timeout: Duration(5 * time.Second), 20 | Groups: UpstreamGroups{ 21 | UpstreamDefaultCfgName: { 22 | {Host: "host1"}, 23 | {Host: "host2"}, 24 | }, 25 | }, 26 | } 27 | }) 28 | 29 | Describe("IsEnabled", func() { 30 | It("should be false by default", func() { 31 | cfg := Upstreams{} 32 | Expect(defaults.Set(&cfg)).Should(Succeed()) 33 | 34 | Expect(cfg.IsEnabled()).Should(BeFalse()) 35 | }) 36 | 37 | When("enabled", func() { 38 | It("should be true", func() { 39 | Expect(cfg.IsEnabled()).Should(BeTrue()) 40 | }) 41 | }) 42 | 43 | When("disabled", func() { 44 | It("should be false", func() { 45 | cfg := Upstreams{} 46 | 47 | Expect(cfg.IsEnabled()).Should(BeFalse()) 48 | }) 49 | }) 50 | }) 51 | 52 | Describe("LogConfig", func() { 53 | It("should log configuration", func() { 54 | cfg.LogConfig(logger) 55 | 56 | Expect(hook.Calls).ShouldNot(BeEmpty()) 57 | Expect(hook.Messages).Should(ContainElements( 58 | ContainSubstring("timeout:"), 59 | ContainSubstring("groups:"), 60 | ContainSubstring(":host2:"), 61 | )) 62 | }) 63 | }) 64 | 65 | Describe("validate", func() { 66 | It("should compute defaults", func() { 67 | cfg.Timeout = -1 68 | 69 | cfg.validate(logger) 70 | 71 | 
Expect(cfg.Timeout).Should(BeNumerically(">", 0)) 72 | 73 | Expect(hook.Calls).ShouldNot(BeEmpty()) 74 | Expect(hook.Messages).Should(ContainElement(ContainSubstring("timeout"))) 75 | }) 76 | 77 | It("should not override valid user values", func() { 78 | cfg.validate(logger) 79 | 80 | Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("timeout"))) 81 | }) 82 | }) 83 | }) 84 | 85 | Context("UpstreamGroupConfig", func() { 86 | var cfg UpstreamGroup 87 | 88 | BeforeEach(func() { 89 | upstreamsCfg, err := WithDefaults[Upstreams]() 90 | Expect(err).Should(Succeed()) 91 | 92 | cfg = NewUpstreamGroup("test", upstreamsCfg, []Upstream{ 93 | {Host: "host1"}, 94 | {Host: "host2"}, 95 | }) 96 | }) 97 | 98 | Describe("IsEnabled", func() { 99 | It("should be false by default", func() { 100 | cfg := UpstreamGroup{} 101 | Expect(defaults.Set(&cfg)).Should(Succeed()) 102 | 103 | Expect(cfg.IsEnabled()).Should(BeFalse()) 104 | }) 105 | 106 | When("enabled", func() { 107 | It("should be true", func() { 108 | Expect(cfg.IsEnabled()).Should(BeTrue()) 109 | }) 110 | }) 111 | 112 | When("disabled", func() { 113 | It("should be false", func() { 114 | cfg := UpstreamGroup{} 115 | 116 | Expect(cfg.IsEnabled()).Should(BeFalse()) 117 | }) 118 | }) 119 | }) 120 | 121 | Describe("LogConfig", func() { 122 | It("should log configuration", func() { 123 | cfg.LogConfig(logger) 124 | 125 | Expect(hook.Calls).ShouldNot(BeEmpty()) 126 | Expect(hook.Messages).Should(ContainElements( 127 | ContainSubstring("group: test"), 128 | ContainSubstring("upstreams:"), 129 | ContainSubstring(":host1:"), 130 | ContainSubstring(":host2:"), 131 | )) 132 | }) 133 | }) 134 | }) 135 | }) 136 | -------------------------------------------------------------------------------- /docs/embed.go: -------------------------------------------------------------------------------- 1 | package docs 2 | 3 | import _ "embed" 4 | 5 | //go:embed api/openapi.yaml 6 | var OpenAPI string 7 | 
-------------------------------------------------------------------------------- /docs/fb_dns_config.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0xERR0R/blocky/62610f657e052b9f5dba77bdfd6c2229d558ab6b/docs/fb_dns_config.png -------------------------------------------------------------------------------- /docs/grafana-dashboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0xERR0R/blocky/62610f657e052b9f5dba77bdfd6c2229d558ab6b/docs/grafana-dashboard.png -------------------------------------------------------------------------------- /docs/grafana-query-dashboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0xERR0R/blocky/62610f657e052b9f5dba77bdfd6c2229d558ab6b/docs/grafana-query-dashboard.png -------------------------------------------------------------------------------- /docs/includes/abbreviations.md: -------------------------------------------------------------------------------- 1 | *[DNS]: Domain Name System 2 | *[k8s]: Kubernetes 3 | *[UDP]: User Datagram Protocol 4 | *[TCP]: Transmission Control Protocol 5 | *[HTTP]: Hypertext Transfer Protocol 6 | *[HTTPS]: Hypertext Transfer Protocol Secure 7 | *[DoH]: DNS-over-HTTPS 8 | *[DoT]: DNS-over-TLS 9 | *[DNSSEC]: Domain Name System Security Extensions 10 | *[eDNS]: Extended DNS 11 | *[REST]: Representational State Transfer 12 | *[API]: Application Programming Interface 13 | *[CLI]: Command Line Interface 14 | *[YAML]: YAML Ain't Markup Language 15 | *[Helm]: package manager for Kubernetes 16 | *[CNAME]: Canonical Name 17 | *[CIDR]: Classless Inter-Domain Routing 18 | *[NXDOMAIN]: Non-Existent Domain 19 | *[TTL]: Time-To-Live 20 | *[rDNS]: Reverse DNS 21 | *[SSL]: Secure Sockets Layer 22 | *[CSV]: Comma-separated values 23 | *[SAMBA]: Server Message Block Protocol (Windows Network File 
System) 24 | *[DHCP]: Dynamic Host Configuration Protocol 25 | *[duration format]: Example: "300ms", "1.5h" or "2h45m". Valid time units are "ns", "us", "ms", "s", "m", "h". 26 | *[regex]: Regular expression -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Blocky 2 | 3 | <figure> 4 | <img src="https://raw.githubusercontent.com/0xERR0R/blocky/main/docs/blocky.svg" width="200" /> 5 | </figure> 6 | 7 | Blocky is a DNS proxy and ad-blocker for the local network written in Go with the following features: 8 | 9 | ## Features 10 | 11 | - **Blocking** - :no_entry: Blocking of DNS queries with external lists (Ad-block, malware) and allowlisting 12 | 13 | * Definition of allow/denylists per client group (Kids, Smart home devices, etc.) 14 | * Periodic reload of external allow/denylists 15 | * Regex support 16 | * Blocking of request domain, response CNAME (deep CNAME inspection) and response IP addresses (against IP lists) 17 | 18 | - **Advanced DNS configuration** - :nerd: not just an ad-blocker 19 | 20 | * Custom DNS resolution for certain domain names 21 | * Conditional forwarding to external DNS servers 22 | * Upstream resolvers can be defined per client group 23 | 24 | - **Performance** - :rocket: Improves speed and performance in your network 25 | 26 | * Customizable caching of DNS answers for queries -> improves DNS resolution speed and reduces the amount of external DNS 27 | queries 28 | * Prefetching and caching of frequently used queries 29 | * Using multiple external resolvers simultaneously 30 | * Low memory footprint 31 | 32 | - **Various Protocols** - :computer: Supports modern DNS protocols 33 | 34 | * DNS over UDP and TCP 35 | * DNS over HTTPS (aka DoH) 36 | * DNS over TLS (aka DoT) 37 | 38 | - **Security and Privacy** - :dark_sunglasses: Secure communication 39 | 40 | * Supports modern DNS extensions: DNSSEC, eDNS, ... 
41 | * Freely configurable blocking lists - no hidden filtering etc. 42 | * Provides DoH Endpoint 43 | * Uses random upstream resolvers from the configuration - increases your privacy through the distribution of your DNS 44 | traffic over multiple providers 45 | * Open source development 46 | * Blocky does **NOT** collect any user data, telemetry, statistics etc. 47 | 48 | - **Integration** - :notebook_with_decorative_cover: various integrations 49 | 50 | * [Prometheus](https://prometheus.io/) metrics 51 | * Prepared [Grafana](https://grafana.com/) dashboards (Prometheus and database) 52 | * Logging of DNS queries per day / per client in CSV format or MySQL/MariaDB/PostgreSQL/Timescale database - easy to 53 | analyze 54 | * Various REST API endpoints 55 | * CLI tool 56 | 57 | - **Simple configuration** - :baby: single configuration file in YAML format 58 | 59 | * Simple to maintain 60 | * Simple to back up 61 | 62 | - **Simple installation/configuration** - :cloud: blocky was designed for simple installation 63 | 64 | * Stateless (no database, no temporary files) 65 | * Docker image with Multi-arch support 66 | * Single binary 67 | * Supports x86-64 and ARM architectures -> runs fine on Raspberry PI 68 | * Community supported Helm chart for k8s deployment 69 | 70 | 71 | ## Contribution 72 | 73 | Issues, feature suggestions and pull requests are welcome! Blocky lives on :material-github:[GitHub](https://github.com/0xERR0R/blocky). 74 | 75 | --8<-- "docs/includes/abbreviations.md" 76 | -------------------------------------------------------------------------------- /docs/interfaces.md: -------------------------------------------------------------------------------- 1 | # Interfaces 2 | 3 | ## REST API 4 | 5 | 6 | ??? abstract "OpenAPI specification" 7 | 8 | ```yaml 9 | --8<-- "docs/api/openapi.yaml" 10 | ``` 11 | 12 | If the HTTP listener is enabled, blocky provides a REST API. You can download the [OpenAPI YAML](api/openapi.yaml) interface specification. 
13 | 14 | You can also browse the interactive API documentation (RapiDoc) [online](rapidoc.html). 15 | 16 | ## CLI 17 | 18 | Blocky provides a CLI to control it. The CLI uses the REST API internally. 19 | 20 | To run the CLI, please ensure that the blocky DNS server is running, then execute `blocky help` for help or 21 | 22 | - `./blocky blocking enable` to enable blocking 23 | - `./blocky blocking disable` to disable blocking 24 | - `./blocky blocking disable --duration [duration]` to disable blocking for a certain amount of time (30s, 5m, 10m30s, 25 | ...) 26 | - `./blocky blocking disable --groups ads,othergroup` to disable blocking only for special groups 27 | - `./blocky blocking status` to print current status of blocking 28 | - `./blocky query <domain>` execute DNS query (A) (simple replacement for dig, useful for debug purposes) 29 | - `./blocky query <domain> --type <queryType>` execute DNS query with passed query type (A, AAAA, MX, ...) 30 | - `./blocky lists refresh` reloads all allow/denylists 31 | - `./blocky validate [--config /path/to/config.yaml]` validates configuration file 32 | 33 | !!! tip 34 | 35 | To run this inside Docker, run `docker exec blocky ./blocky blocking status` 36 | 37 | --8<-- "docs/includes/abbreviations.md" 38 | -------------------------------------------------------------------------------- /docs/network_configuration.md: -------------------------------------------------------------------------------- 1 | # Network configuration 2 | 3 | In order to benefit from all the advantages of blocky like ad-blocking, privacy and speed, it is necessary to use 4 | blocky as DNS server for your devices. You can configure DNS server on each device manually or use DHCP in your network 5 | router and push the right settings to your device. With this approach, you will configure blocky only once in your 6 | router and each device in your network will automatically use blocky as DNS server. 
7 | 8 | ## Transparent configuration with DHCP 9 | 10 | Let us assume blocky is installed on a Raspberry PI with the fixed IP address `192.168.178.2`. Each device which connects to 11 | the router will obtain an IP address and receive the network configuration. The IP address of the Raspberry PI should be 12 | pushed to the device as DNS server. 13 | 14 | ``` 15 | ┌──────────────┐ ┌─────────────────┐ 16 | │ │ │ Raspberry PI │ 17 | │ Router │ │ blocky │ 18 | │ │ │ 192.168.178.2 │ 19 | └─▲─────┬──────┘ └────▲────────────┘ 20 | │1 │ │ 3 21 | │ │ │ 22 | │ │ │ 23 | │ │ │ 24 | │ │ │ 25 | │ │ │ 26 | │ │ │ 27 | │ │ ┌─────────────┴──────┐ 28 | │ │ 2 │ │ 29 | │ └───────► Network device │ 30 | │ │ Android │ 31 | └─────────────┤ │ 32 | └────────────────────┘ 33 | ``` 34 | 35 | **1** - Network device asks the DHCP server (on Router) for the network configuration 36 | 37 | **2** - Router assigns a free IP address to the device and says "Use 192.168.178.2 as DNS server" 38 | 39 | **3** - The client makes DNS queries and is happy to use **blocky** :smile: 40 | 41 | !!! warning 42 | 43 | It is necessary to assign the server which runs blocky (e.g. Raspberry PI) a fixed IP address. 
44 | 45 | ### Example configuration with FritzBox 46 | 47 | To configure the DNS server in the FritzBox, please open in the FritzBox web interface: 48 | 49 | * in the navigation menu on the left side: Home Network -> Network 50 | * Network Settings tab on the top 51 | * "IPv4 Configuration" Button at the bottom of the page 52 | * Enter the IP address of blocky under "Local DNS server", see screenshot 53 | 54 | ![FritzBox DNS configuration](fb_dns_config.png) 55 | 56 | --8<-- "docs/includes/abbreviations.md" -------------------------------------------------------------------------------- /docs/rapidoc.html: -------------------------------------------------------------------------------- 1 | <!doctype html> 2 | <html> 3 | <head> 4 | <meta charset="utf-8"> 5 | <script type="module" src="https://unpkg.com/rapidoc/dist/rapidoc-min.js"></script> 6 | </head> 7 | <body> 8 | <rapi-doc 9 | spec-url="api/openapi.yaml" 10 | theme = "light" 11 | allow-authentication = "false" 12 | show-header = "false" 13 | bg-color = "#fdf8ed" 14 | nav-bg-color = "#3f4d67" 15 | nav-text-color = "#a9b7d0" 16 | nav-hover-bg-color = "#333f54" 17 | nav-hover-text-color = "#fff" 18 | nav-accent-color = "#f87070" 19 | primary-color = "#5c7096" 20 | allow-try = "false" 21 | > </rapi-doc> 22 | </body> 23 | </html> -------------------------------------------------------------------------------- /e2e/e2e_suite_test.go: -------------------------------------------------------------------------------- 1 | package e2e 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/0xERR0R/blocky/log" 9 | . "github.com/onsi/ginkgo/v2" 10 | . 
"github.com/onsi/gomega" 11 | ) 12 | 13 | func init() { 14 | log.Silence() 15 | } 16 | 17 | func TestLists(t *testing.T) { 18 | RegisterFailHandler(Fail) 19 | RunSpecs(t, "e2e Suite", Label("e2e")) 20 | } 21 | 22 | var _ = BeforeSuite(func(ctx context.Context) { 23 | SetDefaultEventuallyTimeout(5 * time.Second) 24 | }) 25 | -------------------------------------------------------------------------------- /evt/events.go: -------------------------------------------------------------------------------- 1 | package evt 2 | 3 | import ( 4 | "github.com/asaskevich/EventBus" 5 | ) 6 | 7 | const ( 8 | // BlockingEnabledEvent fires if blocking status will be changed. Parameter: boolean (enabled = true) 9 | BlockingEnabledEvent = "blocking:enabled" 10 | 11 | // BlockingCacheGroupChanged fires, if a list group is changed. Parameter: list type, group name, element count 12 | BlockingCacheGroupChanged = "blocking:cachingGroupChanged" 13 | 14 | // CachingDomainPrefetched fires if a domain will be prefetched, Parameter: domain name 15 | CachingDomainPrefetched = "caching:prefetched" 16 | 17 | // CachingResultCacheChanged fires if a result cache was changed, Parameter: new cache size 18 | CachingResultCacheChanged = "caching:resultCacheChanged" 19 | 20 | // CachingPrefetchCacheHit fires if a query result was found in the prefetch cache, Parameter: domain name 21 | CachingPrefetchCacheHit = "caching:prefetchHit" 22 | 23 | // CachingDomainsToPrefetchCountChanged fires, if a number of domains being prefetched changed, Parameter: new count 24 | CachingDomainsToPrefetchCountChanged = "caching:domainsToPrefetchCountChanged" 25 | 26 | // CachingFailedDownloadChanged fires, if a download of a blocking list or hosts file fails 27 | CachingFailedDownloadChanged = "caching:failedDownload" 28 | 29 | // ApplicationStarted fires on start of the application. 
Parameter: version number, build time 30 | ApplicationStarted = "application:started" 31 | ) 32 | 33 | //nolint:gochecknoglobals 34 | var evtBus = EventBus.New() 35 | 36 | // Bus returns the global bus instance 37 | func Bus() EventBus.Bus { 38 | return evtBus 39 | } 40 | -------------------------------------------------------------------------------- /helpertest/http.go: -------------------------------------------------------------------------------- 1 | package helpertest 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "net/http" 7 | "net/url" 8 | "sync/atomic" 9 | 10 | "github.com/onsi/ginkgo/v2" 11 | ) 12 | 13 | type HTTPProxy struct { 14 | Addr net.Addr 15 | requestTarget atomic.Value // string: HTTP Host of latest request 16 | } 17 | 18 | // TestHTTPProxy returns a new HTTPProxy server. 19 | // 20 | // All requests return http.StatusNotImplemented. 21 | func TestHTTPProxy() *HTTPProxy { 22 | proxyListener, err := net.ListenTCP("tcp4", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) 23 | if err != nil { 24 | ginkgo.Fail(fmt.Sprintf("could not create HTTP proxy listener: %s", err)) 25 | } 26 | 27 | proxy := &HTTPProxy{ 28 | Addr: proxyListener.Addr(), 29 | } 30 | 31 | proxySrv := http.Server{ //nolint:gosec 32 | Addr: "127.0.0.1:0", 33 | Handler: proxy, 34 | } 35 | 36 | go func() { _ = proxySrv.Serve(proxyListener) }() 37 | ginkgo.DeferCleanup(proxySrv.Close) 38 | 39 | return proxy 40 | } 41 | 42 | // URL returns the proxy's URL for use by clients. 43 | func (p *HTTPProxy) URL() *url.URL { 44 | return &url.URL{ 45 | Scheme: "http", 46 | Host: p.Addr.String(), 47 | } 48 | } 49 | 50 | // Check ReqURL has the right type signature for http.Transport.Proxy 51 | var _ = http.Transport{Proxy: (*HTTPProxy)(nil).ReqURL} 52 | 53 | func (p *HTTPProxy) ReqURL(*http.Request) (*url.URL, error) { 54 | return p.URL(), nil 55 | } 56 | 57 | // RequestTarget returns the target of the last request. 
58 | func (p *HTTPProxy) RequestTarget() string { 59 | val := p.requestTarget.Load() 60 | if val == nil { 61 | ginkgo.Fail(fmt.Sprintf("http proxy %s received no requests", p.Addr)) 62 | } 63 | 64 | return val.(string) 65 | } 66 | 67 | func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) { 68 | p.requestTarget.Store(req.Host) 69 | 70 | w.WriteHeader(http.StatusNotImplemented) 71 | } 72 | -------------------------------------------------------------------------------- /helpertest/mock_call_sequence.go: -------------------------------------------------------------------------------- 1 | package helpertest 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | const mockCallTimeout = 2 * time.Second 11 | 12 | type MockCallSequence[T any] struct { 13 | driver func(chan<- T, chan<- error) 14 | res chan T 15 | err chan error 16 | callCount uint 17 | initOnce sync.Once 18 | closeOnce sync.Once 19 | } 20 | 21 | func NewMockCallSequence[T any](driver func(chan<- T, chan<- error)) MockCallSequence[T] { 22 | return MockCallSequence[T]{ 23 | driver: driver, 24 | } 25 | } 26 | 27 | func (m *MockCallSequence[T]) Call() (T, error) { 28 | m.callCount++ 29 | 30 | m.initOnce.Do(func() { 31 | m.res = make(chan T) 32 | m.err = make(chan error) 33 | 34 | // This goroutine never stops 35 | go func() { 36 | defer m.Close() 37 | 38 | m.driver(m.res, m.err) 39 | }() 40 | }) 41 | 42 | ctx, cancel := context.WithTimeout(context.Background(), mockCallTimeout) 43 | defer cancel() 44 | 45 | select { 46 | case t, ok := <-m.res: 47 | if !ok { 48 | break 49 | } 50 | 51 | return t, nil 52 | 53 | case err, ok := <-m.err: 54 | if !ok { 55 | break 56 | } 57 | 58 | var zero T 59 | 60 | return zero, err 61 | 62 | case <-ctx.Done(): 63 | panic(fmt.Sprintf("mock call sequence driver timed-out on call %d", m.CallCount())) 64 | } 65 | 66 | panic("mock call sequence called after driver returned (or sequence Close was called explicitly)") 67 | } 68 | 69 | func (m 
*MockCallSequence[T]) CallCount() uint { 70 | return m.callCount 71 | } 72 | 73 | func (m *MockCallSequence[T]) Close() { 74 | m.closeOnce.Do(func() { 75 | close(m.res) 76 | close(m.err) 77 | }) 78 | } 79 | -------------------------------------------------------------------------------- /helpertest/tmpdata.go: -------------------------------------------------------------------------------- 1 | package helpertest 2 | 3 | import ( 4 | "bufio" 5 | "io/fs" 6 | "os" 7 | "path/filepath" 8 | 9 | . "github.com/onsi/ginkgo/v2" 10 | . "github.com/onsi/gomega" 11 | ) 12 | 13 | type TmpFolder struct { 14 | Path string 15 | prefix string 16 | } 17 | 18 | type TmpFile struct { 19 | Path string 20 | Folder *TmpFolder 21 | } 22 | 23 | func NewTmpFolder(prefix string) *TmpFolder { 24 | ipref := prefix 25 | 26 | if len(ipref) == 0 { 27 | ipref = "blocky" 28 | } 29 | 30 | path, err := os.MkdirTemp("", ipref) 31 | Expect(err).Should(Succeed()) 32 | 33 | res := &TmpFolder{ 34 | Path: path, 35 | prefix: ipref, 36 | } 37 | 38 | DeferCleanup(res.Clean) 39 | 40 | return res 41 | } 42 | 43 | func (tf *TmpFolder) Clean() error { 44 | if len(tf.Path) > 0 { 45 | return os.RemoveAll(tf.Path) 46 | } 47 | 48 | return nil 49 | } 50 | 51 | func (tf *TmpFolder) CreateSubFolder(name string) *TmpFolder { 52 | var path string 53 | 54 | var err error 55 | 56 | if len(name) > 0 { 57 | path = filepath.Join(tf.Path, name) 58 | err = os.Mkdir(path, fs.ModePerm) 59 | } else { 60 | path, err = os.MkdirTemp(tf.Path, tf.prefix) 61 | } 62 | 63 | Expect(err).Should(Succeed()) 64 | 65 | res := &TmpFolder{ 66 | Path: path, 67 | prefix: tf.prefix, 68 | } 69 | 70 | return res 71 | } 72 | 73 | func (tf *TmpFolder) CreateEmptyFile(name string) *TmpFile { 74 | f, res := tf.createFile(name) 75 | defer f.Close() 76 | 77 | return res 78 | } 79 | 80 | func (tf *TmpFolder) CreateStringFile(name string, lines ...string) *TmpFile { 81 | f, res := tf.createFile(name) 82 | defer f.Close() 83 | 84 | first := true 85 | w := 
bufio.NewWriter(f) 86 | 87 | for _, l := range lines { 88 | if first { 89 | first = false 90 | } else { 91 | _, err := w.WriteString("\n") 92 | Expect(err).Should(Succeed()) 93 | } 94 | 95 | _, err := w.WriteString(l) 96 | Expect(err).Should(Succeed()) 97 | } 98 | 99 | w.Flush() 100 | 101 | return res 102 | } 103 | 104 | func (tf *TmpFolder) JoinPath(name string) string { 105 | return filepath.Join(tf.Path, name) 106 | } 107 | 108 | func (tf *TmpFolder) createFile(name string) (*os.File, *TmpFile) { 109 | var ( 110 | f *os.File 111 | err error 112 | ) 113 | 114 | if len(name) > 0 { 115 | f, err = os.Create(filepath.Join(tf.Path, name)) 116 | } else { 117 | f, err = os.CreateTemp(tf.Path, "temp") 118 | } 119 | 120 | Expect(err).Should(Succeed()) 121 | 122 | return f, &TmpFile{ 123 | Path: f.Name(), 124 | Folder: tf, 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /lists/downloader.go: -------------------------------------------------------------------------------- 1 | package lists 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net" 9 | "net/http" 10 | 11 | "github.com/0xERR0R/blocky/config" 12 | "github.com/0xERR0R/blocky/evt" 13 | "github.com/avast/retry-go/v4" 14 | ) 15 | 16 | // TransientError represents a temporary error like timeout, network errors... 
17 | type TransientError struct { 18 | inner error 19 | } 20 | 21 | func (e *TransientError) Error() string { 22 | return fmt.Sprintf("temporary error occurred: %v", e.inner) 23 | } 24 | 25 | func (e *TransientError) Unwrap() error { 26 | return e.inner 27 | } 28 | 29 | // FileDownloader is able to download some text file 30 | type FileDownloader interface { 31 | DownloadFile(ctx context.Context, link string) (io.ReadCloser, error) 32 | } 33 | 34 | // httpDownloader downloads files via HTTP protocol 35 | type httpDownloader struct { 36 | cfg config.Downloader 37 | 38 | client http.Client 39 | } 40 | 41 | func NewDownloader(cfg config.Downloader, transport http.RoundTripper) FileDownloader { 42 | return newDownloader(cfg, transport) 43 | } 44 | 45 | func newDownloader(cfg config.Downloader, transport http.RoundTripper) *httpDownloader { 46 | return &httpDownloader{ 47 | cfg: cfg, 48 | 49 | client: http.Client{ 50 | Transport: transport, 51 | Timeout: cfg.Timeout.ToDuration(), 52 | }, 53 | } 54 | } 55 | 56 | func (d *httpDownloader) DownloadFile(ctx context.Context, link string) (io.ReadCloser, error) { 57 | var body io.ReadCloser 58 | 59 | err := retry.Do( 60 | func() error { 61 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, link, nil) 62 | if err != nil { 63 | return err 64 | } 65 | 66 | resp, httpErr := d.client.Do(req) 67 | if httpErr == nil { 68 | if resp.StatusCode == http.StatusOK { 69 | body = resp.Body 70 | 71 | return nil 72 | } 73 | 74 | _ = resp.Body.Close() 75 | 76 | return fmt.Errorf("got status code %d", resp.StatusCode) 77 | } 78 | 79 | var netErr net.Error 80 | if errors.As(httpErr, &netErr) && netErr.Timeout() { 81 | return &TransientError{inner: netErr} 82 | } 83 | 84 | return httpErr 85 | }, 86 | retry.Attempts(d.cfg.Attempts), 87 | retry.DelayType(retry.FixedDelay), 88 | retry.Delay(d.cfg.Cooldown.ToDuration()), 89 | retry.LastErrorOnly(true), 90 | retry.OnRetry(func(n uint, err error) { 91 | var transientErr *TransientError 92 | 
93 | var dnsErr *net.DNSError 94 | 95 | logger := logger(). 96 | WithField("link", link). 97 | WithField("attempt", fmt.Sprintf("%d/%d", n+1, d.cfg.Attempts)) 98 | 99 | switch { 100 | case errors.As(err, &transientErr): 101 | logger.Warnf("Temporary network err / Timeout occurred: %s", transientErr) 102 | case errors.As(err, &dnsErr): 103 | logger.Warnf("Name resolution err: %s", dnsErr.Err) 104 | default: 105 | logger.Warnf("Can't download file: %s", err) 106 | } 107 | 108 | onDownloadError(link) 109 | })) 110 | 111 | return body, err 112 | } 113 | 114 | func onDownloadError(link string) { 115 | evt.Bus().Publish(evt.CachingFailedDownloadChanged, link) 116 | } 117 | -------------------------------------------------------------------------------- /lists/list_cache_benchmark_test.go: -------------------------------------------------------------------------------- 1 | package lists 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/0xERR0R/blocky/config" 8 | ) 9 | 10 | func BenchmarkRefresh(b *testing.B) { 11 | file1, _ := createTestListFile(b.TempDir(), 100000) 12 | file2, _ := createTestListFile(b.TempDir(), 150000) 13 | file3, _ := createTestListFile(b.TempDir(), 130000) 14 | lists := map[string][]config.BytesSource{ 15 | "gr1": config.NewBytesSources(file1, file2, file3), 16 | } 17 | 18 | cfg := config.SourceLoading{ 19 | Concurrency: 5, 20 | RefreshPeriod: config.Duration(-1), 21 | } 22 | downloader := NewDownloader(config.Downloader{}, nil) 23 | cache, _ := NewListCache(context.Background(), ListCacheTypeDenylist, cfg, lists, downloader) 24 | 25 | b.ReportAllocs() 26 | 27 | for n := 0; n < b.N; n++ { 28 | _ = cache.Refresh() 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /lists/list_cache_enum.go: -------------------------------------------------------------------------------- 1 | // Code generated by go-enum DO NOT EDIT. 
2 | // Version: 3 | // Revision: 4 | // Build Date: 5 | // Built By: 6 | 7 | package lists 8 | 9 | import ( 10 | "fmt" 11 | "strings" 12 | ) 13 | 14 | const ( 15 | // ListCacheTypeDenylist is a ListCacheType of type Denylist. 16 | // is a list with blocked domains 17 | ListCacheTypeDenylist ListCacheType = iota 18 | // ListCacheTypeAllowlist is a ListCacheType of type Allowlist. 19 | // is a list with allowlisted domains / IPs 20 | ListCacheTypeAllowlist 21 | ) 22 | 23 | var ErrInvalidListCacheType = fmt.Errorf("not a valid ListCacheType, try [%s]", strings.Join(_ListCacheTypeNames, ", ")) 24 | 25 | const _ListCacheTypeName = "denylistallowlist" 26 | 27 | var _ListCacheTypeNames = []string{ 28 | _ListCacheTypeName[0:8], 29 | _ListCacheTypeName[8:17], 30 | } 31 | 32 | // ListCacheTypeNames returns a list of possible string values of ListCacheType. 33 | func ListCacheTypeNames() []string { 34 | tmp := make([]string, len(_ListCacheTypeNames)) 35 | copy(tmp, _ListCacheTypeNames) 36 | return tmp 37 | } 38 | 39 | var _ListCacheTypeMap = map[ListCacheType]string{ 40 | ListCacheTypeDenylist: _ListCacheTypeName[0:8], 41 | ListCacheTypeAllowlist: _ListCacheTypeName[8:17], 42 | } 43 | 44 | // String implements the Stringer interface. 45 | func (x ListCacheType) String() string { 46 | if str, ok := _ListCacheTypeMap[x]; ok { 47 | return str 48 | } 49 | return fmt.Sprintf("ListCacheType(%d)", x) 50 | } 51 | 52 | // IsValid provides a quick way to determine if the typed value is 53 | // part of the allowed enumerated values 54 | func (x ListCacheType) IsValid() bool { 55 | _, ok := _ListCacheTypeMap[x] 56 | return ok 57 | } 58 | 59 | var _ListCacheTypeValue = map[string]ListCacheType{ 60 | _ListCacheTypeName[0:8]: ListCacheTypeDenylist, 61 | _ListCacheTypeName[8:17]: ListCacheTypeAllowlist, 62 | } 63 | 64 | // ParseListCacheType attempts to convert a string to a ListCacheType. 
65 | func ParseListCacheType(name string) (ListCacheType, error) { 66 | if x, ok := _ListCacheTypeValue[name]; ok { 67 | return x, nil 68 | } 69 | return ListCacheType(0), fmt.Errorf("%s is %w", name, ErrInvalidListCacheType) 70 | } 71 | 72 | // MarshalText implements the text marshaller method. 73 | func (x ListCacheType) MarshalText() ([]byte, error) { 74 | return []byte(x.String()), nil 75 | } 76 | 77 | // UnmarshalText implements the text unmarshaller method. 78 | func (x *ListCacheType) UnmarshalText(text []byte) error { 79 | name := string(text) 80 | tmp, err := ParseListCacheType(name) 81 | if err != nil { 82 | return err 83 | } 84 | *x = tmp 85 | return nil 86 | } 87 | -------------------------------------------------------------------------------- /lists/list_suite_test.go: -------------------------------------------------------------------------------- 1 | package lists 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/0xERR0R/blocky/log" 7 | 8 | . "github.com/onsi/ginkgo/v2" 9 | . "github.com/onsi/gomega" 10 | ) 11 | 12 | func init() { 13 | log.Silence() 14 | } 15 | 16 | func TestLists(t *testing.T) { 17 | RegisterFailHandler(Fail) 18 | RunSpecs(t, "Lists Suite") 19 | } 20 | -------------------------------------------------------------------------------- /lists/parsers/adapt.go: -------------------------------------------------------------------------------- 1 | package parsers 2 | 3 | import "context" 4 | 5 | // Adapt returns a parser that wraps `inner` converting each parsed value. 6 | func Adapt[From, To any](inner SeriesParser[From], adapt func(From) To) SeriesParser[To] { 7 | return TryAdapt(inner, func(from From) (To, error) { 8 | return adapt(from), nil 9 | }) 10 | } 11 | 12 | // TryAdapt returns a parser that wraps `inner` and tries to convert each parsed value. 
13 | func TryAdapt[From, To any](inner SeriesParser[From], adapt func(From) (To, error)) SeriesParser[To] { 14 | return newAdapter(inner, adapt) 15 | } 16 | 17 | // TryAdaptMethod returns a parser that wraps `inner` and tries to convert each parsed value 18 | // using the given method with pointer receiver of `To`. 19 | func TryAdaptMethod[ToPtr *To, From, To any]( 20 | inner SeriesParser[From], method func(ToPtr, From) error, 21 | ) SeriesParser[*To] { 22 | return TryAdapt(inner, func(from From) (*To, error) { 23 | res := new(To) 24 | 25 | err := method(res, from) 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | return res, nil 31 | }) 32 | } 33 | 34 | type adapter[From, To any] struct { 35 | inner SeriesParser[From] 36 | adapt func(From) (To, error) 37 | } 38 | 39 | func newAdapter[From, To any](inner SeriesParser[From], adapt func(From) (To, error)) SeriesParser[To] { 40 | return &adapter[From, To]{inner, adapt} 41 | } 42 | 43 | func (a *adapter[From, To]) Position() string { 44 | return a.inner.Position() 45 | } 46 | 47 | func (a *adapter[From, To]) Next(ctx context.Context) (To, error) { 48 | from, err := a.inner.Next(ctx) 49 | if err != nil { 50 | var zero To 51 | 52 | return zero, err 53 | } 54 | 55 | res, err := a.adapt(from) 56 | if err != nil { 57 | var zero To 58 | 59 | return zero, err 60 | } 61 | 62 | return res, nil 63 | } 64 | -------------------------------------------------------------------------------- /lists/parsers/filtererrors.go: -------------------------------------------------------------------------------- 1 | package parsers 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | ) 7 | 8 | // NoErrorLimit can be used to continue parsing until EOF. 9 | const NoErrorLimit = -1 10 | 11 | var ErrTooManyErrors = errors.New("too many parse errors") 12 | 13 | type FilteredSeriesParser[T any] interface { 14 | SeriesParser[T] 15 | 16 | // OnErr registers a callback invoked for each error encountered. 
17 | OnErr(func(error)) 18 | } 19 | 20 | // FilterErrors returns a parser that wraps `inner` and passes each resumable error to `filter`. 21 | // 22 | // If `filter` returns nil the failed item is skipped and parsing continues, otherwise the returned error is propagated. 23 | func FilterErrors[T any](inner SeriesParser[T], filter func(error) error) FilteredSeriesParser[T] { 24 | return newErrorFilter(inner, filter) 25 | } 26 | 27 | // AllowErrors returns a parser that wraps `inner` and tries to continue parsing. 28 | // 29 | // Once more than `n` errors have occurred, it returns `ErrTooManyErrors`; with `NoErrorLimit`, resumable errors are always skipped. 30 | func AllowErrors[T any](inner SeriesParser[T], n int) FilteredSeriesParser[T] { 31 | if n == NoErrorLimit { 32 | return FilterErrors(inner, func(error) error { return nil }) 33 | } 34 | 35 | count := 0 36 | 37 | return FilterErrors(inner, func(err error) error { 38 | count++ 39 | 40 | if count > n { 41 | return ErrTooManyErrors 42 | } 43 | 44 | return nil 45 | }) 46 | } 47 | 48 | type errorFilter[T any] struct { 49 | inner SeriesParser[T] 50 | filter func(error) error 51 | } 52 | 53 | func newErrorFilter[T any](inner SeriesParser[T], filter func(error) error) FilteredSeriesParser[T] { 54 | return &errorFilter[T]{inner, filter} 55 | } 56 | 57 | func (f *errorFilter[T]) OnErr(callback func(error)) { 58 | filter := f.filter 59 | 60 | f.filter = func(err error) error { 61 | callback(ErrWithPosition(f.inner, err)) 62 | 63 | return filter(err) 64 | } 65 | } 66 | 67 | func (f *errorFilter[T]) Position() string { 68 | return f.inner.Position() 69 | } 70 | 71 | func (f *errorFilter[T]) Next(ctx context.Context) (T, error) { 72 | var zero T 73 | 74 | for { 75 | res, err := f.inner.Next(ctx) 76 | if err != nil { 77 | if IsNonResumableErr(err) { 78 | // bypass the filter, and just propagate the error 79 | return zero, err 80 | } 81 | 82 | err = f.filter(err) 83 | if err != nil { 84 | return zero, err 85 | } 86 | 87 | continue 88 | } 89 | 90 | return res, nil 91 | } 92 | } 93 | -------------------------------------------------------------------------------- 
/lists/parsers/lines.go: -------------------------------------------------------------------------------- 1 | package parsers 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "encoding" 7 | "fmt" 8 | "io" 9 | "strings" 10 | "unicode" 11 | ) 12 | 13 | // Lines splits `r` into a series of lines. 14 | // 15 | // Empty lines are skipped, and comments are stripped. 16 | func Lines(r io.Reader) SeriesParser[string] { 17 | return newLines(r) 18 | } 19 | 20 | // LinesAs returns a parser that parses each line of `r` as a `T`. 21 | func LinesAs[TPtr TextUnmarshaler[T], T any](r io.Reader) SeriesParser[*T] { 22 | return UnmarshalEach[TPtr](Lines(r)) 23 | } 24 | 25 | // UnmarshalEach returns a parser that unmarshals each string of `inner` as a `T`. 26 | func UnmarshalEach[TPtr TextUnmarshaler[T], T any](inner SeriesParser[string]) SeriesParser[*T] { 27 | stringToBytes := func(s string) []byte { 28 | return []byte(s) 29 | } 30 | 31 | return TryAdaptMethod(Adapt(inner, stringToBytes), TPtr.UnmarshalText) 32 | } 33 | 34 | type TextUnmarshaler[T any] interface { 35 | encoding.TextUnmarshaler 36 | *T 37 | } 38 | 39 | type lines struct { 40 | scanner *bufio.Scanner 41 | lineNo uint 42 | } 43 | 44 | func newLines(r io.Reader) SeriesParser[string] { 45 | scanner := bufio.NewScanner(r) 46 | scanner.Split(bufio.ScanLines) 47 | 48 | return &lines{scanner: scanner} 49 | } 50 | 51 | func (l *lines) Position() string { 52 | return fmt.Sprintf("line %d", l.lineNo) 53 | } 54 | 55 | func (l *lines) Next(ctx context.Context) (string, error) { 56 | for { 57 | l.lineNo++ 58 | 59 | if err := ctx.Err(); err != nil { 60 | return "", NewNonResumableError(err) 61 | } 62 | 63 | if !l.scanner.Scan() { 64 | break 65 | } 66 | 67 | text := strings.TrimSpace(l.scanner.Text()) 68 | 69 | if len(text) == 0 { 70 | continue // empty line 71 | } 72 | 73 | if idx := strings.IndexRune(text, '#'); idx != -1 { 74 | if idx == 0 { 75 | continue // commented line 76 | } 77 | 78 | // end of line comment 79 | text = 
text[:idx] 80 | text = strings.TrimRightFunc(text, unicode.IsSpace) 81 | } 82 | 83 | return text, nil 84 | } 85 | 86 | err := l.scanner.Err() 87 | if err != nil { 88 | // bufio.Scanner does not support continuing after an error 89 | return "", NewNonResumableError(err) 90 | } 91 | 92 | return "", NewNonResumableError(io.EOF) 93 | } 94 | -------------------------------------------------------------------------------- /lists/parsers/parser.go: -------------------------------------------------------------------------------- 1 | package parsers 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | ) 9 | 10 | // SeriesParser parses a series of `T`. 11 | type SeriesParser[T any] interface { 12 | // Next advances the cursor in the underlying data source, 13 | // and returns a `T`, or an error. 14 | // 15 | // Fatal parse errors, where no more calls to `Next` should 16 | // be made, are of type `NonResumableError`. 17 | // Other errors apply to the item being parsed, and have no 18 | // impact on the rest of the series. 19 | Next(context.Context) (T, error) 20 | 21 | // Position returns a string that gives a user-readable indication 22 | // as to where in the parser's underlying data source the cursor is. 23 | // 24 | // The string should be easily understandable by the user. 25 | Position() string 26 | } 27 | 28 | // ForEach is a helper for consuming a parser. 29 | // 30 | // It stops iteration at the first error encountered. 31 | // If that error is `io.EOF`, `nil` is returned instead. 32 | // Any other error is wrapped with the parser's position using `ErrWithPosition`. 33 | // 34 | // To continue iteration on resumable errors, use with `FilterErrors`.
35 | func ForEach[T any](ctx context.Context, parser SeriesParser[T], callback func(T) error) (rerr error) { 36 | defer func() { 37 | rerr = ErrWithPosition(parser, rerr) 38 | }() 39 | 40 | for { 41 | if err := ctx.Err(); err != nil { 42 | return err 43 | } 44 | 45 | res, err := parser.Next(ctx) 46 | if err != nil { 47 | if errors.Is(err, io.EOF) { 48 | return nil 49 | } 50 | 51 | return err 52 | } 53 | 54 | err = callback(res) 55 | if err != nil { 56 | return err 57 | } 58 | } 59 | } 60 | 61 | // ErrWithPosition adds the `parser`'s position to the given `err`. 62 | func ErrWithPosition[T any](parser SeriesParser[T], err error) error { 63 | if err == nil { 64 | return nil 65 | } 66 | 67 | return fmt.Errorf("%s: %w", parser.Position(), err) 68 | } 69 | 70 | // IsNonResumableErr is a helper to check if an error returned by a parser is non-resumable. 71 | func IsNonResumableErr(err error) bool { 72 | var nonResumableError *NonResumableError 73 | 74 | return errors.As(err, &nonResumableError) 75 | } 76 | 77 | // NonResumableError represents an error from which a parser cannot recover. 78 | type NonResumableError struct { 79 | inner error 80 | } 81 | 82 | // NewNonResumableError creates and returns a new `NonResumableError`. 83 | func NewNonResumableError(inner error) error { 84 | return &NonResumableError{inner} 85 | } 86 | 87 | func (e *NonResumableError) Error() string { 88 | return fmt.Sprintf("non resumable parse error: %s", e.inner.Error()) 89 | } 90 | 91 | func (e *NonResumableError) Unwrap() error { 92 | return e.inner 93 | } 94 | -------------------------------------------------------------------------------- /lists/parsers/parsers_suite_test.go: -------------------------------------------------------------------------------- 1 | package parsers 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/0xERR0R/blocky/log" 7 | 8 | . "github.com/onsi/ginkgo/v2" 9 | .
"github.com/onsi/gomega" 10 | ) 11 | 12 | func init() { 13 | log.Silence() 14 | } 15 | 16 | func TestLists(t *testing.T) { 17 | RegisterFailHandler(Fail) 18 | RunSpecs(t, "Parsers Suite") 19 | } 20 | -------------------------------------------------------------------------------- /lists/sourcereader.go: -------------------------------------------------------------------------------- 1 | package lists 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "os" 8 | "strings" 9 | 10 | "github.com/0xERR0R/blocky/config" 11 | ) 12 | 13 | type SourceOpener interface { 14 | fmt.Stringer 15 | 16 | Open(ctx context.Context) (io.ReadCloser, error) 17 | } 18 | 19 | func NewSourceOpener(txtLocInfo string, source config.BytesSource, downloader FileDownloader) (SourceOpener, error) { 20 | switch source.Type { 21 | case config.BytesSourceTypeText: 22 | return &textOpener{source: source, locInfo: txtLocInfo}, nil 23 | 24 | case config.BytesSourceTypeHttp: 25 | return &httpOpener{source: source, downloader: downloader}, nil 26 | 27 | case config.BytesSourceTypeFile: 28 | return &fileOpener{source: source}, nil 29 | } 30 | 31 | return nil, fmt.Errorf("cannot open %s", source) 32 | } 33 | 34 | type textOpener struct { 35 | source config.BytesSource 36 | locInfo string 37 | } 38 | 39 | func (o *textOpener) Open(_ context.Context) (io.ReadCloser, error) { 40 | return io.NopCloser(strings.NewReader(o.source.From)), nil 41 | } 42 | 43 | func (o *textOpener) String() string { 44 | return fmt.Sprintf("%s: %s", o.locInfo, o.source) 45 | } 46 | 47 | type httpOpener struct { 48 | source config.BytesSource 49 | downloader FileDownloader 50 | } 51 | 52 | func (o *httpOpener) Open(ctx context.Context) (io.ReadCloser, error) { 53 | return o.downloader.DownloadFile(ctx, o.source.From) 54 | } 55 | 56 | func (o *httpOpener) String() string { 57 | return o.source.String() 58 | } 59 | 60 | type fileOpener struct { 61 | source config.BytesSource 62 | } 63 | 64 | func (o *fileOpener) Open(_ 
context.Context) (io.ReadCloser, error) { 65 | return os.Open(o.source.From) 66 | } 67 | 68 | func (o *fileOpener) String() string { 69 | return o.source.String() 70 | } 71 | -------------------------------------------------------------------------------- /log/context.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/sirupsen/logrus" 7 | ) 8 | 9 | type ctxKey struct{} 10 | 11 | func NewCtx(ctx context.Context, logger *logrus.Entry) (context.Context, *logrus.Entry) { 12 | ctx = context.WithValue(ctx, ctxKey{}, logger) 13 | 14 | return ctx, entryWithCtx(ctx, logger) 15 | } 16 | 17 | func FromCtx(ctx context.Context) *logrus.Entry { 18 | logger, ok := ctx.Value(ctxKey{}).(*logrus.Entry) 19 | if !ok { 20 | // Fallback to the global logger 21 | return logrus.NewEntry(Log()) 22 | } 23 | 24 | // Ensure `logger.Context == ctx`, not always the case since `ctx` could be a child of `logger.Context` 25 | return entryWithCtx(ctx, logger) 26 | } 27 | 28 | func entryWithCtx(ctx context.Context, logger *logrus.Entry) *logrus.Entry { 29 | loggerCopy := *logger 30 | loggerCopy.Context = ctx 31 | 32 | return &loggerCopy 33 | } 34 | 35 | func WrapCtx(ctx context.Context, wrap func(*logrus.Entry) *logrus.Entry) (context.Context, *logrus.Entry) { 36 | logger := FromCtx(ctx) 37 | logger = wrap(logger) 38 | 39 | return NewCtx(ctx, logger) 40 | } 41 | 42 | func CtxWithFields(ctx context.Context, fields logrus.Fields) (context.Context, *logrus.Entry) { 43 | return WrapCtx(ctx, func(e *logrus.Entry) *logrus.Entry { 44 | return e.WithFields(fields) 45 | }) 46 | } 47 | -------------------------------------------------------------------------------- /log/logger_enum.go: -------------------------------------------------------------------------------- 1 | // Code generated by go-enum DO NOT EDIT. 
2 | // Version: 3 | // Revision: 4 | // Build Date: 5 | // Built By: 6 | 7 | package log 8 | 9 | import ( 10 | "fmt" 11 | "strings" 12 | ) 13 | 14 | const ( 15 | // FormatTypeText is a FormatType of type Text. 16 | // logging as text 17 | FormatTypeText FormatType = iota 18 | // FormatTypeJson is a FormatType of type Json. 19 | // JSON format 20 | FormatTypeJson 21 | ) 22 | 23 | var ErrInvalidFormatType = fmt.Errorf("not a valid FormatType, try [%s]", strings.Join(_FormatTypeNames, ", ")) 24 | 25 | const _FormatTypeName = "textjson" 26 | 27 | var _FormatTypeNames = []string{ 28 | _FormatTypeName[0:4], 29 | _FormatTypeName[4:8], 30 | } 31 | 32 | // FormatTypeNames returns a list of possible string values of FormatType. 33 | func FormatTypeNames() []string { 34 | tmp := make([]string, len(_FormatTypeNames)) 35 | copy(tmp, _FormatTypeNames) 36 | return tmp 37 | } 38 | 39 | var _FormatTypeMap = map[FormatType]string{ 40 | FormatTypeText: _FormatTypeName[0:4], 41 | FormatTypeJson: _FormatTypeName[4:8], 42 | } 43 | 44 | // String implements the Stringer interface. 45 | func (x FormatType) String() string { 46 | if str, ok := _FormatTypeMap[x]; ok { 47 | return str 48 | } 49 | return fmt.Sprintf("FormatType(%d)", x) 50 | } 51 | 52 | // IsValid provides a quick way to determine if the typed value is 53 | // part of the allowed enumerated values 54 | func (x FormatType) IsValid() bool { 55 | _, ok := _FormatTypeMap[x] 56 | return ok 57 | } 58 | 59 | var _FormatTypeValue = map[string]FormatType{ 60 | _FormatTypeName[0:4]: FormatTypeText, 61 | _FormatTypeName[4:8]: FormatTypeJson, 62 | } 63 | 64 | // ParseFormatType attempts to convert a string to a FormatType. 65 | func ParseFormatType(name string) (FormatType, error) { 66 | if x, ok := _FormatTypeValue[name]; ok { 67 | return x, nil 68 | } 69 | return FormatType(0), fmt.Errorf("%s is %w", name, ErrInvalidFormatType) 70 | } 71 | 72 | // MarshalText implements the text marshaller method. 
73 | func (x FormatType) MarshalText() ([]byte, error) { 74 | return []byte(x.String()), nil 75 | } 76 | 77 | // UnmarshalText implements the text unmarshaller method. 78 | func (x *FormatType) UnmarshalText(text []byte) error { 79 | name := string(text) 80 | tmp, err := ParseFormatType(name) 81 | if err != nil { 82 | return err 83 | } 84 | *x = tmp 85 | return nil 86 | } 87 | -------------------------------------------------------------------------------- /log/mock_entry.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "github.com/sirupsen/logrus" 5 | "github.com/sirupsen/logrus/hooks/test" 6 | "github.com/stretchr/testify/mock" 7 | ) 8 | 9 | func NewMockEntry() (*logrus.Entry, *MockLoggerHook) { 10 | logger, _ := test.NewNullLogger() 11 | logger.Level = logrus.TraceLevel 12 | 13 | entry := logrus.Entry{Logger: logger} 14 | hook := MockLoggerHook{} 15 | 16 | entry.Logger.AddHook(&hook) 17 | 18 | hook.On("Fire", mock.Anything).Return(nil) 19 | 20 | return &entry, &hook 21 | } 22 | 23 | type MockLoggerHook struct { 24 | mock.Mock 25 | 26 | Messages []string 27 | } 28 | 29 | // Levels implements `logrus.Hook`. 30 | func (h *MockLoggerHook) Levels() []logrus.Level { 31 | return logrus.AllLevels 32 | } 33 | 34 | // Fire implements `logrus.Hook`. 
35 | func (h *MockLoggerHook) Fire(entry *logrus.Entry) error { 36 | _ = h.Called() 37 | 38 | h.Messages = append(h.Messages, entry.Message) 39 | 40 | return nil 41 | } 42 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/0xERR0R/blocky/cmd" 7 | ) 8 | 9 | func main() { 10 | cmd.Execute() 11 | os.Exit(0) 12 | } 13 | -------------------------------------------------------------------------------- /main_static.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | // +build linux 3 | 4 | package main 5 | 6 | import ( 7 | "os" 8 | "time" 9 | _ "time/tzdata" 10 | 11 | _ "github.com/breml/rootcerts" 12 | 13 | reaper "github.com/ramr/go-reaper" 14 | ) 15 | 16 | //nolint:gochecknoinits 17 | func init() { 18 | go reaper.Start(reaper.Config{ 19 | DisablePid1Check: true, 20 | }) 21 | 22 | setLocaltime() 23 | } 24 | 25 | // set localtime to /etc/localtime if available 26 | // or modify the system time with the TZ environment variable if it is provided 27 | func setLocaltime() { 28 | // load /etc/localtime without modifying it 29 | if lt, err := os.ReadFile("/etc/localtime"); err == nil { 30 | if t, err := time.LoadLocationFromTZData("", lt); err == nil { 31 | time.Local = t 32 | 33 | return 34 | } 35 | } 36 | 37 | // use zoneinfo from time/tzdata and set location with the TZ environment variable 38 | if tz := os.Getenv("TZ"); tz != "" { 39 | if t, err := time.LoadLocation(tz); err == nil { 40 | time.Local = t 41 | 42 | return 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /metrics/metrics.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "github.com/0xERR0R/blocky/config" 5 | 6 | "github.com/go-chi/chi/v5" 
7 | "github.com/prometheus/client_golang/prometheus" 8 | "github.com/prometheus/client_golang/prometheus/collectors" 9 | "github.com/prometheus/client_golang/prometheus/promhttp" 10 | ) 11 | 12 | //nolint:gochecknoglobals 13 | var Reg = prometheus.NewRegistry() 14 | 15 | // RegisterMetric registers prometheus collector 16 | func RegisterMetric(c prometheus.Collector) { 17 | _ = Reg.Register(c) 18 | } 19 | 20 | // Start starts prometheus endpoint 21 | func Start(router *chi.Mux, cfg config.Metrics) { 22 | if cfg.Enable { 23 | _ = Reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) 24 | _ = Reg.Register(collectors.NewGoCollector()) 25 | router.Handle(cfg.Path, promhttp.InstrumentMetricHandler(Reg, 26 | promhttp.HandlerFor(Reg, promhttp.HandlerOpts{}))) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: blocky 2 | site_description: blocky Documentation 3 | theme: 4 | name: material 5 | palette: 6 | primary: teal 7 | accent: teal 8 | extra: 9 | version: 10 | provider: mike 11 | social: 12 | - icon: fontawesome/brands/github 13 | link: https://github.com/0xERR0R/blocky 14 | - icon: simple/codeberg 15 | link: https://codeberg.org/0xERR0R/blocky 16 | - icon: fontawesome/brands/docker 17 | link: https://hub.docker.com/r/spx01/blocky 18 | repo_url: https://github.com/0xERR0R/blocky 19 | 20 | markdown_extensions: 21 | - abbr 22 | - pymdownx.snippets 23 | - pymdownx.emoji: 24 | emoji_index: !!python/name:material.extensions.emoji.twemoji 25 | emoji_generator: !!python/name:material.extensions.emoji.to_svg 26 | - pymdownx.highlight 27 | - pymdownx.superfences 28 | - admonition 29 | - pymdownx.details 30 | 31 | nav: 32 | - 'Welcome': 'index.md' 33 | - 'Configuration': 'configuration.md' 34 | - 'Installation': 'installation.md' 35 | - 'Prometheus / Grafana': 'prometheus_grafana.md' 36 | - 
'Interfaces': 'interfaces.md' 37 | - 'Network configuration': 'network_configuration.md' 38 | - 'Additional information': 'additional_information.md' 39 | -------------------------------------------------------------------------------- /model/models.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | //go:generate go tool go-enum -f=$GOFILE --marshal --names 4 | import ( 5 | "net" 6 | "time" 7 | 8 | "github.com/miekg/dns" 9 | ) 10 | 11 | // ResponseType represents the type of the response ENUM( 12 | // RESOLVED // the response was resolved by the external upstream resolver 13 | // CACHED // the response was resolved from cache 14 | // BLOCKED // the query was blocked 15 | // CONDITIONAL // the query was resolved by the conditional upstream resolver 16 | // CUSTOMDNS // the query was resolved by a custom rule 17 | // HOSTSFILE // the query was resolved by looking up the hosts file 18 | // FILTERED // the query was filtered by query type 19 | // NOTFQDN // the query was filtered as it is not fqdn conform 20 | // SPECIAL // the query was resolved by the special use domain name resolver 21 | // ) 22 | type ResponseType int 23 | 24 | func (t ResponseType) ToExtendedErrorCode() uint16 { 25 | switch t { 26 | case ResponseTypeRESOLVED: 27 | return dns.ExtendedErrorCodeOther 28 | case ResponseTypeCACHED: 29 | return dns.ExtendedErrorCodeCachedError 30 | case ResponseTypeCONDITIONAL: 31 | return dns.ExtendedErrorCodeForgedAnswer 32 | case ResponseTypeCUSTOMDNS: 33 | return dns.ExtendedErrorCodeForgedAnswer 34 | case ResponseTypeHOSTSFILE: 35 | return dns.ExtendedErrorCodeForgedAnswer 36 | case ResponseTypeNOTFQDN: 37 | return dns.ExtendedErrorCodeBlocked 38 | case ResponseTypeBLOCKED: 39 | return dns.ExtendedErrorCodeBlocked 40 | case ResponseTypeFILTERED: 41 | return dns.ExtendedErrorCodeFiltered 42 | case ResponseTypeSPECIAL: 43 | return dns.ExtendedErrorCodeFiltered 44 | default: 45 | return 
dns.ExtendedErrorCodeOther 46 | } 47 | } 48 | 49 | // Response represents the response of a DNS query 50 | type Response struct { 51 | Res *dns.Msg 52 | Reason string 53 | RType ResponseType 54 | } 55 | 56 | // RequestProtocol represents the server protocol ENUM( 57 | // TCP // is the TCP protocol 58 | // UDP // is the UDP protocol 59 | // ) 60 | type RequestProtocol uint8 61 | 62 | // Request represents client's DNS request 63 | type Request struct { 64 | ClientIP net.IP 65 | RequestClientID string 66 | Protocol RequestProtocol 67 | ClientNames []string 68 | Req *dns.Msg 69 | RequestTS time.Time 70 | } 71 | -------------------------------------------------------------------------------- /querylog/logger_writer.go: -------------------------------------------------------------------------------- 1 | package querylog 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | 7 | "github.com/0xERR0R/blocky/log" 8 | "github.com/sirupsen/logrus" 9 | ) 10 | 11 | const loggerPrefixLoggerWriter = "queryLog" 12 | 13 | type LoggerWriter struct { 14 | logger *logrus.Entry 15 | } 16 | 17 | func NewLoggerWriter() *LoggerWriter { 18 | return &LoggerWriter{logger: log.PrefixedLog(loggerPrefixLoggerWriter)} 19 | } 20 | 21 | func (d *LoggerWriter) Write(entry *LogEntry) { 22 | fields := LogEntryFields(entry) 23 | 24 | d.logger.WithFields(fields).Infof("query resolved") 25 | } 26 | 27 | func (d *LoggerWriter) CleanUp() { 28 | // Nothing to do 29 | } 30 | 31 | func LogEntryFields(entry *LogEntry) logrus.Fields { 32 | return withoutZeroes(logrus.Fields{ 33 | "client_ip": entry.ClientIP, 34 | "client_names": strings.Join(entry.ClientNames, "; "), 35 | "response_reason": entry.ResponseReason, 36 | "response_type": entry.ResponseType, 37 | "response_code": entry.ResponseCode, 38 | "question_name": entry.QuestionName, 39 | "question_type": entry.QuestionType, 40 | "answer": entry.Answer, 41 | "duration_ms": entry.DurationMs, 42 | "instance": entry.BlockyInstance, 43 | }) 44 | } 45 | 46 | func 
withoutZeroes(fields logrus.Fields) logrus.Fields { 47 | for k, v := range fields { 48 | if reflect.ValueOf(v).IsZero() { 49 | delete(fields, k) 50 | } 51 | } 52 | 53 | return fields 54 | } 55 | -------------------------------------------------------------------------------- /querylog/logger_writer_test.go: -------------------------------------------------------------------------------- 1 | package querylog 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/sirupsen/logrus" 7 | "github.com/sirupsen/logrus/hooks/test" 8 | 9 | . "github.com/onsi/gomega" 10 | 11 | . "github.com/onsi/ginkgo/v2" 12 | ) 13 | 14 | var _ = Describe("LoggerWriter", func() { 15 | Describe("logger query log", func() { 16 | When("New log entry was created", func() { 17 | It("should be logged", func() { 18 | writer := NewLoggerWriter() 19 | logger, hook := test.NewNullLogger() 20 | writer.logger = logger.WithField("k", "v") 21 | 22 | writer.Write(&LogEntry{ 23 | Start: time.Now(), 24 | DurationMs: 20, 25 | }) 26 | 27 | Expect(hook.Entries).Should(HaveLen(1)) 28 | Expect(hook.LastEntry().Message).Should(Equal("query resolved")) 29 | }) 30 | }) 31 | When("Cleanup is called", func() { 32 | It("should do nothing", func() { 33 | writer := NewLoggerWriter() 34 | writer.CleanUp() 35 | }) 36 | }) 37 | }) 38 | 39 | Describe("LogEntryFields", func() { 40 | It("should return log fields", func() { 41 | entry := LogEntry{ 42 | ClientIP: "ip", 43 | DurationMs: 100, 44 | QuestionType: "qtype", 45 | ResponseCode: "rcode", 46 | } 47 | 48 | fields := LogEntryFields(&entry) 49 | 50 | Expect(fields).Should(HaveKeyWithValue("client_ip", entry.ClientIP)) 51 | Expect(fields).Should(HaveKeyWithValue("duration_ms", entry.DurationMs)) 52 | Expect(fields).Should(HaveKeyWithValue("question_type", entry.QuestionType)) 53 | Expect(fields).Should(HaveKeyWithValue("response_code", entry.ResponseCode)) 54 | 55 | Expect(fields).ShouldNot(HaveKey("client_names")) 56 | Expect(fields).ShouldNot(HaveKey("question_name")) 57 | }) 58 | 
}) 59 | 60 | DescribeTable("withoutZeroes", 61 | func(value any, isZero bool) { 62 | fields := withoutZeroes(logrus.Fields{"a": value}) 63 | 64 | if isZero { 65 | Expect(fields).Should(BeEmpty()) 66 | } else { 67 | Expect(fields).ShouldNot(BeEmpty()) 68 | } 69 | }, 70 | Entry("empty string", 71 | "", 72 | true), 73 | Entry("non-empty string", 74 | "something", 75 | false), 76 | Entry("zero int", 77 | 0, 78 | true), 79 | Entry("non-zero int", 80 | 1, 81 | false), 82 | ) 83 | }) 84 | -------------------------------------------------------------------------------- /querylog/none_writer.go: -------------------------------------------------------------------------------- 1 | package querylog 2 | 3 | type NoneWriter struct{} 4 | 5 | func NewNoneWriter() *NoneWriter { 6 | return &NoneWriter{} 7 | } 8 | 9 | func (d *NoneWriter) Write(*LogEntry) { 10 | // Nothing to do 11 | } 12 | 13 | func (d *NoneWriter) CleanUp() { 14 | // Nothing to do 15 | } 16 | -------------------------------------------------------------------------------- /querylog/none_writer_test.go: -------------------------------------------------------------------------------- 1 | package querylog 2 | 3 | import ( 4 | . "github.com/onsi/ginkgo/v2" 5 | ) 6 | 7 | var _ = Describe("NoneWriter", func() { 8 | Describe("NoneWriter", func() { 9 | When("write is called", func() { 10 | It("should do nothing", func() { 11 | NewNoneWriter().Write(nil) 12 | }) 13 | }) 14 | When("cleanUp is called", func() { 15 | It("should do nothing", func() { 16 | NewNoneWriter().CleanUp() 17 | }) 18 | }) 19 | }) 20 | }) 21 | -------------------------------------------------------------------------------- /querylog/querylog_suite_test.go: -------------------------------------------------------------------------------- 1 | package querylog 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/0xERR0R/blocky/log" 7 | 8 | . "github.com/onsi/ginkgo/v2" 9 | . 
"github.com/onsi/gomega" 10 | ) 11 | 12 | func init() { 13 | log.Silence() 14 | } 15 | 16 | func TestResolver(t *testing.T) { 17 | RegisterFailHandler(Fail) 18 | RunSpecs(t, "Querylog Suite") 19 | } 20 | -------------------------------------------------------------------------------- /querylog/writer.go: -------------------------------------------------------------------------------- 1 | package querylog 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | type LogEntry struct { 8 | Start time.Time 9 | ClientIP string 10 | ClientNames []string 11 | DurationMs int64 12 | ResponseReason string 13 | ResponseType string 14 | ResponseCode string 15 | QuestionType string 16 | QuestionName string 17 | Answer string 18 | BlockyInstance string 19 | } 20 | 21 | type Writer interface { 22 | Write(entry *LogEntry) 23 | CleanUp() 24 | } 25 | -------------------------------------------------------------------------------- /redis/redis_suite_test.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/0xERR0R/blocky/log" 8 | "github.com/go-redis/redis/v8" 9 | . "github.com/onsi/ginkgo/v2" 10 | . 
"github.com/onsi/gomega" 11 | ) 12 | 13 | func init() { 14 | log.Silence() 15 | redis.SetLogger(NoLogs{}) 16 | } 17 | 18 | func TestRedisClient(t *testing.T) { 19 | RegisterFailHandler(Fail) 20 | RunSpecs(t, "Redis Suite") 21 | } 22 | 23 | type NoLogs struct{} 24 | 25 | func (l NoLogs) Printf(context.Context, string, ...interface{}) {} 26 | -------------------------------------------------------------------------------- /resolver/ede_resolver.go: -------------------------------------------------------------------------------- 1 | package resolver 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/0xERR0R/blocky/config" 7 | "github.com/0xERR0R/blocky/model" 8 | "github.com/0xERR0R/blocky/util" 9 | "github.com/miekg/dns" 10 | ) 11 | 12 | // A EDEResolver is responsible for adding the reason for the response as EDNS0 option 13 | type EDEResolver struct { 14 | configurable[*config.EDE] 15 | NextResolver 16 | typed 17 | } 18 | 19 | // NewEDEResolver creates new resolver instance which adds the reason for 20 | // the response as EDNS0 option to the response if it is enabled in the configuration 21 | func NewEDEResolver(cfg config.EDE) *EDEResolver { 22 | return &EDEResolver{ 23 | configurable: withConfig(&cfg), 24 | typed: withType("extended_error_code"), 25 | } 26 | } 27 | 28 | // Resolve adds the reason as EDNS0 option to the response of the next resolver 29 | // if it is enabled in the configuration 30 | func (r *EDEResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { 31 | if !r.cfg.Enable { 32 | return r.next.Resolve(ctx, request) 33 | } 34 | 35 | resp, err := r.next.Resolve(ctx, request) 36 | if err != nil { 37 | return nil, err 38 | } 39 | 40 | r.addExtraReasoning(resp) 41 | 42 | return resp, nil 43 | } 44 | 45 | // addExtraReasoning adds the reason for the response as EDNS0 option 46 | func (r *EDEResolver) addExtraReasoning(res *model.Response) { 47 | infocode := res.RType.ToExtendedErrorCode() 48 | 49 | if infocode == 
dns.ExtendedErrorCodeOther { 50 | // dns.ExtendedErrorCodeOther seems broken in some clients 51 | return 52 | } 53 | 54 | edeOption := new(dns.EDNS0_EDE) 55 | edeOption.InfoCode = infocode 56 | edeOption.ExtraText = res.Reason 57 | 58 | util.SetEdns0Option(res.Res, edeOption) 59 | } 60 | -------------------------------------------------------------------------------- /resolver/filtering_resolver.go: -------------------------------------------------------------------------------- 1 | package resolver 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/0xERR0R/blocky/config" 7 | "github.com/0xERR0R/blocky/model" 8 | "github.com/miekg/dns" 9 | ) 10 | 11 | // FilteringResolver filters DNS queries by query type (for example, it can drop all AAAA queries) 12 | // and returns an empty ANSWER with NOERROR for filtered queries 13 | type FilteringResolver struct { 14 | configurable[*config.Filtering] 15 | NextResolver 16 | typed 17 | } 18 | 19 | func NewFilteringResolver(cfg config.Filtering) *FilteringResolver { 20 | return &FilteringResolver{ 21 | configurable: withConfig(&cfg), 22 | typed: withType("filtering"), 23 | } 24 | } 25 | 26 | func (r *FilteringResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { 27 | qType := request.Req.Question[0].Qtype 28 | if r.cfg.QueryTypes.Contains(dns.Type(qType)) { 29 | response := new(dns.Msg) 30 | response.SetRcode(request.Req, dns.RcodeSuccess) 31 | 32 | return &model.Response{Res: response, RType: model.ResponseTypeFILTERED}, nil 33 | } 34 | 35 | return r.next.Resolve(ctx, request) 36 | } 37 | -------------------------------------------------------------------------------- /resolver/filtering_resolver_test.go: -------------------------------------------------------------------------------- 1 | package resolver 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/0xERR0R/blocky/config" 7 | . "github.com/0xERR0R/blocky/helpertest" 8 | "github.com/0xERR0R/blocky/log" 9 | . "github.com/0xERR0R/blocky/model" 10 | 11 | "github.com/miekg/dns" 12 | .
"github.com/onsi/ginkgo/v2" 13 | . "github.com/onsi/gomega" 14 | "github.com/stretchr/testify/mock" 15 | ) 16 | 17 | var _ = Describe("FilteringResolver", func() { 18 | var ( 19 | sut *FilteringResolver 20 | sutConfig config.Filtering 21 | m *mockResolver 22 | mockAnswer *dns.Msg 23 | 24 | ctx context.Context 25 | cancelFn context.CancelFunc 26 | ) 27 | 28 | Describe("Type", func() { 29 | It("follows conventions", func() { 30 | expectValidResolverType(sut) 31 | }) 32 | }) 33 | 34 | BeforeEach(func() { 35 | ctx, cancelFn = context.WithCancel(context.Background()) 36 | DeferCleanup(cancelFn) 37 | 38 | mockAnswer = new(dns.Msg) 39 | }) 40 | 41 | JustBeforeEach(func() { 42 | sut = NewFilteringResolver(sutConfig) 43 | m = &mockResolver{} 44 | m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil) 45 | sut.Next(m) 46 | }) 47 | 48 | Describe("IsEnabled", func() { 49 | It("is false", func() { 50 | Expect(sut.IsEnabled()).Should(BeFalse()) 51 | }) 52 | }) 53 | 54 | Describe("LogConfig", func() { 55 | It("should log something", func() { 56 | logger, hook := log.NewMockEntry() 57 | 58 | sut.LogConfig(logger) 59 | 60 | Expect(hook.Calls).ShouldNot(BeEmpty()) 61 | }) 62 | }) 63 | 64 | When("Filtering query types are defined", func() { 65 | BeforeEach(func() { 66 | sutConfig = config.Filtering{ 67 | QueryTypes: config.NewQTypeSet(AAAA, MX), 68 | } 69 | }) 70 | It("Should delegate to next resolver if request query has other type", func() { 71 | Expect(sut.Resolve(ctx, newRequest("example.com.", A))). 72 | Should( 73 | SatisfyAll( 74 | HaveNoAnswer(), 75 | HaveResponseType(ResponseTypeRESOLVED), 76 | HaveReturnCode(dns.RcodeSuccess), 77 | )) 78 | 79 | // delegated to next resolver 80 | Expect(m.Calls).Should(HaveLen(1)) 81 | }) 82 | It("Should return empty answer for defined query type", func() { 83 | Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). 
84 | Should( 85 | SatisfyAll( 86 | HaveNoAnswer(), 87 | HaveResponseType(ResponseTypeFILTERED), 88 | HaveReturnCode(dns.RcodeSuccess), 89 | )) 90 | 91 | // no call of next resolver 92 | Expect(m.Calls).Should(BeZero()) 93 | }) 94 | }) 95 | 96 | When("No filtering query types are defined", func() { 97 | BeforeEach(func() { 98 | sutConfig = config.Filtering{} 99 | }) 100 | It("Should return empty answer without error", func() { 101 | Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). 102 | Should( 103 | SatisfyAll( 104 | HaveNoAnswer(), 105 | HaveResponseType(ResponseTypeRESOLVED), 106 | HaveReturnCode(dns.RcodeSuccess), 107 | )) 108 | 109 | // delegated to next resolver 110 | Expect(m.Calls).Should(HaveLen(1)) 111 | }) 112 | }) 113 | }) 114 | -------------------------------------------------------------------------------- /resolver/fqdn_only_resolver.go: -------------------------------------------------------------------------------- 1 | package resolver 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "github.com/0xERR0R/blocky/config" 8 | "github.com/0xERR0R/blocky/model" 9 | "github.com/0xERR0R/blocky/util" 10 | "github.com/miekg/dns" 11 | ) 12 | 13 | type FQDNOnlyResolver struct { 14 | configurable[*config.FQDNOnly] 15 | NextResolver 16 | typed 17 | } 18 | 19 | func NewFQDNOnlyResolver(cfg config.FQDNOnly) *FQDNOnlyResolver { 20 | return &FQDNOnlyResolver{ 21 | configurable: withConfig(&cfg), 22 | typed: withType("fqdn_only"), 23 | } 24 | } 25 | 26 | func (r *FQDNOnlyResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { 27 | if r.IsEnabled() { 28 | domainFromQuestion := util.ExtractDomain(request.Req.Question[0]) 29 | if !strings.Contains(domainFromQuestion, ".") { 30 | response := new(dns.Msg) 31 | response.Rcode = dns.RcodeNameError 32 | 33 | return &model.Response{Res: response, RType: model.ResponseTypeNOTFQDN, Reason: "NOTFQDN"}, nil 34 | } 35 | } 36 | 37 | return r.next.Resolve(ctx, request) 38 | } 39 | 
-------------------------------------------------------------------------------- /resolver/metrics_resolver_test.go: -------------------------------------------------------------------------------- 1 | package resolver 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | 7 | "github.com/0xERR0R/blocky/config" 8 | "github.com/0xERR0R/blocky/log" 9 | 10 | . "github.com/0xERR0R/blocky/helpertest" 11 | . "github.com/0xERR0R/blocky/model" 12 | 13 | "github.com/miekg/dns" 14 | . "github.com/onsi/ginkgo/v2" 15 | . "github.com/onsi/gomega" 16 | "github.com/prometheus/client_golang/prometheus" 17 | "github.com/prometheus/client_golang/prometheus/testutil" 18 | "github.com/stretchr/testify/mock" 19 | ) 20 | 21 | var _ = Describe("MetricResolver", func() { 22 | var ( 23 | sut *MetricsResolver 24 | m *mockResolver 25 | 26 | ctx context.Context 27 | cancelFn context.CancelFunc 28 | ) 29 | 30 | Describe("Type", func() { 31 | It("follows conventions", func() { 32 | expectValidResolverType(sut) 33 | }) 34 | }) 35 | 36 | BeforeEach(func() { 37 | ctx, cancelFn = context.WithCancel(context.Background()) 38 | DeferCleanup(cancelFn) 39 | 40 | sut = NewMetricsResolver(config.Metrics{Enable: true}) 41 | m = &mockResolver{} 42 | m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) 43 | sut.Next(m) 44 | }) 45 | 46 | Describe("IsEnabled", func() { 47 | It("is true", func() { 48 | Expect(sut.IsEnabled()).Should(BeTrue()) 49 | }) 50 | }) 51 | 52 | Describe("LogConfig", func() { 53 | It("should log something", func() { 54 | logger, hook := log.NewMockEntry() 55 | 56 | sut.LogConfig(logger) 57 | 58 | Expect(hook.Calls).ShouldNot(BeEmpty()) 59 | }) 60 | }) 61 | 62 | Describe("Recording prometheus metrics", func() { 63 | Context("Recording request metrics", func() { 64 | When("Request will be performed", func() { 65 | It("Should record metrics", func() { 66 | Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client"))). 
// NoResponse is the single shared value returned by NoOpResolver.Resolve.
// Callers can compare a response against it by pointer.
var NoResponse = &model.Response{} //nolint:gochecknoglobals

// NoOpResolver is used to finish a resolver branch as created in RewriterResolver
type NoOpResolver struct{}

// NewNoOpResolver returns a pointer to a new, empty NoOpResolver.
func NewNoOpResolver() *NoOpResolver {
	return &NoOpResolver{}
}

// Type implements `Resolver`.
// It always returns the constant "noop".
func (NoOpResolver) Type() string {
	return "noop"
}

// String implements `fmt.Stringer`.
// It returns the same value as Type.
func (r NoOpResolver) String() string {
	return r.Type()
}

// IsEnabled implements `config.Configurable`.
// A NoOpResolver is always reported as enabled.
func (NoOpResolver) IsEnabled() bool {
	return true
}

// LogConfig implements `config.Configurable`.
// There is nothing to report, so nothing is logged.
func (NoOpResolver) LogConfig(*logrus.Entry) {
}

// Resolve ignores both arguments and always returns NoResponse with a nil error.
func (NoOpResolver) Resolve(context.Context, *model.Request) (*model.Response, error) {
	return NoResponse, nil
}
10 | "github.com/go-redis/redis/v8" 11 | 12 | . "github.com/onsi/ginkgo/v2" 13 | . "github.com/onsi/gomega" 14 | ) 15 | 16 | const ( 17 | timeout = 50 * time.Millisecond 18 | ) 19 | 20 | var defaultUpstreamsConfig config.Upstreams 21 | 22 | func init() { 23 | log.Silence() 24 | redis.SetLogger(NoLogs{}) 25 | 26 | var err error 27 | 28 | defaultUpstreamsConfig, err = config.WithDefaults[config.Upstreams]() 29 | if err != nil { 30 | panic(err) 31 | } 32 | 33 | // Shorter timeout for tests 34 | defaultUpstreamsConfig.Timeout = config.Duration(timeout) 35 | } 36 | 37 | func TestResolver(t *testing.T) { 38 | RegisterFailHandler(Fail) 39 | RunSpecs(t, "Resolver Suite") 40 | } 41 | 42 | type NoLogs struct{} 43 | 44 | func (l NoLogs) Printf(_ context.Context, _ string, _ ...interface{}) {} 45 | -------------------------------------------------------------------------------- /resolver/strict_resolver.go: -------------------------------------------------------------------------------- 1 | package resolver 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "strings" 8 | "sync/atomic" 9 | 10 | "github.com/0xERR0R/blocky/config" 11 | "github.com/0xERR0R/blocky/model" 12 | "github.com/0xERR0R/blocky/util" 13 | 14 | "github.com/sirupsen/logrus" 15 | ) 16 | 17 | const ( 18 | strictResolverType = "strict" 19 | ) 20 | 21 | // StrictResolver delegates the DNS message strictly to the first configured upstream resolver 22 | // if it can't provide the answer in time the next resolver is used 23 | type StrictResolver struct { 24 | configurable[*config.UpstreamGroup] 25 | typed 26 | 27 | resolvers atomic.Pointer[[]*upstreamResolverStatus] 28 | } 29 | 30 | // NewStrictResolver creates a new strict resolver instance 31 | func NewStrictResolver( 32 | ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, 33 | ) (*StrictResolver, error) { 34 | r := newStrictResolver( 35 | cfg, 36 | []Resolver{bootstrap}, // if init strategy is fast, use bootstrap until init finishes 37 | 
) 38 | 39 | return initGroupResolvers(ctx, r, cfg, bootstrap) 40 | } 41 | 42 | func newStrictResolver( 43 | cfg config.UpstreamGroup, resolvers []Resolver, 44 | ) *StrictResolver { 45 | r := StrictResolver{ 46 | configurable: withConfig(&cfg), 47 | typed: withType(strictResolverType), 48 | } 49 | 50 | r.setResolvers(newUpstreamResolverStatuses(resolvers)) 51 | 52 | return &r 53 | } 54 | 55 | func (r *StrictResolver) setResolvers(resolvers []*upstreamResolverStatus) { 56 | r.resolvers.Store(&resolvers) 57 | } 58 | 59 | func (r *StrictResolver) Name() string { 60 | return r.String() 61 | } 62 | 63 | func (r *StrictResolver) String() string { 64 | resolvers := *r.resolvers.Load() 65 | 66 | upstreams := make([]string, len(resolvers)) 67 | for i, s := range resolvers { 68 | upstreams[i] = s.resolver.String() 69 | } 70 | 71 | return fmt.Sprintf("%s upstreams '%s (%s)'", strictResolverType, r.cfg.Name, strings.Join(upstreams, ",")) 72 | } 73 | 74 | // Resolve sends the query request in a strict order to the upstream resolvers 75 | func (r *StrictResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) { 76 | ctx, logger := r.log(ctx) 77 | 78 | // start with first resolver 79 | for _, resolver := range *r.resolvers.Load() { 80 | logger.Debugf("using %s as resolver", resolver.resolver) 81 | 82 | resp, err := resolver.resolve(ctx, request) 83 | if err != nil { 84 | // log error and try next upstream 85 | logger.WithField("resolver", resolver.resolver).Debug("resolution failed from resolver, cause: ", err) 86 | 87 | continue 88 | } 89 | 90 | logger.WithFields(logrus.Fields{ 91 | "resolver": *resolver, 92 | "answer": util.AnswerToString(resp.Res.Answer), 93 | }).Debug("using response from resolver") 94 | 95 | return resp, nil 96 | } 97 | 98 | return nil, errors.New("resolution was not successful, no resolver returned an answer in time") 99 | } 100 | -------------------------------------------------------------------------------- /server/http.go: 
-------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "net/http" 7 | "time" 8 | 9 | "github.com/0xERR0R/blocky/config" 10 | "github.com/go-chi/chi/v5" 11 | "github.com/go-chi/cors" 12 | ) 13 | 14 | type httpServer struct { 15 | inner http.Server 16 | 17 | name string 18 | } 19 | 20 | func newHTTPServer(name string, handler http.Handler, cfg *config.Config) *httpServer { 21 | var ( 22 | writeTimeout = cfg.Blocking.Loading.Downloads.WriteTimeout 23 | readTimeout = cfg.Blocking.Loading.Downloads.ReadTimeout 24 | readHeaderTimeout = cfg.Blocking.Loading.Downloads.ReadHeaderTimeout 25 | ) 26 | 27 | return &httpServer{ 28 | inner: http.Server{ 29 | ReadTimeout: time.Duration(readTimeout), 30 | ReadHeaderTimeout: time.Duration(readHeaderTimeout), 31 | WriteTimeout: time.Duration(writeTimeout), 32 | Handler: withCommonMiddleware(handler), 33 | }, 34 | 35 | name: name, 36 | } 37 | } 38 | 39 | func (s *httpServer) String() string { 40 | return s.name 41 | } 42 | 43 | func (s *httpServer) Serve(ctx context.Context, l net.Listener) error { 44 | go func() { 45 | <-ctx.Done() 46 | 47 | s.inner.Close() 48 | }() 49 | 50 | return s.inner.Serve(l) 51 | } 52 | 53 | func withCommonMiddleware(inner http.Handler) *chi.Mux { 54 | // Middleware must be defined before routes, so 55 | // create a new router and mount the inner handler 56 | mux := chi.NewMux() 57 | 58 | mux.Use( 59 | secureHeadersMiddleware, 60 | newCORSMiddleware(), 61 | ) 62 | 63 | mux.Mount("/", inner) 64 | 65 | return mux 66 | } 67 | 68 | type httpMiddleware = func(http.Handler) http.Handler 69 | 70 | func secureHeadersMiddleware(next http.Handler) http.Handler { 71 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 72 | if r.TLS != nil { 73 | w.Header().Set("strict-transport-security", "max-age=63072000") 74 | w.Header().Set("x-frame-options", "DENY") 75 | w.Header().Set("x-content-type-options", "nosniff") 
76 | w.Header().Set("x-xss-protection", "1; mode=block") 77 | } 78 | 79 | next.ServeHTTP(w, r) 80 | }) 81 | } 82 | 83 | func newCORSMiddleware() httpMiddleware { 84 | const corsMaxAge = 5 * time.Minute 85 | 86 | options := cors.Options{ 87 | AllowCredentials: true, 88 | AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, 89 | AllowedMethods: []string{"GET", "POST"}, 90 | AllowedOrigins: []string{"*"}, 91 | ExposedHeaders: []string{"Link"}, 92 | MaxAge: int(corsMaxAge.Seconds()), 93 | } 94 | 95 | return cors.New(options).Handler 96 | } 97 | -------------------------------------------------------------------------------- /server/server_config_trigger.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | // +build !windows 3 | 4 | package server 5 | 6 | import ( 7 | "context" 8 | "os" 9 | "os/signal" 10 | "syscall" 11 | ) 12 | 13 | func registerPrintConfigurationTrigger(ctx context.Context, s *Server) { 14 | signals := make(chan os.Signal, 1) 15 | signal.Notify(signals, syscall.SIGUSR1) 16 | 17 | go func() { 18 | for { 19 | select { 20 | case <-signals: 21 | s.printConfiguration() 22 | 23 | case <-ctx.Done(): 24 | return 25 | } 26 | } 27 | }() 28 | } 29 | -------------------------------------------------------------------------------- /server/server_config_trigger_windows.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import "context" 4 | 5 | func registerPrintConfigurationTrigger(ctx context.Context, s *Server) { 6 | } 7 | -------------------------------------------------------------------------------- /server/server_suite_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/0xERR0R/blocky/log" 7 | . "github.com/onsi/ginkgo/v2" 8 | . 
// SplitTLD cuts the right-most label off a domain name.
//
// Trailing dots are dropped first. The text after the last remaining dot is
// returned as label and the text before it as rest:
//
//	www.example.com -> ("com", "www.example")
//
// A name without any dot is returned unchanged as label, with an empty rest.
func SplitTLD(domain string) (label, rest string) {
	trimmed := strings.TrimRight(domain, ".")

	if dot := strings.LastIndexByte(trimmed, '.'); dot >= 0 {
		return trimmed[dot+1:], trimmed[:dot]
	}

	return trimmed, ""
}
"github.com/onsi/gomega" 6 | ) 7 | 8 | var _ = Describe("SpltTLD", func() { 9 | It("should split a tld", func() { 10 | key, rest := SplitTLD("www.example.com") 11 | Expect(key).Should(Equal("com")) 12 | Expect(rest).Should(Equal("www.example")) 13 | }) 14 | 15 | It("should not split a plain string", func() { 16 | key, rest := SplitTLD("example") 17 | Expect(key).Should(Equal("example")) 18 | Expect(rest).Should(Equal("")) 19 | }) 20 | 21 | It("should not crash with an empty string", func() { 22 | key, rest := SplitTLD("") 23 | Expect(key).Should(Equal("")) 24 | Expect(rest).Should(Equal("")) 25 | }) 26 | 27 | It("should ignore trailing dots", func() { 28 | key, rest := SplitTLD("www.example.com.") 29 | Expect(key).Should(Equal("com")) 30 | Expect(rest).Should(Equal("www.example")) 31 | 32 | key, rest = SplitTLD(rest) 33 | Expect(key).Should(Equal("example")) 34 | Expect(rest).Should(Equal("www")) 35 | }) 36 | 37 | It("should skip empty parts", func() { 38 | key, rest := SplitTLD("www.example..com") 39 | Expect(key).Should(Equal("com")) 40 | Expect(rest).Should(Equal("www.example.")) 41 | 42 | key, rest = SplitTLD(rest) 43 | Expect(key).Should(Equal("example")) 44 | Expect(rest).Should(Equal("www")) 45 | }) 46 | }) 47 | -------------------------------------------------------------------------------- /trie/trie_suite_test.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/0xERR0R/blocky/log" 7 | 8 | . "github.com/onsi/ginkgo/v2" 9 | . "github.com/onsi/gomega" 10 | ) 11 | 12 | func init() { 13 | log.Silence() 14 | } 15 | 16 | func TestTrie(t *testing.T) { 17 | RegisterFailHandler(Fail) 18 | RunSpecs(t, "Trie Suite") 19 | } 20 | -------------------------------------------------------------------------------- /trie/trie_test.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | . 
"github.com/onsi/ginkgo/v2" 5 | . "github.com/onsi/gomega" 6 | ) 7 | 8 | var _ = Describe("Trie", func() { 9 | var sut *Trie 10 | 11 | BeforeEach(func() { 12 | sut = NewTrie(SplitTLD) 13 | }) 14 | 15 | Describe("Basic operations", func() { 16 | When("Trie is created", func() { 17 | It("should be empty", func() { 18 | Expect(sut.IsEmpty()).Should(BeTrue()) 19 | }) 20 | 21 | It("should not find domains", func() { 22 | Expect(sut.HasParentOf("example.com")).Should(BeFalse()) 23 | }) 24 | 25 | It("should not insert the empty string", func() { 26 | sut.Insert("") 27 | Expect(sut.HasParentOf("")).Should(BeFalse()) 28 | }) 29 | }) 30 | 31 | When("Adding a domain", func() { 32 | var ( 33 | domainOkTLD = "com" 34 | domainOk = "example." + domainOkTLD 35 | 36 | domainKo = "example.org" 37 | ) 38 | 39 | BeforeEach(func() { 40 | Expect(sut.HasParentOf(domainOk)).Should(BeFalse()) 41 | sut.Insert(domainOk) 42 | Expect(sut.HasParentOf(domainOk)).Should(BeTrue()) 43 | }) 44 | 45 | AfterEach(func() { 46 | Expect(sut.HasParentOf(domainOk)).Should(BeTrue()) 47 | }) 48 | 49 | It("should be found", func() {}) 50 | 51 | It("should contain subdomains", func() { 52 | subdomain := "www." + domainOk 53 | 54 | Expect(sut.HasParentOf(subdomain)).Should(BeTrue()) 55 | }) 56 | 57 | It("should support inserting subdomains", func() { 58 | subdomain := "www." 
const (
	// IPv4PtrSuffix is the suffix of a fully qualified IPv4 reverse-lookup name.
	IPv4PtrSuffix = ".in-addr.arpa."
	// IPv6PtrSuffix is the suffix of a fully qualified IPv6 reverse-lookup name.
	IPv6PtrSuffix = ".ip6.arpa."

	byteBits = 8
)

// ErrInvalidArpaAddrLen is returned when the reversed address part of an arpa
// name does not have the number of labels required by its address family.
var ErrInvalidArpaAddrLen = errors.New("arpa hostname is not of expected length")

// ParseIPFromArpaAddr converts a fully qualified reverse-lookup name
// (ending in IPv4PtrSuffix or IPv6PtrSuffix) back into the IP it stands for.
//
// It returns an error when the suffix is neither of the two, when the number
// of labels is wrong (ErrInvalidArpaAddrLen), or when a label is not a valid
// number for its position.
func ParseIPFromArpaAddr(arpa string) (net.IP, error) {
	if strings.HasSuffix(arpa, IPv4PtrSuffix) {
		return parseIPv4FromArpaAddr(arpa)
	}

	if strings.HasSuffix(arpa, IPv6PtrSuffix) {
		return parseIPv6FromArpaAddr(arpa)
	}

	return nil, fmt.Errorf("invalid arpa hostname: %s", arpa)
}

// parseIPv4FromArpaAddr parses "d.c.b.a.in-addr.arpa." into the address a.b.c.d.
// Each label must be a decimal number that fits into one byte.
func parseIPv4FromArpaAddr(arpa string) (net.IP, error) {
	const base10 = 10

	revAddr := strings.TrimSuffix(arpa, IPv4PtrSuffix)

	parts := strings.Split(revAddr, ".")
	if len(parts) != net.IPv4len {
		return nil, ErrInvalidArpaAddrLen
	}

	buf := make([]byte, 0, net.IPv4len)

	// Parse and add each byte, in reverse, to the buffer
	for i := len(parts) - 1; i >= 0; i-- {
		part, err := strconv.ParseUint(parts[i], base10, byteBits)
		if err != nil {
			return nil, err
		}

		buf = append(buf, byte(part))
	}

	return net.IPv4(buf[0], buf[1], buf[2], buf[3]), nil
}

// parseIPv6FromArpaAddr parses a 32-label nibble name ending in ".ip6.arpa."
// into the 16-byte address it stands for.
// Each label must be a hexadecimal number that fits into one nibble (0-f).
func parseIPv6FromArpaAddr(arpa string) (net.IP, error) {
	const (
		base16     = 16
		ipv6Bytes  = 2 * net.IPv6len
		nibbleBits = byteBits / 2
	)

	revAddr := strings.TrimSuffix(arpa, IPv6PtrSuffix)

	parts := strings.Split(revAddr, ".")
	if len(parts) != ipv6Bytes {
		return nil, ErrInvalidArpaAddrLen
	}

	buf := make([]byte, 0, net.IPv6len)

	// Walk the labels from the end, two at a time: the later label is the
	// most significant nibble of the byte, the earlier one the least.
	// The length check above guarantees an even count, so i-1 is always valid.
	for i := len(parts) - 1; i >= 0; i -= 2 {
		// Parse with a 4-bit size so that a label wider than one nibble
		// (e.g. "ff") is rejected instead of leaking into the other nibble.
		msNibble, err := strconv.ParseUint(parts[i], base16, nibbleBits)
		if err != nil {
			return nil, err
		}

		lsNibble, err := strconv.ParseUint(parts[i-1], base16, nibbleBits)
		if err != nil {
			return nil, err
		}

		part := msNibble<<nibbleBits | lsNibble

		buf = append(buf, byte(part))
	}

	return net.IP(buf), nil
}
17 | // If the OPT record is removed, true will be returned. 18 | func RemoveEdns0Record(msg *dns.Msg) bool { 19 | if msg == nil || msg.IsEdns0() == nil { 20 | return false 21 | } 22 | 23 | for i, rr := range msg.Extra { 24 | if rr.Header().Rrtype == dns.TypeOPT { 25 | msg.Extra = slices.Delete(msg.Extra, i, i+1) 26 | 27 | return true 28 | } 29 | } 30 | 31 | return false 32 | } 33 | 34 | // GetEdns0Option returns the option with the given code from the OPT record in the 35 | // Extra section of the given message. 36 | // If the option is not found, nil will be returned. 37 | func GetEdns0Option[T EDNS0Option](msg *dns.Msg) T { 38 | if msg == nil { 39 | return nil 40 | } 41 | 42 | opt := msg.IsEdns0() 43 | if opt == nil { 44 | return nil 45 | } 46 | 47 | var t T 48 | 49 | for _, o := range opt.Option { 50 | if o.Option() == t.Option() { 51 | t, ok := o.(T) 52 | if !ok { 53 | panic(fmt.Errorf("dns option with code %d is not of type %T", t.Option(), t)) 54 | } 55 | 56 | return t 57 | } 58 | } 59 | 60 | return nil 61 | } 62 | 63 | // RemoveEdns0Option removes the option according to the given type from the OPT record 64 | // in the Extra section of the given message. 65 | // If there are no more options in the OPT record, the OPT record will be removed. 66 | // If the option is successfully removed, true will be returned. 67 | func RemoveEdns0Option[T EDNS0Option](msg *dns.Msg) bool { 68 | if msg == nil { 69 | return false 70 | } 71 | 72 | opt := msg.IsEdns0() 73 | if opt == nil { 74 | return false 75 | } 76 | 77 | res := false 78 | 79 | var t T 80 | 81 | for i, o := range opt.Option { 82 | if o.Option() == t.Option() { 83 | opt.Option = slices.Delete(opt.Option, i, i+1) 84 | 85 | res = true 86 | 87 | break 88 | } 89 | } 90 | 91 | if len(opt.Option) == 0 { 92 | RemoveEdns0Record(msg) 93 | } 94 | 95 | return res 96 | } 97 | 98 | // SetEdns0Option adds the given option to the OPT record in the Extra section of the 99 | // given message. 
// baseTransport is http.DefaultTransport with its concrete type recovered.
// It is only used as the template for DefaultHTTPTransport.
//
//nolint:gochecknoglobals
var baseTransport *http.Transport

// init resolves baseTransport once and panics if the Go runtime in use does
// not provide http.DefaultTransport as an *http.Transport.
//
//nolint:gochecknoinits
func init() {
	base, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		panic(fmt.Errorf(
			"unsupported Go version: http.DefaultTransport is not of type *http.Transport: it is a %T",
			http.DefaultTransport,
		))
	}

	baseTransport = base
}

// DefaultHTTPTransport returns a new Transport with the same defaults as net/http.
//
// Each call returns a distinct *http.Transport whose exported fields listed
// below are copied from http.DefaultTransport, so callers can adjust their
// copy without touching the shared default.
func DefaultHTTPTransport() *http.Transport {
	return &http.Transport{
		DialContext:            baseTransport.DialContext,
		DialTLSContext:         baseTransport.DialTLSContext,
		DisableCompression:     baseTransport.DisableCompression,
		DisableKeepAlives:      baseTransport.DisableKeepAlives,
		ExpectContinueTimeout:  baseTransport.ExpectContinueTimeout,
		ForceAttemptHTTP2:      baseTransport.ForceAttemptHTTP2,
		GetProxyConnectHeader:  baseTransport.GetProxyConnectHeader,
		IdleConnTimeout:        baseTransport.IdleConnTimeout,
		MaxConnsPerHost:        baseTransport.MaxConnsPerHost,
		MaxIdleConns:           baseTransport.MaxIdleConns,
		// Copy the matching field, not MaxConnsPerHost: the two limits are
		// independent settings of the transport.
		MaxIdleConnsPerHost:    baseTransport.MaxIdleConnsPerHost,
		MaxResponseHeaderBytes: baseTransport.MaxResponseHeaderBytes,
		OnProxyConnectResponse: baseTransport.OnProxyConnectResponse,
		Proxy:                  baseTransport.Proxy,
		ProxyConnectHeader:     baseTransport.ProxyConnectHeader,
		ReadBufferSize:         baseTransport.ReadBufferSize,
		ResponseHeaderTimeout:  baseTransport.ResponseHeaderTimeout,
		TLSClientConfig:        baseTransport.TLSClientConfig,
		TLSHandshakeTimeout:    baseTransport.TLSHandshakeTimeout,
		TLSNextProto:           baseTransport.TLSNextProto,
		WriteBufferSize:        baseTransport.WriteBufferSize,
	}
}

// HTTPClientIP returns the client IP of r.
//
// The X-Forwarded-For header is preferred; if it is empty, r.RemoteAddr is
// used. The value may be "host:port" or a bare address.
// The result is nil when the chosen value cannot be parsed as an IP.
//
// NOTE(review): the header value is parsed as a single address, so a
// comma-separated proxy chain yields nil - confirm whether that is intended.
func HTTPClientIP(r *http.Request) net.IP {
	addr := r.Header.Get("X-FORWARDED-FOR")
	if addr == "" {
		addr = r.RemoteAddr
	}

	ip, _, err := net.SplitHostPort(addr)
	if err != nil {
		// No port present: treat the whole value as the address.
		return net.ParseIP(addr)
	}

	return net.ParseIP(ip)
}
"github.com/onsi/gomega" 14 | ) 15 | 16 | var _ = Describe("HTTP Util", func() { 17 | Describe("DefaultHTTPTransport", func() { 18 | It("returns a new transport", func() { 19 | a := DefaultHTTPTransport() 20 | Expect(a).Should(BeIdenticalTo(a)) 21 | 22 | b := DefaultHTTPTransport() 23 | Expect(a).ShouldNot(BeIdenticalTo(b)) 24 | }) 25 | 26 | It("returns a copy of http.DefaultTransport", func() { 27 | Expect(cmp.Diff( 28 | DefaultHTTPTransport(), http.DefaultTransport, 29 | cmpopts.IgnoreUnexported(http.Transport{}), 30 | // Non nil func field comparers 31 | cmp.Comparer(cmpAsPtrs[func(context.Context, string, string) (net.Conn, error)]), 32 | cmp.Comparer(cmpAsPtrs[func(*http.Request) (*url.URL, error)]), 33 | )).Should(BeEmpty()) 34 | }) 35 | }) 36 | 37 | Describe("HTTPClientIP", func() { 38 | It("extracts the IP from RemoteAddr", func() { 39 | r, err := http.NewRequest(http.MethodGet, "http://example.com", nil) 40 | Expect(err).Should(Succeed()) 41 | 42 | ip := net.IPv4allrouter 43 | r.RemoteAddr = net.JoinHostPort(ip.String(), "78954") 44 | 45 | Expect(HTTPClientIP(r)).Should(Equal(ip)) 46 | }) 47 | 48 | It("extracts the IP from RemoteAddr without a port", func() { 49 | r, err := http.NewRequest(http.MethodGet, "http://example.com", nil) 50 | Expect(err).Should(Succeed()) 51 | 52 | ip := net.IPv4allrouter 53 | r.RemoteAddr = ip.String() 54 | 55 | Expect(HTTPClientIP(r)).Should(Equal(ip)) 56 | }) 57 | 58 | It("extracts the IP from the X-Forwarded-For header", func() { 59 | r, err := http.NewRequest(http.MethodGet, "http://example.com", nil) 60 | Expect(err).Should(Succeed()) 61 | 62 | ip := net.IPv4bcast 63 | r.RemoteAddr = ip.String() 64 | 65 | r.Header.Set("X-Forwarded-For", ip.String()) 66 | 67 | Expect(HTTPClientIP(r)).Should(Equal(ip)) 68 | }) 69 | }) 70 | }) 71 | 72 | // Go and cmp don't define func comparisons, besides with nil. 73 | // In practice we can just compare them as pointers. 
const (
	// certSerialMaxBits bounds the random serial number: it is drawn from
	// the range [0, 2^certSerialMaxBits).
	certSerialMaxBits = 128
	// certExpiryYears is how long a generated certificate stays valid.
	certExpiryYears = 5
)

// TLSGenerateSelfSignedCert returns a new self-signed cert for the given domains.
//
// The certificate uses a freshly generated ECDSA P-256 key, lists domains as
// its DNS names and is valid for server authentication from now on for
// certExpiryYears years. The returned value has Leaf already populated.
//
// Being self-signed, no client will trust this certificate.
func TLSGenerateSelfSignedCert(domains []string) (tls.Certificate, error) {
	serialMax := new(big.Int).Lsh(big.NewInt(1), certSerialMaxBits)

	serial, err := rand.Int(rand.Reader, serialMax)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("unable to generate serial number: %w", err)
	}

	// Read the clock once so both ends of the validity window are derived
	// from the same instant.
	now := time.Now()

	template := &x509.Certificate{
		SerialNumber: serial,

		Subject:  pkix.Name{Organization: []string{"Blocky"}},
		DNSNames: domains,

		NotBefore: now,
		NotAfter:  now.AddDate(certExpiryYears, 0, 0),

		KeyUsage:    x509.KeyUsageDigitalSignature,
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}

	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("unable to generate private key: %w", err)
	}

	// Template is passed as both subject and issuer: that is what makes the
	// certificate self-signed.
	der, err := x509.CreateCertificate(rand.Reader, template, template, &privKey.PublicKey, privKey)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("cert creation from template failed: %w", err)
	}

	// Parse the generated DER back into a usable cert
	// This avoids needing to do it for each TLS handshake (see tls.Certificate.Leaf comment)
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("generated cert DER could not be parsed: %w", err)
	}

	tlsCert := tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  privKey,
		Leaf:        cert,
	}

	return tlsCert, nil
}
3 | import ( 4 | "testing" 5 | 6 | "github.com/0xERR0R/blocky/log" 7 | . "github.com/onsi/ginkgo/v2" 8 | . "github.com/onsi/gomega" 9 | ) 10 | 11 | func init() { 12 | log.Silence() 13 | } 14 | 15 | func TestLists(t *testing.T) { 16 | RegisterFailHandler(Fail) 17 | RunSpecs(t, "Util Suite") 18 | } 19 | -------------------------------------------------------------------------------- /web/index.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "embed" 5 | "io/fs" 6 | ) 7 | 8 | // IndexTmpl html template for the start page 9 | // 10 | //go:embed index.html 11 | var IndexTmpl string 12 | 13 | //go:embed all:static 14 | var static embed.FS 15 | 16 | func Assets() (fs.FS, error) { 17 | return fs.Sub(static, "static") 18 | } 19 | -------------------------------------------------------------------------------- /web/index.html: -------------------------------------------------------------------------------- 1 | <!DOCTYPE html> 2 | <html> 3 | <head> 4 | <title>blocky</title> 5 | </head> 6 | <body> 7 | <h1>blocky</h1> 8 | <ul> 9 | {{range .Links}} 10 | <li><a href="{{.URL}}">{{.Title}}</a></li> 11 | {{end}} 12 | </ul> 13 | 14 | <p><span class="small">Version {{.Version}} Build time {{.BuildTime}}</span></p> 15 | </body> 16 | </html> -------------------------------------------------------------------------------- /web/static/rapidoc.html: -------------------------------------------------------------------------------- 1 | <!doctype html> 2 | <html> 3 | <head> 4 | <meta charset="utf-8"> 5 | <script type="module" src="/static/rapidoc-min.js"></script> 6 | </head> 7 | <body> 8 | <rapi-doc 9 | spec-url="/docs/openapi.yaml" 10 | theme = "light" 11 | allow-authentication = "false" 12 | show-header = "false" 13 | bg-color = "#fdf8ed" 14 | nav-bg-color = "#3f4d67" 15 | nav-text-color = "#a9b7d0" 16 | nav-hover-bg-color = "#333f54" 17 | nav-hover-text-color = "#fff" 18 | nav-accent-color = "#f87070" 19 | 
primary-color = "#5c7096" 20 | > </rapi-doc> 21 | </body> 22 | </html> --------------------------------------------------------------------------------