├── .codecov.yml ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yaml │ └── config.yml ├── dependabot.yml └── workflows │ ├── codecov.yml │ ├── codeql-action.yml │ ├── docker-hub.yml │ ├── docker-publish.yml │ ├── go.yml │ ├── golangci-lint.yml │ └── goreleaser.yml ├── .gitignore ├── .goreleaser.yml ├── CNAME ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── _config.yml ├── api ├── README.md ├── api.go ├── api_test.go ├── context.go ├── group.go ├── router.go └── tree.go ├── authcache ├── authserver.go ├── authserver_test.go ├── ns_cache.go └── ns_cache_test.go ├── cache ├── cache.go ├── cache_test.go ├── hash.go ├── hash_test.go ├── shard.go └── shard_test.go ├── config ├── config.go └── config_test.go ├── contrib └── linux │ ├── adduser.sh │ ├── sdns.conf │ └── sdns.service ├── dnsutil ├── dnsutil.go ├── dnsutil_test.go ├── ttl.go └── ttl_test.go ├── doc.go ├── docker-compose.yml ├── gen.go ├── go.mod ├── go.sum ├── logo.png ├── middleware ├── accesslist │ ├── accesslist.go │ └── accesslist_test.go ├── accesslog │ ├── accesslog.go │ └── accesslog_test.go ├── as112 │ ├── as112.go │ └── as112_test.go ├── blocklist │ ├── blocklist.go │ ├── blocklist_test.go │ ├── updater.go │ └── updater_test.go ├── cache │ ├── cache.go │ ├── cache_test.go │ └── item.go ├── chain.go ├── chain_test.go ├── chaos │ ├── chaos.go │ └── chaos_test.go ├── edns │ ├── edns.go │ └── edns_test.go ├── failover │ ├── failover.go │ └── failover_test.go ├── forwarder │ ├── forwarder.go │ └── forwarder_test.go ├── hostsfile │ ├── hostsfile.go │ └── hostsfile_test.go ├── loop │ ├── loop.go │ └── loop_test.go ├── metrics │ ├── metrics.go │ └── metrics_test.go ├── middleware.go ├── middleware_test.go ├── ratelimit │ ├── ratelimit.go │ └── ratelimit_test.go ├── recovery │ ├── recovery.go │ └── recovery_test.go ├── resolver │ ├── auto_trust_anchor.go │ ├── auto_trust_anchor_test.go │ ├── client.go │ ├── client_test.go │ ├── handler.go │ 
├── handler_test.go │ ├── nsec3.go │ ├── nsec3_test.go │ ├── resolver.go │ ├── resolver_test.go │ ├── singleinflight.go │ ├── utils.go │ └── utils_test.go └── response_writer.go ├── mock ├── writer.go └── writer_test.go ├── response ├── typify.go └── typify_test.go ├── sdns.go ├── server ├── doh │ ├── doh.go │ ├── doh_test.go │ ├── msg.go │ ├── msg_test.go │ ├── qtype.go │ └── qtype_test.go ├── doq │ ├── doq.go │ ├── doq_test.go │ └── response_writer.go ├── server.go └── server_test.go ├── snap └── snapcraft.yaml ├── waitgroup ├── waitgroup.go └── waitgroup_test.go └── zregister.go /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: 40% 6 | threshold: null 7 | patch: false 8 | changes: false 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: "Bug Report" 2 | description: Create a new issue for a bug. 3 | title: "[BUG] - " 4 | labels: [ 5 | "bug" 6 | ] 7 | assignees: 8 | - semihalev 9 | body: 10 | - type: markdown 11 | attributes: 12 | value: | 13 | Thank you for dedicating time to report this issue. We appreciate your effort. Kindly complete the form below. 14 | - type: textarea 15 | id: description 16 | attributes: 17 | label: "Description" 18 | description: Please enter an explicit description of your issue 19 | placeholder: Short and explicit description of your incident... 
20 | validations: 21 | required: true 22 | - type: textarea 23 | id: reprod 24 | attributes: 25 | label: "Reproduction steps" 26 | description: Please list the steps needed to reproduce the issue 27 | render: bash 28 | validations: 29 | required: true 30 | - type: textarea 31 | id: sdns-version 32 | attributes: 33 | label: sdns version 34 | description: "`sdns --version` shows the version information" 35 | render: bash 36 | validations: 37 | required: true 38 | - type: textarea 39 | id: logs 40 | attributes: 41 | label: "Logs" 42 | description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. 43 | render: bash 44 | validations: 45 | required: false 46 | - type: dropdown 47 | id: os 48 | attributes: 49 | label: "OS" 50 | description: What is the impacted environment? 51 | multiple: true 52 | options: 53 | - Windows 54 | - Linux 55 | - MacOS 56 | - FreeBSD 57 | - Other 58 | - type: textarea 59 | id: ctx 60 | attributes: 61 | label: Additional context 62 | description: Anything else you would like to add 63 | validations: 64 | required: false 65 | validations: 66 | required: false 67 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | contact_links: 2 | - name: Feature Request 3 | url: https://github.com/semihalev/sdns/discussions/new?category=ideas 4 | about: Discuss a new feature for SDNS 5 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "gomod" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | time: "15:30" 8 | labels: 9 | - "dependencies" 10 | - package-ecosystem: "github-actions" 11 | directory: "/" 12 | schedule: 13 | interval: "daily" 14 | 
time: "15:30" 15 | labels: 16 | - "dependencies" 17 | -------------------------------------------------------------------------------- /.github/workflows/codecov.yml: -------------------------------------------------------------------------------- 1 | name: Codecov 2 | 3 | on: 4 | push: 5 | pull_request: 6 | 7 | jobs: 8 | run: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@master 13 | 14 | - name: Set up Go 1.x 15 | uses: actions/setup-go@v5 16 | with: 17 | go-version: ^1.22 18 | 19 | - name: Generate coverage report 20 | run: make test 21 | 22 | - name: Upload coverage to Codecov 23 | uses: codecov/codecov-action@v5.3.1 24 | with: 25 | fail_ci_if_error: true 26 | token: ${{ secrets.CODECOV_TOKEN }} 27 | file: ./coverage.out 28 | -------------------------------------------------------------------------------- /.github/workflows/codeql-action.yml: -------------------------------------------------------------------------------- 1 | name: "Code Scanning - Action" 2 | 3 | on: 4 | push: 5 | pull_request: 6 | schedule: 7 | - cron: '0 0 * * 0' 8 | 9 | jobs: 10 | CodeQL-Build: 11 | # CodeQL runs on ubuntu-latest, windows-latest, and macos-latest 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@v4 17 | with: 18 | # Must fetch at least the immediate parents so that if this is 19 | # a pull request then we can checkout the head of the pull request. 20 | # Only include this option if you are running this workflow on pull requests. 21 | fetch-depth: 2 22 | 23 | # If this run was triggered by a pull request event then checkout 24 | # the head of the pull request instead of the merge commit. 25 | # Only include this step if you are running this workflow on pull requests. 26 | - run: git checkout HEAD^2 27 | if: ${{ github.event_name == 'pull_request' }} 28 | 29 | # Initializes the CodeQL tools for scanning. 
30 | - name: Initialize CodeQL 31 | uses: github/codeql-action/init@v3 32 | # Override language selection by uncommenting this and choosing your languages 33 | # with: 34 | # languages: go, javascript, csharp, python, cpp, java 35 | 36 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 37 | # If this step fails, then you should remove it and run the build manually (see below). 38 | - name: Autobuild 39 | uses: github/codeql-action/autobuild@v3 40 | 41 | # ℹ️ Command-line programs to run using the OS shell. 42 | # 📚 https://git.io/JvXDl 43 | 44 | # ✏️ If the Autobuild fails above, remove it and uncomment the following 45 | # three lines and modify them (or add more) to build your code if your 46 | # project uses a compiled language 47 | 48 | #- run: | 49 | # make bootstrap 50 | # make release 51 | 52 | - name: Perform CodeQL Analysis 53 | uses: github/codeql-action/analyze@v3 54 | -------------------------------------------------------------------------------- /.github/workflows/docker-hub.yml: -------------------------------------------------------------------------------- 1 | name: "Docker Hub" 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: 7 | - "master" 8 | tags: 9 | - v* 10 | jobs: 11 | docker: 12 | runs-on: ubuntu-latest 13 | if: github.event_name == 'push' 14 | 15 | steps: 16 | - 17 | name: Checkout 18 | uses: actions/checkout@v4 19 | - 20 | name: Docker meta 21 | id: meta 22 | uses: docker/metadata-action@v5 23 | with: 24 | images: | 25 | c1982/sdns 26 | tags: | 27 | type=semver,pattern={{version}} 28 | type=raw,value=latest,enable={{is_default_branch}} 29 | - 30 | name: Set up QEMU 31 | uses: docker/setup-qemu-action@v3 32 | - 33 | name: Set up Docker Buildx 34 | uses: docker/setup-buildx-action@v3 35 | - 36 | name: Login to Docker Hub 37 | uses: docker/login-action@v3 38 | with: 39 | username: ${{ secrets.DOCKERHUB_USERNAME }} 40 | password: ${{ secrets.DOCKERHUB_TOKEN }} 41 | - 42 | name: Build and push 43 | uses: 
docker/build-push-action@v6 44 | with: 45 | context: . 46 | platforms: linux/amd64,linux/arm64 47 | push: true 48 | tags: ${{ steps.meta.outputs.tags }} 49 | labels: ${{ steps.meta.outputs.labels }} 50 | -------------------------------------------------------------------------------- /.github/workflows/docker-publish.yml: -------------------------------------------------------------------------------- 1 | name: "Docker Package" 2 | 3 | on: 4 | push: 5 | # Publish `master` as Docker `latest` image. 6 | branches: 7 | - master 8 | 9 | # Publish `v1.2.3` tags as releases. 10 | tags: 11 | - v* 12 | 13 | env: 14 | # TODO: Change variable to your image's name. 15 | IMAGE_NAME: sdns 16 | 17 | jobs: 18 | # Push image to GitHub Packages. 19 | # See also https://docs.docker.com/docker-hub/builds/ 20 | push: 21 | runs-on: ubuntu-latest 22 | if: github.event_name == 'push' 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | 27 | - name: Version 28 | id: get_version 29 | run: | 30 | VERSION=$(echo "${{ github.ref }}" | sed -e 's,.*/\(.*\),\1,') 31 | [ "$VERSION" == "master" ] && VERSION=latest 32 | echo ::set-output name=VERSION::$VERSION 33 | 34 | - name: Build image 35 | run: docker build -t docker.pkg.github.com/${{ github.repository }}/sdns:${{ steps.get_version.outputs.VERSION }} . 
36 | 37 | - name: Log into registry 38 | run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login docker.pkg.github.com -u ${{ github.actor }} --password-stdin 39 | 40 | - name: Push image 41 | run: docker push docker.pkg.github.com/${{ github.repository }}/sdns:${{ steps.get_version.outputs.VERSION }} 42 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | 11 | build: 12 | name: Build 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | os: [ubuntu-latest, macos-latest, windows-latest] 17 | 18 | steps: 19 | - name: Set up Go 1.x 20 | uses: actions/setup-go@v5 21 | with: 22 | go-version: ^1.22 23 | id: go 24 | 25 | - name: Check out code into the Go module directory 26 | uses: actions/checkout@v4 27 | 28 | - name: Get dependencies 29 | run: | 30 | go get -v -t -d ./... 31 | 32 | - name: Build 33 | run: go build -v . 34 | 35 | - name: Test 36 | run: make test 37 | -------------------------------------------------------------------------------- /.github/workflows/golangci-lint.yml: -------------------------------------------------------------------------------- 1 | name: Linter 2 | on: 3 | push: 4 | tags: 5 | - v* 6 | branches: 7 | - master 8 | pull_request: 9 | jobs: 10 | golangci: 11 | name: lint 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: actions/setup-go@v5 16 | with: 17 | cache: false 18 | go-version: ^1.22 19 | - name: golangci-lint 20 | uses: golangci/golangci-lint-action@v6.4.1 21 | with: 22 | # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. 
23 | version: v1.62.2 24 | 25 | # Optional: working directory, useful for monorepos 26 | # working-directory: somedir 27 | 28 | # Optional: golangci-lint command line arguments. 29 | # args: --disable typecheck 30 | 31 | # Optional: show only new issues if it's a pull request. The default value is `false`. 32 | # only-new-issues: true 33 | -------------------------------------------------------------------------------- /.github/workflows/goreleaser.yml: -------------------------------------------------------------------------------- 1 | name: Releaser 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v[0-9]+.[0-9]+.[0-9]+*' 7 | 8 | jobs: 9 | goreleaser: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - 13 | name: Checkout 14 | uses: actions/checkout@v4 15 | with: 16 | fetch-depth: 0 17 | - 18 | name: Set up Go 1.x 19 | uses: actions/setup-go@v5 20 | with: 21 | go-version: ^1.22 22 | - 23 | name: Run GoReleaser 24 | uses: goreleaser/goreleaser-action@v6.2.1 25 | with: 26 | version: latest 27 | args: release --clean 28 | env: 29 | GITHUB_TOKEN: ${{ secrets.GH_API_TOKEN }} 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pem 2 | *.crt 3 | *.key 4 | sdns 5 | sdns.conf 6 | coverage.out 7 | db 8 | access.log 9 | vendor 10 | -------------------------------------------------------------------------------- /.goreleaser.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | env_files: 4 | github_token: ~/.GH_TOKEN.txt 5 | 6 | env: 7 | - CGO_ENABLED=0 8 | 9 | before: 10 | hooks: 11 | - go mod download 12 | - go generate ./... 
13 | 14 | builds: 15 | - id: nonf 16 | targets: 17 | - darwin_amd64 18 | - darwin_arm64 19 | - windows_amd64 20 | - freebsd_amd64 21 | - openbsd_amd64 22 | - netbsd_amd64 23 | flags: 24 | - -trimpath 25 | ldflags: 26 | - -s -w 27 | - id: linux 28 | goos: 29 | - linux 30 | goarch: 31 | - arm 32 | - arm64 33 | - mips 34 | - mipsle 35 | - mips64 36 | - mips64le 37 | goarm: 38 | - 5 39 | - 6 40 | - 7 41 | gomips: 42 | - softfloat 43 | flags: 44 | - -trimpath 45 | ldflags: 46 | - -s -w 47 | - id: nf 48 | goos: 49 | - linux 50 | goarch: 51 | - amd64 52 | flags: 53 | - -trimpath 54 | ldflags: 55 | - -s -w 56 | 57 | release: 58 | github: 59 | owner: semihalev 60 | name: sdns 61 | prerelease: false 62 | draft: true 63 | 64 | archives: 65 | - 66 | name_template: "{{ .ProjectName }}-{{ .Version }}_{{ .Os }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}" 67 | format_overrides: 68 | - goos: windows 69 | format: zip 70 | wrap_in_directory: true 71 | files: 72 | - README.md 73 | - LICENSE 74 | 75 | checksum: 76 | name_template: '{{ .ProjectName }}-{{ .Version }}_sha256sums.txt' 77 | algorithm: sha256 78 | 79 | changelog: 80 | disable: true 81 | 82 | nfpms: 83 | - file_name_template: '{{ .ProjectName }}_{{ .Version }}_{{- if eq .Arch "amd64" }}x86_64{{- else }}{{ .Arch }}{{ end }}' 84 | builds: 85 | - nf 86 | homepage: https://sdns.dev 87 | description: A high-performance, recursive DNS resolver server with DNSSEC support, focused on preserving privacy 88 | maintainer: Yasar Alev <semihalev@gmail.com> 89 | license: MIT 90 | bindir: /usr/bin 91 | contents: 92 | - src: "./contrib/linux/sdns.service" 93 | dst: "/lib/systemd/system/sdns.service" 94 | - src: "./contrib/linux/sdns.conf" 95 | dst: "/etc/sdns.conf" 96 | type: config 97 | - dst: "/var/lib/sdns" 98 | type: dir 99 | scripts: 100 | postinstall: "contrib/linux/adduser.sh" 101 | release: 1 102 | formats: 103 | - deb 104 | - rpm 105 | overrides: 106 | deb: 107 | dependencies: 108 | - 
systemd-sysv 109 | rpm: 110 | dependencies: 111 | - systemd 112 | 113 | brews: 114 | - 115 | repository: 116 | owner: semihalev 117 | name: homebrew-tap 118 | directory: Formula 119 | homepage: https://sdns.dev 120 | description: A high-performance, recursive DNS resolver server with DNSSEC support, focused on preserving privacy 121 | dependencies: 122 | - name: go 123 | type: build 124 | commit_author: 125 | name: semihalev 126 | email: semihalev@gmail.com 127 | service: | 128 | run [opt_bin/"sdns", "-config", etc/"sdns.conf"] 129 | keep_alive true 130 | require_root true 131 | error_log_path var/"log/sdns.log" 132 | log_path var/"log/sdns.log" 133 | working_dir opt_prefix 134 | test: | 135 | fork do 136 | exec bin/"sdns", "-config", testpath/"sdns.conf" 137 | end 138 | sleep(2) 139 | assert_predicate testpath/"sdns.conf", :exist? 140 | -------------------------------------------------------------------------------- /CNAME: -------------------------------------------------------------------------------- 1 | sdns.dev -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 
55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | murat3ok@gmail.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 
99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code\_of\_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to SDNS 2 | 3 | First and foremost, thank you for considering contributing to SDNS! It's people like you that make SDNS such a great tool. 4 | 5 | ## Getting Started 6 | 7 | * Make sure you have a [GitHub account](https://github.com/signup/free). 8 | * Fork the repository on GitHub. 
9 | * Decide if you want to work on an existing issue or if you want to propose a new feature or bug fix. 10 | 11 | ## Making Changes 12 | 13 | 1. Create a new branch in your fork from the main branch. Name your branch something descriptive. 14 | 2. Make the changes in your fork. 15 | 3. If you're adding a feature or fixing a bug, please add or modify existing tests if applicable. 16 | 4. Run all tests to ensure your changes don't negatively impact existing code. 17 | 5. Commit your changes to your branch. Keep commit messages clear and concise, stating what you did and why. 18 | 19 | ## Submitting Changes 20 | 21 | 1. Push your changes to your fork on GitHub. 22 | 2. Open a pull request against the main branch of the original repository. 23 | 3. Please ensure your pull request description clearly describes the problem and solution and relates to any issues it addresses. 24 | 25 | ## Additional Resources 26 | 27 | * [Issue tracker](https://github.com/semihalev/sdns/issues) 28 | * [General GitHub documentation](https://docs.github.com/) 29 | * [GitHub pull request documentation](https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests) 30 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:alpine AS builder 2 | 3 | COPY . 
/go/src/github.com/semihalev/sdns/ 4 | 5 | WORKDIR /go/src/github.com/semihalev/sdns 6 | RUN apk --no-cache add \ 7 | ca-certificates \ 8 | gcc \ 9 | binutils-gold \ 10 | git \ 11 | musl-dev 12 | 13 | RUN go build -ldflags "-linkmode external -extldflags -static -s -w" -o /tmp/sdns \ 14 | && strip --strip-all /tmp/sdns 15 | 16 | FROM scratch 17 | 18 | COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ 19 | COPY --from=builder /tmp/sdns /sdns 20 | 21 | EXPOSE 53/tcp 22 | EXPOSE 53/udp 23 | EXPOSE 853 24 | EXPOSE 8053 25 | EXPOSE 8080 26 | 27 | ENTRYPOINT ["/sdns"] 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2018 semihalev, https://github.com/semihalev 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | GO ?= go 2 | TESTFOLDER := $(shell $(GO) list ./...) 3 | BIN = sdns 4 | 5 | all: generate tidy test build 6 | 7 | .PHONY: test 8 | test: 9 | echo "mode: atomic" > coverage.out 10 | for d in $(TESTFOLDER); do \ 11 | $(GO) test -v -covermode=atomic -race -coverprofile=profile.out $$d > profiles.out; \ 12 | cat profiles.out; \ 13 | if grep -q "^--- FAIL" profiles.out; then \ 14 | rm -rf profiles.out; \ 15 | rm -rf profile.out; \ 16 | exit 1; \ 17 | fi; \ 18 | if [ -f profile.out ]; then \ 19 | cat profile.out | grep -v "mode:" >> coverage.out; \ 20 | rm -rf profile.out; \ 21 | fi; \ 22 | rm -rf profiles.out; \ 23 | done 24 | 25 | .PHONY: generate 26 | generate: 27 | $(GO) generate 28 | 29 | .PHONY: tidy 30 | tidy: 31 | $(GO) mod tidy 32 | 33 | .PHONY: build 34 | build: 35 | $(GO) build 36 | 37 | .PHONY: clean 38 | clean: 39 | rm -rf $(BIN) 40 | rm -rf zregister.go 41 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | Only the following version is currently being supported with security updates: 6 | 7 | | Version | Supported | 8 | | ------- | ------------------ | 9 | | 1.3.x | :white\_check\_mark: | 10 | | < 1.3 | :x: | 11 | 12 | ## Reporting a Vulnerability 13 | 14 | We take security issues seriously. If you discover a security vulnerability in this project, please follow these steps: 15 | 16 | 1. **Open an Issue**: Once you've made sure you're on the latest version and the vulnerability still exists, open an issue on our GitHub repository. Describe the vulnerability in detail, including the steps to reproduce if possible. 17 | 2. 
**Discussion**: After you report the vulnerability, we'll engage in a discussion with you on the issue to understand it better and evaluate its impact. 18 | 3. **Resolution**: We will address the security issue and release a new version with the necessary patches as soon as possible. 19 | 20 | Your efforts to responsibly disclose your findings are sincerely appreciated and will be acknowledged. 21 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-dinky 2 | title: SDNS 3 | show_downloads: true 4 | plugins: 5 | - jemoji 6 | -------------------------------------------------------------------------------- /api/README.md: -------------------------------------------------------------------------------- 1 | # HTTP API 2 | 3 | You can manage all blocks with basic HTTP API functions. 4 | 5 | ## Authentication 6 | 7 | An API bearer token can be set in the sdns config. If the token is set, the Authorization header must be sent with every API request. 8 | ### Example Header 9 | `Authorization: Bearer my_very_long_token` 10 | 11 | ## Actions 12 | 13 | ### GET /api/v1/block/set/:key 14 | 15 | Creates a new block. 16 | 17 | __request__ 18 | 19 | > curl http://localhost:8080/api/v1/block/set/domain.com 20 | 21 | __response__ 22 | 23 | ```json 24 | {"success":true} 25 | ``` 26 | 27 | ### GET /api/v1/block/get/:key 28 | 29 | Looks up an existing block. 30 | 31 | __request__ 32 | 33 | > curl http://localhost:8080/api/v1/block/get/domain.com 34 | 35 | __response__ 36 | 37 | ```json 38 | {"success":true} 39 | ``` 40 | or 41 | 42 | ```json 43 | {"error":"domain.com not found"} 44 | ``` 45 | 46 | ### GET /api/v1/block/exists/:key 47 | 48 | Checks whether a block exists for the given key. 
49 | 50 | __request__ 51 | 52 | > curl http://localhost:8080/api/v1/block/exists/domain.com 53 | 54 | __response__ 55 | 56 | ```json 57 | {"success":true} 58 | ``` 59 | 60 | ### GET /api/v1/block/remove/:key 61 | 62 | Deletes the block. 63 | 64 | __request__ 65 | 66 | > curl http://localhost:8080/api/v1/block/remove/domain.com 67 | 68 | __response__ 69 | 70 | ```json 71 | {"success":true} 72 | ``` 73 | 74 | ### GET /api/v1/purge/domain/type 75 | 76 | Purge a cached query. 77 | 78 | __request__ 79 | 80 | > curl http://localhost:8080/api/v1/purge/example.com/MX 81 | 82 | __response__ 83 | 84 | ```json 85 | {"success":true} 86 | ``` 87 | 88 | ### GET /metrics 89 | 90 | Export the prometheus metrics. 91 | -------------------------------------------------------------------------------- /api/api.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "context" 5 | "encoding/base64" 6 | "net/http" 7 | "net/http/pprof" 8 | "os" 9 | "strings" 10 | "time" 11 | 12 | "github.com/miekg/dns" 13 | "github.com/prometheus/client_golang/prometheus/promhttp" 14 | "github.com/semihalev/log" 15 | "github.com/semihalev/sdns/config" 16 | "github.com/semihalev/sdns/dnsutil" 17 | "github.com/semihalev/sdns/middleware" 18 | "github.com/semihalev/sdns/middleware/blocklist" 19 | ) 20 | 21 | // API type 22 | type API struct { 23 | addr string 24 | bearerToken string 25 | router *Router 26 | blocklist *blocklist.BlockList 27 | } 28 | 29 | var debugpprof bool 30 | 31 | func init() { 32 | _, debugpprof = os.LookupEnv("SDNS_PPROF") 33 | } 34 | 35 | // New return new api 36 | func New(cfg *config.Config) *API { 37 | var bl *blocklist.BlockList 38 | 39 | b := middleware.Get("blocklist") 40 | if b != nil { 41 | bl = b.(*blocklist.BlockList) 42 | } 43 | 44 | a := &API{ 45 | addr: cfg.API, 46 | blocklist: bl, 47 | router: NewRouter(), 48 | bearerToken: cfg.BearerToken, 49 | } 50 | 51 | return a 52 | } 53 | 54 | func (a *API) 
checkToken(ctx *Context) bool { 55 | if a.bearerToken == "" { 56 | return true 57 | } 58 | // Expect exactly "Authorization: <scheme> <token>" with the Bearer scheme. 59 | authHeader := ctx.Request.Header.Get("Authorization") 60 | if authHeader == "" { 61 | ctx.JSON(http.StatusUnauthorized, Json{"error": "unauthorized"}) 62 | return false 63 | } 64 | 65 | tokenSplit := strings.Split(authHeader, " ") 66 | if len(tokenSplit) != 2 { 67 | ctx.JSON(http.StatusUnauthorized, Json{"error": "unauthorized"}) 68 | return false 69 | } 70 | // The auth-scheme name is case-insensitive (RFC 7235, section 2.1). 71 | if strings.EqualFold(tokenSplit[0], "Bearer") && a.bearerToken == tokenSplit[1] { 72 | return true 73 | } 74 | 75 | ctx.JSON(http.StatusUnauthorized, Json{"error": "unauthorized"}) 76 | return false 77 | } 78 | 79 | func (a *API) existsBlock(ctx *Context) { 80 | if !a.checkToken(ctx) { 81 | return 82 | } 83 | 84 | ctx.JSON(http.StatusOK, Json{"exists": a.blocklist.Exists(ctx.Param("key"))}) 85 | } 86 | 87 | func (a *API) getBlock(ctx *Context) { 88 | if !a.checkToken(ctx) { 89 | return 90 | } 91 | 92 | if ok, _ := a.blocklist.Get(ctx.Param("key")); !ok { 93 | ctx.JSON(http.StatusNotFound, Json{"error": ctx.Param("key") + " not found"}) 94 | } else { 95 | ctx.JSON(http.StatusOK, Json{"success": ok}) 96 | } 97 | } 98 | 99 | func (a *API) removeBlock(ctx *Context) { 100 | if !a.checkToken(ctx) { 101 | return 102 | } 103 | 104 | ctx.JSON(http.StatusOK, Json{"success": a.blocklist.Remove(ctx.Param("key"))}) 105 | } 106 | 107 | func (a *API) setBlock(ctx *Context) { 108 | if !a.checkToken(ctx) { 109 | return 110 | } 111 | 112 | ctx.JSON(http.StatusOK, Json{"success": a.blocklist.Set(ctx.Param("key"))}) 113 | } 114 | 115 | func (a *API) metrics(ctx *Context) { 116 | if !a.checkToken(ctx) { 117 | return 118 | } 119 | 120 | promhttp.Handler().ServeHTTP(ctx.Writer, ctx.Request) 121 | } 122 | 123 | func (a *API) purge(ctx *Context) { 124 | if !a.checkToken(ctx) { 125 | return 126 | } 127 | 128 | qtype := strings.ToUpper(ctx.Param("qtype")) 129 | qname := dns.Fqdn(ctx.Param("qname")) 130 | 131 | bqname :=
base64.StdEncoding.EncodeToString([]byte(qtype + ":" + qname)) 132 | 133 | req := new(dns.Msg) 134 | req.SetQuestion(dns.Fqdn(bqname), dns.TypeNULL) 135 | req.Question[0].Qclass = dns.ClassCHAOS 136 | 137 | _, _ = dnsutil.ExchangeInternal(context.Background(), req) 138 | 139 | ctx.JSON(http.StatusOK, Json{"success": true}) 140 | } 141 | 142 | // Run API server 143 | func (a *API) Run(ctx context.Context) { 144 | if a.addr == "" { 145 | return 146 | } 147 | 148 | if debugpprof { 149 | profiler := a.router.Group("/debug") 150 | { 151 | profiler.GET("/", func(ctx *Context) { 152 | http.Redirect(ctx.Writer, ctx.Request, profiler.path+"/pprof/", http.StatusMovedPermanently) 153 | }) 154 | profiler.GET("/pprof/", func(ctx *Context) { pprof.Index(ctx.Writer, ctx.Request) }) 155 | profiler.GET("/pprof/*", func(ctx *Context) { pprof.Index(ctx.Writer, ctx.Request) }) 156 | profiler.GET("/pprof/cmdline", func(ctx *Context) { pprof.Cmdline(ctx.Writer, ctx.Request) }) 157 | profiler.GET("/pprof/profile", func(ctx *Context) { pprof.Profile(ctx.Writer, ctx.Request) }) 158 | profiler.GET("/pprof/symbol", func(ctx *Context) { pprof.Symbol(ctx.Writer, ctx.Request) }) 159 | profiler.GET("/pprof/trace", func(ctx *Context) { pprof.Trace(ctx.Writer, ctx.Request) }) 160 | } 161 | } 162 | 163 | if a.blocklist != nil { 164 | block := a.router.Group("/api/v1/block") 165 | { 166 | block.GET("/exists/:key", a.existsBlock) 167 | block.GET("/get/:key", a.getBlock) 168 | block.GET("/remove/:key", a.removeBlock) 169 | block.GET("/set/:key", a.setBlock) 170 | } 171 | } 172 | 173 | a.router.GET("/api/v1/purge/:qname/:qtype", a.purge) 174 | 175 | a.router.GET("/metrics", a.metrics) 176 | 177 | srv := &http.Server{ 178 | Addr: a.addr, 179 | Handler: a.router, 180 | } 181 | 182 | go func() { 183 | if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { 184 | log.Error("Start API server failed", "error", err.Error()) 185 | } 186 | }() 187 | 188 | log.Info("API server listening...", 
"addr", a.addr) 189 | if a.bearerToken != "" { 190 | log.Info("API authorization bearer token is enabled") // never log the secret itself 191 | } 192 | 193 | go func() { 194 | <-ctx.Done() 195 | 196 | log.Info("API server stopping...", "addr", a.addr) 197 | 198 | apiCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 199 | defer cancel() 200 | 201 | if err := srv.Shutdown(apiCtx); err != nil { 202 | log.Error("Shutdown API server failed:", "error", err.Error()) 203 | } 204 | }() 205 | } 206 | -------------------------------------------------------------------------------- /api/context.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | ) 7 | 8 | type ( 9 | Context struct { 10 | Request *http.Request 11 | Writer http.ResponseWriter 12 | Handler Handler 13 | Params *Params 14 | } 15 | 16 | Handler func(ctx *Context) 17 | 18 | Param struct { 19 | Key string 20 | Value string 21 | } 22 | 23 | Params []Param 24 | 25 | Json map[string]any 26 | ) 27 | // JSON marshals data and writes it as the response body with the given status code; on a marshal error it writes only a 500 status. 28 | func (ctx *Context) JSON(code int, data any) { 29 | buf, err := json.Marshal(data) 30 | if err != nil { 31 | ctx.Writer.WriteHeader(http.StatusInternalServerError) 32 | return 33 | } 34 | // Headers must be set before WriteHeader; changes made afterwards are ignored. 35 | ctx.Writer.Header().Set("Content-Type", "application/json") 36 | ctx.Writer.WriteHeader(code) 37 | 38 | _, _ = ctx.Writer.Write(buf) 39 | } 40 | 41 | func (ctx *Context) Param(key string) string { 42 | params := *ctx.Params 43 | for _, p := range params { 44 | if p.Key == key { 45 | return p.Value 46 | } 47 | } 48 | 49 | return "" 50 | } 51 | 52 | func (ctx *Context) addParameter(key, value string) { 53 | i := len(*ctx.Params) 54 | *ctx.Params = (*ctx.Params)[:i+1] 55 | (*ctx.Params)[i] = Param{ 56 | Key: key, 57 | Value: value, 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /api/group.go: -------------------------------------------------------------------------------- 1 |
package api 2 | 3 | type Group struct { 4 | parent *Router 5 | path string 6 | } 7 | 8 | func (g *Group) Handle(method, path string, handle Handler) { 9 | g.parent.Handle(method, g.path+path, handle) 10 | } 11 | 12 | func (g *Group) GET(path string, handle Handler) { 13 | g.parent.GET(g.path+path, handle) 14 | } 15 | 16 | func (g *Group) POST(path string, handle Handler) { 17 | g.parent.POST(g.path+path, handle) 18 | } 19 | -------------------------------------------------------------------------------- /api/router.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "os" 7 | "runtime/debug" 8 | "sync" 9 | 10 | "github.com/semihalev/log" 11 | ) 12 | 13 | type Router struct { 14 | get Tree 15 | post Tree 16 | delete Tree 17 | put Tree 18 | patch Tree 19 | head Tree 20 | connect Tree 21 | trace Tree 22 | options Tree 23 | 24 | ctxPool sync.Pool 25 | } 26 | 27 | var extraHeaders = map[string]string{ 28 | "Server": "sdns", 29 | "Access-Control-Allow-Origin": "*", 30 | "Access-Control-Allow-Methods": "GET,POST", 31 | "Cache-Control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0", 32 | "Pragma": "no-cache", 33 | } 34 | 35 | func NewRouter() *Router { 36 | r := &Router{} 37 | 38 | r.ctxPool.New = func() any { 39 | params := make(Params, 0, 20) 40 | return &Context{Params: ¶ms} 41 | } 42 | 43 | return r 44 | } 45 | 46 | func (rt *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { 47 | defer func() { 48 | if r := recover(); r != nil { 49 | http.Error(w, "Internal Server Error", http.StatusInternalServerError) 50 | log.Error("Recovered in API", "recover", r) 51 | 52 | _, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n\n", r)) 53 | debug.PrintStack() 54 | } 55 | }() 56 | 57 | for k, v := range extraHeaders { 58 | w.Header().Set(k, v) 59 | } 60 | 61 | if len(r.URL.Path) > 2048 { 62 | http.Error(w, "Bad Request", http.StatusBadRequest) 63 | return 
64 | } 65 | 66 | ctx := rt.getContext(w, r) 67 | 68 | if r.Method[0] == 'G' { 69 | rt.get.Lookup(ctx) 70 | } else { 71 | tree := rt.selectTree(r.Method) 72 | if tree != nil { 73 | tree.Lookup(ctx) 74 | } 75 | } 76 | 77 | if ctx.Handler == nil { 78 | http.NotFound(w, r) 79 | rt.putContext(ctx) 80 | return 81 | } 82 | 83 | ctx.Handler(ctx) 84 | 85 | rt.putContext(ctx) 86 | } 87 | 88 | func (rt *Router) Handle(method, path string, handle Handler) { 89 | tree := rt.selectTree(method) 90 | tree.Add(path, handle) 91 | } 92 | 93 | func (rt *Router) GET(path string, handle Handler) { 94 | rt.get.Add(path, handle) 95 | } 96 | 97 | func (rt *Router) POST(path string, handle Handler) { 98 | rt.post.Add(path, handle) 99 | } 100 | 101 | func (rt *Router) Group(rp string) *Group { 102 | return &Group{parent: rt, path: rp} 103 | } 104 | 105 | func (rt *Router) getContext(w http.ResponseWriter, r *http.Request) *Context { 106 | ctx := rt.ctxPool.Get().(*Context) 107 | 108 | ctx.Request = r 109 | ctx.Writer = w 110 | ctx.Handler = nil 111 | (*ctx.Params) = (*ctx.Params)[:0] 112 | 113 | return ctx 114 | } 115 | 116 | func (rt *Router) putContext(ctx *Context) { 117 | rt.ctxPool.Put(ctx) 118 | } 119 | 120 | func (rt *Router) selectTree(method string) *Tree { 121 | switch method { 122 | case http.MethodGet: 123 | return &rt.get 124 | case http.MethodPost: 125 | return &rt.post 126 | case http.MethodDelete: 127 | return &rt.delete 128 | case http.MethodPut: 129 | return &rt.put 130 | case http.MethodPatch: 131 | return &rt.patch 132 | case http.MethodHead: 133 | return &rt.head 134 | case http.MethodConnect: 135 | return &rt.connect 136 | case http.MethodTrace: 137 | return &rt.trace 138 | case http.MethodOptions: 139 | return &rt.options 140 | default: 141 | return nil 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /authcache/authserver.go: -------------------------------------------------------------------------------- 1 | 
package authcache 2 | 3 | import ( 4 | "sort" 5 | "sync" 6 | "sync/atomic" 7 | "time" 8 | ) 9 | 10 | // AuthServer type 11 | type AuthServer struct { 12 | // place atomic members at the start to fix alignment for ARM32 13 | Rtt int64 14 | Count int64 15 | Addr string 16 | Version Version 17 | } 18 | 19 | // Version type 20 | type Version byte 21 | 22 | const ( 23 | // IPv4 mode 24 | IPv4 Version = 0x1 25 | 26 | // IPv6 mode 27 | IPv6 Version = 0x2 28 | ) 29 | 30 | // NewAuthServer returns a new server 31 | func NewAuthServer(addr string, version Version) *AuthServer { 32 | return &AuthServer{ 33 | Addr: addr, 34 | Version: version, 35 | } 36 | } 37 | 38 | func (v Version) String() string { 39 | switch v { 40 | case IPv4: 41 | return "IPv4" 42 | case IPv6: 43 | return "IPv6" 44 | default: 45 | return "Unknown" 46 | } 47 | } 48 | // String reports version, address, average RTT and a health label derived from that average. 49 | func (a *AuthServer) String() string { 50 | count := atomic.LoadInt64(&a.Count) 51 | rn := atomic.LoadInt64(&a.Rtt) 52 | 53 | if count == 0 { 54 | count = 1 55 | } 56 | // Rtt is a running sum over Count samples (Sort re-averages it), so judge health on the mean. 57 | avg := time.Duration(rn) / time.Duration(count) 58 | var health string 59 | if avg >= time.Second { 60 | health = "POOR" 61 | } else if avg > 0 { 62 | health = "GOOD" 63 | } else { 64 | health = "UNKNOWN" 65 | } 66 | rtt := avg.Round(time.Millisecond) 67 | 68 | return a.Version.String() + ":" + a.Addr + " rtt:" + rtt.String() + " health:[" + health + "]" 69 | } 70 | 71 | // AuthServers type 72 | type AuthServers struct { 73 | sync.RWMutex 74 | // place atomic members at the start to fix alignment for ARM32 75 | Called uint64 76 | ErrorCount uint32 77 | 78 | Zone string 79 | 80 | List []*AuthServer 81 | Nss []string 82 | 83 | CheckingDisable bool 84 | Checked bool 85 | } 86 | 87 | // Sort averages each server's accumulated rtt and sorts servers by it, lowest first; when called is a multiple of 1000 the stats are reset instead. 88 | func Sort(serversList []*AuthServer, called uint64) { 89 | for _, s := range serversList { 90 | // clear stats and start again 91 | if called%1e3 == 0 { 92 | atomic.StoreInt64(&s.Rtt, 0) 93 | atomic.StoreInt64(&s.Count, 0) 94 | 95 | continue 96 | } 97 | 98 |
rtt := atomic.LoadInt64(&s.Rtt) 99 | count := atomic.LoadInt64(&s.Count) 100 | 101 | if count > 0 { 102 | // average rtt 103 | atomic.StoreInt64(&s.Rtt, rtt/count) 104 | atomic.StoreInt64(&s.Count, 1) 105 | } 106 | } 107 | sort.Slice(serversList, func(i, j int) bool { 108 | return atomic.LoadInt64(&serversList[i].Rtt) < atomic.LoadInt64(&serversList[j].Rtt) 109 | }) 110 | } 111 | -------------------------------------------------------------------------------- /authcache/authserver_test.go: -------------------------------------------------------------------------------- 1 | package authcache 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func Test_TrySort(t *testing.T) { 13 | s := &AuthServers{ 14 | List: []*AuthServer{}, 15 | } 16 | 17 | for i := 0; i < 10; i++ { 18 | s.List = append(s.List, NewAuthServer(fmt.Sprintf("0.0.0.%d:53", i), IPv4)) 19 | s.List = append(s.List, NewAuthServer(fmt.Sprintf("[::%d]:53", i), IPv6)) 20 | } 21 | 22 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 23 | for i := 0; i < 2000; i++ { 24 | for j := range s.List { 25 | s.List[j].Count++ 26 | s.List[j].Rtt += (time.Duration(r.Intn(2000-0)+0) * time.Millisecond).Nanoseconds() 27 | Sort(s.List, uint64(i)) 28 | } 29 | } 30 | 31 | assert.Equal(t, int64(1), s.List[0].Count) 32 | } 33 | -------------------------------------------------------------------------------- /authcache/ns_cache.go: -------------------------------------------------------------------------------- 1 | package authcache 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/miekg/dns" 7 | "github.com/semihalev/sdns/cache" 8 | ) 9 | 10 | // NS represents a cache entry 11 | type NS struct { 12 | Servers *AuthServers 13 | DSRR []dns.RR 14 | TTL time.Duration 15 | 16 | ut time.Time 17 | } 18 | 19 | // NSCache type 20 | type NSCache struct { 21 | cache *cache.Cache 22 | 23 | now func() time.Time 24 | } 25 | 26 | // NewNSCache return new cache 27 
| func NewNSCache() *NSCache { 28 | n := &NSCache{ 29 | cache: cache.New(defaultCap), 30 | now: time.Now, 31 | } 32 | 33 | return n 34 | } 35 | 36 | // Get returns the entry for a key or an error 37 | func (n *NSCache) Get(key uint64) (*NS, error) { 38 | el, ok := n.cache.Get(key) 39 | 40 | if !ok { 41 | return nil, cache.ErrCacheNotFound 42 | } 43 | 44 | ns := el.(*NS) 45 | 46 | elapsed := n.now().UTC().Sub(ns.ut) 47 | 48 | if elapsed >= ns.TTL { 49 | return nil, cache.ErrCacheExpired 50 | } 51 | 52 | return ns, nil 53 | } 54 | 55 | // Set sets a keys value to a NS 56 | func (n *NSCache) Set(key uint64, dsRR []dns.RR, servers *AuthServers, ttl time.Duration) { 57 | if ttl > maximumTTL { 58 | ttl = maximumTTL 59 | } else if ttl < minimumTTL { 60 | ttl = minimumTTL 61 | } 62 | 63 | n.cache.Add(key, &NS{ 64 | Servers: servers, 65 | DSRR: dsRR, 66 | TTL: ttl, 67 | ut: n.now().UTC().Round(time.Second), 68 | }) 69 | } 70 | 71 | // Remove remove a cache 72 | func (n *NSCache) Remove(key uint64) { 73 | n.cache.Remove(key) 74 | } 75 | 76 | const ( 77 | maximumTTL = 12 * time.Hour 78 | minimumTTL = 1 * time.Hour 79 | defaultCap = 1024 * 256 80 | ) 81 | -------------------------------------------------------------------------------- /authcache/ns_cache_test.go: -------------------------------------------------------------------------------- 1 | package authcache 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/semihalev/sdns/cache" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func Test_NSCache(t *testing.T) { 13 | nscache := NewNSCache() 14 | 15 | m := new(dns.Msg) 16 | m.SetQuestion(dns.Fqdn("example.com."), dns.TypeA) 17 | key := cache.Hash(m.Question[0]) 18 | 19 | a := NewAuthServer("0.0.0.0:53", IPv4) 20 | _ = a.String() 21 | 22 | servers := &AuthServers{List: []*AuthServer{a}} 23 | 24 | _, err := nscache.Get(key) 25 | assert.Error(t, err) 26 | assert.Equal(t, err.Error(), "cache not found") 27 | 28 | nscache.Set(key, 
nil, servers, time.Hour) 29 | 30 | _, err = nscache.Get(key) 31 | assert.NoError(t, err) 32 | 33 | nscache.now = func() time.Time { 34 | return time.Now().Add(30 * time.Minute) 35 | } 36 | _, err = nscache.Get(key) 37 | assert.NoError(t, err) 38 | 39 | nscache.now = func() time.Time { 40 | return time.Now().Add(2 * time.Hour) 41 | } 42 | _, err = nscache.Get(key) 43 | assert.Error(t, err) 44 | assert.Equal(t, err.Error(), "cache expired") 45 | 46 | _, err = nscache.Get(key) 47 | assert.Error(t, err) 48 | 49 | nscache.Remove(key) 50 | } 51 | -------------------------------------------------------------------------------- /cache/cache.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016-2020 The CoreDNS authors and contributors 2 | // Adapted for SDNS usage by Semih Alev. 3 | 4 | package cache 5 | 6 | import ( 7 | "errors" 8 | ) 9 | 10 | var ( 11 | // ErrCacheNotFound error 12 | ErrCacheNotFound = errors.New("cache not found") 13 | // ErrCacheExpired error 14 | ErrCacheExpired = errors.New("cache expired") 15 | ) 16 | 17 | // Cache is cache. 18 | type Cache struct { 19 | shards [shardSize]*shard 20 | } 21 | 22 | // New returns a new cache. 23 | func New(size int) *Cache { 24 | ssize := size / shardSize 25 | if ssize < 4 { 26 | ssize = 4 27 | } 28 | 29 | c := &Cache{} 30 | 31 | // Initialize all the shards 32 | for i := 0; i < shardSize; i++ { 33 | c.shards[i] = newShard(ssize) 34 | } 35 | return c 36 | } 37 | 38 | // Get looks up element index under key. 39 | func (c *Cache) Get(key uint64) (interface{}, bool) { 40 | shard := key & (shardSize - 1) 41 | return c.shards[shard].Get(key) 42 | } 43 | 44 | // Add adds a new element to the cache. If the element already exists it is overwritten. 45 | func (c *Cache) Add(key uint64, el interface{}) { 46 | shard := key & (shardSize - 1) 47 | c.shards[shard].Add(key, el) 48 | } 49 | 50 | // Remove removes the element indexed with key. 
51 | func (c *Cache) Remove(key uint64) { 52 | shard := key & (shardSize - 1) 53 | c.shards[shard].Remove(key) 54 | } 55 | 56 | // Len returns the number of elements in the cache. 57 | func (c *Cache) Len() int { 58 | l := 0 59 | for _, s := range c.shards { 60 | l += s.Len() 61 | } 62 | return l 63 | } 64 | -------------------------------------------------------------------------------- /cache/cache_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016-2020 The CoreDNS authors and contributors 2 | // Adapted for SDNS usage by Semih Alev. 3 | 4 | package cache 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | ) 10 | 11 | func TestCacheAddAndGet(t *testing.T) { 12 | c := New(4) 13 | c.Add(1, 1) 14 | 15 | if _, found := c.Get(1); !found { 16 | t.Fatal("Failed to find inserted record") 17 | } 18 | } 19 | 20 | func TestCacheLen(t *testing.T) { 21 | c := New(4) 22 | 23 | c.Add(1, 1) 24 | if l := c.Len(); l != 1 { 25 | t.Fatalf("Cache size should %d, got %d", 1, l) 26 | } 27 | 28 | c.Add(1, 1) 29 | if l := c.Len(); l != 1 { 30 | t.Fatalf("Cache size should %d, got %d", 1, l) 31 | } 32 | 33 | c.Add(2, 2) 34 | if l := c.Len(); l != 2 { 35 | t.Fatalf("Cache size should %d, got %d", 2, l) 36 | } 37 | } 38 | 39 | func TestCacheRemove(t *testing.T) { 40 | c := New(4) 41 | 42 | c.Add(1, 1) 43 | if l := c.Len(); l != 1 { 44 | t.Fatalf("Cache size should %d, got %d", 1, l) 45 | } 46 | 47 | c.Remove(1) 48 | if l := c.Len(); l != 0 { 49 | t.Fatalf("Cache size should %d, got %d", 1, l) 50 | } 51 | } 52 | 53 | func BenchmarkCacheGet(b *testing.B) { 54 | const items = 1 << 16 55 | c := New(12 * items) 56 | v := []byte("xyza") 57 | for i := 0; i < items; i++ { 58 | c.Add(uint64(i), v) 59 | } 60 | 61 | b.ReportAllocs() 62 | b.SetBytes(items) 63 | b.RunParallel(func(pb *testing.PB) { 64 | for pb.Next() { 65 | for i := 0; i < items; i++ { 66 | b, _ := c.Get(uint64(i)) 67 | if string(b.([]byte)) != string(v) { 68 | panic(fmt.Errorf("BUG: 
invalid value obtained; got %q; want %q", b, v)) 69 | } 70 | } 71 | } 72 | }) 73 | } 74 | 75 | func BenchmarkCacheSet(b *testing.B) { 76 | const items = 1 << 16 77 | c := New(12 * items) 78 | b.ReportAllocs() 79 | b.SetBytes(items) 80 | b.RunParallel(func(pb *testing.PB) { 81 | v := []byte("xyza") 82 | for pb.Next() { 83 | for i := 0; i < items; i++ { 84 | c.Add(uint64(i), v) 85 | } 86 | } 87 | }) 88 | } 89 | 90 | func BenchmarkCacheSetGet(b *testing.B) { 91 | const items = 1 << 16 92 | c := New(12 * items) 93 | b.ReportAllocs() 94 | b.SetBytes(2 * items) 95 | b.RunParallel(func(pb *testing.PB) { 96 | v := []byte("xyza") 97 | for pb.Next() { 98 | for i := 0; i < items; i++ { 99 | c.Add(uint64(i), v) 100 | } 101 | for i := 0; i < items; i++ { 102 | b, _ := c.Get(uint64(i)) 103 | if string(b.([]byte)) != string(v) { 104 | panic(fmt.Errorf("BUG: invalid value obtained; got %q; want %q", b, v)) 105 | } 106 | } 107 | } 108 | }) 109 | } 110 | -------------------------------------------------------------------------------- /cache/hash.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016-2020 The CoreDNS authors and contributors 2 | // Adapted for SDNS usage by Semih Alev. 
3 | 4 | package cache 5 | 6 | import ( 7 | "bytes" 8 | "hash" 9 | "sync" 10 | 11 | "github.com/cespare/xxhash/v2" 12 | "github.com/miekg/dns" 13 | ) 14 | 15 | // Hash returns a hash for cache 16 | func Hash(q dns.Question, cd ...bool) uint64 { 17 | h := AcquireHash() 18 | defer ReleaseHash(h) 19 | 20 | buf := AcquireBuf() 21 | defer ReleaseBuf(buf) 22 | 23 | buf.Write([]byte{uint8(q.Qclass >> 8), uint8(q.Qclass & 0xff)}) 24 | buf.Write([]byte{uint8(q.Qtype >> 8), uint8(q.Qtype & 0xff)}) 25 | 26 | if len(cd) > 0 && cd[0] { 27 | buf.WriteByte(1) 28 | } 29 | 30 | for i := range q.Name { 31 | c := q.Name[i] 32 | if c >= 'A' && c <= 'Z' { 33 | c += 'a' - 'A' 34 | } 35 | buf.WriteByte(c) 36 | } 37 | 38 | _, _ = h.Write(buf.Bytes()) 39 | 40 | return h.Sum64() 41 | } 42 | 43 | var bufferPool sync.Pool 44 | var hashPool sync.Pool 45 | 46 | // AcquireHash returns a hash from pool 47 | func AcquireHash() hash.Hash64 { 48 | v := hashPool.Get() 49 | if v == nil { 50 | return xxhash.New() 51 | } 52 | return v.(hash.Hash64) 53 | } 54 | 55 | // ReleaseHash returns hash to pool 56 | func ReleaseHash(h hash.Hash64) { 57 | h.Reset() 58 | hashPool.Put(h) 59 | } 60 | 61 | // AcquireBuf returns a buf from pool 62 | func AcquireBuf() *bytes.Buffer { 63 | v := bufferPool.Get() 64 | if v == nil { 65 | return &bytes.Buffer{} 66 | } 67 | return v.(*bytes.Buffer) 68 | } 69 | 70 | // ReleaseBuf returns buf to pool 71 | func ReleaseBuf(buf *bytes.Buffer) { 72 | buf.Reset() 73 | bufferPool.Put(buf) 74 | } 75 | -------------------------------------------------------------------------------- /cache/hash_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016-2020 The CoreDNS authors and contributors 2 | // Adapted for SDNS usage by Semih Alev. 
3 | 4 | package cache 5 | 6 | import ( 7 | "testing" 8 | 9 | "github.com/miekg/dns" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func Test_Hash(t *testing.T) { 14 | 15 | q := dns.Question{Name: "goOgle.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} 16 | 17 | asset := Hash(q) 18 | 19 | assert.Equal(t, uint64(13726664550454464700), asset) 20 | 21 | asset = Hash(q, true) 22 | 23 | assert.Equal(t, uint64(8882204296994448420), asset) 24 | } 25 | 26 | func Benchmark_Hash(b *testing.B) { 27 | q := dns.Question{Name: "goOgle.com.", Qtype: dns.TypeA, Qclass: dns.ClassANY} 28 | 29 | b.ResetTimer() 30 | 31 | for i := 0; i < b.N; i++ { 32 | Hash(q) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /cache/shard.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016-2020 The CoreDNS authors and contributors 2 | // Adapted for SDNS usage by Semih Alev. 3 | 4 | package cache 5 | 6 | import "sync" 7 | 8 | // shard is a cache with random eviction. 9 | type shard struct { 10 | items map[uint64]interface{} 11 | size int 12 | 13 | sync.RWMutex 14 | } 15 | 16 | // newShard returns a new shard with size. 17 | func newShard(size int) *shard { return &shard{items: make(map[uint64]interface{}), size: size} } 18 | 19 | // Add adds element indexed by key into the cache. Any existing element is overwritten. 20 | // If key is new and the shard is full, one random element is evicted first. The check, 21 | // eviction and insert happen under a single lock so the shard never grows past size. 22 | func (s *shard) Add(key uint64, el interface{}) { 23 | s.Lock() 24 | defer s.Unlock() 25 | 26 | if _, exists := s.items[key]; !exists && len(s.items) >= s.size { 27 | s.evictLocked() 28 | } 29 | s.items[key] = el 30 | } 31 | 32 | // Remove removes the element indexed by key from the cache. 33 | func (s *shard) Remove(key uint64) { 34 | s.Lock() 35 | delete(s.items, key) 36 | s.Unlock() 37 | } 38 | 39 | // Evict removes a random element from the cache. It does nothing when the cache is empty. 40 | func (s *shard) Evict() { 41 | s.Lock() 42 | s.evictLocked() 43 | s.Unlock() 44 | } 45 | 46 | // evictLocked deletes the first key yielded by map iteration (Go randomizes that order). 47 | // The caller must hold the write lock. 48 | func (s *shard) evictLocked() { 49 | for k := range s.items { 50 | delete(s.items, k) 51 | return 52 | } 53 | } 54 | 55 | // Get looks up the element indexed under key. 56 | func (s *shard) Get(key uint64) (interface{}, bool) { 57 | s.RLock() 58 | el, found := s.items[key] 59 | s.RUnlock() 60 | return el, found 61 | } 62 | 63 | // Len returns the current length of the cache. 64 | func (s *shard) Len() int { 65 | s.RLock() 66 | l := len(s.items) 67 | s.RUnlock() 68 | return l 69 | } 70 | 71 | const shardSize = 256 72 | -------------------------------------------------------------------------------- /cache/shard_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016-2020 The CoreDNS authors and contributors 2 | // Adapted for SDNS usage by Semih Alev.
3 | 4 | package cache 5 | 6 | import "testing" 7 | 8 | func TestShardAddAndGet(t *testing.T) { 9 | s := newShard(4) 10 | s.Add(1, 1) 11 | 12 | if _, found := s.Get(1); !found { 13 | t.Fatal("Failed to find inserted record") 14 | } 15 | } 16 | 17 | func TestShardLen(t *testing.T) { 18 | s := newShard(4) 19 | 20 | s.Add(1, 1) 21 | if l := s.Len(); l != 1 { 22 | t.Fatalf("Shard size should %d, got %d", 1, l) 23 | } 24 | 25 | s.Add(1, 1) 26 | if l := s.Len(); l != 1 { 27 | t.Fatalf("Shard size should %d, got %d", 1, l) 28 | } 29 | 30 | s.Add(2, 2) 31 | if l := s.Len(); l != 2 { 32 | t.Fatalf("Shard size should %d, got %d", 2, l) 33 | } 34 | } 35 | 36 | func TestShardEvict(t *testing.T) { 37 | s := newShard(1) 38 | s.Evict() // empty cache 39 | 40 | s.Add(1, 1) 41 | s.Add(2, 2) 42 | // 1 should be gone 43 | 44 | if _, found := s.Get(1); found { 45 | t.Fatal("Found item that should have been evicted") 46 | } 47 | } 48 | 49 | func TestShardLenEvict(t *testing.T) { 50 | s := newShard(4) 51 | s.Add(1, 1) 52 | s.Add(2, 1) 53 | s.Add(3, 1) 54 | s.Add(4, 1) 55 | 56 | if l := s.Len(); l != 4 { 57 | t.Fatalf("Shard size should %d, got %d", 4, l) 58 | } 59 | 60 | // This should evict one element 61 | s.Add(5, 1) 62 | if l := s.Len(); l != 4 { 63 | t.Fatalf("Shard size should %d, got %d", 4, l) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /config/config_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/semihalev/log" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func Test_config(t *testing.T) { 12 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler)) 13 | 14 | const configFile = "example.conf" 15 | 16 | err := generateConfig(configFile) 17 | assert.NoError(t, err) 18 | 19 | _, err = Load(configFile, "0.0.0") 20 | assert.NoError(t, err) 21 | 22 | os.Remove(configFile) 23 | 
os.Remove("db") 24 | } 25 | 26 | func Test_configError(t *testing.T) { 27 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler)) 28 | 29 | const configFile = "" 30 | 31 | _, err := Load(configFile, "0.0.0") 32 | assert.Error(t, err) 33 | 34 | os.Remove("db") 35 | } 36 | -------------------------------------------------------------------------------- /contrib/linux/adduser.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | groupadd --system sdns 4 | useradd --system -d /var/lib/sdns -s /usr/sbin/nologin -g sdns sdns 5 | mkdir -p /var/lib/sdns 6 | chown sdns:sdns /var/lib/sdns 7 | -------------------------------------------------------------------------------- /contrib/linux/sdns.service: -------------------------------------------------------------------------------- 1 | [Unit] 2 | Description=SDNS - Fast DNS Resolver 3 | ConditionPathExists=/var/lib/sdns 4 | Wants=network.target 5 | After=network.target 6 | 7 | [Service] 8 | Type=simple 9 | User=sdns 10 | Group=sdns 11 | LimitNOFILE=131072 12 | Restart=on-failure 13 | RestartSec=10 14 | Environment="SDNS_DEBUGNS=false" 15 | Environment="SDNS_PPROF=false" 16 | WorkingDirectory=/var/lib/sdns 17 | ExecStart=/usr/bin/sdns --config=/etc/sdns.conf 18 | PermissionsStartOnly=true 19 | StandardOutput=syslog 20 | StandardError=journal 21 | SyslogIdentifier=sdns 22 | AmbientCapabilities=CAP_NET_BIND_SERVICE 23 | 24 | [Install] 25 | WantedBy=multi-user.target 26 | -------------------------------------------------------------------------------- /dnsutil/ttl.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016-2020 The CoreDNS authors and contributors 2 | // Adapted for SDNS usage by Semih Alev. 
3 | 4 | package dnsutil 5 | 6 | import ( 7 | "time" 8 | 9 | "github.com/miekg/dns" 10 | "github.com/semihalev/sdns/response" 11 | ) 12 | 13 | // MinimalTTL scans the message returns the lowest TTL found taking into the response.Type of the message. 14 | func MinimalTTL(m *dns.Msg, mt response.Type) time.Duration { 15 | if mt != response.NoError && mt != response.NameError && mt != response.NoData { 16 | return MinimalDefaultTTL 17 | } 18 | 19 | // No records or OPT is the only record, return a short ttl as a fail safe. 20 | if len(m.Answer)+len(m.Ns) == 0 && 21 | (len(m.Extra) == 0 || (len(m.Extra) == 1 && m.Extra[0].Header().Rrtype == dns.TypeOPT)) { 22 | return MinimalDefaultTTL 23 | } 24 | 25 | minTTL := MaximumDefaulTTL 26 | for _, r := range m.Answer { 27 | if r.Header().Ttl < uint32(minTTL.Seconds()) { 28 | minTTL = time.Duration(r.Header().Ttl) * time.Second 29 | } 30 | } 31 | for _, r := range m.Ns { 32 | if r.Header().Ttl < uint32(minTTL.Seconds()) { 33 | minTTL = time.Duration(r.Header().Ttl) * time.Second 34 | } 35 | } 36 | 37 | for _, r := range m.Extra { 38 | if r.Header().Rrtype == dns.TypeOPT { 39 | // OPT records use TTL field for extended rcode and flags 40 | continue 41 | } 42 | if r.Header().Ttl < uint32(minTTL.Seconds()) { 43 | minTTL = time.Duration(r.Header().Ttl) * time.Second 44 | } 45 | } 46 | return minTTL 47 | } 48 | 49 | const ( 50 | // MinimalDefaultTTL is the absolute lowest TTL. 51 | MinimalDefaultTTL = 5 * time.Second 52 | // MaximumDefaulTTL is the maximum TTL. 53 | MaximumDefaulTTL = 24 * time.Hour 54 | ) 55 | -------------------------------------------------------------------------------- /dnsutil/ttl_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016-2020 The CoreDNS authors and contributors 2 | // Adapted for SDNS usage by Semih Alev. 
3 | 4 | package dnsutil 5 | 6 | import ( 7 | "testing" 8 | "time" 9 | 10 | "github.com/miekg/dns" 11 | "github.com/semihalev/sdns/response" 12 | ) 13 | 14 | // See https://github.com/kubernetes/dns/issues/121, add some specific tests for those use cases. 15 | 16 | func makeRR(data string) dns.RR { 17 | r, _ := dns.NewRR(data) 18 | 19 | return r 20 | } 21 | 22 | func TestMinimalTTL(t *testing.T) { 23 | utc := time.Now().UTC() 24 | 25 | mt, _ := response.Typify(nil, utc) 26 | if mt != response.OtherError { 27 | t.Fatalf("Expected type to be response.NoData, got %s", mt) 28 | } 29 | 30 | dur := MinimalTTL(nil, mt) // minTTL on msg is 3600 (neg. ttl on SOA) 31 | if dur != time.Duration(MinimalDefaultTTL) { 32 | t.Fatalf("Expected minttl duration to be %d, got %d", 1800, dur) 33 | } 34 | 35 | m := new(dns.Msg) 36 | m.SetQuestion("z.alm.im.", dns.TypeA) 37 | m.SetEdns0(dns.DefaultMsgSize, true) 38 | 39 | mt, _ = response.Typify(m, utc) 40 | if mt != response.NoError { 41 | t.Fatalf("Expected type to be response.NoData, got %s", mt) 42 | } 43 | dur = MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA) 44 | if dur != time.Duration(MinimalDefaultTTL) { 45 | t.Fatalf("Expected minttl duration to be %d, got %d", 1800, dur) 46 | } 47 | 48 | m.Ns = []dns.RR{ 49 | makeRR("alm.im. 1800 IN SOA ivan.ns.cloudflare.com. dns.cloudflare.com. 2025042470 10000 2400 604800 3600"), 50 | } 51 | 52 | mt, _ = response.Typify(m, utc) 53 | if mt != response.NoData { 54 | t.Fatalf("Expected type to be response.NoData, got %s", mt) 55 | } 56 | 57 | dur = MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA) 58 | if dur != time.Duration(1800*time.Second) { 59 | t.Fatalf("Expected minttl duration to be %d, got %d", 1800, dur) 60 | } 61 | 62 | m.Extra = []dns.RR{ 63 | makeRR("alm.im. 
1200 IN A 127.0.0.1"), 64 | } 65 | 66 | m.Rcode = dns.RcodeNameError 67 | mt, _ = response.Typify(m, utc) 68 | if mt != response.NameError { 69 | t.Fatalf("Expected type to be response.NameError, got %s", mt) 70 | } 71 | dur = MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA) 72 | if dur != time.Duration(1200*time.Second) { 73 | t.Fatalf("Expected minttl duration to be %d, got %d", 1800, dur) 74 | } 75 | 76 | m.Answer = []dns.RR{ 77 | makeRR("z.alm.im. 600 IN A 127.0.0.1"), 78 | } 79 | dur = MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA) 80 | if dur != time.Duration(600*time.Second) { 81 | t.Fatalf("Expected minttl duration to be %d, got %d", 1800, dur) 82 | } 83 | } 84 | 85 | func BenchmarkMinimalTTL(b *testing.B) { 86 | m := new(dns.Msg) 87 | m.SetQuestion("example.org.", dns.TypeA) 88 | m.Ns = []dns.RR{ 89 | makeRR("a.example.org. 1800 IN A 127.0.0.53"), 90 | makeRR("b.example.org. 1900 IN A 127.0.0.53"), 91 | makeRR("c.example.org. 1600 IN A 127.0.0.53"), 92 | makeRR("d.example.org. 1100 IN A 127.0.0.53"), 93 | makeRR("e.example.org. 1000 IN A 127.0.0.53"), 94 | } 95 | m.Extra = []dns.RR{ 96 | makeRR("a.example.org. 1800 IN A 127.0.0.53"), 97 | makeRR("b.example.org. 1600 IN A 127.0.0.53"), 98 | makeRR("c.example.org. 1400 IN A 127.0.0.53"), 99 | makeRR("d.example.org. 1200 IN A 127.0.0.53"), 100 | makeRR("e.example.org. 1100 IN A 127.0.0.53"), 101 | } 102 | 103 | utc := time.Now().UTC() 104 | mt, _ := response.Typify(m, utc) 105 | 106 | b.ResetTimer() 107 | for i := 0; i < b.N; i++ { 108 | dur := MinimalTTL(m, mt) 109 | if dur != 1000*time.Second { 110 | b.Fatalf("Wrong MinimalTTL %d, expected %d", dur, 1000*time.Second) 111 | } 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | A high-performance, recursive DNS resolver server with DNSSEC support, focused on preserving privacy. 
3 | https://sdns.dev for more information. 4 | */ 5 | package main // import "github.com/semihalev/sdns" 6 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.8" 2 | 3 | services: 4 | sdns: 5 | image: c1982/sdns 6 | container_name: sdns 7 | restart: unless-stopped 8 | ports: 9 | - 127.0.0.1:53:53 10 | - 127.0.0.1:53:53/udp 11 | read_only: false 12 | -------------------------------------------------------------------------------- /gen.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package main 4 | 5 | import ( 6 | "fmt" 7 | "os" 8 | "path/filepath" 9 | "sort" 10 | ) 11 | 12 | // middleware list order very important, handlers call via this order. 13 | var middlewareList = []string{ 14 | "recovery", 15 | "loop", 16 | "metrics", 17 | "accesslist", 18 | "ratelimit", 19 | "edns", 20 | "accesslog", 21 | "chaos", 22 | "hostsfile", 23 | "blocklist", 24 | "as112", 25 | "cache", 26 | "failover", 27 | "resolver", 28 | "forwarder", 29 | } 30 | 31 | func main() { 32 | var pathlist []string 33 | for _, name := range middlewareList { 34 | stat, err := os.Stat(filepath.Join(middlewareDir, name)) 35 | if err != nil { 36 | fmt.Println(err) 37 | os.Exit(1) 38 | } 39 | if !stat.IsDir() { 40 | fmt.Println("path is not directory") 41 | os.Exit(1) 42 | } 43 | pathlist = append(pathlist, filepath.Join(prefixDir, middlewareDir, name)) 44 | } 45 | 46 | file, err := os.Create(filename) 47 | if err != nil { 48 | fmt.Println(err) 49 | os.Exit(1) 50 | } 51 | 52 | defer file.Close() 53 | 54 | file.WriteString("// Code generated by gen.go DO NOT EDIT.\n") 55 | 56 | file.WriteString("\npackage main\n\nimport (\n") 57 | file.WriteString("\t\"github.com/semihalev/sdns/config\"\n") 58 | file.WriteString("\t\"github.com/semihalev/sdns/middleware\"\n") 59 | 60 | 
sort.StringSlice(pathlist).Sort() 61 | 62 | for _, path := range pathlist { 63 | file.WriteString("\t\"" + path + "\"\n") 64 | } 65 | 66 | file.WriteString(")\n\n") 67 | 68 | file.WriteString("func init() {\n") 69 | for _, name := range middlewareList { 70 | file.WriteString("\tmiddleware.Register(\"" + name + "\", func(cfg *config.Config) middleware.Handler { return " + name + ".New(cfg) })\n") 71 | } 72 | 73 | file.WriteString("}\n") 74 | } 75 | 76 | const ( 77 | filename = "zregister.go" 78 | prefixDir = "github.com/semihalev/sdns" 79 | middlewareDir = "middleware" 80 | ) 81 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/semihalev/sdns 2 | 3 | require ( 4 | github.com/BurntSushi/toml v1.4.0 5 | github.com/cespare/xxhash/v2 v2.3.0 6 | github.com/miekg/dns v1.1.63 7 | github.com/prometheus/client_golang v1.20.5 8 | github.com/quic-go/quic-go v0.49.0 9 | github.com/semihalev/log v0.1.1 10 | github.com/stretchr/testify v1.10.0 11 | github.com/yl2chen/cidranger v1.0.2 12 | golang.org/x/time v0.10.0 13 | ) 14 | 15 | require ( 16 | github.com/beorn7/perks v1.0.1 // indirect 17 | github.com/davecgh/go-spew v1.1.1 // indirect 18 | github.com/go-stack/stack v1.8.0 // indirect 19 | github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect 20 | github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect 21 | github.com/klauspost/compress v1.17.9 // indirect 22 | github.com/kr/text v0.2.0 // indirect 23 | github.com/mattn/go-colorable v0.1.7 // indirect 24 | github.com/mattn/go-isatty v0.0.19 // indirect 25 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 26 | github.com/onsi/ginkgo/v2 v2.9.5 // indirect 27 | github.com/pmezard/go-difflib v1.0.0 // indirect 28 | github.com/prometheus/client_model v0.6.1 // indirect 29 | github.com/prometheus/common v0.55.0 // indirect 30 
| github.com/prometheus/procfs v0.15.1 // indirect 31 | github.com/quic-go/qpack v0.5.1 // indirect 32 | go.uber.org/mock v0.5.0 // indirect 33 | golang.org/x/crypto v0.31.0 // indirect 34 | golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect 35 | golang.org/x/mod v0.18.0 // indirect 36 | golang.org/x/net v0.33.0 // indirect 37 | golang.org/x/sync v0.10.0 // indirect 38 | golang.org/x/sys v0.28.0 // indirect 39 | golang.org/x/text v0.21.0 // indirect 40 | golang.org/x/tools v0.22.0 // indirect 41 | google.golang.org/protobuf v1.34.2 // indirect 42 | gopkg.in/yaml.v3 v3.0.1 // indirect 43 | ) 44 | 45 | go 1.22 46 | 47 | toolchain go1.22.5 48 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/semihalev/sdns/86830209f5550c03071fde5adf33dea0ff6e086b/logo.png -------------------------------------------------------------------------------- /middleware/accesslist/accesslist.go: -------------------------------------------------------------------------------- 1 | package accesslist 2 | 3 | import ( 4 | "context" 5 | "net" 6 | 7 | "github.com/semihalev/log" 8 | "github.com/semihalev/sdns/config" 9 | "github.com/semihalev/sdns/middleware" 10 | "github.com/yl2chen/cidranger" 11 | ) 12 | 13 | // AccessList type 14 | type AccessList struct { 15 | ranger cidranger.Ranger 16 | } 17 | 18 | // New return accesslist 19 | func New(cfg *config.Config) *AccessList { 20 | if len(cfg.AccessList) == 0 { 21 | cfg.AccessList = append(cfg.AccessList, "0.0.0.0/0") 22 | cfg.AccessList = append(cfg.AccessList, "::0/0") 23 | } 24 | 25 | a := new(AccessList) 26 | a.ranger = cidranger.NewPCTrieRanger() 27 | for _, cidr := range cfg.AccessList { 28 | _, ipnet, err := net.ParseCIDR(cidr) 29 | if err != nil { 30 | log.Error("Access list parse cidr failed", "error", err.Error()) 31 | continue 32 | } 33 | 34 | _ = 
a.ranger.Insert(cidranger.NewBasicRangerEntry(*ipnet)) 35 | 36 | } 37 | 38 | return a 39 | } 40 | 41 | // Name return middleware name 42 | func (a *AccessList) Name() string { return name } 43 | 44 | // ServeDNS implements the Handle interface. 45 | func (a *AccessList) ServeDNS(ctx context.Context, ch *middleware.Chain) { 46 | if ch.Writer.Internal() { 47 | ch.Next(ctx) 48 | return 49 | } 50 | 51 | allowed, _ := a.ranger.Contains(ch.Writer.RemoteIP()) 52 | 53 | if !allowed { 54 | //no reply to client 55 | ch.Cancel() 56 | return 57 | } 58 | 59 | ch.Next(ctx) 60 | } 61 | 62 | const name = "accesslist" 63 | -------------------------------------------------------------------------------- /middleware/accesslist/accesslist_test.go: -------------------------------------------------------------------------------- 1 | package accesslist 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/semihalev/log" 8 | "github.com/semihalev/sdns/config" 9 | "github.com/semihalev/sdns/middleware" 10 | "github.com/semihalev/sdns/mock" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func Test_AccesslistDefaults(t *testing.T) { 15 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler)) 16 | 17 | cfg := new(config.Config) 18 | cfg.AccessList = []string{} 19 | 20 | a := New(cfg) 21 | 22 | ch := middleware.NewChain([]middleware.Handler{a}) 23 | 24 | mw := mock.NewWriter("udp", "8.8.8.8:0") 25 | ch.Writer = mw 26 | a.ServeDNS(context.Background(), ch) 27 | } 28 | 29 | func Test_Accesslist(t *testing.T) { 30 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler)) 31 | 32 | cfg := new(config.Config) 33 | cfg.AccessList = []string{"127.0.0.1/32", "1"} 34 | 35 | middleware.Register("accesslist", func(cfg *config.Config) middleware.Handler { return New(cfg) }) 36 | middleware.Setup(cfg) 37 | 38 | a := middleware.Get("accesslist").(*AccessList) 39 | assert.Equal(t, "accesslist", a.Name()) 40 | 41 | ch := middleware.NewChain([]middleware.Handler{}) 42 
| 43 | mw := mock.NewWriter("udp", "127.0.0.255:0") 44 | ch.Writer = mw 45 | a.ServeDNS(context.Background(), ch) 46 | 47 | mw = mock.NewWriter("udp", "0.0.0.0:0") 48 | ch.Writer = mw 49 | a.ServeDNS(context.Background(), ch) 50 | } 51 | -------------------------------------------------------------------------------- /middleware/accesslog/accesslog.go: -------------------------------------------------------------------------------- 1 | package accesslog 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "strconv" 7 | "strings" 8 | "time" 9 | 10 | "github.com/miekg/dns" 11 | "github.com/semihalev/log" 12 | "github.com/semihalev/sdns/config" 13 | "github.com/semihalev/sdns/middleware" 14 | ) 15 | 16 | // AccessLog type 17 | type AccessLog struct { 18 | cfg *config.Config 19 | logFile *os.File 20 | } 21 | 22 | // New returns a new AccessLog 23 | func New(cfg *config.Config) *AccessLog { 24 | var logFile *os.File 25 | var err error 26 | 27 | if cfg.AccessLog != "" { 28 | logFile, err = os.OpenFile(cfg.AccessLog, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0600) 29 | if err != nil { 30 | log.Error("Access log file open failed", "error", strings.Trim(err.Error(), "\n")) 31 | } 32 | } 33 | 34 | return &AccessLog{ 35 | cfg: cfg, 36 | logFile: logFile, 37 | } 38 | } 39 | 40 | // Name return middleware name 41 | func (a *AccessLog) Name() string { return name } 42 | 43 | // ServeDNS implements the Handle interface. 
44 | func (a *AccessLog) ServeDNS(ctx context.Context, ch *middleware.Chain) { 45 | ch.Next(ctx) 46 | 47 | w := ch.Writer 48 | 49 | if a.logFile != nil && w.Written() && !w.Internal() { 50 | resp := w.Msg() 51 | 52 | cd := "-cd" 53 | if resp.CheckingDisabled { 54 | cd = "+cd" 55 | } 56 | 57 | record := []string{ 58 | w.RemoteIP().String() + " -", 59 | "[" + time.Now().Format("02/Jan/2006:15:04:05 -0700") + "]", 60 | formatQuestion(resp.Question[0]), 61 | w.Proto(), 62 | cd, 63 | dns.RcodeToString[resp.Rcode], 64 | strconv.Itoa(resp.Len()), 65 | } 66 | 67 | _, err := a.logFile.WriteString(strings.Join(record, " ") + "\n") 68 | if err != nil { 69 | log.Error("Access log write failed", "error", strings.Trim(err.Error(), "\n")) 70 | } 71 | } 72 | } 73 | 74 | func formatQuestion(q dns.Question) string { 75 | return "\"" + strings.ToLower(q.Name) + " " + dns.ClassToString[q.Qclass] + " " + dns.TypeToString[q.Qtype] + "\"" 76 | } 77 | 78 | const name = "accesslog" 79 | -------------------------------------------------------------------------------- /middleware/accesslog/accesslog_test.go: -------------------------------------------------------------------------------- 1 | package accesslog 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/miekg/dns" 9 | "github.com/semihalev/log" 10 | "github.com/semihalev/sdns/config" 11 | "github.com/semihalev/sdns/middleware" 12 | "github.com/semihalev/sdns/mock" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func Test_accesslog(t *testing.T) { 17 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler)) 18 | 19 | cfg := &config.Config{ 20 | AccessLog: "access_test.log", 21 | } 22 | 23 | middleware.Register("accesslog", func(cfg *config.Config) middleware.Handler { return New(cfg) }) 24 | middleware.Setup(cfg) 25 | a := middleware.Get("accesslog").(*AccessLog) 26 | 27 | assert.Equal(t, "accesslog", a.Name()) 28 | assert.NotNil(t, a.logFile) 29 | 30 | ch := 
middleware.NewChain([]middleware.Handler{a}) 31 | 32 | mw := mock.NewWriter("udp", "127.0.0.1:0") 33 | req := new(dns.Msg) 34 | req.SetQuestion("test.com.", dns.TypeA) 35 | 36 | ch.Reset(mw, req) 37 | 38 | resp := new(dns.Msg) 39 | resp.SetRcode(req, dns.RcodeServerFailure) 40 | resp.Question = req.Copy().Question 41 | 42 | _ = ch.Writer.WriteMsg(resp) 43 | 44 | a.ServeDNS(context.Background(), ch) 45 | 46 | assert.Equal(t, dns.RcodeServerFailure, mw.Msg().Rcode) 47 | 48 | resp.CheckingDisabled = true 49 | a.ServeDNS(context.Background(), ch) 50 | 51 | assert.True(t, resp.CheckingDisabled) 52 | 53 | assert.NoError(t, a.logFile.Close()) 54 | 55 | a.ServeDNS(context.Background(), ch) 56 | 57 | assert.NoError(t, os.Remove(cfg.AccessLog)) 58 | } 59 | -------------------------------------------------------------------------------- /middleware/as112/as112_test.go: -------------------------------------------------------------------------------- 1 | package as112 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/semihalev/log" 9 | "github.com/semihalev/sdns/config" 10 | "github.com/semihalev/sdns/middleware" 11 | "github.com/semihalev/sdns/mock" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func Test_AS112(t *testing.T) { 16 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler)) 17 | 18 | cfg := new(config.Config) 19 | cfg.EmptyZones = []string{ 20 | "10.in-addr.arpa.", 21 | "example.arpa", 22 | } 23 | 24 | middleware.Register("as112", func(cfg *config.Config) middleware.Handler { return New(cfg) }) 25 | middleware.Setup(cfg) 26 | 27 | a := middleware.Get("as112").(*AS112) 28 | 29 | assert.Equal(t, "as112", a.Name()) 30 | 31 | ch := middleware.NewChain([]middleware.Handler{}) 32 | 33 | req := new(dns.Msg) 34 | req.SetQuestion("10.in-addr.arpa.", dns.TypeSOA) 35 | ch.Request = req 36 | 37 | mw := mock.NewWriter("udp", "127.0.0.1:0") 38 | ch.Writer = mw 39 | 40 | a.ServeDNS(context.Background(), ch) 41 | 
assert.Equal(t, true, len(mw.Msg().Answer) > 0) 42 | assert.Equal(t, dns.RcodeSuccess, mw.Rcode()) 43 | 44 | req.SetQuestion("10.in-addr.arpa.", dns.TypeNS) 45 | 46 | mw = mock.NewWriter("udp", "127.0.0.1:0") 47 | ch.Writer = mw 48 | a.ServeDNS(context.Background(), ch) 49 | assert.Equal(t, true, len(mw.Msg().Answer) > 0) 50 | assert.Equal(t, dns.RcodeSuccess, mw.Rcode()) 51 | 52 | req.SetQuestion("10.in-addr.arpa.", dns.TypeSOA) 53 | 54 | mw = mock.NewWriter("udp", "127.0.0.1:0") 55 | ch.Writer = mw 56 | a.ServeDNS(context.Background(), ch) 57 | assert.Equal(t, true, len(mw.Msg().Answer) > 0) 58 | assert.Equal(t, dns.RcodeSuccess, mw.Rcode()) 59 | 60 | req.SetQuestion("10.in-addr.arpa.", dns.TypeDS) 61 | 62 | mw = mock.NewWriter("udp", "127.0.0.1:0") 63 | ch.Writer = mw 64 | a.ServeDNS(context.Background(), ch) 65 | assert.False(t, mw.Written()) 66 | 67 | req.SetQuestion("20.in-addr.arpa.", dns.TypeNS) 68 | 69 | mw = mock.NewWriter("udp", "127.0.0.1:0") 70 | ch.Writer = mw 71 | a.ServeDNS(context.Background(), ch) 72 | assert.False(t, mw.Written()) 73 | 74 | req.SetQuestion("example.com.", dns.TypeNS) 75 | 76 | mw = mock.NewWriter("udp", "127.0.0.1:0") 77 | ch.Writer = mw 78 | a.ServeDNS(context.Background(), ch) 79 | assert.False(t, mw.Written()) 80 | 81 | req.SetQuestion("10.10.in-addr.arpa.", dns.TypeSOA) 82 | 83 | mw = mock.NewWriter("udp", "127.0.0.1:0") 84 | ch.Writer = mw 85 | a.ServeDNS(context.Background(), ch) 86 | assert.Equal(t, true, len(mw.Msg().Ns) > 0) 87 | assert.Equal(t, dns.RcodeNameError, mw.Rcode()) 88 | 89 | req.SetQuestion("10.10.in-addr.arpa.", dns.TypeA) 90 | 91 | mw = mock.NewWriter("udp", "127.0.0.1:0") 92 | ch.Writer = mw 93 | a.ServeDNS(context.Background(), ch) 94 | assert.Equal(t, true, len(mw.Msg().Ns) > 0) 95 | assert.Equal(t, dns.RcodeNameError, mw.Rcode()) 96 | 97 | req.SetQuestion("10.10.in-addr.arpa.", dns.TypeNS) 98 | 99 | mw = mock.NewWriter("udp", "127.0.0.1:0") 100 | ch.Writer = mw 101 | a.ServeDNS(context.Background(), ch) 
102 | assert.Equal(t, true, len(mw.Msg().Ns) > 0) 103 | assert.Equal(t, dns.RcodeNameError, mw.Rcode()) 104 | } 105 | -------------------------------------------------------------------------------- /middleware/blocklist/blocklist.go: -------------------------------------------------------------------------------- 1 | package blocklist 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net" 7 | "os" 8 | "path/filepath" 9 | "sync" 10 | 11 | "github.com/miekg/dns" 12 | "github.com/semihalev/sdns/config" 13 | "github.com/semihalev/sdns/middleware" 14 | ) 15 | 16 | // BlockList type 17 | type BlockList struct { 18 | mu sync.RWMutex 19 | 20 | nullroute net.IP 21 | null6route net.IP 22 | 23 | m map[string]bool 24 | w map[string]bool 25 | 26 | cfg *config.Config 27 | } 28 | 29 | // New returns a new BlockList 30 | func New(cfg *config.Config) *BlockList { 31 | b := &BlockList{ 32 | nullroute: net.ParseIP(cfg.Nullroute), 33 | null6route: net.ParseIP(cfg.Nullroutev6), 34 | 35 | m: make(map[string]bool), 36 | w: make(map[string]bool), 37 | 38 | cfg: cfg, 39 | } 40 | 41 | go b.fetchBlocklists() 42 | 43 | return b 44 | } 45 | 46 | // Name return middleware name 47 | func (b *BlockList) Name() string { return name } 48 | 49 | // ServeDNS implements the Handle interface. 
50 | func (b *BlockList) ServeDNS(ctx context.Context, ch *middleware.Chain) { 51 | w, req := ch.Writer, ch.Request 52 | 53 | q := req.Question[0] 54 | 55 | if !b.Exists(q.Name) { 56 | ch.Next(ctx) 57 | return 58 | } 59 | 60 | msg := new(dns.Msg) 61 | msg.SetReply(req) 62 | msg.Authoritative, msg.RecursionAvailable = true, true 63 | 64 | switch q.Qtype { 65 | case dns.TypeA: 66 | rrHeader := dns.RR_Header{ 67 | Name: q.Name, 68 | Rrtype: dns.TypeA, 69 | Class: dns.ClassINET, 70 | Ttl: 3600, 71 | } 72 | a := &dns.A{Hdr: rrHeader, A: b.nullroute} 73 | msg.Answer = append(msg.Answer, a) 74 | case dns.TypeAAAA: 75 | rrHeader := dns.RR_Header{ 76 | Name: q.Name, 77 | Rrtype: dns.TypeAAAA, 78 | Class: dns.ClassINET, 79 | Ttl: 3600, 80 | } 81 | a := &dns.AAAA{Hdr: rrHeader, AAAA: b.null6route} 82 | msg.Answer = append(msg.Answer, a) 83 | default: 84 | rrHeader := dns.RR_Header{ 85 | Name: q.Name, 86 | Rrtype: dns.TypeSOA, 87 | Class: dns.ClassINET, 88 | Ttl: 86400, 89 | } 90 | soa := &dns.SOA{ 91 | Hdr: rrHeader, 92 | Ns: q.Name, 93 | Mbox: ".", 94 | Serial: 0, 95 | Refresh: 28800, 96 | Retry: 7200, 97 | Expire: 604800, 98 | Minttl: 86400, 99 | } 100 | msg.Extra = append(msg.Answer, soa) 101 | } 102 | 103 | _ = w.WriteMsg(msg) 104 | 105 | ch.Cancel() 106 | } 107 | 108 | // Get returns the entry for a key or an error 109 | func (b *BlockList) Get(key string) (bool, error) { 110 | b.mu.RLock() 111 | defer b.mu.RUnlock() 112 | 113 | key = dns.CanonicalName(key) 114 | val, ok := b.m[key] 115 | 116 | if !ok { 117 | return false, errors.New("block not found") 118 | } 119 | 120 | return val, nil 121 | } 122 | 123 | // Remove removes an entry from the cache 124 | func (b *BlockList) Remove(key string) bool { 125 | if !b.Exists(key) { 126 | return false 127 | } 128 | 129 | b.mu.Lock() 130 | defer b.mu.Unlock() 131 | 132 | key = dns.CanonicalName(key) 133 | delete(b.m, key) 134 | b.save() 135 | 136 | return true 137 | } 138 | 139 | // Set sets a value in the BlockList 140 | func (b 
*BlockList) Set(key string) bool { 141 | b.mu.Lock() 142 | defer b.mu.Unlock() 143 | 144 | key = dns.CanonicalName(key) 145 | 146 | if b.w[key] { 147 | return false 148 | } 149 | 150 | b.m[key] = true 151 | b.save() 152 | 153 | return true 154 | } 155 | 156 | func (b *BlockList) set(key string) bool { 157 | b.mu.Lock() 158 | defer b.mu.Unlock() 159 | 160 | key = dns.CanonicalName(key) 161 | 162 | if b.w[key] { 163 | return false 164 | } 165 | 166 | b.m[key] = true 167 | 168 | return true 169 | } 170 | 171 | // Exists returns whether or not a key exists in the cache 172 | func (b *BlockList) Exists(key string) bool { 173 | b.mu.RLock() 174 | defer b.mu.RUnlock() 175 | 176 | key = dns.CanonicalName(key) 177 | _, ok := b.m[key] 178 | 179 | return ok 180 | } 181 | 182 | // Length returns the caches length 183 | func (b *BlockList) Length() int { 184 | b.mu.RLock() 185 | defer b.mu.RUnlock() 186 | 187 | return len(b.m) 188 | } 189 | 190 | func (b *BlockList) save() { 191 | path := filepath.Join(b.cfg.BlockListDir, "local") 192 | 193 | file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) 194 | if err != nil { 195 | return 196 | } 197 | 198 | _, _ = file.WriteString("# The file generated by auto. 
DO NOT EDIT\n") 199 | for d := range b.m { 200 | _, _ = file.WriteString(d + "\n") 201 | } 202 | 203 | _ = file.Close() 204 | } 205 | 206 | const name = "blocklist" 207 | -------------------------------------------------------------------------------- /middleware/blocklist/blocklist_test.go: -------------------------------------------------------------------------------- 1 | package blocklist 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/miekg/dns" 12 | "github.com/semihalev/sdns/config" 13 | "github.com/semihalev/sdns/middleware" 14 | "github.com/semihalev/sdns/mock" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | func Test_BlockList(t *testing.T) { 19 | testDomain := "test.com." 20 | 21 | cfg := new(config.Config) 22 | cfg.Nullroute = "0.0.0.0" 23 | cfg.Nullroutev6 = "::0" 24 | cfg.BlockListDir = filepath.Join(os.TempDir(), "sdns_temp") 25 | 26 | middleware.Register("blocklist", func(cfg *config.Config) middleware.Handler { return New(cfg) }) 27 | middleware.Setup(cfg) 28 | 29 | blocklist := middleware.Get("blocklist").(*BlockList) 30 | 31 | assert.Equal(t, "blocklist", blocklist.Name()) 32 | blocklist.Set(testDomain) 33 | 34 | ch := middleware.NewChain([]middleware.Handler{}) 35 | 36 | req := new(dns.Msg) 37 | req.SetQuestion("test.com.", dns.TypeA) 38 | ch.Request = req 39 | 40 | mw := mock.NewWriter("udp", "127.0.0.1:0") 41 | ch.Writer = mw 42 | 43 | blocklist.ServeDNS(context.Background(), ch) 44 | assert.Equal(t, true, len(mw.Msg().Answer) > 0) 45 | 46 | req.SetQuestion("test.com.", dns.TypeAAAA) 47 | ch.Request = req 48 | 49 | blocklist.ServeDNS(context.Background(), ch) 50 | assert.Equal(t, true, len(mw.Msg().Answer) > 0) 51 | 52 | req.SetQuestion("test.com.", dns.TypeNS) 53 | ch.Request = req 54 | 55 | blocklist.ServeDNS(context.Background(), ch) 56 | assert.Equal(t, true, len(mw.Msg().Extra) > 0) 57 | 58 | mw = mock.NewWriter("udp", "127.0.0.1:0") 59 | ch.Writer = mw 60 
| req.SetQuestion("test2.com.", dns.TypeA) 61 | blocklist.ServeDNS(context.Background(), ch) 62 | assert.Nil(t, mw.Msg()) 63 | 64 | assert.Equal(t, blocklist.Exists(testDomain), true) 65 | assert.Equal(t, blocklist.Exists(strings.ToUpper(testDomain)), true) 66 | 67 | _, err := blocklist.Get(testDomain) 68 | assert.NoError(t, err) 69 | 70 | assert.Equal(t, blocklist.Length(), 1) 71 | 72 | if exists := blocklist.Exists(fmt.Sprintf("%sfuzz", testDomain)); exists { 73 | t.Error("fuzz existed in block blocklist") 74 | } 75 | 76 | if blocklistLen := blocklist.Length(); blocklistLen != 1 { 77 | t.Error("invalid length: ", blocklistLen) 78 | } 79 | 80 | blocklist.Remove(testDomain) 81 | assert.Equal(t, blocklist.Exists(testDomain), false) 82 | 83 | _, err = blocklist.Get(testDomain) 84 | assert.Error(t, err) 85 | 86 | blocklist.Set(testDomain) 87 | } 88 | -------------------------------------------------------------------------------- /middleware/blocklist/updater.go: -------------------------------------------------------------------------------- 1 | package blocklist 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "net/url" 9 | "os" 10 | "path/filepath" 11 | "strings" 12 | "sync" 13 | "time" 14 | 15 | "github.com/miekg/dns" 16 | "github.com/semihalev/log" 17 | ) 18 | 19 | var timesSeen = make(map[string]int) 20 | 21 | func (b *BlockList) fetchBlocklists() { 22 | if b.cfg.BlockListDir == "" { 23 | b.cfg.BlockListDir = filepath.Join(b.cfg.Directory, "blacklists") 24 | } 25 | 26 | <-time.After(time.Second) 27 | 28 | if err := b.updateBlocklists(); err != nil { 29 | log.Error("Update blocklists failed", "error", err.Error()) 30 | } 31 | 32 | if err := b.readBlocklists(); err != nil { 33 | log.Error("Read blocklists failed", "dir", b.cfg.BlockListDir, "error", err.Error()) 34 | } 35 | } 36 | 37 | func (b *BlockList) updateBlocklists() error { 38 | if _, err := os.Stat(b.cfg.BlockListDir); os.IsNotExist(err) { 39 | if err := os.Mkdir(b.cfg.BlockListDir, 
0750); err != nil { 40 | return fmt.Errorf("error creating blacklist directory: %s", err) 41 | } 42 | } 43 | 44 | b.mu.Lock() 45 | for _, entry := range b.cfg.Whitelist { 46 | b.w[dns.CanonicalName(entry)] = true 47 | } 48 | b.mu.Unlock() 49 | 50 | for _, entry := range b.cfg.Blocklist { 51 | b.set(entry) 52 | } 53 | 54 | b.fetchBlocklist() 55 | 56 | return nil 57 | } 58 | 59 | func (b *BlockList) downloadBlocklist(uri, name string) error { 60 | filePath := filepath.FromSlash(fmt.Sprintf("%s/%s", b.cfg.BlockListDir, name)) 61 | 62 | output, err := os.Create(filePath) 63 | if err != nil { 64 | return fmt.Errorf("error creating file: %s", err) 65 | } 66 | 67 | defer func() { 68 | err := output.Close() 69 | if err != nil { 70 | log.Warn("Blocklist file close failed", "name", name, "error", err.Error()) 71 | } 72 | }() 73 | 74 | response, err := http.Get(uri) 75 | if err != nil { 76 | return fmt.Errorf("error downloading source: %s", err) 77 | } 78 | defer response.Body.Close() 79 | 80 | if _, err := io.Copy(output, response.Body); err != nil { 81 | return fmt.Errorf("error copying output: %s", err) 82 | } 83 | 84 | return nil 85 | } 86 | 87 | func (b *BlockList) fetchBlocklist() { 88 | var wg sync.WaitGroup 89 | 90 | for _, uri := range b.cfg.BlockLists { 91 | wg.Add(1) 92 | 93 | u, _ := url.Parse(uri) 94 | host := u.Host 95 | timesSeen[host] = timesSeen[host] + 1 96 | fileName := fmt.Sprintf("%s.%d.tmp", host, timesSeen[host]) 97 | 98 | go func(uri string, name string) { 99 | log.Info("Fetching blacklist", "uri", uri) 100 | if err := b.downloadBlocklist(uri, name); err != nil { 101 | log.Error("Fetching blacklist", "uri", uri, "error", err.Error()) 102 | } 103 | 104 | wg.Done() 105 | }(uri, fileName) 106 | } 107 | 108 | wg.Wait() 109 | } 110 | 111 | func (b *BlockList) readBlocklists() error { 112 | log.Info("Loading blocked domains...", "path", b.cfg.BlockListDir) 113 | 114 | if _, err := os.Stat(b.cfg.BlockListDir); os.IsNotExist(err) { 115 | log.Warn("Path not 
found, skipping...", "path", b.cfg.BlockListDir) 116 | return nil 117 | } 118 | 119 | err := filepath.Walk(b.cfg.BlockListDir, func(path string, f os.FileInfo, _ error) error { 120 | if !f.IsDir() { 121 | file, err := os.Open(filepath.FromSlash(path)) 122 | if err != nil { 123 | return fmt.Errorf("error opening file: %s", err) 124 | } 125 | 126 | if err = b.parseHostFile(file); err != nil { 127 | _ = file.Close() 128 | return fmt.Errorf("error parsing hostfile %s", err) 129 | } 130 | 131 | _ = file.Close() 132 | 133 | if filepath.Ext(path) == ".tmp" { 134 | _ = os.Remove(filepath.FromSlash(path)) 135 | } 136 | } 137 | 138 | return nil 139 | }) 140 | 141 | if err != nil { 142 | return fmt.Errorf("error walking location %s", err) 143 | } 144 | 145 | log.Info("Blocked domains loaded", "total", b.Length()) 146 | 147 | return nil 148 | } 149 | 150 | func (b *BlockList) parseHostFile(file *os.File) error { 151 | scanner := bufio.NewScanner(file) 152 | for scanner.Scan() { 153 | line := scanner.Text() 154 | line = strings.TrimSpace(line) 155 | isComment := strings.HasPrefix(line, "#") 156 | 157 | if !isComment && line != "" { 158 | fields := strings.Fields(line) 159 | 160 | if len(fields) > 1 && !strings.HasPrefix(fields[1], "#") { 161 | line = fields[1] 162 | } else { 163 | line = fields[0] 164 | } 165 | 166 | line = dns.CanonicalName(line) 167 | 168 | if !b.Exists(line) { 169 | b.set(line) 170 | } 171 | } 172 | } 173 | 174 | if err := scanner.Err(); err != nil { 175 | return fmt.Errorf("error scanning hostfile: %s", err) 176 | } 177 | 178 | return nil 179 | } 180 | -------------------------------------------------------------------------------- /middleware/blocklist/updater_test.go: -------------------------------------------------------------------------------- 1 | package blocklist 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/semihalev/log" 9 | "github.com/semihalev/sdns/config" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 
13 | const ( 14 | testDomain = "www.google.com" 15 | ) 16 | 17 | func Test_UpdateBlocklists(t *testing.T) { 18 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler)) 19 | 20 | tempDir := filepath.Join(os.TempDir(), "sdns_temp") 21 | 22 | cfg := new(config.Config) 23 | cfg.BlockListDir = tempDir 24 | cfg.Whitelist = append(cfg.Whitelist, testDomain) 25 | cfg.Blocklist = append(cfg.Blocklist, testDomain) 26 | 27 | cfg.BlockLists = []string{} 28 | cfg.BlockLists = append(cfg.BlockLists, "https://raw.githubusercontent.com/quidsup/notrack/master/trackers.txt") 29 | cfg.BlockLists = append(cfg.BlockLists, "https://test.dev/hosts") 30 | 31 | b := New(cfg) 32 | 33 | err := b.updateBlocklists() 34 | assert.NoError(t, err) 35 | 36 | err = b.readBlocklists() 37 | assert.NoError(t, err) 38 | } 39 | -------------------------------------------------------------------------------- /middleware/cache/item.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016-2020 The CoreDNS authors and contributors 2 | // Adapted for SDNS usage by Semih Alev. 3 | 4 | package cache 5 | 6 | import ( 7 | "time" 8 | 9 | "github.com/miekg/dns" 10 | "golang.org/x/time/rate" 11 | ) 12 | 13 | type item struct { 14 | Rcode int 15 | Authoritative bool 16 | AuthenticatedData bool 17 | RecursionAvailable bool 18 | Answer []dns.RR 19 | Ns []dns.RR 20 | Extra []dns.RR 21 | 22 | Limiter *rate.Limiter 23 | 24 | origTTL uint32 25 | stored time.Time 26 | 27 | prefetching bool 28 | } 29 | 30 | func newItem(m *dns.Msg, now time.Time, d time.Duration, queryRate int) *item { 31 | i := new(item) 32 | i.Rcode = m.Rcode 33 | i.Authoritative = m.Authoritative 34 | i.AuthenticatedData = m.AuthenticatedData 35 | i.RecursionAvailable = m.RecursionAvailable 36 | i.Answer = m.Answer 37 | i.Ns = m.Ns 38 | i.Extra = make([]dns.RR, len(m.Extra)) 39 | // Don't copy OPT records as these are hop-by-hop. 
40 | j := 0 41 | for _, e := range m.Extra { 42 | if e.Header().Rrtype == dns.TypeOPT { 43 | continue 44 | } 45 | i.Extra[j] = e 46 | j++ 47 | } 48 | i.Extra = i.Extra[:j] 49 | 50 | i.origTTL = uint32(d.Seconds()) 51 | i.stored = now.UTC() 52 | 53 | limit := rate.Limit(0) 54 | if queryRate > 0 { 55 | limit = rate.Every(time.Second / time.Duration(queryRate)) 56 | } 57 | 58 | i.Limiter = rate.NewLimiter(limit, queryRate) 59 | 60 | return i 61 | } 62 | 63 | // toMsg turns i into a message, it tailors the reply to m. 64 | // The Authoritative bit is always set to 0, because the answer is from the cache. 65 | func (i *item) toMsg(m *dns.Msg, now time.Time) *dns.Msg { 66 | m1 := new(dns.Msg) 67 | m1.SetReply(m) 68 | 69 | m1.Authoritative = false 70 | m1.AuthenticatedData = i.AuthenticatedData 71 | m1.RecursionAvailable = i.RecursionAvailable 72 | m1.Rcode = i.Rcode 73 | 74 | m1.Answer = i.Answer 75 | m1.Ns = i.Ns 76 | m1.Extra = i.Extra 77 | 78 | m1.Answer = make([]dns.RR, len(i.Answer)) 79 | m1.Ns = make([]dns.RR, len(i.Ns)) 80 | m1.Extra = make([]dns.RR, len(i.Extra)) 81 | 82 | ttl := uint32(i.ttl(now)) 83 | for j, r := range i.Answer { 84 | m1.Answer[j] = dns.Copy(r) 85 | m1.Answer[j].Header().Ttl = ttl 86 | } 87 | for j, r := range i.Ns { 88 | m1.Ns[j] = dns.Copy(r) 89 | m1.Ns[j].Header().Ttl = ttl 90 | } 91 | // newItem skips OPT records, so we can just use i.Extra as is. 
92 | for j, r := range i.Extra { 93 | m1.Extra[j] = dns.Copy(r) 94 | m1.Extra[j].Header().Ttl = ttl 95 | } 96 | return m1 97 | } 98 | 99 | func (i *item) ttl(now time.Time) int { 100 | ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds()) 101 | return ttl 102 | } 103 | -------------------------------------------------------------------------------- /middleware/chain.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/miekg/dns" 7 | ) 8 | 9 | // Chain type 10 | type Chain struct { 11 | Writer ResponseWriter 12 | Request *dns.Msg 13 | 14 | handlers []Handler 15 | 16 | head int 17 | tail int 18 | count int 19 | } 20 | 21 | // NewChain return new fresh chain 22 | func NewChain(handlers []Handler) *Chain { 23 | return &Chain{ 24 | Writer: &responseWriter{}, 25 | handlers: handlers, 26 | count: len(handlers), 27 | } 28 | } 29 | 30 | // Next call next dns handler in the chain 31 | func (ch *Chain) Next(ctx context.Context) { 32 | if ch.count == 0 { 33 | return 34 | } 35 | 36 | handler := ch.handlers[ch.head] 37 | ch.head = (ch.head + 1) % len(ch.handlers) 38 | ch.count-- 39 | 40 | handler.ServeDNS(ctx, ch) 41 | } 42 | 43 | // Cancel next calls 44 | func (ch *Chain) Cancel() { 45 | ch.count = 0 46 | } 47 | 48 | // CancelWithRcode next calls with rcode 49 | func (ch *Chain) CancelWithRcode(rcode int, do bool) { 50 | m := new(dns.Msg) 51 | m.Extra = ch.Request.Extra 52 | m.SetRcode(ch.Request, rcode) 53 | 54 | m.RecursionAvailable = true 55 | m.RecursionDesired = true 56 | 57 | if opt := m.IsEdns0(); opt != nil { 58 | opt.SetDo(do) 59 | } 60 | 61 | _ = ch.Writer.WriteMsg(m) 62 | 63 | ch.count = 0 64 | } 65 | 66 | // Reset the chain variables 67 | func (ch *Chain) Reset(w dns.ResponseWriter, r *dns.Msg) { 68 | ch.Writer.Reset(w) 69 | ch.Request = r 70 | ch.count = len(ch.handlers) 71 | ch.head, ch.tail = 0, 0 72 | } 73 | 
-------------------------------------------------------------------------------- /middleware/chain_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/semihalev/sdns/mock" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func Test_Chain(t *testing.T) { 13 | w := mock.NewWriter("tcp", "127.0.0.1:0") 14 | ch := NewChain([]Handler{&dummy{}}) 15 | req := new(dns.Msg) 16 | req.SetQuestion("test.com.", dns.TypeA) 17 | req.SetEdns0(512, true) 18 | ch.Reset(w, req) 19 | 20 | ch.Next(context.Background()) 21 | 22 | req.Rcode = dns.RcodeSuccess 23 | err := ch.Writer.WriteMsg(req) 24 | assert.NoError(t, err) 25 | 26 | data, err := req.Pack() 27 | assert.NoError(t, err) 28 | 29 | assert.Equal(t, true, ch.Writer.Written()) 30 | assert.Equal(t, dns.RcodeSuccess, ch.Writer.Rcode()) 31 | 32 | _, err = ch.Writer.Write(data) 33 | assert.Equal(t, errAlreadyWritten, err) 34 | 35 | ch.Reset(mock.NewWriter("tcp", "127.0.0.1:0"), req) 36 | size, err := ch.Writer.Write(data) 37 | assert.NoError(t, err) 38 | assert.Equal(t, len(data), size) 39 | assert.NotNil(t, ch.Writer.Msg()) 40 | 41 | err = ch.Writer.WriteMsg(req) 42 | assert.Equal(t, errAlreadyWritten, err) 43 | 44 | ch.Reset(mock.NewWriter("tcp", "127.0.0.1:0"), req) 45 | _, err = ch.Writer.Write([]byte{}) 46 | assert.Error(t, err) 47 | 48 | assert.Equal(t, "tcp", ch.Writer.Proto()) 49 | assert.Equal(t, "127.0.0.1", ch.Writer.RemoteIP().String()) 50 | 51 | ch.Cancel() 52 | assert.Equal(t, 0, ch.count) 53 | 54 | ch.Reset(mock.NewWriter("tcp", "127.0.0.1:0"), req) 55 | 56 | ch.CancelWithRcode(dns.RcodeServerFailure, true) 57 | assert.True(t, ch.Writer.Written()) 58 | assert.Equal(t, dns.RcodeServerFailure, ch.Writer.Rcode()) 59 | assert.Equal(t, 0, ch.count) 60 | } 61 | -------------------------------------------------------------------------------- /middleware/chaos/chaos.go: 
-------------------------------------------------------------------------------- 1 | package chaos 2 | 3 | import ( 4 | "context" 5 | "os" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/semihalev/sdns/config" 9 | "github.com/semihalev/sdns/middleware" 10 | ) 11 | 12 | // Chaos type 13 | type Chaos struct { 14 | chaos bool 15 | version string 16 | } 17 | 18 | // New return accesslist 19 | func New(cfg *config.Config) *Chaos { 20 | return &Chaos{ 21 | version: "SDNS v" + cfg.ServerVersion() + " (github.com/semihalev/sdns)", 22 | chaos: cfg.Chaos, 23 | } 24 | } 25 | 26 | // Name return middleware name 27 | func (c *Chaos) Name() string { return name } 28 | 29 | // ServeDNS implements the Handle interface. 30 | func (c *Chaos) ServeDNS(ctx context.Context, ch *middleware.Chain) { 31 | w, req := ch.Writer, ch.Request 32 | 33 | q := req.Question[0] 34 | 35 | if q.Qclass != dns.ClassCHAOS || q.Qtype != dns.TypeTXT || !c.chaos { 36 | ch.Next(ctx) 37 | return 38 | } 39 | 40 | resp := new(dns.Msg) 41 | resp.SetReply(req) 42 | 43 | switch q.Name { 44 | case "version.bind.", "version.server.": 45 | resp.Answer = []dns.RR{ 46 | &dns.TXT{ 47 | Hdr: dns.RR_Header{ 48 | Name: q.Name, 49 | Rrtype: dns.TypeTXT, 50 | Class: q.Qclass, 51 | }, 52 | Txt: []string{c.version}, 53 | }} 54 | case "hostname.bind.", "id.server.": 55 | hostname, err := os.Hostname() 56 | if err != nil { 57 | hostname = "unknown" 58 | } 59 | 60 | resp.Answer = []dns.RR{ 61 | &dns.TXT{ 62 | Hdr: dns.RR_Header{ 63 | Name: q.Name, 64 | Rrtype: dns.TypeTXT, 65 | Class: q.Qclass, 66 | }, 67 | Txt: []string{limitTXTLength(hostname)}, 68 | }} 69 | default: 70 | ch.Next(ctx) 71 | return 72 | } 73 | 74 | _ = w.WriteMsg(resp) 75 | ch.Cancel() 76 | } 77 | 78 | func limitTXTLength(s string) string { 79 | if len(s) < 256 { 80 | return s 81 | } 82 | return s[:255] 83 | } 84 | 85 | const name = "chaos" 86 | -------------------------------------------------------------------------------- /middleware/chaos/chaos_test.go: 
-------------------------------------------------------------------------------- 1 | package chaos 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/semihalev/sdns/config" 9 | "github.com/semihalev/sdns/middleware" 10 | "github.com/semihalev/sdns/mock" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func Test_Chaos(t *testing.T) { 15 | cfg := new(config.Config) 16 | cfg.Chaos = true 17 | 18 | middleware.Register("chaos", func(cfg *config.Config) middleware.Handler { return New(cfg) }) 19 | middleware.Setup(cfg) 20 | 21 | c := middleware.Get("chaos").(*Chaos) 22 | assert.Equal(t, "chaos", c.Name()) 23 | 24 | ch := middleware.NewChain([]middleware.Handler{}) 25 | 26 | mw := mock.NewWriter("udp", "127.0.0.1:0") 27 | req := new(dns.Msg) 28 | req.SetQuestion("version.bind.", dns.TypeTXT) 29 | ch.Reset(mw, req) 30 | c.ServeDNS(context.Background(), ch) 31 | 32 | assert.False(t, mw.Written()) 33 | 34 | mw = mock.NewWriter("udp", "127.0.0.1:0") 35 | req.Question[0].Qclass = dns.ClassCHAOS 36 | ch.Reset(mw, req) 37 | c.ServeDNS(context.Background(), ch) 38 | 39 | assert.True(t, mw.Written()) 40 | assert.Equal(t, dns.RcodeSuccess, mw.Rcode()) 41 | 42 | mw = mock.NewWriter("udp", "127.0.0.1:0") 43 | req.Question[0].Name = "hostname.bind." 44 | ch.Reset(mw, req) 45 | c.ServeDNS(context.Background(), ch) 46 | 47 | assert.True(t, mw.Written()) 48 | assert.Equal(t, dns.RcodeSuccess, mw.Rcode()) 49 | 50 | mw = mock.NewWriter("udp", "127.0.0.1:0") 51 | req.Question[0].Name = "unknown.bind." 
52 | ch.Reset(mw, req) 53 | c.ServeDNS(context.Background(), ch) 54 | 55 | assert.False(t, mw.Written()) 56 | } 57 | -------------------------------------------------------------------------------- /middleware/edns/edns.go: -------------------------------------------------------------------------------- 1 | package edns 2 | 3 | import ( 4 | "context" 5 | "encoding/hex" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/semihalev/sdns/config" 9 | "github.com/semihalev/sdns/dnsutil" 10 | "github.com/semihalev/sdns/middleware" 11 | ) 12 | 13 | // EDNS type 14 | type EDNS struct { 15 | cookiesecret string 16 | nsidstr string 17 | } 18 | 19 | // New return edns 20 | func New(cfg *config.Config) *EDNS { 21 | return &EDNS{cookiesecret: cfg.CookieSecret, nsidstr: cfg.NSID} 22 | } 23 | 24 | // Name return middleware name 25 | func (e *EDNS) Name() string { return name } 26 | 27 | // ResponseWriter implement of ctx.ResponseWriter 28 | type ResponseWriter struct { 29 | middleware.ResponseWriter 30 | *EDNS 31 | 32 | opt *dns.OPT 33 | size int 34 | do bool 35 | cookie string 36 | nsid bool 37 | noedns bool 38 | noad bool 39 | } 40 | 41 | // ServeDNS implements the Handle interface. 
42 | func (e *EDNS) ServeDNS(ctx context.Context, ch *middleware.Chain) { 43 | w, req := ch.Writer, ch.Request 44 | 45 | if req.Opcode > 0 { 46 | _ = dnsutil.NotSupported(w, req) 47 | 48 | ch.Cancel() 49 | return 50 | } 51 | 52 | noedns := req.IsEdns0() == nil 53 | 54 | opt, size, cookie, nsid, do := dnsutil.SetEdns0(req) 55 | if opt.Version() != 0 { 56 | opt.SetVersion(0) 57 | 58 | ch.CancelWithRcode(dns.RcodeBadVers, do) 59 | 60 | return 61 | } 62 | 63 | switch w.Proto() { 64 | case "tcp", "doq", "doh": 65 | size = dns.MaxMsgSize 66 | } 67 | 68 | if noedns { 69 | size = dns.MinMsgSize 70 | } 71 | 72 | ch.Writer = &ResponseWriter{ 73 | ResponseWriter: w, 74 | EDNS: e, 75 | 76 | opt: opt, 77 | size: size, 78 | do: do, 79 | cookie: cookie, 80 | noedns: noedns, 81 | nsid: nsid, 82 | noad: !req.AuthenticatedData && !do, 83 | } 84 | 85 | ch.Next(ctx) 86 | 87 | ch.Writer = w 88 | } 89 | 90 | // WriteMsg implements the ctx.ResponseWriter interface 91 | func (w *ResponseWriter) WriteMsg(m *dns.Msg) error { 92 | m.Compress = true 93 | 94 | if !w.do { 95 | m = dnsutil.ClearDNSSEC(m) 96 | } 97 | m = dnsutil.ClearOPT(m) 98 | 99 | if !w.noedns { 100 | w.opt.SetDo(w.do) 101 | w.setCookie() 102 | w.setNSID() 103 | m.Extra = append(m.Extra, w.opt) 104 | } 105 | 106 | if w.noad { 107 | m.AuthenticatedData = false 108 | } 109 | 110 | if w.Proto() == "udp" && m.Len() > w.size { 111 | m.Truncated = true 112 | m.Answer = []dns.RR{} 113 | m.Ns = []dns.RR{} 114 | m.AuthenticatedData = false 115 | } 116 | 117 | return w.ResponseWriter.WriteMsg(m) 118 | } 119 | 120 | func (w *ResponseWriter) setCookie() { 121 | if w.cookie == "" { 122 | return 123 | } 124 | 125 | w.opt.Option = append(w.opt.Option, &dns.EDNS0_COOKIE{ 126 | Code: dns.EDNS0COOKIE, 127 | Cookie: dnsutil.GenerateServerCookie(w.cookiesecret, w.RemoteIP().String(), w.cookie), 128 | }) 129 | } 130 | 131 | func (w *ResponseWriter) setNSID() { 132 | if w.nsidstr == "" || !w.nsid { 133 | return 134 | } 135 | 136 | w.opt.Option = 
append(w.opt.Option, &dns.EDNS0_NSID{ 137 | Code: dns.EDNS0NSID, 138 | Nsid: hex.EncodeToString([]byte(w.nsidstr)), 139 | }) 140 | } 141 | 142 | const name = "edns" 143 | -------------------------------------------------------------------------------- /middleware/edns/edns_test.go: -------------------------------------------------------------------------------- 1 | package edns 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "testing" 7 | 8 | "github.com/miekg/dns" 9 | "github.com/semihalev/sdns/config" 10 | "github.com/semihalev/sdns/middleware" 11 | "github.com/semihalev/sdns/mock" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | type dummy struct{} 16 | 17 | func (d *dummy) ServeDNS(ctx context.Context, ch *middleware.Chain) { 18 | w, req := ch.Writer, ch.Request 19 | 20 | m := new(dns.Msg) 21 | m.SetReply(req) 22 | 23 | rrHeader := dns.RR_Header{ 24 | Name: req.Question[0].Name, 25 | Rrtype: dns.TypeA, 26 | Class: dns.ClassINET, 27 | Ttl: 3600, 28 | } 29 | a := &dns.A{Hdr: rrHeader, A: net.ParseIP("127.0.0.1")} 30 | 31 | for i := 0; i < 100; i++ { 32 | m.Answer = append(m.Answer, a) 33 | } 34 | 35 | _ = w.WriteMsg(m) 36 | } 37 | 38 | func (d *dummy) Name() string { return "dummy" } 39 | 40 | func Test_EDNS(t *testing.T) { 41 | testDomain := "example.com." 
42 | 43 | cfg := new(config.Config) 44 | 45 | middleware.Register("edns", func(cfg *config.Config) middleware.Handler { return New(cfg) }) 46 | middleware.Setup(cfg) 47 | 48 | edns := middleware.Get("edns").(*EDNS) 49 | assert.Equal(t, "edns", edns.Name()) 50 | 51 | ch := middleware.NewChain([]middleware.Handler{edns, &dummy{}}) 52 | 53 | req := new(dns.Msg) 54 | req.SetQuestion(testDomain, dns.TypeA) 55 | 56 | mw := mock.NewWriter("tcp", "127.0.0.1:0") 57 | ch.Reset(mw, req) 58 | ch.Next(context.Background()) 59 | 60 | assert.True(t, ch.Writer.Written()) 61 | assert.Equal(t, dns.RcodeSuccess, ch.Writer.Rcode()) 62 | assert.Nil(t, ch.Writer.Msg().IsEdns0()) 63 | 64 | req.SetEdns0(4096, true) 65 | opt := req.IsEdns0() 66 | opt.SetVersion(100) 67 | 68 | mw = mock.NewWriter("udp", "127.0.0.1:0") 69 | ch.Reset(mw, req) 70 | ch.Next(context.Background()) 71 | 72 | assert.True(t, ch.Writer.Written()) 73 | assert.Equal(t, dns.RcodeBadVers, ch.Writer.Rcode()) 74 | 75 | opt = req.IsEdns0() 76 | opt.SetVersion(0) 77 | opt.SetUDPSize(512) 78 | 79 | mw = mock.NewWriter("tcp", "127.0.0.1:0") 80 | ch.Reset(mw, req) 81 | ch.Next(context.Background()) 82 | 83 | if assert.True(t, ch.Writer.Written()) { 84 | assert.False(t, ch.Writer.Msg().Truncated) 85 | } 86 | 87 | mw = mock.NewWriter("udp", "127.0.0.1:0") 88 | ch.Reset(mw, req) 89 | ch.Next(context.Background()) 90 | 91 | if assert.True(t, ch.Writer.Written()) { 92 | assert.True(t, ch.Writer.Msg().Truncated) 93 | } 94 | 95 | opt.Option = append(opt.Option, &dns.EDNS0_COOKIE{ 96 | Code: dns.EDNS0COOKIE, 97 | Cookie: "testtesttesttest", 98 | }) 99 | opt.SetUDPSize(4096) 100 | mw = mock.NewWriter("udp", "127.0.0.1:0") 101 | ch.Reset(mw, req) 102 | ch.Next(context.Background()) 103 | } 104 | -------------------------------------------------------------------------------- /middleware/failover/failover.go: -------------------------------------------------------------------------------- 1 | package failover 2 | 3 | import ( 4 | 
"context" 5 | "net" 6 | "strings" 7 | "time" 8 | 9 | "github.com/miekg/dns" 10 | "github.com/semihalev/log" 11 | "github.com/semihalev/sdns/config" 12 | "github.com/semihalev/sdns/dnsutil" 13 | "github.com/semihalev/sdns/middleware" 14 | ) 15 | 16 | // Failover type 17 | type Failover struct { 18 | servers []string 19 | } 20 | 21 | // ResponseWriter implement of ctx.ResponseWriter 22 | type ResponseWriter struct { 23 | middleware.ResponseWriter 24 | 25 | f *Failover 26 | } 27 | 28 | // New return failover 29 | func New(cfg *config.Config) *Failover { 30 | fallbackservers := []string{} 31 | for _, s := range cfg.FallbackServers { 32 | host, _, _ := net.SplitHostPort(s) 33 | 34 | if ip := net.ParseIP(host); ip != nil && ip.To4() != nil { 35 | fallbackservers = append(fallbackservers, s) 36 | } else if ip != nil && ip.To16() != nil { 37 | fallbackservers = append(fallbackservers, s) 38 | } else { 39 | log.Error("Fallback server is not correct. Check your config.", "server", s) 40 | } 41 | } 42 | 43 | return &Failover{servers: fallbackservers} 44 | } 45 | 46 | // Name return middleware name 47 | func (f *Failover) Name() string { return name } 48 | 49 | // ServeDNS implements the Handle interface. 
50 | func (f *Failover) ServeDNS(ctx context.Context, ch *middleware.Chain) { 51 | w := ch.Writer 52 | 53 | ch.Writer = &ResponseWriter{ResponseWriter: w, f: f} 54 | 55 | ch.Next(ctx) 56 | 57 | ch.Writer = w 58 | } 59 | 60 | // WriteMsg implements the ctx.ResponseWriter interface 61 | func (w *ResponseWriter) WriteMsg(m *dns.Msg) error { 62 | if len(m.Question) == 0 || len(w.f.servers) == 0 { 63 | return w.ResponseWriter.WriteMsg(m) 64 | } 65 | 66 | if m.Rcode != dns.RcodeServerFailure || !m.RecursionDesired { 67 | return w.ResponseWriter.WriteMsg(m) 68 | } 69 | 70 | req := new(dns.Msg) 71 | req.SetQuestion(m.Question[0].Name, m.Question[0].Qtype) 72 | req.Question[0].Qclass = m.Question[0].Qclass 73 | req.SetEdns0(dnsutil.DefaultMsgSize, true) 74 | req.CheckingDisabled = m.CheckingDisabled 75 | 76 | ctx := context.Background() 77 | 78 | for _, server := range w.f.servers { 79 | ctx, cancel := context.WithTimeout(ctx, 5*time.Second) 80 | defer cancel() 81 | resp, err := dnsutil.Exchange(ctx, req, server, "udp") 82 | if err != nil { 83 | log.Info("Failover query failed", "query", formatQuestion(req.Question[0]), "error", err.Error()) 84 | continue 85 | } 86 | 87 | resp.Id = m.Id 88 | 89 | return w.ResponseWriter.WriteMsg(resp) 90 | } 91 | 92 | return w.ResponseWriter.WriteMsg(m) 93 | } 94 | 95 | func formatQuestion(q dns.Question) string { 96 | return strings.ToLower(q.Name) + " " + dns.ClassToString[q.Qclass] + " " + dns.TypeToString[q.Qtype] 97 | } 98 | 99 | const name = "failover" 100 | -------------------------------------------------------------------------------- /middleware/failover/failover_test.go: -------------------------------------------------------------------------------- 1 | package failover 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/semihalev/log" 9 | "github.com/semihalev/sdns/config" 10 | "github.com/semihalev/sdns/middleware" 11 | "github.com/semihalev/sdns/mock" 12 | 
"github.com/stretchr/testify/assert" 13 | ) 14 | 15 | type dummy struct{} 16 | 17 | func (d *dummy) ServeDNS(ctx context.Context, ch *middleware.Chain) { 18 | w, req := ch.Writer, ch.Request 19 | 20 | m := new(dns.Msg) 21 | m.SetRcode(req, dns.RcodeServerFailure) 22 | 23 | _ = w.WriteMsg(m) 24 | } 25 | 26 | func (d *dummy) Name() string { return "dummy" } 27 | 28 | func Test_Failover(t *testing.T) { 29 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler)) 30 | 31 | cfg := new(config.Config) 32 | cfg.FallbackServers = []string{"[::255]:53", "8.8.8.8:53", "1"} 33 | 34 | middleware.Register("failover", func(cfg *config.Config) middleware.Handler { return New(cfg) }) 35 | middleware.Setup(cfg) 36 | 37 | f := middleware.Get("failover").(*Failover) 38 | assert.Equal(t, "failover", f.Name()) 39 | 40 | ch := middleware.NewChain([]middleware.Handler{f, &dummy{}}) 41 | 42 | ctx := context.Background() 43 | 44 | req := new(dns.Msg) 45 | req.SetQuestion("example.com.", dns.TypeA) 46 | req.RecursionDesired = false 47 | 48 | mw := mock.NewWriter("udp", "127.0.0.1:0") 49 | ch.Writer = mw 50 | ch.Request = req 51 | 52 | ch.Reset(mw, req) 53 | ch.Next(ctx) 54 | 55 | assert.Equal(t, dns.RcodeServerFailure, mw.Rcode()) 56 | 57 | req.RecursionDesired = true 58 | 59 | ch.Reset(mw, req) 60 | ch.Next(ctx) 61 | 62 | assert.Equal(t, mw.Rcode(), dns.RcodeSuccess) 63 | 64 | f.servers = []string{} 65 | 66 | ch.Reset(mw, req) 67 | ch.Next(ctx) 68 | 69 | assert.Equal(t, mw.Rcode(), dns.RcodeServerFailure) 70 | 71 | f.servers = []string{"[::255]:53"} 72 | 73 | ch.Reset(mw, req) 74 | ch.Next(ctx) 75 | 76 | assert.Equal(t, mw.Rcode(), dns.RcodeServerFailure) 77 | } 78 | -------------------------------------------------------------------------------- /middleware/forwarder/forwarder.go: -------------------------------------------------------------------------------- 1 | package forwarder 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "strings" 7 | 8 | "github.com/miekg/dns" 9 | 
"github.com/semihalev/log" 10 | "github.com/semihalev/sdns/config" 11 | "github.com/semihalev/sdns/dnsutil" 12 | "github.com/semihalev/sdns/middleware" 13 | ) 14 | 15 | type server struct { 16 | Addr string 17 | Proto string 18 | } 19 | 20 | // Forwarder type 21 | type Forwarder struct { 22 | servers []*server 23 | dnssec bool 24 | } 25 | 26 | // New return forwarder 27 | func New(cfg *config.Config) *Forwarder { 28 | forwarderservers := []*server{} 29 | for _, s := range cfg.ForwarderServers { 30 | srv := &server{Proto: "udp"} 31 | 32 | if strings.HasPrefix(s, "tls://") { 33 | s = strings.TrimPrefix(s, "tls://") 34 | srv.Proto = "tcp-tls" 35 | } 36 | 37 | host, _, _ := net.SplitHostPort(s) 38 | 39 | if ip := net.ParseIP(host); ip != nil && ip.To4() != nil { 40 | srv.Addr = s 41 | forwarderservers = append(forwarderservers, srv) 42 | } else if ip != nil && ip.To16() != nil { 43 | srv.Addr = s 44 | forwarderservers = append(forwarderservers, srv) 45 | } else { 46 | log.Error("Forwarder server is not correct. Check your config.", "server", s) 47 | } 48 | } 49 | 50 | return &Forwarder{servers: forwarderservers, dnssec: cfg.DNSSEC == "on"} 51 | } 52 | 53 | // Name return middleware name 54 | func (f *Forwarder) Name() string { return name } 55 | 56 | // ServeDNS implements the Handle interface. 
57 | func (f *Forwarder) ServeDNS(ctx context.Context, ch *middleware.Chain) { 58 | w, req := ch.Writer, ch.Request 59 | 60 | if len(req.Question) == 0 || len(f.servers) == 0 { 61 | ch.CancelWithRcode(dns.RcodeServerFailure, true) 62 | return 63 | } 64 | 65 | if !req.CheckingDisabled { 66 | req.CheckingDisabled = !f.dnssec 67 | } 68 | 69 | for _, server := range f.servers { 70 | resp, err := dnsutil.Exchange(ctx, req, server.Addr, server.Proto) 71 | if err != nil { 72 | log.Info("forwarder query failed", "query", formatQuestion(req.Question[0]), "error", err.Error()) 73 | continue 74 | } 75 | 76 | resp.Id = req.Id 77 | if !f.dnssec { 78 | resp.CheckingDisabled = false 79 | } 80 | 81 | _ = w.WriteMsg(resp) 82 | return 83 | } 84 | 85 | ch.CancelWithRcode(dns.RcodeServerFailure, true) 86 | } 87 | 88 | func formatQuestion(q dns.Question) string { 89 | return strings.ToLower(q.Name) + " " + dns.ClassToString[q.Qclass] + " " + dns.TypeToString[q.Qtype] 90 | } 91 | 92 | const name = "forwarder" 93 | -------------------------------------------------------------------------------- /middleware/forwarder/forwarder_test.go: -------------------------------------------------------------------------------- 1 | package forwarder 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/semihalev/log" 9 | "github.com/semihalev/sdns/config" 10 | "github.com/semihalev/sdns/middleware" 11 | "github.com/semihalev/sdns/mock" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func Test_Forwarder(t *testing.T) { 16 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler)) 17 | 18 | cfg := new(config.Config) 19 | cfg.ForwarderServers = []string{"[::255]:53", "8.8.8.8:53", "1", "tls://8.8.8.8:853"} 20 | 21 | middleware.Register("forwarder", func(cfg *config.Config) middleware.Handler { return New(cfg) }) 22 | middleware.Setup(cfg) 23 | 24 | f := middleware.Get("forwarder").(*Forwarder) 25 | assert.Equal(t, "forwarder", f.Name()) 26 | 27 | 
ch := middleware.NewChain([]middleware.Handler{f}) 28 | 29 | ctx := context.Background() 30 | 31 | req := new(dns.Msg) 32 | req.SetQuestion("example.com.", dns.TypeA) 33 | req.RecursionDesired = false 34 | 35 | mw := mock.NewWriter("udp", "127.0.0.1:0") 36 | ch.Writer = mw 37 | ch.Request = req 38 | 39 | ch.Reset(mw, req) 40 | ch.Next(ctx) 41 | 42 | assert.Equal(t, dns.RcodeSuccess, mw.Rcode()) 43 | 44 | req.RecursionDesired = true 45 | 46 | ch.Reset(mw, req) 47 | ch.Next(ctx) 48 | 49 | assert.Equal(t, mw.Rcode(), dns.RcodeSuccess) 50 | 51 | f.servers = []*server{} 52 | 53 | ch.Reset(mw, req) 54 | ch.Next(ctx) 55 | 56 | assert.Equal(t, mw.Rcode(), dns.RcodeServerFailure) 57 | 58 | srv := &server{Addr: "[::255]:53", Proto: "udp"} 59 | f.servers = []*server{srv} 60 | 61 | ch.Reset(mw, req) 62 | ch.Next(ctx) 63 | 64 | assert.Equal(t, mw.Rcode(), dns.RcodeServerFailure) 65 | 66 | srv = &server{Addr: "8.8.8.8:853", Proto: "tcp-tls"} 67 | f.servers = []*server{srv} 68 | 69 | ch.Reset(mw, req) 70 | ch.Next(ctx) 71 | 72 | assert.Equal(t, mw.Rcode(), dns.RcodeSuccess) 73 | } 74 | -------------------------------------------------------------------------------- /middleware/loop/loop.go: -------------------------------------------------------------------------------- 1 | package loop 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/miekg/dns" 7 | "github.com/semihalev/log" 8 | "github.com/semihalev/sdns/config" 9 | "github.com/semihalev/sdns/middleware" 10 | ) 11 | 12 | // Loop dummy type 13 | type Loop struct{} 14 | 15 | type ctxKey string 16 | 17 | // New return loop 18 | func New(cfg *config.Config) *Loop { 19 | return &Loop{} 20 | } 21 | 22 | // Name return middleware name 23 | func (l *Loop) Name() string { return name } 24 | 25 | // ServeDNS implements the Handle interface. 
26 | func (l *Loop) ServeDNS(ctx context.Context, ch *middleware.Chain) { 27 | req := ch.Request 28 | 29 | if len(req.Question) == 0 { 30 | ch.Cancel() 31 | return 32 | } 33 | 34 | qKey := req.Question[0].Name + ":" + dns.TypeToString[req.Question[0].Qtype] 35 | 36 | key := ctxKey("loopcheck:" + qKey) 37 | 38 | if v := ctx.Value(key); v != nil { 39 | count := v.(uint64) 40 | 41 | if count > 10 { 42 | log.Warn("Loop detected", "query", qKey) 43 | ch.CancelWithRcode(dns.RcodeServerFailure, false) 44 | return 45 | } 46 | 47 | count++ 48 | ctx = context.WithValue(ctx, key, count) 49 | } else { 50 | ctx = context.WithValue(ctx, key, uint64(1)) 51 | } 52 | 53 | ch.Next(ctx) 54 | } 55 | 56 | const name = "loop" 57 | -------------------------------------------------------------------------------- /middleware/loop/loop_test.go: -------------------------------------------------------------------------------- 1 | package loop 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/semihalev/log" 9 | "github.com/semihalev/sdns/config" 10 | "github.com/semihalev/sdns/middleware" 11 | "github.com/semihalev/sdns/mock" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func Test_loop(t *testing.T) { 16 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler)) 17 | 18 | middleware.Register("loop", func(cfg *config.Config) middleware.Handler { return New(cfg) }) 19 | middleware.Setup(&config.Config{}) 20 | 21 | l := middleware.Get("loop").(*Loop) 22 | 23 | assert.Equal(t, "loop", l.Name()) 24 | 25 | ch := middleware.NewChain([]middleware.Handler{l, l, l, l, l, l, l, l, l, l, l}) 26 | 27 | ctx := context.Background() 28 | 29 | mw := mock.NewWriter("udp", "127.0.0.1:0") 30 | req := new(dns.Msg) 31 | req.SetQuestion("example.com.", dns.TypeA) 32 | ch.Reset(mw, req) 33 | l.ServeDNS(ctx, ch) 34 | assert.Equal(t, dns.RcodeServerFailure, mw.Msg().Rcode) 35 | 36 | mw = mock.NewWriter("udp", "127.0.0.1:0") 37 | req.Question = 
[]dns.Question{} 38 | ch.Reset(mw, req) 39 | l.ServeDNS(ctx, ch) 40 | assert.Nil(t, mw.Msg()) 41 | } 42 | -------------------------------------------------------------------------------- /middleware/metrics/metrics.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/prometheus/client_golang/prometheus" 9 | "github.com/semihalev/sdns/config" 10 | "github.com/semihalev/sdns/middleware" 11 | ) 12 | 13 | // Metrics type 14 | type Metrics struct { 15 | queries *prometheus.CounterVec 16 | } 17 | 18 | // New return new metrics 19 | func New(cfg *config.Config) *Metrics { 20 | m := &Metrics{ 21 | queries: prometheus.NewCounterVec( 22 | prometheus.CounterOpts{ 23 | Name: "dns_queries_total", 24 | Help: "How many DNS queries processed", 25 | }, 26 | []string{"qtype", "rcode"}, 27 | ), 28 | } 29 | _ = prometheus.Register(m.queries) 30 | 31 | return m 32 | } 33 | 34 | // Name return middleware name 35 | func (m *Metrics) Name() string { return name } 36 | 37 | // ServeDNS implements the Handle interface. 
38 | func (m *Metrics) ServeDNS(ctx context.Context, ch *middleware.Chain) { 39 | ch.Next(ctx) 40 | 41 | if !ch.Writer.Written() { 42 | return 43 | } 44 | 45 | labels := AcquireLabels() 46 | defer ReleaseLabels(labels) 47 | 48 | labels["qtype"] = dns.TypeToString[ch.Request.Question[0].Qtype] 49 | labels["rcode"] = dns.RcodeToString[ch.Writer.Rcode()] 50 | 51 | m.queries.With(labels).Inc() 52 | } 53 | 54 | var labelsPool sync.Pool 55 | 56 | // AcquireLabels returns a label from pool 57 | func AcquireLabels() prometheus.Labels { 58 | x := labelsPool.Get() 59 | if x == nil { 60 | return prometheus.Labels{"qtype": "", "rcode": ""} 61 | } 62 | 63 | return x.(prometheus.Labels) 64 | } 65 | 66 | // ReleaseLabels returns labels to pool 67 | func ReleaseLabels(labels prometheus.Labels) { 68 | labelsPool.Put(labels) 69 | } 70 | 71 | const name = "metrics" 72 | -------------------------------------------------------------------------------- /middleware/metrics/metrics_test.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/semihalev/sdns/config" 9 | "github.com/semihalev/sdns/middleware" 10 | "github.com/semihalev/sdns/mock" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func Test_Metrics(t *testing.T) { 15 | middleware.Register("metrics", func(cfg *config.Config) middleware.Handler { return New(cfg) }) 16 | middleware.Setup(&config.Config{}) 17 | 18 | m := middleware.Get("metrics").(*Metrics) 19 | 20 | assert.Equal(t, "metrics", m.Name()) 21 | 22 | ch := middleware.NewChain([]middleware.Handler{}) 23 | 24 | mw := mock.NewWriter("udp", "127.0.0.1:0") 25 | req := new(dns.Msg) 26 | req.SetQuestion("test.com.", dns.TypeA) 27 | 28 | ch.Reset(mw, req) 29 | 30 | m.ServeDNS(context.Background(), ch) 31 | assert.Equal(t, dns.RcodeServerFailure, mw.Rcode()) 32 | 33 | _ = ch.Writer.WriteMsg(req) 34 | assert.Equal(t, true, 
ch.Writer.Written()) 35 | 36 | m.ServeDNS(context.Background(), ch) 37 | assert.Equal(t, dns.RcodeSuccess, mw.Rcode()) 38 | } 39 | -------------------------------------------------------------------------------- /middleware/middleware.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "plugin" 7 | "sync" 8 | 9 | "github.com/semihalev/log" 10 | "github.com/semihalev/sdns/config" 11 | ) 12 | 13 | // Handler interface 14 | type Handler interface { 15 | Name() string 16 | ServeDNS(context.Context, *Chain) 17 | } 18 | 19 | type middleware struct { 20 | mu sync.RWMutex 21 | 22 | cfg *config.Config 23 | handlers []handler 24 | } 25 | 26 | type handler struct { 27 | name string 28 | new func(*config.Config) Handler 29 | } 30 | 31 | var ( 32 | chainHandlers []Handler 33 | setup bool 34 | m middleware 35 | ) 36 | 37 | // Register a middleware 38 | func Register(name string, new func(*config.Config) Handler) { 39 | RegisterAt(name, new, len(m.handlers)) 40 | } 41 | 42 | // RegisterAt a middleware at an index 43 | func RegisterAt(name string, new func(*config.Config) Handler, idx int) { 44 | log.Debug("Register middleware", "name", name, "index", idx) 45 | 46 | m.mu.Lock() 47 | defer m.mu.Unlock() 48 | 49 | m.handlers = append(m.handlers, handler{}) 50 | copy(m.handlers[idx+1:], m.handlers[idx:]) 51 | m.handlers[idx] = handler{name: name, new: new} 52 | } 53 | 54 | // RegisterBefore a middleware before another middleware 55 | func RegisterBefore(name string, new func(*config.Config) Handler, before string) { 56 | log.Debug("Register middleware", "name", name, "before", before) 57 | 58 | m.mu.Lock() 59 | defer m.mu.Unlock() 60 | 61 | for idx, v := range m.handlers { 62 | if v.name == before { 63 | m.handlers = append(m.handlers, handler{}) 64 | copy(m.handlers[idx+1:], m.handlers[idx:]) 65 | m.handlers[idx] = handler{name: name, new: new} 66 | return 67 | } 68 | } 69 | 70 | 
panic(fmt.Sprintf("Middleware %s not found", before)) 71 | } 72 | 73 | // Setup handlers 74 | func Setup(cfg *config.Config) { 75 | if setup { 76 | panic("middleware setup already done") 77 | } 78 | 79 | m.cfg = cfg 80 | 81 | LoadExternalPlugins() 82 | 83 | m.mu.Lock() 84 | defer m.mu.Unlock() 85 | 86 | for i, handler := range m.handlers { 87 | h := handler.new(m.cfg) 88 | chainHandlers = append(chainHandlers, h) 89 | 90 | log.Debug("Middleware registered", "name", h.Name(), "index", i) 91 | } 92 | 93 | setup = true 94 | } 95 | 96 | // LoadExternalPlugins load external plugins into chain 97 | func LoadExternalPlugins() { 98 | for name, pcfg := range m.cfg.Plugins { 99 | pl, err := plugin.Open(pcfg.Path) 100 | if err != nil { 101 | log.Error("Plugin open failed", "plugin", name, "error", err.Error()) 102 | continue 103 | } 104 | 105 | newFuncSym, err := pl.Lookup("New") 106 | if err != nil { 107 | log.Error("Plugin new function lookup failed", "plugin", name, "error", err.Error()) 108 | continue 109 | } 110 | 111 | newFn, ok := newFuncSym.(func(cfg *config.Config) Handler) 112 | 113 | if !ok { 114 | log.Error("Plugin new function assert failed", "plugin", name) 115 | continue 116 | } 117 | 118 | RegisterBefore(name, newFn, "cache") 119 | log.Info("Plugin successfully loaded", "plugin", name, "path", pcfg.Path) 120 | } 121 | } 122 | 123 | // Handlers return registered handlers 124 | func Handlers() []Handler { 125 | handlers := chainHandlers 126 | return handlers 127 | } 128 | 129 | // List return names of handlers 130 | func List() (list []string) { 131 | m.mu.RLock() 132 | defer m.mu.RUnlock() 133 | 134 | for _, handler := range m.handlers { 135 | list = append(list, handler.name) 136 | } 137 | 138 | return list 139 | } 140 | 141 | // Get return a handler by name 142 | func Get(name string) Handler { 143 | if !setup { 144 | return nil 145 | } 146 | 147 | m.mu.RLock() 148 | defer m.mu.RUnlock() 149 | 150 | for i, handler := range m.handlers { 151 | if handler.name 
== name { 152 | if len(chainHandlers) <= i { 153 | return nil 154 | } 155 | return chainHandlers[i] 156 | } 157 | } 158 | 159 | return nil 160 | } 161 | 162 | // Ready return true if middleware setup was done 163 | func Ready() bool { 164 | m.mu.RLock() 165 | defer m.mu.RUnlock() 166 | 167 | return setup 168 | } 169 | -------------------------------------------------------------------------------- /middleware/middleware_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/semihalev/sdns/config" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type dummy struct{} 12 | 13 | func (d *dummy) ServeDNS(ctx context.Context, ch *Chain) { ch.Next(ctx) } 14 | func (d *dummy) Name() string { return "dummy" } 15 | 16 | func Test_Middleware(t *testing.T) { 17 | Register("dummy", func(*config.Config) Handler { 18 | return &dummy{} 19 | }) 20 | 21 | cfg := &config.Config{} 22 | 23 | d := Get("dummy") 24 | assert.Nil(t, d) 25 | 26 | assert.NotPanics(t, func() { 27 | Setup(cfg) 28 | }) 29 | 30 | assert.True(t, Ready()) 31 | 32 | assert.Panics(t, func() { 33 | Setup(cfg) 34 | }) 35 | assert.True(t, len(List()) == 1) 36 | assert.True(t, len(Handlers()) == 1) 37 | 38 | d = Get("dummy") 39 | assert.NotNil(t, d) 40 | 41 | d = Get("none") 42 | assert.Nil(t, d) 43 | 44 | chainHandlers = []Handler{} 45 | d = Get("dummy") 46 | assert.Nil(t, d) 47 | } 48 | 49 | func Test_RegisterAt(t *testing.T) { 50 | m.handlers = []handler{} 51 | 52 | Register("dummy", func(*config.Config) Handler { 53 | return &dummy{} 54 | }) 55 | RegisterAt("dummy2", func(*config.Config) Handler { 56 | return &dummy{} 57 | }, 0) 58 | 59 | assert.True(t, len(m.handlers) == 2) 60 | assert.True(t, m.handlers[0].name == "dummy2") 61 | assert.True(t, m.handlers[1].name == "dummy") 62 | 63 | RegisterBefore("dummy3", func(*config.Config) Handler { 64 | return &dummy{} 65 | }, "dummy") 66 | 
assert.True(t, len(m.handlers) == 3) 67 | assert.True(t, m.handlers[0].name == "dummy2") 68 | assert.True(t, m.handlers[1].name == "dummy3") 69 | assert.True(t, m.handlers[2].name == "dummy") 70 | 71 | assert.Panics(t, func() { 72 | RegisterAt("dummy4", func(*config.Config) Handler { 73 | return &dummy{} 74 | }, 4) 75 | }) 76 | assert.Panics(t, func() { 77 | RegisterAt("dummy5", func(*config.Config) Handler { 78 | return &dummy{} 79 | }, -1) 80 | }) 81 | assert.Panics(t, func() { 82 | RegisterBefore("dummy6", func(*config.Config) Handler { 83 | return &dummy{} 84 | }, "noexist") 85 | }) 86 | 87 | m.handlers = []handler{} 88 | } 89 | -------------------------------------------------------------------------------- /middleware/ratelimit/ratelimit.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "sync/atomic" 7 | "time" 8 | 9 | "github.com/cespare/xxhash/v2" 10 | "github.com/miekg/dns" 11 | "github.com/semihalev/sdns/cache" 12 | "github.com/semihalev/sdns/config" 13 | "github.com/semihalev/sdns/dnsutil" 14 | "github.com/semihalev/sdns/middleware" 15 | "golang.org/x/time/rate" 16 | ) 17 | 18 | type limiter struct { 19 | rl *rate.Limiter 20 | cookie atomic.Value 21 | } 22 | 23 | // RateLimit type 24 | type RateLimit struct { 25 | cookiesecret string 26 | 27 | cache *cache.Cache 28 | rate int 29 | } 30 | 31 | // New return accesslist 32 | func New(cfg *config.Config) *RateLimit { 33 | r := &RateLimit{ 34 | cache: cache.New(cacheSize), 35 | cookiesecret: cfg.CookieSecret, 36 | rate: cfg.ClientRateLimit, 37 | } 38 | 39 | return r 40 | } 41 | 42 | // Name return middleware name 43 | func (r *RateLimit) Name() string { return name } 44 | 45 | // ServeDNS implements the Handle interface. 
46 | func (r *RateLimit) ServeDNS(ctx context.Context, ch *middleware.Chain) { 47 | w, req := ch.Writer, ch.Request 48 | 49 | if w.Internal() { 50 | ch.Next(ctx) 51 | return 52 | } 53 | 54 | if r.rate == 0 { 55 | ch.Next(ctx) 56 | return 57 | } 58 | 59 | if w.RemoteIP() == nil { 60 | ch.Next(ctx) 61 | return 62 | } else if w.RemoteIP().IsLoopback() { 63 | ch.Next(ctx) 64 | return 65 | } 66 | 67 | var cachedcookie, clientcookie, servercookie string 68 | 69 | l := r.getLimiter(w.RemoteIP()) 70 | cachedcookie = l.cookie.Load().(string) 71 | 72 | if opt := req.IsEdns0(); opt != nil { 73 | for _, option := range opt.Option { 74 | switch option.Option() { 75 | case dns.EDNS0COOKIE: 76 | if len(option.String()) >= cookieSize { 77 | clientcookie = option.String()[:cookieSize] 78 | servercookie = dnsutil.GenerateServerCookie(r.cookiesecret, w.RemoteIP().String(), clientcookie) 79 | 80 | if cachedcookie == "" || cachedcookie == option.String() { 81 | ch.Next(ctx) 82 | 83 | l.cookie.Store(servercookie) 84 | return 85 | } 86 | 87 | if w.Proto() == "udp" { 88 | if !l.rl.Allow() { 89 | ch.Cancel() 90 | return 91 | } 92 | 93 | l.cookie.Store(servercookie) 94 | option.(*dns.EDNS0_COOKIE).Cookie = servercookie 95 | 96 | ch.CancelWithRcode(dns.RcodeBadCookie, false) 97 | 98 | return 99 | } 100 | } 101 | } 102 | } 103 | } 104 | 105 | if !l.rl.Allow() { 106 | //no reply to client 107 | ch.Cancel() 108 | return 109 | } 110 | 111 | ch.Next(ctx) 112 | 113 | if servercookie != "" { 114 | l.cookie.Store(servercookie) 115 | } 116 | } 117 | 118 | func (r *RateLimit) getLimiter(remoteip net.IP) *limiter { 119 | xxhash := xxhash.New() 120 | _, _ = xxhash.Write(remoteip) 121 | key := xxhash.Sum64() 122 | 123 | if v, ok := r.cache.Get(key); ok { 124 | return v.(*limiter) 125 | } 126 | 127 | limit := rate.Limit(0) 128 | if r.rate > 0 { 129 | limit = rate.Every(time.Minute / time.Duration(r.rate)) 130 | } 131 | 132 | rl := rate.NewLimiter(limit, r.rate) 133 | 134 | l := &limiter{rl: rl} 135 | 
l.cookie.Store("") 136 | 137 | r.cache.Add(key, l) 138 | 139 | return l 140 | } 141 | 142 | const ( 143 | cacheSize = 256 * 100 144 | cookieSize = 16 145 | 146 | name = "ratelimit" 147 | ) 148 | -------------------------------------------------------------------------------- /middleware/ratelimit/ratelimit_test.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/semihalev/sdns/config" 9 | "github.com/semihalev/sdns/middleware" 10 | "github.com/semihalev/sdns/mock" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func Test_RateLimit(t *testing.T) { 15 | cfg := new(config.Config) 16 | cfg.ClientRateLimit = 1 17 | 18 | middleware.Register("ratelimit", func(cfg *config.Config) middleware.Handler { return New(cfg) }) 19 | middleware.Setup(cfg) 20 | 21 | r := middleware.Get("ratelimit").(*RateLimit) 22 | 23 | assert.Equal(t, "ratelimit", r.Name()) 24 | 25 | ch := middleware.NewChain([]middleware.Handler{}) 26 | 27 | req := new(dns.Msg) 28 | req.SetQuestion("example.com.", dns.TypeA) 29 | req.SetEdns0(4096, true) 30 | 31 | opt := req.IsEdns0() 32 | opt.Option = append(opt.Option, &dns.EDNS0_COOKIE{ 33 | Code: dns.EDNS0COOKIE, 34 | Cookie: "testtesttesttest", 35 | }) 36 | 37 | mw := mock.NewWriter("udp", "") 38 | ch.Reset(mw, req) 39 | r.ServeDNS(context.Background(), ch) 40 | 41 | mw = mock.NewWriter("udp", "10.0.0.1:0") 42 | ch.Reset(mw, req) 43 | r.ServeDNS(context.Background(), ch) 44 | r.ServeDNS(context.Background(), ch) 45 | if assert.True(t, mw.Written()) { 46 | assert.Equal(t, dns.RcodeBadCookie, mw.Rcode()) 47 | } 48 | 49 | opt.Option = nil 50 | opt.Option = append(opt.Option, &dns.EDNS0_COOKIE{ 51 | Code: dns.EDNS0COOKIE, 52 | Cookie: "testtesttesttest", 53 | }) 54 | 55 | mw = mock.NewWriter("udp", "10.0.0.1:0") 56 | ch.Reset(mw, req) 57 | r.ServeDNS(context.Background(), ch) 58 | assert.False(t, 
mw.Written()) 59 | 60 | mw = mock.NewWriter("tcp", "10.0.0.2:0") 61 | ch.Reset(mw, req) 62 | r.ServeDNS(context.Background(), ch) 63 | r.ServeDNS(context.Background(), ch) 64 | assert.False(t, mw.Written()) 65 | 66 | opt.Option = nil 67 | mw = mock.NewWriter("udp", "10.0.0.1:0") 68 | ch.Reset(mw, req) 69 | r.ServeDNS(context.Background(), ch) 70 | r.ServeDNS(context.Background(), ch) 71 | assert.False(t, mw.Written()) 72 | 73 | mw = mock.NewWriter("udp", "0.0.0.0:0") 74 | ch.Reset(mw, req) 75 | r.ServeDNS(context.Background(), ch) 76 | 77 | mw = mock.NewWriter("udp", "127.0.0.1:0") 78 | ch.Reset(mw, req) 79 | r.ServeDNS(context.Background(), ch) 80 | 81 | r.rate = 0 82 | 83 | r.ServeDNS(context.Background(), ch) 84 | } 85 | -------------------------------------------------------------------------------- /middleware/recovery/recovery.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "runtime/debug" 8 | 9 | "github.com/miekg/dns" 10 | "github.com/semihalev/log" 11 | "github.com/semihalev/sdns/config" 12 | "github.com/semihalev/sdns/middleware" 13 | ) 14 | 15 | // Recovery dummy type 16 | type Recovery struct{} 17 | 18 | // New return recovery 19 | func New(cfg *config.Config) *Recovery { 20 | return &Recovery{} 21 | } 22 | 23 | // Name return middleware name 24 | func (r *Recovery) Name() string { return name } 25 | 26 | // ServeDNS implements the Handle interface. 
27 | func (r *Recovery) ServeDNS(ctx context.Context, ch *middleware.Chain) { 28 | defer func() { 29 | if r := recover(); r != nil { 30 | ch.CancelWithRcode(dns.RcodeServerFailure, false) 31 | 32 | log.Error("Recovered in ServeDNS", "recover", r) 33 | 34 | _, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n\n", r)) 35 | debug.PrintStack() 36 | } 37 | }() 38 | 39 | ch.Next(ctx) 40 | } 41 | 42 | const name = "recovery" 43 | -------------------------------------------------------------------------------- /middleware/recovery/recovery_test.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/miekg/dns" 9 | "github.com/semihalev/log" 10 | "github.com/semihalev/sdns/config" 11 | "github.com/semihalev/sdns/middleware" 12 | "github.com/semihalev/sdns/mock" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func Test_Recovery(t *testing.T) { 17 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler)) 18 | 19 | stderr := os.Stderr 20 | os.Stderr, _ = os.Open(os.DevNull) 21 | 22 | middleware.Register("recovery", func(cfg *config.Config) middleware.Handler { return New(cfg) }) 23 | middleware.Setup(&config.Config{}) 24 | 25 | r := middleware.Get("recovery").(*Recovery) 26 | 27 | assert.Equal(t, "recovery", r.Name()) 28 | 29 | ch := middleware.NewChain([]middleware.Handler{r, nil}) 30 | 31 | mw := mock.NewWriter("udp", "127.0.0.1:0") 32 | req := new(dns.Msg) 33 | req.SetQuestion("test.com.", dns.TypeA) 34 | 35 | ch.Reset(mw, req) 36 | 37 | r.ServeDNS(context.Background(), ch) 38 | 39 | assert.Equal(t, dns.RcodeServerFailure, mw.Msg().Rcode) 40 | 41 | ch = middleware.NewChain([]middleware.Handler{r}) 42 | ch.Reset(mw, req) 43 | r.ServeDNS(context.Background(), ch) 44 | 45 | os.Stderr = stderr 46 | } 47 | -------------------------------------------------------------------------------- /middleware/resolver/auto_trust_anchor_test.go: 
-------------------------------------------------------------------------------- 1 | package resolver_test 2 | 3 | import ( 4 | "context" 5 | "crypto/x509" 6 | "encoding/base64" 7 | "net" 8 | "os" 9 | "path/filepath" 10 | "testing" 11 | "time" 12 | 13 | "github.com/miekg/dns" 14 | "github.com/semihalev/log" 15 | "github.com/semihalev/sdns/authcache" 16 | "github.com/semihalev/sdns/config" 17 | "github.com/semihalev/sdns/middleware/resolver" 18 | "github.com/stretchr/testify/assert" 19 | ) 20 | 21 | var ( 22 | privateKey = "MIICXAIBAAKBgQCWIKiOFx/LqVppaUSfW2a9hEnfUS+Qb752/fL3odiGQxCrxcmcEXvn+APSN3ipRetdLdHeB7FSZQ4eIhBtgKjBuFlqQj8pnZOWhV16w80HFjYg/ea9nhG8IziTzK/lsSIk2cDTe1k9kD5WUaLRijLJEEy7gLkOOFmt3Ho675dw2QIDAQABAoGANYEsMX/iUBZqZ5kh4N2Vb0O/hDyOBB8fNY9qUYE4BxnNzjpukRXWICVPT1N/yGxn5syWuFfrhZ8IegrP6gbpnZ1ViYRONOkrfoGOm/U71IL8mlr/NCrxAd/ifB4Db1HOEvlewwQ3G8+HE7HBAjYpup+w4Yw/Du2Cw6dtlJ9MmWUCQQDDHj0MCxWus38EHBwueVjmKq/gE1oZpuLCGmjVZIXlA7yw4IlQU27Y+XlEdVJIMRIUQ1K7Zdw/KFU+aKfBmYy3AkEAxPijYGdWPSZDZn/9tPMfdBtipz1wXHREHLBNOPOcgP4TjVoqBY8Yl6ZUwwjTA8C2JZ1ZU4oSUyLuecbH/N1+7wJAVb2w99zbH1UTSLwNikKa1TIG7UGzwzf5x3ARh0xQJk4ZGeThkmHHgSNHrdScXsrpdewLq/vb6AkSRIV6ynFuSwJAHLk5cfR/0fkDeS4O/FU77/2SXFsMSJ8304suJ7D20KS8iy9r01Wzu2GpGKvvwatXpJKWlSUcWP1OE3oWbdyLBwJBAMCcuKf9EIw9Wgkt9KKhJXKpSqUr1xN+3WZf4bmQl4nT1mITMPcmnQla/JYepnspYrt06L16Ed8vf4u8AbEW68I=" 23 | ) 24 | 25 | type dummyHandler struct { 26 | DNSKEY []dns.RR 27 | } 28 | 29 | func (d *dummyHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { 30 | resp := new(dns.Msg) 31 | resp.Question = r.Question 32 | resp.Authoritative = true 33 | resp.SetRcode(r, dns.RcodeSuccess) 34 | 35 | if r.Question[0].Qtype == dns.TypeDNSKEY { 36 | resp.Answer = append(resp.Answer, d.DNSKEY...) 
37 | } 38 | 39 | _ = w.WriteMsg(resp) 40 | } 41 | 42 | func makeRootKeysConfig() *config.Config { 43 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler)) 44 | 45 | cfg := new(config.Config) 46 | cfg.RootServers = []string{"127.0.0.1:44302"} 47 | cfg.RootKeys = []string{ 48 | ". 86400 IN DNSKEY 257 3 8 AwEAAZYgqI4XH8upWmlpRJ9bZr2ESd9RL5Bvvnb98veh2IZDEKvFyZwRe+f4A9I3eKlF610t0d4HsVJlDh4iEG2AqMG4WWpCPymdk5aFXXrDzQcWNiD95r2eEbwjOJPMr+WxIiTZwNN7WT2QPlZRotGKMskQTLuAuQ44Wa3cejrvl3DZ", 49 | ". 86400 IN DNSKEY 257 3 8 AwEAAcjLi71oV55rThFidre9DgnEgJwOmPPg0XwWmFkz3uNoT3+SaT6hErHuJS2I8+vc4rZIoGaNdlOrsNBEqyfaikDniq6+PwdNFK8Adt8xBCh9YOZkexdb8i59MABbv1TtJ130O9L8OQ9MOfJfyLm9UknV4D5y8HDDOBcjkJ2U4DRx", 50 | ". 86400 IN DNSKEY 257 3 8 AwEAAc78m2ldR7iPdjdFZlGheNgdUclcSrPSx+E5s0XWiW6nBaDDawTICkwWI7m7Uzuva1myKkZKgidtwmmxS1P6/xjsRCn1xEjPXvim5Xzr0gjsp16KFQsR8IALGu6dxJYn7WHB+UdT3yiV0x6FwAVb/ilYsOMmn3S/oaaTx4Oh7OEL", 51 | ". 86400 IN DNSKEY 385 3 8 AwEAAcZMCRaHx5n6GWwAWFrwhnfheNafQfaoBcn1IfwmQ8RD/0V+WAeFU+CVH+zinmqamv/V+zF2FF03WZjuq5HpOtqFVAKsQmC3Wb6DttjwzgNs0Iywgy/Ae8QZp03WApVmzcr4hDvxXeP5ABMwf8vR7gF/JtArb2Mlnekh/7sWs/wr", 52 | } 53 | cfg.Maxdepth = 30 54 | cfg.Expire = 600 55 | cfg.CacheSize = 1024 56 | cfg.Timeout.Duration = 2 * time.Second 57 | cfg.Directory = filepath.Join(os.TempDir(), "sdns_temp_autota") 58 | cfg.IPv6Access = false 59 | 60 | _ = os.Mkdir(cfg.Directory, 0777) 61 | 62 | return cfg 63 | } 64 | 65 | func runTestServer() { 66 | buf, _ := base64.StdEncoding.DecodeString(privateKey) 67 | privKey, _ := x509.ParsePKCS1PrivateKey(buf) 68 | 69 | dnskey1 := &dns.DNSKEY{ 70 | Hdr: dns.RR_Header{ 71 | Name: ".", 72 | Rrtype: dns.TypeDNSKEY, 73 | Class: dns.ClassINET, 74 | Ttl: 86400, 75 | }, 76 | Algorithm: 8, 77 | Flags: 257, 78 | Protocol: 3, 79 | PublicKey: "AwEAAZYgqI4XH8upWmlpRJ9bZr2ESd9RL5Bvvnb98veh2IZDEKvFyZwRe+f4A9I3eKlF610t0d4HsVJlDh4iEG2AqMG4WWpCPymdk5aFXXrDzQcWNiD95r2eEbwjOJPMr+WxIiTZwNN7WT2QPlZRotGKMskQTLuAuQ44Wa3cejrvl3DZ", 80 | } 81 | 82 | dnskey2 := 
&dns.DNSKEY{ 83 | Hdr: dns.RR_Header{ 84 | Name: ".", 85 | Rrtype: dns.TypeDNSKEY, 86 | Class: dns.ClassINET, 87 | Ttl: 86400, 88 | }, 89 | Algorithm: 8, 90 | Flags: 385, 91 | Protocol: 3, 92 | PublicKey: "AwEAAcjLi71oV55rThFidre9DgnEgJwOmPPg0XwWmFkz3uNoT3+SaT6hErHuJS2I8+vc4rZIoGaNdlOrsNBEqyfaikDniq6+PwdNFK8Adt8xBCh9YOZkexdb8i59MABbv1TtJ130O9L8OQ9MOfJfyLm9UknV4D5y8HDDOBcjkJ2U4DRx", 93 | } 94 | 95 | rrsigdk := &dns.RRSIG{ 96 | Hdr: dns.RR_Header{ 97 | Name: ".", 98 | Rrtype: dns.TypeRRSIG, 99 | Class: dns.ClassINET, 100 | Ttl: 86400, 101 | }, 102 | TypeCovered: dns.TypeDNSKEY, 103 | Algorithm: 8, 104 | SignerName: ".", 105 | KeyTag: dnskey1.KeyTag(), 106 | Inception: uint32(time.Now().UTC().Unix()), 107 | Expiration: uint32(time.Now().UTC().Add(15 * 24 * time.Hour).Unix()), 108 | OrigTtl: 3600, 109 | } 110 | 111 | dnskey3 := &dns.DNSKEY{ 112 | Hdr: dns.RR_Header{ 113 | Name: ".", 114 | Rrtype: dns.TypeDNSKEY, 115 | Class: dns.ClassINET, 116 | Ttl: 86400, 117 | }, 118 | Algorithm: 8, 119 | Flags: 257, 120 | Protocol: 3, 121 | } 122 | 123 | _, _ = dnskey3.Generate(1024) 124 | 125 | _ = rrsigdk.Sign(privKey, []dns.RR{dnskey1, dnskey2, dnskey3}) 126 | 127 | /*if s, ok := privkey.(*rsa.PrivateKey); ok { 128 | buf := x509.MarshalPKCS1PrivateKey(s) 129 | fmt.Println(base64.StdEncoding.EncodeToString(buf)) 130 | rrsig.Sign(s, []dns.RR{dnskey}) 131 | }*/ 132 | 133 | go func() { 134 | _ = dns.ListenAndServe("127.0.0.1:44302", "udp", &dummyHandler{DNSKEY: []dns.RR{dnskey1, dnskey2, dnskey3, rrsigdk}}) 135 | }() 136 | } 137 | 138 | func Test_autota(t *testing.T) { 139 | runTestServer() 140 | 141 | cfg := makeRootKeysConfig() 142 | 143 | r := resolver.NewResolver(cfg) 144 | 145 | time.Sleep(time.Second) 146 | 147 | req := new(dns.Msg) 148 | req.SetQuestion(".", dns.TypeDNSKEY) 149 | req.SetEdns0(1400, true) 150 | 151 | rootservers := &authcache.AuthServers{} 152 | rootservers.Zone = "." 
153 | 154 | for _, s := range cfg.RootServers { 155 | host, _, _ := net.SplitHostPort(s) 156 | if ip := net.ParseIP(host); ip != nil && ip.To4() != nil { 157 | rootservers.List = append(rootservers.List, authcache.NewAuthServer(s, authcache.IPv4)) 158 | } 159 | } 160 | 161 | resp, err := r.Resolve(context.Background(), req, rootservers, true, 30, 0, false, nil) 162 | 163 | assert.True(t, resp.AuthenticatedData) 164 | assert.NoError(t, err) 165 | assert.Len(t, resp.Answer, 4) 166 | 167 | os.Remove(filepath.Join(cfg.Directory, "trust-anchor.db")) 168 | } 169 | -------------------------------------------------------------------------------- /middleware/resolver/client.go: -------------------------------------------------------------------------------- 1 | package resolver 2 | 3 | // Originally this Client from github.com/miekg/dns 4 | // Adapted for resolver package usage by Semih Alev. 5 | 6 | import ( 7 | "encoding/binary" 8 | "errors" 9 | "io" 10 | "net" 11 | "sync" 12 | "time" 13 | 14 | "github.com/miekg/dns" 15 | ) 16 | 17 | const ( 18 | headerSize = 12 19 | ) 20 | 21 | // A Conn represents a connection to a DNS server. 22 | type Conn struct { 23 | net.Conn // a net.Conn holding the connection 24 | UDPSize uint16 // minimum receive buffer for UDP messages 25 | } 26 | 27 | // Exchange performs a synchronous query 28 | func (co *Conn) Exchange(m *dns.Msg) (r *dns.Msg, rtt time.Duration, err error) { 29 | 30 | opt := m.IsEdns0() 31 | // If EDNS0 is used use that for size. 
32 | if opt != nil && opt.UDPSize() >= dns.MinMsgSize { 33 | co.UDPSize = opt.UDPSize() 34 | } 35 | 36 | if opt == nil && co.UDPSize < dns.MinMsgSize { 37 | co.UDPSize = dns.MinMsgSize 38 | } 39 | 40 | t := time.Now() 41 | 42 | if err = co.WriteMsg(m); err != nil { 43 | return nil, 0, err 44 | } 45 | 46 | r, err = co.ReadMsg() 47 | if err == nil && r.Id != m.Id { 48 | err = dns.ErrId 49 | } 50 | 51 | rtt = time.Since(t) 52 | 53 | return r, rtt, err 54 | } 55 | 56 | // ReadMsg reads a message from the connection co. 57 | // If the received message contains a TSIG record the transaction signature 58 | // is verified. This method always tries to return the message, however if an 59 | // error is returned there are no guarantees that the returned message is a 60 | // valid representation of the packet read. 61 | func (co *Conn) ReadMsg() (*dns.Msg, error) { 62 | var ( 63 | p []byte 64 | n int 65 | err error 66 | ) 67 | 68 | if _, ok := co.Conn.(net.PacketConn); ok { 69 | p = AcquireBuf(co.UDPSize) 70 | n, err = co.Read(p) 71 | } else { 72 | var length uint16 73 | if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil { 74 | return nil, err 75 | } 76 | 77 | p = AcquireBuf(length) 78 | n, err = io.ReadFull(co.Conn, p) 79 | } 80 | 81 | if err != nil { 82 | return nil, err 83 | } else if n < headerSize { 84 | return nil, dns.ErrShortRead 85 | } 86 | 87 | defer ReleaseBuf(p) 88 | 89 | m := new(dns.Msg) 90 | if err := m.Unpack(p); err != nil { 91 | // If an error was returned, we still want to allow the user to use 92 | // the message, but naively they can just check err if they don't want 93 | // to use an erroneous message 94 | return m, err 95 | } 96 | return m, err 97 | } 98 | 99 | // Read implements the net.Conn read method. 
100 | func (co *Conn) Read(p []byte) (n int, err error) { 101 | if co.Conn == nil { 102 | return 0, dns.ErrConnEmpty 103 | } 104 | 105 | if _, ok := co.Conn.(net.PacketConn); ok { 106 | // UDP connection 107 | return co.Conn.Read(p) 108 | } 109 | 110 | var length uint16 111 | if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil { 112 | return 0, err 113 | } 114 | if int(length) > len(p) { 115 | return 0, io.ErrShortBuffer 116 | } 117 | 118 | return io.ReadFull(co.Conn, p[:length]) 119 | } 120 | 121 | // WriteMsg sends a message through the connection co. 122 | // If the message m contains a TSIG record the transaction 123 | // signature is calculated. 124 | func (co *Conn) WriteMsg(m *dns.Msg) (err error) { 125 | size := uint16(m.Len()) + 1 126 | 127 | out := AcquireBuf(size) 128 | defer ReleaseBuf(out) 129 | 130 | out, err = m.PackBuffer(out) 131 | if err != nil { 132 | return err 133 | } 134 | _, err = co.Write(out) 135 | return err 136 | } 137 | 138 | // Write implements the net.Conn Write method. 
139 | func (co *Conn) Write(p []byte) (int, error) { 140 | if len(p) > dns.MaxMsgSize { 141 | return 0, errors.New("message too large") 142 | } 143 | 144 | if _, ok := co.Conn.(net.PacketConn); ok { 145 | return co.Conn.Write(p) 146 | } 147 | 148 | l := make([]byte, 2) 149 | binary.BigEndian.PutUint16(l, uint16(len(p))) 150 | 151 | n, err := (&net.Buffers{l, p}).WriteTo(co.Conn) 152 | return int(n), err 153 | } 154 | 155 | var bufferPool sync.Pool 156 | 157 | // AcquireBuf returns an buf from pool 158 | func AcquireBuf(size uint16) []byte { 159 | x := bufferPool.Get() 160 | if x == nil { 161 | return make([]byte, size) 162 | } 163 | buf := *(x.(*[]byte)) 164 | if cap(buf) < int(size) { 165 | return make([]byte, size) 166 | } 167 | return buf[:size] 168 | } 169 | 170 | // ReleaseBuf returns buf to pool 171 | func ReleaseBuf(buf []byte) { 172 | bufferPool.Put(&buf) 173 | } 174 | -------------------------------------------------------------------------------- /middleware/resolver/client_test.go: -------------------------------------------------------------------------------- 1 | package resolver 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | "time" 7 | 8 | "github.com/miekg/dns" 9 | "github.com/semihalev/sdns/dnsutil" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func Test_ClientTimeout(t *testing.T) { 14 | req := new(dns.Msg) 15 | req.SetQuestion(".", dns.TypeNS) 16 | req.SetEdns0(dnsutil.DefaultMsgSize, true) 17 | 18 | dialer := &net.Dialer{Deadline: time.Now().Add(2 * time.Second)} 19 | co := &Conn{} 20 | 21 | var err error 22 | co.Conn, err = dialer.Dial("udp4", "127.1.0.255:53") 23 | assert.NoError(t, err) 24 | 25 | err = co.SetDeadline(time.Now().Add(2 * time.Second)) 26 | assert.NoError(t, err) 27 | 28 | _, _, err = co.Exchange(req) 29 | assert.Error(t, err) 30 | assert.NoError(t, co.Close()) 31 | } 32 | 33 | func Test_Client(t *testing.T) { 34 | req := new(dns.Msg) 35 | req.SetQuestion(".", dns.TypeNS) 36 | req.SetEdns0(dnsutil.DefaultMsgSize, true) 
37 | 38 | dialer := &net.Dialer{Deadline: time.Now().Add(2 * time.Second)} 39 | co := &Conn{} 40 | 41 | var err error 42 | co.Conn, err = dialer.Dial("udp4", "198.41.0.4:53") 43 | assert.NoError(t, err) 44 | 45 | err = co.SetDeadline(time.Now().Add(2 * time.Second)) 46 | assert.NoError(t, err) 47 | 48 | r, _, err := co.Exchange(req) 49 | assert.NoError(t, err) 50 | assert.NotNil(t, r) 51 | } 52 | -------------------------------------------------------------------------------- /middleware/resolver/handler.go: -------------------------------------------------------------------------------- 1 | package resolver 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "time" 7 | 8 | "github.com/miekg/dns" 9 | "github.com/semihalev/log" 10 | "github.com/semihalev/sdns/authcache" 11 | "github.com/semihalev/sdns/cache" 12 | "github.com/semihalev/sdns/config" 13 | "github.com/semihalev/sdns/dnsutil" 14 | "github.com/semihalev/sdns/middleware" 15 | ) 16 | 17 | // DNSHandler type 18 | type DNSHandler struct { 19 | resolver *Resolver 20 | cfg *config.Config 21 | } 22 | 23 | type ctxKey string 24 | 25 | var debugns bool 26 | 27 | func init() { 28 | _, debugns = os.LookupEnv("SDNS_DEBUGNS") 29 | } 30 | 31 | // New returns a new Handler 32 | func New(cfg *config.Config) *DNSHandler { 33 | if cfg.Maxdepth == 0 { 34 | cfg.Maxdepth = 30 35 | } 36 | 37 | if cfg.QueryTimeout.Duration == 0 { 38 | cfg.QueryTimeout.Duration = 10 * time.Second 39 | } 40 | 41 | return &DNSHandler{ 42 | resolver: NewResolver(cfg), 43 | cfg: cfg, 44 | } 45 | } 46 | 47 | // Name return middleware name 48 | func (h *DNSHandler) Name() string { return name } 49 | 50 | // ServeDNS implements the Handle interface. 
51 | func (h *DNSHandler) ServeDNS(ctx context.Context, ch *middleware.Chain) { 52 | if len(h.cfg.ForwarderServers) > 0 { 53 | ch.Next(ctx) 54 | return 55 | } 56 | 57 | w, req := ch.Writer, ch.Request 58 | 59 | if v := ctx.Value(ctxKey("reqid")); v == nil { 60 | ctx = context.WithValue(ctx, ctxKey("reqid"), req.Id) 61 | } 62 | msg := h.handle(ctx, req) 63 | 64 | _ = w.WriteMsg(msg) 65 | } 66 | 67 | func (h *DNSHandler) handle(ctx context.Context, req *dns.Msg) *dns.Msg { 68 | q := req.Question[0] 69 | 70 | do := false 71 | opt := req.IsEdns0() 72 | if opt != nil { 73 | do = opt.Do() 74 | } 75 | 76 | if q.Qtype == dns.TypeANY { 77 | return dnsutil.SetRcode(req, dns.RcodeNotImplemented, do) 78 | } 79 | 80 | // debug ns stats 81 | if debugns && q.Qclass == dns.ClassCHAOS && q.Qtype == dns.TypeHINFO { 82 | return h.nsStats(req) 83 | } 84 | 85 | // check purge query 86 | if q.Qclass == dns.ClassCHAOS && q.Qtype == dns.TypeNULL { 87 | if qname, qtype, ok := dnsutil.ParsePurgeQuestion(req); ok { 88 | if qtype == dns.TypeNS { 89 | h.purge(qname) 90 | } 91 | 92 | resp := dnsutil.SetRcode(req, dns.RcodeSuccess, do) 93 | txt, _ := dns.NewRR(q.Name + ` 20 IN TXT "cache purged"`) 94 | 95 | resp.Extra = append(resp.Extra, txt) 96 | 97 | return resp 98 | } 99 | } 100 | 101 | if q.Name != rootzone && !req.RecursionDesired { 102 | return dnsutil.SetRcode(req, dns.RcodeServerFailure, do) 103 | } 104 | 105 | // we shouldn't send rd and ad flag to aa servers 106 | req.RecursionDesired = false 107 | req.AuthenticatedData = false 108 | 109 | if !req.CheckingDisabled { 110 | req.CheckingDisabled = !h.resolver.dnssec 111 | } 112 | 113 | ctx, cancel := context.WithDeadline(ctx, time.Now().Add(h.cfg.QueryTimeout.Duration)) 114 | defer cancel() 115 | 116 | depth := h.cfg.Maxdepth 117 | resp, err := h.resolver.Resolve(ctx, req, h.resolver.rootservers, true, depth, 0, false, nil, q.Name == rootzone) 118 | 119 | if !h.resolver.dnssec { 120 | req.CheckingDisabled = false 121 | if resp != nil { 
122 | resp.CheckingDisabled = false 123 | } 124 | } 125 | 126 | if err != nil { 127 | log.Info("Resolve query failed", "query", formatQuestion(q), "error", err.Error()) 128 | 129 | return dnsutil.SetRcode(req, dns.RcodeServerFailure, do) 130 | } 131 | 132 | if resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeNotZone { 133 | return dnsutil.SetRcode(req, dns.RcodeServerFailure, do) 134 | } 135 | 136 | return resp 137 | } 138 | 139 | func (h *DNSHandler) nsStats(req *dns.Msg) *dns.Msg { 140 | q := req.Question[0] 141 | 142 | msg := new(dns.Msg) 143 | msg.SetReply(req) 144 | 145 | msg.Authoritative = false 146 | msg.RecursionAvailable = true 147 | 148 | servers := h.resolver.rootservers 149 | ttl := uint32(0) 150 | name := rootzone 151 | 152 | if q.Name != rootzone { 153 | nsKey := cache.Hash(dns.Question{Name: q.Name, Qtype: dns.TypeNS, Qclass: dns.ClassINET}, msg.CheckingDisabled) 154 | ns, err := h.resolver.ncache.Get(nsKey) 155 | if err != nil { 156 | nsKey = cache.Hash(dns.Question{Name: q.Name, Qtype: dns.TypeNS, Qclass: dns.ClassINET}, !msg.CheckingDisabled) 157 | ns, err := h.resolver.ncache.Get(nsKey) 158 | if err == nil { 159 | servers = ns.Servers 160 | name = q.Name 161 | } 162 | } else { 163 | servers = ns.Servers 164 | name = q.Name 165 | } 166 | } 167 | 168 | var serversList []*authcache.AuthServer 169 | 170 | servers.RLock() 171 | serversList = append(serversList, servers.List...) 
172 | 	servers.RUnlock()
173 |
174 | 	authcache.Sort(serversList, 1)
175 |
176 | 	rrHeader := dns.RR_Header{
177 | 		Name:   name,
178 | 		Rrtype: dns.TypeHINFO,
179 | 		Class:  dns.ClassCHAOS,
180 | 		Ttl:    ttl,
181 | 	}
182 |
183 | 	for _, server := range serversList {
184 | 		hinfo := &dns.HINFO{Hdr: rrHeader, Cpu: "Host", Os: server.String()}
185 | 		msg.Ns = append(msg.Ns, hinfo)
186 | 	}
187 |
188 | 	return msg
189 | }
190 |
191 | func (h *DNSHandler) purge(qname string) {
192 | 	q := dns.Question{Name: qname, Qtype: dns.TypeNS, Qclass: dns.ClassINET}
193 |
194 | 	key := cache.Hash(q, false)
195 | 	h.resolver.ncache.Remove(key)
196 |
197 | 	key = cache.Hash(q, true)
198 | 	h.resolver.ncache.Remove(key)
199 | }
200 |
201 | const name = "resolver"
202 |
--------------------------------------------------------------------------------
/middleware/resolver/handler_test.go:
--------------------------------------------------------------------------------
1 | package resolver
2 |
3 | import (
4 | 	"context"
5 | 	"encoding/base64"
6 | 	"os"
7 | 	"path/filepath"
8 | 	"testing"
9 | 	"time"
10 |
11 | 	"github.com/miekg/dns"
12 | 	"github.com/semihalev/log"
13 | 	"github.com/semihalev/sdns/config"
14 | 	"github.com/semihalev/sdns/dnsutil"
15 | 	"github.com/semihalev/sdns/middleware"
16 | 	"github.com/semihalev/sdns/middleware/edns"
17 | 	"github.com/semihalev/sdns/mock"
18 | 	"github.com/stretchr/testify/assert"
19 | )
20 |
21 | func makeTestConfig() *config.Config {
22 | 	log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
23 |
24 | 	cfg := new(config.Config)
25 | 	cfg.RootServers = []string{"192.5.5.241:53"}
26 | 	cfg.Root6Servers = []string{"[2001:500:2f::f]:53"}
27 | 	cfg.RootKeys = []string{
28 | 		". 172800 IN DNSKEY 257 3 8 AwEAAaz/tAm8yTn4Mfeh5eyI96WSVexTBAvkMgJzkKTOiW1vkIbzxeF3+/4RgWOq7HrxRixHlFlExOLAJr5emLvN7SWXgnLh4+B5xQlNVz8Og8kvArMtNROxVQuCaSnIDdD5LKyWbRd2n9WGe2R8PzgCmr3EgVLrjyBxWezF0jLHwVN8efS3rCj/EWgvIWgb9tarpVUDK/b58Da+sqqls3eNbuv7pr+eoZG+SrDK6nWeL3c6H5Apxz7LjVc1uTIdsIXxuOLYA4/ilBmSVIzuDWfdRUfhHdY6+cn8HFRm+2hM8AnXGXws9555KrUB5qihylGa8subX2Nn6UwNR1AkUTV74bU=",
29 | 	}
30 | 	cfg.Maxdepth = 30
31 | 	cfg.Expire = 600
32 | 	cfg.CacheSize = 1024
33 | 	cfg.Timeout.Duration = 2 * time.Second
34 | 	cfg.Directory = filepath.Join(os.TempDir(), "sdns_temp")
35 | 	cfg.IPv6Access = true
36 | 	cfg.DNSSEC = "on"
37 |
38 | 	if !middleware.Ready() {
39 | 		middleware.Register("edns", func(cfg *config.Config) middleware.Handler { return edns.New(cfg) })
40 | 		middleware.Register("resolver", func(cfg *config.Config) middleware.Handler { return New(cfg) })
41 | 		middleware.Setup(cfg)
42 | 	}
43 |
44 | 	return cfg
45 | }
46 |
47 | func Test_handler(t *testing.T) {
48 | 	makeTestConfig()
49 |
50 | 	ctx := context.Background()
51 |
52 | 	handler := middleware.Get("resolver").(*DNSHandler)
53 |
54 | 	time.Sleep(2 * time.Second)
55 |
56 | 	assert.Equal(t, "resolver", handler.Name())
57 |
58 | 	m := new(dns.Msg)
59 | 	m.SetQuestion("www.apple.com.", dns.TypeA)
60 | 	r := handler.handle(ctx, m)
61 | 	assert.Equal(t, len(r.Answer) > 0, true)
62 |
63 | 	m = new(dns.Msg)
64 | 	// test again for caches
65 | 	m.SetQuestion("www.apple.com.", dns.TypeA)
66 | 	r = handler.handle(ctx, m)
67 | 	assert.Equal(t, len(r.Answer) > 0, true)
68 |
69 | 	m = new(dns.Msg)
70 | 	m.SetEdns0(dnsutil.DefaultMsgSize, true)
71 | 	m.SetQuestion("dnssec-failed.org.", dns.TypeA)
72 | 	r = handler.handle(ctx, m)
73 | 	assert.Equal(t, len(r.Answer) == 0, true)
74 |
75 | 	m = new(dns.Msg)
76 | 	m.SetQuestion("example.com.", dns.TypeA)
77 | 	r = handler.handle(ctx, m)
78 | 	assert.Equal(t, len(r.Answer) > 0, true)
79 |
80 | 	m = new(dns.Msg)
81 | 	m.SetQuestion(".", dns.TypeANY)
82 | 	r = handler.handle(ctx, m)
83 | 	assert.Equal(t, r.Rcode, dns.RcodeNotImplemented)
84 |
85 | 	m = new(dns.Msg)
86 | 	m.SetQuestion(".", dns.TypeNS)
87 | 	m.RecursionDesired = false
88 | 	r = handler.handle(ctx, m)
89 | 	assert.NotEqual(t, r.Rcode, dns.RcodeServerFailure)
90 | }
91 |
92 | func Test_HandlerHINFO(t *testing.T) {
93 | 	ctx := context.Background()
94 | 	cfg := makeTestConfig()
95 | 	handler := New(cfg)
96 |
97 | 	m := new(dns.Msg)
98 | 	m.SetQuestion(".", dns.TypeHINFO)
99 | 	m.Question[0].Qclass = dns.ClassCHAOS
100 |
101 | 	debugns = true
102 | 	resp := handler.handle(ctx, m)
103 |
104 | 	assert.Equal(t, true, len(resp.Ns) > 0)
105 | }
106 |
107 | func Test_HandlerPurge(t *testing.T) {
108 | 	ctx := context.Background()
109 | 	cfg := makeTestConfig()
110 | 	handler := New(cfg)
111 |
112 | 	bqname := base64.StdEncoding.EncodeToString([]byte("NS:."))
113 |
114 | 	req := new(dns.Msg)
115 | 	req.SetQuestion(dns.Fqdn(bqname), dns.TypeNULL)
116 | 	req.Question[0].Qclass = dns.ClassCHAOS
117 |
118 | 	resp := handler.handle(ctx, req)
119 |
120 | 	assert.Equal(t, true, len(resp.Extra) > 0)
121 | }
122 |
123 | func Test_HandlerServe(t *testing.T) {
124 | 	cfg := makeTestConfig()
125 | 	h := New(cfg)
126 |
127 | 	ch := middleware.NewChain([]middleware.Handler{})
128 | 	mw := mock.NewWriter("tcp", "127.0.0.1:0")
129 |
130 | 	req := new(dns.Msg)
131 | 	req.SetQuestion(".", dns.TypeNS)
132 |
133 | 	ch.Reset(mw, req)
134 |
135 | 	h.ServeDNS(context.Background(), ch)
136 | 	assert.Equal(t, true, ch.Writer.Written())
137 | }
138 |
--------------------------------------------------------------------------------
/middleware/resolver/nsec3.go:
--------------------------------------------------------------------------------
1 | package resolver
2 |
3 | import (
4 | 	"errors"
5 |
6 | 	"github.com/miekg/dns"
7 | )
8 |
9 | var (
10 | 	errNSECTypeExists      = errors.New("NSEC3 record shows question type exists")
11 | 	errNSECMissingCoverage = errors.New("NSEC3 record missing for expected encloser")
12 | 	errNSECBadDelegation   = errors.New("DS or SOA bit set in NSEC3 type map")
13 | 	errNSECNSMissing       = errors.New("NS bit not set in NSEC3 type map")
14 | 	errNSECOptOut          = errors.New("Opt-Out bit not set for NSEC3 record covering next closer")
15 | )
16 |
17 | // typesSet reports whether any of types appears in set.
18 | func typesSet(set []uint16, types ...uint16) bool {
19 | 	tm := make(map[uint16]struct{}, len(types))
20 | 	for _, t := range types {
21 | 		tm[t] = struct{}{}
22 | 	}
23 | 	for _, t := range set {
24 | 		if _, ok := tm[t]; ok {
25 | 			return true
26 | 		}
27 | 	}
28 | 	return false
29 | }
30 |
31 | // findClosestEncloser walks from name towards the root and returns the first
32 | // name (name itself or an ancestor) matched by an NSEC3 RR, together with the
33 | // "next closer" name one label below it. When name itself matches, the next
34 | // closer is name. Both results are empty when nothing matches.
35 | func findClosestEncloser(name string, nsec []dns.RR) (string, string) {
36 | 	labelIndices := dns.Split(name)
37 | 	nc := name
38 | 	for i := 0; i < len(labelIndices); i++ {
39 | 		z := name[labelIndices[i]:]
40 | 		_, err := findMatching(z, nsec)
41 | 		if err != nil {
42 | 			continue
43 | 		}
44 | 		if i != 0 {
45 | 			nc = name[labelIndices[i-1]:]
46 | 		}
47 | 		return z, nc
48 | 	}
49 | 	return "", ""
50 | }
51 |
52 | // findMatching returns the type bitmap of the first NSEC3 RR in nsec that
53 | // matches name. Records that are not NSEC3 are skipped instead of panicking.
54 | func findMatching(name string, nsec []dns.RR) ([]uint16, error) {
55 | 	for _, rr := range nsec {
56 | 		n, ok := rr.(*dns.NSEC3)
57 | 		if ok && n.Match(name) {
58 | 			return n.TypeBitMap, nil
59 | 		}
60 | 	}
61 | 	return nil, errNSECMissingCoverage
62 | }
63 |
64 | // findCoverer returns the type bitmap and Opt-Out flag (bit 0 of Flags) of the
65 | // first NSEC3 RR in nsec that covers name. Non-NSEC3 records are skipped.
66 | func findCoverer(name string, nsec []dns.RR) ([]uint16, bool, error) {
67 | 	for _, rr := range nsec {
68 | 		n, ok := rr.(*dns.NSEC3)
69 | 		if ok && n.Cover(name) {
70 | 			return n.TypeBitMap, (n.Flags & 1) == 1, nil
71 | 		}
72 | 	}
73 | 	return nil, false, errNSECMissingCoverage
74 | }
75 |
76 | // verifyNameError checks an NSEC3 name-error (NXDOMAIN) proof for the question
77 | // name, or for the DNAME target when the response carries one. RFC 5155
78 | // section 8.4 requires an NSEC3 RR matching the closest encloser, one covering
79 | // the next closer name, and one covering the wildcard at the closest
80 | // encloser.
81 | func verifyNameError(msg *dns.Msg, nsec []dns.RR) error {
82 | 	q := msg.Question[0]
83 | 	qname := q.Name
84 |
85 | 	if dname := getDnameTarget(msg); dname != "" {
86 | 		qname = dname
87 | 	}
88 |
89 | 	ce, nc := findClosestEncloser(qname, nsec)
90 | 	if ce == "" {
91 | 		return errNSECMissingCoverage
92 | 	}
93 | 	// Without the next closer check, a matching encloser plus a covered
94 | 	// wildcard would "prove" that every name below the encloser is absent.
95 | 	if _, _, err := findCoverer(nc, nsec); err != nil {
96 | 		return err
97 | 	}
98 | 	if _, _, err := findCoverer("*."+ce, nsec); err != nil {
99 | 		return err
100 | 	}
101 | 	return nil
102 | }
103 |
104 | // verifyNODATA checks an NSEC3 NODATA proof: an NSEC3 RR matching the question
105 | // name (or DNAME target) whose type bitmap has neither the question type nor
106 | // CNAME. For a DS question with no matching RR, RFC 5155 section 8.6 instead
107 | // requires a closest encloser proof whose next closer NSEC3 RR has Opt-Out set.
108 | func verifyNODATA(msg *dns.Msg, nsec []dns.RR) error {
109 | 	q := msg.Question[0]
110 | 	qname := q.Name
111 |
112 | 	if dname := getDnameTarget(msg); dname != "" {
113 | 		qname = dname
114 | 	}
115 |
116 | 	types, err := findMatching(qname, nsec)
117 | 	if err != nil {
118 | 		if q.Qtype != dns.TypeDS {
119 | 			return err
120 | 		}
121 |
122 | 		ce, nc := findClosestEncloser(qname, nsec)
123 | 		if ce == "" {
124 | 			return errNSECMissingCoverage
125 | 		}
126 | 		_, optOut, err := findCoverer(nc, nsec)
127 | 		if err != nil {
128 | 			return err
129 | 		}
130 | 		if !optOut {
131 | 			return errNSECOptOut
132 | 		}
133 | 		return nil
134 | 	}
135 |
136 | 	if typesSet(types, q.Qtype, dns.TypeCNAME) {
137 | 		return errNSECTypeExists
138 | 	}
139 |
140 | 	return nil
141 | }
142 |
143 | // verifyDelegation checks an NSEC3 proof for a delegation without a DS RRset:
144 | // either a matching NSEC3 RR with NS set and neither DS nor SOA set, or a
145 | // closest encloser proof whose next closer NSEC3 RR has the Opt-Out flag.
146 | func verifyDelegation(delegation string, nsec []dns.RR) error {
147 | 	types, err := findMatching(delegation, nsec)
148 | 	if err != nil {
149 | 		ce, nc := findClosestEncloser(delegation, nsec)
150 | 		if ce == "" {
151 | 			return errNSECMissingCoverage
152 | 		}
153 | 		_, optOut, err := findCoverer(nc, nsec)
154 | 		if err != nil {
155 | 			return err
156 | 		}
157 | 		if !optOut {
158 | 			return errNSECOptOut
159 | 		}
160 | 		return nil
161 | 	}
162 | 	if !typesSet(types, dns.TypeNS) {
163 | 		return errNSECNSMissing
164 | 	}
165 | 	if typesSet(types, dns.TypeDS, dns.TypeSOA) {
166 | 		return errNSECBadDelegation
167 | 	}
168 | 	return nil
169 | }
170 |
--------------------------------------------------------------------------------
/middleware/resolver/singleinflight.go:
--------------------------------------------------------------------------------
1 | // Copyright 2013 The Go Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | // Adapted for resolver package usage by Semih Alev.
6 |
7 | package resolver
8 |
9 | import (
10 | 	"sync"
11 |
12 | 	"github.com/miekg/dns"
13 | )
14 |
15 | // call is an in-flight or completed singleflight.Do call
16 | type call struct {
17 | 	wg   sync.WaitGroup
18 | 	val  *dns.Msg
19 | 	err  error
20 | 	dups int
21 | }
22 |
23 | // singleflight represents a class of work and forms a namespace in
24 | // which units of work can be executed with duplicate suppression.
25 | type singleflight struct {
26 | 	sync.RWMutex                  // protects m
27 | 	m            map[uint64]*call // lazily initialized
28 | }
29 |
30 | // callAbortedError is returned to duplicate callers when the shared call's
31 | // function did not return normally (it panicked or called runtime.Goexit).
32 | type callAbortedError struct{}
33 |
34 | func (callAbortedError) Error() string { return "singleflight: shared call did not complete" }
35 |
36 | // Do executes and returns the results of the given function, making
37 | // sure that only one execution is in-flight for a given key at a
38 | // time. If a duplicate comes in, the duplicate caller waits for the
39 | // original to complete and receives the same results.
40 | // The return value shared indicates whether v was given to multiple callers.
41 | //
42 | // If fn panics, the key is still released and waiting duplicates receive a
43 | // callAbortedError instead of blocking forever; the panic itself keeps
44 | // unwinding in the goroutine that ran fn.
45 | func (g *singleflight) Do(key uint64, fn func() (*dns.Msg, error)) (v *dns.Msg, shared bool, err error) {
46 | 	g.Lock()
47 | 	if g.m == nil {
48 | 		g.m = make(map[uint64]*call)
49 | 	}
50 | 	if c, ok := g.m[key]; ok {
51 | 		c.dups++
52 | 		g.Unlock()
53 | 		c.wg.Wait()
54 | 		return c.val, true, c.err
55 | 	}
56 | 	c := new(call)
57 | 	c.wg.Add(1)
58 | 	g.m[key] = c
59 | 	g.Unlock()
60 |
61 | 	g.doCall(c, key, fn)
62 |
63 | 	// doCall removed c from g.m under the lock, so no duplicate can still be
64 | 	// incrementing c.dups while it is read here.
65 | 	return c.val, c.dups > 0, c.err
66 | }
67 |
68 | // doCall runs fn for c. Its deferred cleanup wakes the waiters and removes key
69 | // from the in-flight map even when fn panics, so later callers for the same
70 | // key are never parked forever on an abandoned call.
71 | func (g *singleflight) doCall(c *call, key uint64, fn func() (*dns.Msg, error)) {
72 | 	normalReturn := false
73 | 	defer func() {
74 | 		if !normalReturn {
75 | 			c.err = callAbortedError{}
76 | 		}
77 | 		c.wg.Done()
78 |
79 | 		g.Lock()
80 | 		delete(g.m, key)
81 | 		g.Unlock()
82 | 	}()
83 |
84 | 	c.val, c.err = fn()
85 | 	normalReturn = true
86 | }
87 |
--------------------------------------------------------------------------------
/middleware/resolver/utils_test.go:
--------------------------------------------------------------------------------
1 | package resolver
2 |
3 | import (
4 | 	"fmt"
5 | 	"net"
6 | 	"testing"
7 |
8 | 	"github.com/miekg/dns"
9 | 	"github.com/stretchr/testify/assert"
10 | )
11 |
12 | func Test_shuffleStr(t *testing.T) {
13 |
14 | 	vals := make([]string, 1)
15 |
16 | 	rr := shuffleStr(vals)
17 |
18 | 	if len(rr) != 1 {
19 | 		t.Error("invalid array length")
20 | 	}
21 | }
22 |
23 | func Test_searchAddr(t *testing.T) {
24 | 	testDomain := "google.com."
25 |
26 | 	m := new(dns.Msg)
27 | 	m.SetQuestion(testDomain, dns.TypeA)
28 |
29 | 	m.SetEdns0(512, true)
30 | 	assert.Equal(t, isDO(m), true)
31 |
32 | 	m.Extra = []dns.RR{}
33 | 	assert.Equal(t, isDO(m), false)
34 |
35 | 	a1 := &dns.A{
36 | 		Hdr: dns.RR_Header{
37 | 			Name:   testDomain,
38 | 			Rrtype: dns.TypeA,
39 | 			Class:  dns.ClassINET,
40 | 			Ttl:    10,
41 | 		},
42 | 		A: net.ParseIP("127.0.0.1")}
43 |
44 | 	m.Answer = append(m.Answer, a1)
45 |
46 | 	a2 := &dns.A{
47 | 		Hdr: dns.RR_Header{
48 | 			Name:   testDomain,
49 | 			Rrtype: dns.TypeA,
50 | 			Class:  dns.ClassINET,
51 | 			Ttl:    10,
52 | 		},
53 | 		A: net.ParseIP("192.0.2.1")}
54 |
55 | 	m.Answer = append(m.Answer, a2)
56 |
57 | 	addrs, found := searchAddrs(m)
58 | 	assert.Equal(t, len(addrs), 1)
59 | 	assert.NotEqual(t, addrs[0], "127.0.0.1")
60 | 	assert.Equal(t, addrs[0], "192.0.2.1")
61 | 	assert.Equal(t, found, true)
62 | }
63 |
64 | func Test_extractRRSet(t *testing.T) {
65 | 	var rr []dns.RR
66 | 	for i := 0; i < 3; i++ {
67 | 		a, _ := dns.NewRR(fmt.Sprintf("test.com. 5 IN A 127.0.0.%d", i))
68 | 		rr = append(rr, a)
69 | 	}
70 |
71 | 	rre := extractRRSet(rr, "test.com.", dns.TypeA)
72 | 	assert.Len(t, rre, 3)
73 | }
74 |
--------------------------------------------------------------------------------
/middleware/response_writer.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | 	"errors"
5 | 	"net"
6 |
7 | 	"github.com/miekg/dns"
8 | 	"github.com/semihalev/sdns/mock"
9 | 	"github.com/semihalev/sdns/server/doq"
10 | )
11 |
12 | // ResponseWriter implement of dns.ResponseWriter
13 | type ResponseWriter interface {
14 | 	dns.ResponseWriter
15 | 	Msg() *dns.Msg
16 | 	Rcode() int
17 | 	Written() bool
18 | 	Reset(dns.ResponseWriter)
19 | 	Proto() string
20 | 	RemoteIP() net.IP
21 | 	Internal() bool
22 | }
23 |
24 | type responseWriter struct {
25 | 	dns.ResponseWriter
26 | 	msg      *dns.Msg
27 | 	size     int
28 | 	rcode    int
29 | 	proto    string
30 | 	remoteip net.IP
31 | 	internal bool
32 | }
33 |
34 | var _ ResponseWriter = &responseWriter{}
35 | var errAlreadyWritten = errors.New("msg already written")
36 |
37 | func (w *responseWriter) Msg() *dns.Msg {
38 | 	return w.msg
39 | }
40 |
41 | // Reset rebinds w to rw for a new request and clears all per-request state.
42 | // The protocol comes from rw's local address type (overridden for the mock
43 | // and DoQ writers) and the client IP from rw's remote address.
44 | func (w *responseWriter) Reset(rw dns.ResponseWriter) {
45 | 	w.ResponseWriter = rw
46 | 	w.size = -1
47 | 	w.msg = nil
48 | 	w.rcode = dns.RcodeSuccess
49 | 	// Clear identity fields first so an unrecognized address type cannot
50 | 	// leak the previous request's protocol or client IP into this one.
51 | 	w.proto = ""
52 | 	w.remoteip = nil
53 |
54 | 	switch rw.LocalAddr().(type) {
55 | 	case *net.TCPAddr:
56 | 		w.proto = "tcp"
57 | 	case *net.UDPAddr:
58 | 		w.proto = "udp"
59 | 	}
60 |
61 | 	// Checked type switch: the remote address is not guaranteed to share the
62 | 	// local address's concrete type, and a mismatch must not panic.
63 | 	switch addr := rw.RemoteAddr().(type) {
64 | 	case *net.TCPAddr:
65 | 		w.remoteip = addr.IP
66 | 	case *net.UDPAddr:
67 | 		w.remoteip = addr.IP
68 | 	}
69 |
70 | 	switch writer := rw.(type) {
71 | 	case *mock.Writer:
72 | 		w.proto = writer.Proto()
73 | 	case *doq.ResponseWriter:
74 | 		w.proto = "doq"
75 | 	}
76 |
77 | 	w.internal = w.RemoteAddr().String() == "127.0.0.255:0"
78 | }
79 |
80 | func (w *responseWriter) RemoteIP() net.IP {
81 | 	return w.remoteip
82 | }
83 |
84 | func (w *responseWriter) Proto() string {
85 | 	return w.proto
86 | }
87 |
88 | func (w *responseWriter) Rcode() int {
89 | 	return w.rcode
90 | }
91 |
92 | func (w *responseWriter) Written() bool {
93 | 	return w.size != -1
94 | }
95 |
96 | func (w *responseWriter) Write(m []byte) (int, error) {
97 | 	if w.Written() {
98 | 		return 0, errAlreadyWritten
99 | 	}
100 |
101 | 	msg := new(dns.Msg)
102 | 	if err := msg.Unpack(m); err != nil {
103 | 		return 0, err
104 | 	}
105 | 	w.msg = msg
106 | 	w.rcode = msg.Rcode
107 |
108 | 	n, err := w.ResponseWriter.Write(m)
109 | 	w.size = n
110 | 	return n, err
111 | }
112 |
113 | func (w *responseWriter) WriteMsg(m *dns.Msg) error {
114 | 	if w.Written() {
115 | 		return errAlreadyWritten
116 | 	}
117 |
118 | 	w.msg = m
119 | 	w.rcode = m.Rcode
120 | 	w.size = 0
121 |
122 | 	return w.ResponseWriter.WriteMsg(m)
123 | }
124 |
125 | // Internal func
126 | func (w *responseWriter) Internal() bool { return w.internal }
127 |
--------------------------------------------------------------------------------
/mock/writer.go:
--------------------------------------------------------------------------------
1 | package mock
2 |
3 | import (
4 | 	"net"
5 |
6 | 	"github.com/miekg/dns"
7 | )
8 |
9 | // Writer type
10 | type Writer struct {
11 | 	msg *dns.Msg
12 |
13 | 	proto string
14 |
15 | 	localAddr  net.Addr
16 | 	remoteAddr net.Addr
17 |
18 | 	remoteip net.IP
19 |
20 | 	internal bool
21 | }
22 |
23 | // NewWriter return writer
24 | func NewWriter(proto, addr string) *Writer {
25 | 	w := &Writer{}
26 |
27 | 	switch proto {
28 | 	case "tcp", "doh":
29 | 		w.localAddr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 53}
30 | 		w.remoteAddr, _ = net.ResolveTCPAddr("tcp", addr)
31 | 		w.remoteip = w.remoteAddr.(*net.TCPAddr).IP
32 |
33 | 	case "udp":
34 | 		w.localAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 53}
35 | 		w.remoteAddr, _ = net.ResolveUDPAddr("udp", addr)
36 | 		w.remoteip = w.remoteAddr.(*net.UDPAddr).IP
37 | 	}
38 |
39 | 	w.internal = w.RemoteAddr().String() == "127.0.0.255:0"
40 |
41 | 	w.proto = proto
42 |
43 | 	return w
44 | }
45 |
46 | // Rcode return message response code
47 | func (w *Writer) Rcode() int { 48 | if w.msg == nil { 49 | return dns.RcodeServerFailure 50 | } 51 | 52 | return w.msg.Rcode 53 | } 54 | 55 | // Msg return current dns message 56 | func (w *Writer) Msg() *dns.Msg { 57 | return w.msg 58 | } 59 | 60 | // Write func 61 | func (w *Writer) Write(b []byte) (int, error) { 62 | w.msg = new(dns.Msg) 63 | err := w.msg.Unpack(b) 64 | if err != nil { 65 | return 0, err 66 | } 67 | return len(b), nil 68 | } 69 | 70 | // WriteMsg func 71 | func (w *Writer) WriteMsg(msg *dns.Msg) error { 72 | w.msg = msg 73 | return nil 74 | } 75 | 76 | // Written func 77 | func (w *Writer) Written() bool { 78 | return w.msg != nil 79 | } 80 | 81 | // RemoteIP func 82 | func (w *Writer) RemoteIP() net.IP { return w.remoteip } 83 | 84 | // Proto func 85 | func (w *Writer) Proto() string { return w.proto } 86 | 87 | // Reset func 88 | func (w *Writer) Reset(rw dns.ResponseWriter) {} 89 | 90 | // Close func 91 | func (w *Writer) Close() error { return nil } 92 | 93 | // Hijack func 94 | func (w *Writer) Hijack() {} 95 | 96 | // LocalAddr func 97 | func (w *Writer) LocalAddr() net.Addr { return w.localAddr } 98 | 99 | // RemoteAddr func 100 | func (w *Writer) RemoteAddr() net.Addr { return w.remoteAddr } 101 | 102 | // TsigStatus func 103 | func (w *Writer) TsigStatus() error { return nil } 104 | 105 | // TsigTimersOnly func 106 | func (w *Writer) TsigTimersOnly(ok bool) {} 107 | 108 | // Internal func 109 | func (w *Writer) Internal() bool { return w.internal } 110 | -------------------------------------------------------------------------------- /mock/writer_test.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/miekg/dns" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func Test_Writer(t *testing.T) { 11 | mw := NewWriter("udp", "127.0.0.1:0") 12 | 13 | m := new(dns.Msg) 14 | m.SetQuestion("example.com.", dns.TypeA) 15 | err := 
mw.WriteMsg(m) 16 | 17 | assert.NoError(t, err) 18 | assert.True(t, mw.Written()) 19 | assert.Equal(t, mw.Rcode(), dns.RcodeSuccess) 20 | assert.NotNil(t, mw.Msg()) 21 | assert.Equal(t, mw.LocalAddr().String(), "127.0.0.1:53") 22 | assert.Equal(t, mw.RemoteAddr().String(), "127.0.0.1:0") 23 | assert.Nil(t, mw.Close()) 24 | assert.Nil(t, mw.TsigStatus()) 25 | 26 | mw = NewWriter("tcp", "127.0.0.255:0") 27 | assert.False(t, mw.Written()) 28 | assert.Equal(t, mw.Rcode(), dns.RcodeServerFailure) 29 | 30 | assert.Equal(t, "tcp", mw.Proto()) 31 | assert.Equal(t, "127.0.0.255", mw.RemoteIP().String()) 32 | 33 | _, err = mw.Write([]byte{}) 34 | assert.Error(t, err) 35 | 36 | data, err := m.Pack() 37 | assert.NoError(t, err) 38 | _, err = mw.Write(data) 39 | assert.NoError(t, err) 40 | assert.True(t, mw.Written()) 41 | assert.Equal(t, mw.Rcode(), dns.RcodeSuccess) 42 | assert.True(t, mw.Internal()) 43 | } 44 | -------------------------------------------------------------------------------- /response/typify.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016-2020 The CoreDNS authors and contributors 2 | // Adapted for SDNS usage by Semih Alev. 3 | 4 | package response 5 | 6 | import ( 7 | "fmt" 8 | "time" 9 | 10 | "github.com/miekg/dns" 11 | ) 12 | 13 | // Type is the type of the message. 14 | type Type int 15 | 16 | const ( 17 | // NoError indicates a positive reply 18 | NoError Type = iota 19 | // NameError is a NXDOMAIN in header, SOA in auth. 20 | NameError 21 | // NoData indicates name found, but not the type: NOERROR in header, SOA in auth. 22 | NoData 23 | // Delegation is a msg with a pointer to another nameserver: NOERROR in header, NS in auth, optionally fluff in additional (not checked). 24 | Delegation 25 | // Meta indicates a meta message, NOTIFY, or a transfer: qType is IXFR or AXFR. 26 | Meta 27 | // Update is an dynamic update message. 28 | Update 29 | // OtherError indicates any other error. 
30 | OtherError 31 | // Expired if sigs expired: don't cache these 32 | Expired 33 | // NoCache indicates a no cache reply 34 | NoCache 35 | ) 36 | 37 | var toString = map[Type]string{ 38 | NoError: "NOERROR", 39 | NameError: "NXDOMAIN", 40 | NoData: "NODATA", 41 | Delegation: "DELEGATION", 42 | Meta: "META", 43 | Update: "UPDATE", 44 | OtherError: "OTHERERROR", 45 | Expired: "EXPIRED", 46 | NoCache: "NOCACHE", 47 | } 48 | 49 | func (t Type) String() string { return toString[t] } 50 | 51 | // TypeFromString returns the type from the string s. If not type matches 52 | // the OtherError type and an error are returned. 53 | func TypeFromString(s string) (Type, error) { 54 | for t, str := range toString { 55 | if s == str { 56 | return t, nil 57 | } 58 | } 59 | return NoError, fmt.Errorf("invalid Type: %s", s) 60 | } 61 | 62 | // Typify classifies a message, it returns the Type. 63 | func Typify(m *dns.Msg, t time.Time) (Type, *dns.OPT) { 64 | if m == nil { 65 | return OtherError, nil 66 | } 67 | opt := m.IsEdns0() 68 | do := false 69 | if opt != nil { 70 | do = opt.Do() 71 | } 72 | 73 | if m.Opcode == dns.OpcodeUpdate { 74 | return Update, opt 75 | } 76 | 77 | // Check transfer and update first 78 | if m.Opcode == dns.OpcodeNotify { 79 | return Meta, opt 80 | } 81 | 82 | if len(m.Question) > 0 { 83 | if m.Question[0].Qtype == dns.TypeAXFR || m.Question[0].Qtype == dns.TypeIXFR { 84 | return Meta, opt 85 | } 86 | } 87 | 88 | // If our message contains any expired sigs and we care about that, we should return expired 89 | if do { 90 | if expired := typifyExpired(m, t); expired { 91 | return Expired, opt 92 | } 93 | } 94 | 95 | if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess { 96 | return NoError, opt 97 | } 98 | 99 | soa := false 100 | ns := 0 101 | for _, r := range m.Ns { 102 | if r.Header().Rrtype == dns.TypeSOA { 103 | soa = true 104 | continue 105 | } 106 | if r.Header().Rrtype == dns.TypeNS { 107 | ns++ 108 | } 109 | } 110 | 111 | if !soa && len(m.Question) > 
0 { 112 | if len(m.Answer) == 0 && m.Question[0].Qtype == dns.TypeDNSKEY { 113 | return NoCache, opt 114 | } 115 | } 116 | 117 | if soa && m.Rcode == dns.RcodeSuccess { 118 | return NoData, opt 119 | } 120 | if soa && m.Rcode == dns.RcodeNameError { 121 | return NameError, opt 122 | } 123 | 124 | if ns > 0 && m.Rcode == dns.RcodeSuccess { 125 | return Delegation, opt 126 | } 127 | 128 | if m.Rcode == dns.RcodeSuccess { 129 | return NoError, opt 130 | } 131 | 132 | return OtherError, opt 133 | } 134 | 135 | func typifyExpired(m *dns.Msg, t time.Time) bool { 136 | if expired := typifyExpiredRRSIG(m.Answer, t); expired { 137 | return true 138 | } 139 | if expired := typifyExpiredRRSIG(m.Ns, t); expired { 140 | return true 141 | } 142 | if expired := typifyExpiredRRSIG(m.Extra, t); expired { 143 | return true 144 | } 145 | return false 146 | } 147 | 148 | func typifyExpiredRRSIG(rrs []dns.RR, t time.Time) bool { 149 | for _, r := range rrs { 150 | if r.Header().Rrtype != dns.TypeRRSIG { 151 | continue 152 | } 153 | ok := r.(*dns.RRSIG).ValidityPeriod(t) 154 | if !ok { 155 | return true 156 | } 157 | } 158 | return false 159 | } 160 | -------------------------------------------------------------------------------- /response/typify_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016-2020 The CoreDNS authors and contributors 2 | // Adapted for SDNS usage by Semih Alev. 
3 |
4 | package response
5 |
6 | import (
7 | 	"testing"
8 | 	"time"
9 |
10 | 	"github.com/miekg/dns"
11 | )
12 |
13 | func makeRR(data string) dns.RR {
14 | 	r, _ := dns.NewRR(data)
15 |
16 | 	return r
17 | }
18 |
19 | func TestTypifyNilMsg(t *testing.T) {
20 | 	var m *dns.Msg
21 |
22 | 	ty, _ := Typify(m, time.Now().UTC())
23 | 	if ty != OtherError {
24 | 		t.Errorf("Message wrongly typified, expected OtherError, got %s", ty)
25 | 	}
26 |
27 | 	ty, err := TypeFromString("")
28 | 	if ty != NoError || err == nil {
29 | 		t.Errorf("Unknown type string, expected NoError and an error, got %s and %v", ty, err)
30 | 	}
31 |
32 | 	ty, _ = TypeFromString("NOERROR")
33 | 	if ty != NoError {
34 | 		t.Errorf("Message wrongly typified, expected NoError, got %s", ty)
35 | 	}
36 |
37 | 	ts := ty.String()
38 | 	if ts != "NOERROR" {
39 | 		t.Errorf("Type to string wrong, expected NOERROR, got %s", ts)
40 | 	}
41 | }
42 |
43 | func TestTypify(t *testing.T) {
44 | 	m := new(dns.Msg)
45 | 	m.SetQuestion("miek.nl.", dns.TypeA)
46 |
47 | 	utc := time.Now().UTC()
48 |
49 | 	mt, _ := Typify(m, utc)
50 | 	if mt != NoError {
51 | 		t.Errorf("Message is wrongly typified, expected NoError, got %s", mt)
52 | 	}
53 |
54 | 	m.Opcode = dns.OpcodeUpdate
55 | 	mt, _ = Typify(m, utc)
56 | 	if mt != Update {
57 | 		t.Errorf("Message is wrongly typified, expected Update, got %s", mt)
58 | 	}
59 |
60 | 	m.Opcode = dns.OpcodeNotify
61 | 	mt, _ = Typify(m, utc)
62 | 	if mt != Meta {
63 | 		t.Errorf("Message is wrongly typified, expected Meta, got %s", mt)
64 | 	}
65 |
66 | 	m.Opcode = dns.OpcodeQuery
67 | 	m.SetQuestion("miek.nl.", dns.TypeAXFR)
68 | 	mt, _ = Typify(m, utc)
69 | 	if mt != Meta {
70 | 		t.Errorf("Message is wrongly typified, expected Meta, got %s", mt)
71 | 	}
72 |
73 | 	m.SetQuestion("miek.nl.", dns.TypeA)
74 | 	m.Ns = append(m.Ns, makeRR("nl. 3599 IN SOA ns1.dns.nl. hostmaster.domain-registry.nl. 2018111539 3600 600 2419200 600"))
75 | 	mt, _ = Typify(m, utc)
76 | 	if mt != NoData {
77 | 		t.Errorf("Message is wrongly typified, expected NoData, got %s", mt)
78 | 	}
79 |
80 | 	m.Rcode = dns.RcodeNameError
81 | 	mt, _ = Typify(m, utc)
82 | 	if mt != NameError {
83 | 		t.Errorf("Message is wrongly typified, expected NameError, got %s", mt)
84 | 	}
85 |
86 | 	m.Rcode = dns.RcodeServerFailure
87 | 	mt, _ = Typify(m, utc)
88 | 	if mt != OtherError {
89 | 		t.Errorf("Message is wrongly typified, expected OtherError, got %s", mt)
90 | 	}
91 |
92 | 	m.SetEdns0(4096, true)
93 |
94 | 	m.Rcode = dns.RcodeSuccess
95 | 	m.Answer = append(m.Answer, makeRR("miek.nl. 3600 IN A 127.0.0.1"))
96 | 	mt, _ = Typify(m, utc)
97 | 	if mt != NoError {
98 | 		t.Errorf("Message is wrongly typified, expected NoError, got %s", mt)
99 | 	}
100 |
101 | 	m.Extra = append(m.Extra,
102 | 		makeRR("miek.nl. 1800 IN RRSIG NS 8 2 1800 20160521031301 20160421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
103 | 	)
104 | 	mt, _ = Typify(m, utc)
105 | 	if mt != Expired {
106 | 		t.Errorf("Message is wrongly typified, expected Expired, got %s", mt)
107 | 	}
108 |
109 | 	m.Answer = append(m.Answer,
110 | 		makeRR("miek.nl. 1800 IN RRSIG NS 8 2 1800 20160521031301 20160421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
111 | 	)
112 | 	mt, _ = Typify(m, utc)
113 | 	if mt != Expired {
114 | 		t.Errorf("Message is wrongly typified, expected Expired, got %s", mt)
115 | 	}
116 | }
117 |
118 | func TestTypifyDelegation(t *testing.T) {
119 | 	m := delegationMsg()
120 | 	mt, _ := Typify(m, time.Now().UTC())
121 | 	if mt != Delegation {
122 | 		t.Errorf("Message is wrongly typified, expected Delegation, got %s", mt)
123 | 	}
124 | }
125 |
126 | func TestTypifyRRSIG(t *testing.T) {
127 | 	utc := time.Now().UTC()
128 |
129 | 	m := delegationMsgRRSIGOK()
130 | 	if mt, _ := Typify(m, utc); mt != Delegation {
131 | 		t.Errorf("Message is wrongly typified, expected Delegation, got %s", mt)
132 | 	}
133 |
134 | 	// Still a Delegation because EDNS0 OPT DO bool is not set, so we won't check the sigs.
135 | 	m = delegationMsgRRSIGFail()
136 | 	if mt, _ := Typify(m, utc); mt != Delegation {
137 | 		t.Errorf("Message is wrongly typified, expected Delegation, got %s", mt)
138 | 	}
139 |
140 | 	m = delegationMsgRRSIGFail()
141 | 	m = addOpt(m)
142 | 	if mt, _ := Typify(m, utc); mt != Expired {
143 | 		t.Errorf("Message is wrongly typified, expected Expired, got %s", mt)
144 | 	}
145 | }
146 |
147 | func delegationMsg() *dns.Msg {
148 | 	return &dns.Msg{
149 | 		Ns: []dns.RR{
150 | 			makeRR("miek.nl. 3600 IN NS linode.atoom.net."),
151 | 			makeRR("miek.nl. 3600 IN NS ns-ext.nlnetlabs.nl."),
152 | 			makeRR("miek.nl. 3600 IN NS omval.tednet.nl."),
153 | 		},
154 | 		Extra: []dns.RR{
155 | 			makeRR("omval.tednet.nl. 3600 IN A 185.49.141.42"),
156 | 			makeRR("omval.tednet.nl. 3600 IN AAAA 2a04:b900:0:100::42"),
157 | 		},
158 | 	}
159 | }
160 |
161 | func delegationMsgRRSIGOK() *dns.Msg {
162 | 	del := delegationMsg()
163 | 	del.Ns = append(del.Ns,
164 | 		makeRR("miek.nl. 1800 IN RRSIG NS 8 2 1800 20170521031301 20170421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
165 | 	)
166 | 	return del
167 | }
168 |
169 | func delegationMsgRRSIGFail() *dns.Msg {
170 | 	del := delegationMsg()
171 | 	del.Ns = append(del.Ns,
172 | 		makeRR("miek.nl. 1800 IN RRSIG NS 8 2 1800 20160521031301 20160421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
173 | 	)
174 | 	return del
175 | }
176 |
177 | func addOpt(m *dns.Msg) *dns.Msg {
178 | 	return m.SetEdns0(4096, true)
179 | }
180 |
--------------------------------------------------------------------------------
/sdns.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | //go:generate go run gen.go
4 |
5 | import (
6 | 	"context"
7 | 	"flag"
8 | 	"fmt"
9 | 	"os"
10 | 	"os/signal"
11 | 	"runtime/debug"
12 | 	"syscall"
13 | 	"time"
14 |
15 | 	"github.com/semihalev/log"
16 | 	"github.com/semihalev/sdns/api"
17 | 	"github.com/semihalev/sdns/config"
18 | 	"github.com/semihalev/sdns/middleware"
19 | 	"github.com/semihalev/sdns/server"
20 | )
21 |
22 | const version = "1.4.0"
23 |
24 | var (
25 | 	flagcfgpath  string
26 | 	flagprintver bool
27 |
28 | 	cfg *config.Config
29 | )
30 |
31 | func init() {
32 | 	flag.StringVar(&flagcfgpath, "config", "sdns.conf", "Location of the config file. If it doesn't exist, a new one will be generated.")
33 | 	flag.StringVar(&flagcfgpath, "c", "sdns.conf", "Location of the config file. If it doesn't exist, a new one will be generated.")
34 |
35 | 	flag.BoolVar(&flagprintver, "version", false, "Show the version of the sdns.")
36 | 	flag.BoolVar(&flagprintver, "v", false, "Show the version of the sdns.")
37 |
38 | 	flag.Usage = func() {
39 | 		fmt.Fprintf(os.Stderr, "Usage:\n sdns [OPTIONS]\n\n")
40 | 		fmt.Fprintf(os.Stderr, "Options:\n")
41 | 		fmt.Fprintf(os.Stderr, " -c, --config PATH\tLocation of the config file. If it doesn't exist, a new one will be generated.\n")
42 | 		fmt.Fprintf(os.Stderr, " -v, --version\t\tShow the version of the sdns and exit.\n")
43 | 		fmt.Fprintf(os.Stderr, " -h, --help\t\tShow this help and exit.\n\n")
44 | 		fmt.Fprintf(os.Stderr, "Example:\n")
45 | 		fmt.Fprintf(os.Stderr, " sdns -c sdns.conf\n\n")
46 | 	}
47 | }
48 |
49 | func setup() {
50 | 	var err error
51 |
52 | 	if cfg, err = config.Load(flagcfgpath, version); err != nil {
53 | 		log.Crit("Config loading failed", "error", err.Error())
54 | 	}
55 |
56 | 	if cfg.LogLevel == "" {
57 | 		cfg.LogLevel = "info"
58 | 	}
59 |
60 | 	lvl, err := log.LvlFromString(cfg.LogLevel)
61 | 	if err != nil {
62 | 		log.Crit("Log verbosity level unknown")
63 | 	}
64 |
65 | 	log.Root().SetLevel(lvl)
66 | 	log.Root().SetHandler(log.LvlFilterHandler(lvl, log.StdoutHandler))
67 |
68 | 	middleware.Setup(cfg)
69 | }
70 |
71 | func run(ctx context.Context) *server.Server {
72 | 	srv := server.New(cfg)
73 | 	srv.Run(ctx)
74 |
75 | 	apiServer := api.New(cfg)
76 | 	apiServer.Run(ctx)
77 |
78 | 	return srv
79 | }
80 |
81 | // printver writes the version, VCS revision and toolchain details to stderr
82 | // and exits. Binaries built without module support carry no build info, in
83 | // which case only the version is printed.
84 | func printver() {
85 | 	buildInfo, ok := debug.ReadBuildInfo()
86 | 	if !ok {
87 | 		fmt.Fprintf(os.Stderr, "sdns v%s\n", version)
88 | 		os.Exit(0)
89 | 	}
90 |
91 | 	settings := make(map[string]string)
92 | 	for _, s := range buildInfo.Settings {
93 | 		settings[s.Key] = s.Value
94 | 	}
95 |
96 | 	fmt.Fprintf(os.Stderr, "sdns v%s rev %.7s\nbuilt by %s (%s %s)\n", version,
97 | 		settings["vcs.revision"], buildInfo.GoVersion, settings["GOOS"], settings["GOARCH"])
98 |
99 | 	os.Exit(0)
100 | }
101 |
102 | func main() {
103 | 	flag.Parse()
104 |
105 | 	if flagprintver {
106 | 		printver()
107 | 	}
108 |
109 | 	log.Info("Starting sdns...", "version", version)
110 |
111 | 	setup()
112 |
113 | 	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
114 | 	srv := run(ctx)
115 |
116 | 	<-ctx.Done()
117 |
118 | 	log.Info("Stopping sdns...")
119 |
120 | 	stop()
121 |
122 | 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
123 | 	defer cancel()
124 |
125 | 	for !srv.Stopped() {
126 | 		select {
127 | 		case <-time.After(100 * time.Millisecond):
128 | 			continue
129 | 		case <-ctx.Done():
130 | 			return
131 | 		}
132 | 	}
133 | }
134 |
--------------------------------------------------------------------------------
/server/doh/doh.go:
--------------------------------------------------------------------------------
1 | package doh
2 |
3 | import (
4 | 	"encoding/base64"
5 | 	"encoding/json"
6 | 	"io"
7 | 	"net/http"
8 | 	"strings"
9 |
10 | 	"github.com/miekg/dns"
11 | )
12 |
13 | const minMsgHeaderSize = 12
14 |
15 | // HandleWireFormat serves RFC 8484 wire-format DNS queries (GET "dns" parameter or POST body) using handle.
16 | func HandleWireFormat(handle func(*dns.Msg) *dns.Msg) func(http.ResponseWriter, *http.Request) {
17 | 	return func(w http.ResponseWriter, r *http.Request) {
18 | 		var (
19 | 			buf []byte
20 | 			err error
21 | 		)
22 |
23 | 		switch r.Method {
24 | 		case http.MethodGet:
25 | 			buf, err = base64.RawURLEncoding.DecodeString(r.URL.Query().Get("dns"))
26 | 			if len(buf) == 0 || err != nil {
27 | 				http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
28 | 				return
29 | 			}
30 | 		case http.MethodPost:
31 | 			if r.Header.Get("Content-Type") != "application/dns-message" {
32 | 				http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType)
33 | 				return
34 | 			}
35 |
36 | 			// A DNS message is at most dns.MaxMsgSize bytes. Read one byte more so an
37 | 			// oversized (untrusted) body is rejected instead of buffered in full.
38 | 			buf, err = io.ReadAll(io.LimitReader(r.Body, dns.MaxMsgSize+1))
39 | 			if err != nil {
40 | 				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
41 | 				return
42 | 			}
43 | 			defer r.Body.Close()
44 |
45 | 			if len(buf) > dns.MaxMsgSize {
46 | 				http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge)
47 | 				return
48 | 			}
49 | 		default:
50 | 			http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
51 | 			return
52 | 		}
53 |
54 | 		if len(buf) < minMsgHeaderSize {
55 | 			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
56 | 			return
57 | 		}
58 |
59 | 		req := new(dns.Msg)
60 | 		if err := req.Unpack(buf); err != nil {
61 | 			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
62 | 			return
63 | 		}
64 |
65 | 		msg := handle(req)
66 | 		if msg == nil {
67 | 			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
68 | 			return
69 | 		}
70 |
71 | 		packed, err := msg.Pack()
72 | 		if err != nil {
73 | 			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
74 | 			return
75 | 		}
76 |
77 | 		w.Header().Set("Content-Type", "application/dns-message")
78 |
79 | 		_, _ = w.Write(packed)
80 | 	}
81 | }
82 |
83 | // HandleJSON serves JSON-format DNS queries built from the name, type, cd and do URL parameters using handle.
84 | func HandleJSON(handle func(*dns.Msg) *dns.Msg) func(http.ResponseWriter, *http.Request) {
85 | 	return func(w http.ResponseWriter, r *http.Request) {
86 | 		name := r.URL.Query().Get("name")
87 | 		if name == "" {
88 | 			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
89 | 			return
90 | 		}
91 | 		name = dns.Fqdn(name)
92 |
93 | 		qtype := ParseQTYPE(r.URL.Query().Get("type"))
94 | 		if qtype == dns.TypeNone {
95 | 			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
96 | 			return
97 | 		}
98 |
99 | 		req := new(dns.Msg)
100 | 		req.SetQuestion(name, qtype)
101 | 		req.AuthenticatedData = true
102 |
103 | 		if r.URL.Query().Get("cd") == "true" {
104 | 			req.CheckingDisabled = true
105 | 		}
106 |
107 | 		opt := &dns.OPT{
108 | 			Hdr: dns.RR_Header{
109 | 				Name:   ".",
110 | 				Class:  dns.DefaultMsgSize,
111 | 				Rrtype: dns.TypeOPT,
112 | 			},
113 | 		}
114 |
115 | 		if r.URL.Query().Get("do") == "true" {
116 | 			opt.SetDo()
117 | 		}
118 |
119 | 		req.Extra = append(req.Extra, opt)
120 |
121 | 		msg := handle(req)
122 | 		if msg == nil {
123 | 			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
124 | 			return
125 | 		}
126 |
127 | 		data, err := json.Marshal(NewMsg(msg))
128 | 		if err != nil {
129 | 			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
130 | 			return
131 | 		}
132 |
133 | 		if strings.Contains(r.Header.Get("Accept"), "text/html") {
134 | 			w.Header().Set("Content-Type", "application/x-javascript")
135 | 		} else {
136 | 			w.Header().Set("Content-Type", "application/dns-json")
137 | 		}
138 |
139 | 		_, _ = w.Write(data)
140 | 	}
141 | }
142 |
--------------------------------------------------------------------------------
/server/doh/doh_test.go:
--------------------------------------------------------------------------------
1 | package doh
2 |
3 | import (
4 | 	"bytes"
5 | 	"encoding/base64"
6 | 	"encoding/json"
7 | 	"fmt"
8 | 	"io"
9 | 	"net/http"
10 | 	"net/http/httptest"
11 | 	"testing"
12 |
13 | 	"github.com/miekg/dns"
14 | 	"github.com/stretchr/testify/assert"
15 | )
16 |
17 | func handleTest(w http.ResponseWriter, r *http.Request) {
18 | 	handle := func(req *dns.Msg) *dns.Msg {
19 | 		msg, _ := dns.Exchange(req, "8.8.8.8:53")
20 |
21 | 		return msg
22 | 	}
23 |
24 | 	var handleFn func(http.ResponseWriter, *http.Request)
25 | 	if r.Method == http.MethodGet && r.URL.Query().Get("dns") == "" {
26 | 		handleFn = HandleJSON(handle)
27 | 	} else {
28 | 		handleFn = HandleWireFormat(handle)
29 | 	}
30 |
31 | 	handleFn(w, r)
32 | }
33 | func Test_dohJSON(t *testing.T) {
34 | 	t.Parallel()
35 |
36 | 	w := httptest.NewRecorder()
37 |
38 | 	request, err := http.NewRequest("GET", "/dns-query?name=www.google.com&type=a&do=true&cd=true", nil)
39 | 	assert.NoError(t, err)
40 |
41 | 	request.RemoteAddr = "127.0.0.1:0"
42 |
43 | 	handleTest(w, request)
44 |
45 | 	assert.Equal(t, w.Code, http.StatusOK)
46 |
47 | 	data, err := io.ReadAll(w.Body)
48 | 	assert.NoError(t, err)
49 |
50 | 	var dm Msg
51 | 	err = json.Unmarshal(data, &dm)
52 | 	assert.NoError(t, err)
53 |
54 | 	assert.Equal(t, len(dm.Answer) > 0, true)
55 | }
56 |
57 | func Test_dohJSONerror(t *testing.T) {
58 | 	t.Parallel()
59 |
60 | 	w := httptest.NewRecorder()
61 |
62 | 	request, err := http.NewRequest("GET", "/dns-query?name=", nil)
63 | 	assert.NoError(t, err)
64 |
65 | 	request.RemoteAddr = "127.0.0.1:0"
66 |
67 | 	handleTest(w, request)
68 |
69 | 	assert.Equal(t, w.Code, http.StatusBadRequest)
70 | }
71 |
72 | func Test_dohJSONaccepthtml(t *testing.T) {
73 | 	t.Parallel()
74 |
75 | 	w := httptest.NewRecorder()
76 |
77 | 	request, err := http.NewRequest("GET", "/dns-query?name=www.google.com", nil)
78 | 	assert.NoError(t, err)
79 |
80 | 	request.RemoteAddr = "127.0.0.1:0"
81 |
82 | 	request.Header.Add("Accept", "text/html")
83 | 	handleTest(w, request)
84 |
85 | 	assert.Equal(t, w.Code, http.StatusOK)
86 | 	assert.Equal(t, w.Header().Get("Content-Type"), "application/x-javascript")
87 | }
88 |
89 | func Test_dohWireGET(t *testing.T) {
90 | 	t.Parallel()
91 |
92 | 	w := httptest.NewRecorder()
93 |
94 | 	req := new(dns.Msg)
95 | 	req.SetQuestion("www.google.com.", dns.TypeA)
96 |
97 | 	data, err := req.Pack()
98 | 	assert.NoError(t, err)
99 |
100 | 	dq := base64.RawURLEncoding.EncodeToString(data)
101 |
102 | 	request, err := http.NewRequest("GET", fmt.Sprintf("/dns-query?dns=%s", dq), nil)
103 | 	assert.NoError(t, err)
104 |
105 | 	request.RemoteAddr = "127.0.0.1:0"
106 |
107 | 	handleTest(w, request)
108 |
109 | 	assert.Equal(t, w.Code, http.StatusOK)
110 |
111 | 	data, err = io.ReadAll(w.Body)
112 | 	assert.NoError(t, err)
113 |
114 | 	msg := new(dns.Msg)
115 | 	err = msg.Unpack(data)
116 | 	assert.NoError(t, err)
117 |
118 | 	assert.Equal(t, msg.Rcode, dns.RcodeSuccess)
119 |
120 | 	assert.Equal(t, len(msg.Answer) > 0, true)
121 | }
122 |
123 | func Test_dohWireGETerror(t *testing.T) {
124 | 	t.Parallel()
125 |
126 | 	w := httptest.NewRecorder()
127 |
128 | 	request, err := http.NewRequest("GET", "/dns-query?dns=", nil)
129 | 	assert.NoError(t, err)
130 |
131 | 	request.RemoteAddr = "127.0.0.1:0"
132 |
133 | 	handleTest(w, request)
134 |
135 | 	assert.Equal(t, w.Code, http.StatusBadRequest)
136 | }
137 |
138 | func Test_dohWireGETbadquery(t *testing.T) {
139 |
t.Parallel() 140 | 141 | w := httptest.NewRecorder() 142 | 143 | request, err := http.NewRequest("GET", "/dns-query?dns=Df4", nil) 144 | assert.NoError(t, err) 145 | 146 | request.RemoteAddr = "127.0.0.1:0" 147 | 148 | handleTest(w, request) 149 | 150 | assert.Equal(t, w.Code, http.StatusBadRequest) 151 | } 152 | 153 | func Test_dohWireHEAD(t *testing.T) { 154 | t.Parallel() 155 | 156 | w := httptest.NewRecorder() 157 | 158 | request, err := http.NewRequest("HEAD", "/dns-query?dns=", nil) 159 | assert.NoError(t, err) 160 | 161 | request.RemoteAddr = "127.0.0.1:0" 162 | 163 | handleTest(w, request) 164 | 165 | assert.Equal(t, w.Code, http.StatusMethodNotAllowed) 166 | } 167 | 168 | func Test_dohWirePOST(t *testing.T) { 169 | t.Parallel() 170 | 171 | w := httptest.NewRecorder() 172 | 173 | req := new(dns.Msg) 174 | req.SetQuestion("www.google.com.", dns.TypeA) 175 | 176 | data, err := req.Pack() 177 | assert.NoError(t, err) 178 | 179 | request, err := http.NewRequest("POST", "/dns-query", bytes.NewReader(data)) 180 | assert.NoError(t, err) 181 | 182 | request.RemoteAddr = "127.0.0.1:0" 183 | request.Header.Add("Content-Type", "application/dns-message") 184 | 185 | handleTest(w, request) 186 | 187 | assert.Equal(t, w.Code, http.StatusOK) 188 | 189 | data, err = io.ReadAll(w.Body) 190 | assert.NoError(t, err) 191 | 192 | msg := new(dns.Msg) 193 | err = msg.Unpack(data) 194 | assert.NoError(t, err) 195 | 196 | assert.Equal(t, msg.Rcode, dns.RcodeSuccess) 197 | 198 | assert.Equal(t, len(msg.Answer) > 0, true) 199 | } 200 | 201 | func Test_dohWirePOSTError(t *testing.T) { 202 | t.Parallel() 203 | 204 | w := httptest.NewRecorder() 205 | 206 | request, err := http.NewRequest("POST", "/dns-query", bytes.NewReader([]byte{})) 207 | assert.NoError(t, err) 208 | 209 | request.RemoteAddr = "127.0.0.1:0" 210 | request.Header.Add("Content-Type", "text/html") 211 | 212 | handleTest(w, request) 213 | 214 | assert.Equal(t, w.Code, http.StatusUnsupportedMediaType) 215 | } 216 | 
-------------------------------------------------------------------------------- /server/doh/msg.go: -------------------------------------------------------------------------------- 1 | package doh 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/miekg/dns" 7 | ) 8 | 9 | // Question struct 10 | type Question struct { 11 | Name string `json:"name"` 12 | Qtype uint16 `json:"type"` 13 | Qclass uint16 `json:"-"` 14 | } 15 | 16 | // RR struct 17 | type RR struct { 18 | Name string `json:"name"` 19 | Type uint16 `json:"type"` 20 | TTL uint32 `json:"TTL"` 21 | Data string `json:"data"` 22 | } 23 | 24 | // Msg struct 25 | type Msg struct { 26 | Status int 27 | TC bool 28 | RD bool 29 | RA bool 30 | AD bool 31 | CD bool 32 | Question []Question 33 | Answer []RR `json:",omitempty"` 34 | Authority []RR `json:",omitempty"` 35 | } 36 | 37 | // NewMsg function 38 | func NewMsg(m *dns.Msg) *Msg { 39 | if m == nil { 40 | return nil 41 | } 42 | 43 | msg := &Msg{ 44 | Status: m.Rcode, 45 | TC: m.Truncated, 46 | RD: m.RecursionDesired, 47 | RA: m.RecursionAvailable, 48 | AD: m.AuthenticatedData, 49 | CD: m.CheckingDisabled, 50 | Question: make([]Question, len(m.Question)), 51 | Answer: make([]RR, len(m.Answer)), 52 | Authority: make([]RR, len(m.Ns)), 53 | } 54 | 55 | for i, q := range m.Question { 56 | msg.Question[i] = Question(q) 57 | } 58 | 59 | for i, a := range m.Answer { 60 | msg.Answer[i] = RR{ 61 | Name: a.Header().Name, 62 | Type: a.Header().Rrtype, 63 | TTL: a.Header().Ttl, 64 | Data: strings.TrimPrefix(a.String(), a.Header().String()), 65 | } 66 | } 67 | 68 | for i, a := range m.Ns { 69 | msg.Authority[i] = RR{ 70 | Name: a.Header().Name, 71 | Type: a.Header().Rrtype, 72 | TTL: a.Header().Ttl, 73 | Data: strings.TrimPrefix(a.String(), a.Header().String()), 74 | } 75 | } 76 | 77 | return msg 78 | } 79 | -------------------------------------------------------------------------------- /server/doh/msg_test.go: 
-------------------------------------------------------------------------------- 1 | package doh 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/miekg/dns" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func Test_Msg(t *testing.T) { 11 | m := NewMsg(nil) 12 | assert.Nil(t, m) 13 | 14 | msg := new(dns.Msg) 15 | msg.SetQuestion(".", dns.TypeNS) 16 | 17 | rr, err := dns.NewRR(". 518400 IN NS a.root-servers.net.") 18 | assert.NoError(t, err) 19 | 20 | msg.Answer = append(msg.Answer, rr) 21 | 22 | rr, err = dns.NewRR("a.gtld-servers.net. 172800 IN A 192.5.6.30") 23 | assert.NoError(t, err) 24 | 25 | msg.Ns = append(msg.Ns, rr) 26 | 27 | m = NewMsg(msg) 28 | 29 | assert.Equal(t, m.Answer[0].Data, msg.Answer[0].(*dns.NS).Ns) 30 | assert.Equal(t, m.Authority[0].Data, msg.Ns[0].(*dns.A).A.String()) 31 | } 32 | -------------------------------------------------------------------------------- /server/doh/qtype.go: -------------------------------------------------------------------------------- 1 | package doh 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | 7 | "github.com/miekg/dns" 8 | ) 9 | 10 | var qtype = map[string]uint16{ 11 | "A": dns.TypeA, 12 | "AAAA": dns.TypeAAAA, 13 | "AFSDB": dns.TypeAFSDB, 14 | "AMTRELAY": dns.TypeAMTRELAY, 15 | "ANY": dns.TypeANY, 16 | "APL": dns.TypeAPL, 17 | "ATMA": dns.TypeATMA, 18 | "AVC": dns.TypeAVC, 19 | "AXFR": dns.TypeAXFR, 20 | "CAA": dns.TypeCAA, 21 | "CDNSKEY": dns.TypeCDNSKEY, 22 | "CDS": dns.TypeCDS, 23 | "CERT": dns.TypeCERT, 24 | "CNAME": dns.TypeCNAME, 25 | "CSYNC": dns.TypeCSYNC, 26 | "DHCID": dns.TypeDHCID, 27 | "DLV": dns.TypeDLV, 28 | "DNAME": dns.TypeDNAME, 29 | "DNSKEY": dns.TypeDNSKEY, 30 | "DS": dns.TypeDS, 31 | "EID": dns.TypeEID, 32 | "EUI48": dns.TypeEUI48, 33 | "EUI64": dns.TypeEUI64, 34 | "GID": dns.TypeGID, 35 | "GPOS": dns.TypeGPOS, 36 | "HINFO": dns.TypeHINFO, 37 | "HIP": dns.TypeHIP, 38 | "HTTPS": dns.TypeHTTPS, 39 | "IPSECKEY": dns.TypeIPSECKEY, 40 | "ISDN": dns.TypeISDN, 41 | "IXFR": dns.TypeIXFR, 
42 | "KEY": dns.TypeKEY, 43 | "KX": dns.TypeKX, 44 | "L32": dns.TypeL32, 45 | "L64": dns.TypeL64, 46 | "LOC": dns.TypeLOC, 47 | "LP": dns.TypeLP, 48 | "MAILA": dns.TypeMAILA, 49 | "MAILB": dns.TypeMAILB, 50 | "MB": dns.TypeMB, 51 | "MD": dns.TypeMD, 52 | "MF": dns.TypeMF, 53 | "MG": dns.TypeMG, 54 | "MINFO": dns.TypeMINFO, 55 | "MR": dns.TypeMR, 56 | "MX": dns.TypeMX, 57 | "NAPTR": dns.TypeNAPTR, 58 | "NID": dns.TypeNID, 59 | "NIMLOC": dns.TypeNIMLOC, 60 | "NINFO": dns.TypeNINFO, 61 | "NS": dns.TypeNS, 62 | "NSAP-PTR": dns.TypeNSAPPTR, 63 | "NSEC": dns.TypeNSEC, 64 | "NSEC3": dns.TypeNSEC3, 65 | "NSEC3PARAM": dns.TypeNSEC3PARAM, 66 | "NULL": dns.TypeNULL, 67 | "NXT": dns.TypeNXT, 68 | "None": dns.TypeNone, 69 | "OPENPGPKEY": dns.TypeOPENPGPKEY, 70 | "OPT": dns.TypeOPT, 71 | "PTR": dns.TypePTR, 72 | "PX": dns.TypePX, 73 | "RKEY": dns.TypeRKEY, 74 | "RP": dns.TypeRP, 75 | "RRSIG": dns.TypeRRSIG, 76 | "RT": dns.TypeRT, 77 | "Reserved": dns.TypeReserved, 78 | "SIG": dns.TypeSIG, 79 | "SMIMEA": dns.TypeSMIMEA, 80 | "SOA": dns.TypeSOA, 81 | "SPF": dns.TypeSPF, 82 | "SRV": dns.TypeSRV, 83 | "SSHFP": dns.TypeSSHFP, 84 | "SVCB": dns.TypeSVCB, 85 | "TA": dns.TypeTA, 86 | "TALINK": dns.TypeTALINK, 87 | "TKEY": dns.TypeTKEY, 88 | "TLSA": dns.TypeTLSA, 89 | "TSIG": dns.TypeTSIG, 90 | "TXT": dns.TypeTXT, 91 | "UID": dns.TypeUID, 92 | "UINFO": dns.TypeUINFO, 93 | "UNSPEC": dns.TypeUNSPEC, 94 | "URI": dns.TypeURI, 95 | "X25": dns.TypeX25, 96 | "ZONEMD": dns.TypeZONEMD, 97 | } 98 | 99 | // ParseQTYPE function 100 | func ParseQTYPE(s string) uint16 { 101 | if s == "" { 102 | return dns.TypeA 103 | } 104 | 105 | if v, err := strconv.ParseUint(s, 10, 16); err == nil { 106 | return uint16(v) 107 | } 108 | 109 | if v, ok := qtype[strings.ToUpper(s)]; ok { 110 | return v 111 | } 112 | 113 | return dns.TypeNone 114 | } 115 | -------------------------------------------------------------------------------- /server/doh/qtype_test.go: 
-------------------------------------------------------------------------------- 1 | package doh 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/miekg/dns" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func Test_ParseQTYPE(t *testing.T) { 11 | qtype := ParseQTYPE("") 12 | assert.Equal(t, qtype, dns.TypeA) 13 | 14 | qtype = ParseQTYPE("1") 15 | assert.Equal(t, qtype, dns.TypeA) 16 | 17 | qtype = ParseQTYPE("CNAME") 18 | assert.Equal(t, qtype, dns.TypeCNAME) 19 | 20 | qtype = ParseQTYPE("TEST") 21 | assert.Equal(t, qtype, dns.TypeNone) 22 | } 23 | -------------------------------------------------------------------------------- /server/doq/doq.go: -------------------------------------------------------------------------------- 1 | package doq 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "io" 7 | "time" 8 | 9 | "github.com/miekg/dns" 10 | "github.com/quic-go/quic-go" 11 | ) 12 | 13 | var doqProtos = []string{"doq", "doq-i02", "dq", "doq-i00", "doq-i01", "doq-i11"} 14 | 15 | const ( 16 | minMsgHeaderSize = 14 // fixed msg header size 12 + quic prefix size 2 17 | ProtocolError = 0x2 18 | NoError = 0x0 19 | ) 20 | 21 | type Server struct { 22 | Addr string 23 | Handler dns.Handler 24 | 25 | ln *quic.Listener 26 | } 27 | 28 | func (s *Server) ListenAndServeQUIC(tlsCert, tlsKey string) error { 29 | cert, err := tls.LoadX509KeyPair(tlsCert, tlsKey) 30 | if err != nil { 31 | return err 32 | } 33 | 34 | tlsConfig := &tls.Config{ 35 | Certificates: []tls.Certificate{cert}, 36 | NextProtos: doqProtos, 37 | } 38 | 39 | quicConfig := &quic.Config{ 40 | MaxIdleTimeout: 5 * time.Second, 41 | MaxStreamReceiveWindow: dns.MaxMsgSize, 42 | } 43 | 44 | listener, err := quic.ListenAddr(s.Addr, tlsConfig, quicConfig) 45 | if err != nil { 46 | return err 47 | } 48 | 49 | s.ln = listener 50 | 51 | for { 52 | conn, err := listener.Accept(context.Background()) 53 | if err != nil { 54 | return err 55 | } 56 | 57 | go s.handleConnection(conn) 58 | } 59 | } 60 | 61 | func (s 
*Server) Shutdown() error { 62 | if s.ln == nil { 63 | return nil 64 | } 65 | 66 | err := s.ln.Close() 67 | 68 | if err == quic.ErrServerClosed { 69 | return nil 70 | } 71 | 72 | return err 73 | } 74 | 75 | func (s *Server) handleConnection(conn quic.Connection) { 76 | var ( 77 | stream quic.Stream 78 | buf []byte 79 | err error 80 | ) 81 | 82 | for { 83 | stream, err = conn.AcceptStream(context.Background()) 84 | if err != nil { 85 | _ = conn.CloseWithError(NoError, "") 86 | return 87 | } 88 | 89 | go func() { 90 | defer stream.Close() 91 | 92 | buf, err = io.ReadAll(stream) 93 | if err != nil { 94 | _ = conn.CloseWithError(ProtocolError, err.Error()) 95 | return 96 | } 97 | 98 | if len(buf) < minMsgHeaderSize { 99 | _ = conn.CloseWithError(ProtocolError, "dns msg size too small") 100 | return 101 | } 102 | 103 | req := new(dns.Msg) 104 | if err := req.Unpack(buf[2:]); err != nil { 105 | _ = conn.CloseWithError(ProtocolError, err.Error()) 106 | return 107 | } 108 | req.Id = dns.Id() 109 | 110 | w := &ResponseWriter{Conn: conn, Stream: stream} 111 | 112 | s.Handler.ServeDNS(w, req) 113 | }() 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /server/doq/doq_test.go: -------------------------------------------------------------------------------- 1 | package doq 2 | 3 | import ( 4 | "context" 5 | "crypto/ecdsa" 6 | "crypto/rand" 7 | "crypto/rsa" 8 | "crypto/tls" 9 | "crypto/x509" 10 | "crypto/x509/pkix" 11 | "encoding/pem" 12 | "fmt" 13 | "io" 14 | "math/big" 15 | "os" 16 | "path/filepath" 17 | "testing" 18 | "time" 19 | 20 | "github.com/miekg/dns" 21 | "github.com/quic-go/quic-go" 22 | "github.com/stretchr/testify/assert" 23 | ) 24 | 25 | type dummyHandler struct { 26 | dns.Handler 27 | } 28 | 29 | func makeRR(data string) dns.RR { 30 | r, _ := dns.NewRR(data) 31 | 32 | return r 33 | } 34 | 35 | func (h *dummyHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { 36 | msg := new(dns.Msg) 37 | msg.SetReply(r) 38 | 
msg.Answer = append(msg.Answer, makeRR("example.com. 1800 IN A 0.0.0.0")) 39 | 40 | _ = w.WriteMsg(msg) 41 | } 42 | 43 | func publicKey(priv interface{}) interface{} { 44 | switch k := priv.(type) { 45 | case *rsa.PrivateKey: 46 | return &k.PublicKey 47 | case *ecdsa.PrivateKey: 48 | return &k.PublicKey 49 | default: 50 | return nil 51 | } 52 | } 53 | 54 | func pemBlockForKey(priv interface{}) *pem.Block { 55 | switch k := priv.(type) { 56 | case *rsa.PrivateKey: 57 | return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)} 58 | case *ecdsa.PrivateKey: 59 | b, err := x509.MarshalECPrivateKey(k) 60 | if err != nil { 61 | fmt.Fprintf(os.Stderr, "Unable to marshal ECDSA private key: %v", err) 62 | os.Exit(2) 63 | } 64 | return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b} 65 | default: 66 | return nil 67 | } 68 | } 69 | 70 | func generateCertificate() error { 71 | priv, err := rsa.GenerateKey(rand.Reader, 2048) 72 | if err != nil { 73 | return err 74 | } 75 | 76 | template := x509.Certificate{ 77 | SerialNumber: big.NewInt(1), 78 | Subject: pkix.Name{ 79 | Organization: []string{"Acme Co"}, 80 | }, 81 | NotBefore: time.Now(), 82 | NotAfter: time.Now().Add(time.Hour * 24 * 365 * 3), 83 | 84 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, 85 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 86 | BasicConstraintsValid: true, 87 | } 88 | 89 | template.DNSNames = append(template.DNSNames, "localhost") 90 | 91 | derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv) 92 | if err != nil { 93 | return err 94 | } 95 | 96 | certOut, err := os.OpenFile(filepath.Join(os.TempDir(), "test.cert"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) 97 | if err != nil { 98 | return err 99 | } 100 | 101 | err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) 102 | if err != nil { 103 | return err 104 | } 105 | 106 | certOut.Close() 107 | 108 | keyOut, err := 
os.OpenFile(filepath.Join(os.TempDir(), "test.key"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) 109 | if err != nil { 110 | return err 111 | } 112 | 113 | err = pem.Encode(keyOut, pemBlockForKey(priv)) 114 | if err != nil { 115 | return err 116 | } 117 | 118 | return keyOut.Close() 119 | } 120 | 121 | func Test_doq(t *testing.T) { 122 | err := generateCertificate() 123 | assert.NoError(t, err) 124 | 125 | cert := filepath.Join(os.TempDir(), "test.cert") 126 | privkey := filepath.Join(os.TempDir(), "test.key") 127 | 128 | h := &dummyHandler{} 129 | 130 | s := &Server{ 131 | Addr: "127.0.0.1:45853", 132 | Handler: h, 133 | } 134 | 135 | go func() { 136 | err := s.ListenAndServeQUIC(cert, privkey) 137 | if err == quic.ErrServerClosed { 138 | return 139 | } 140 | assert.NoError(t, err) 141 | }() 142 | 143 | time.Sleep(time.Second) 144 | 145 | tlsConf := &tls.Config{ 146 | InsecureSkipVerify: true, 147 | NextProtos: []string{"doq"}, 148 | } 149 | conn, err := quic.DialAddr(context.Background(), s.Addr, tlsConf, nil) 150 | assert.NoError(t, err) 151 | 152 | stream, err := conn.OpenStreamSync(context.Background()) 153 | assert.NoError(t, err) 154 | 155 | req := new(dns.Msg) 156 | req.SetQuestion("example.com.", dns.TypeA) 157 | req.Id = 0 158 | 159 | buf, err := req.Pack() 160 | assert.NoError(t, err) 161 | 162 | n, err := stream.Write(addPrefixLen(buf)) 163 | assert.NoError(t, err) 164 | assert.Greater(t, n, 17) 165 | 166 | err = stream.Close() 167 | assert.NoError(t, err) 168 | 169 | data, err := io.ReadAll(stream) 170 | assert.NoError(t, err) 171 | 172 | msg := new(dns.Msg) 173 | err = msg.Unpack(data[2:]) 174 | assert.NoError(t, err) 175 | 176 | stream, err = conn.OpenStreamSync(context.Background()) 177 | assert.NoError(t, err) 178 | 179 | time.Sleep(6 * time.Second) 180 | 181 | _, err = stream.Write([]byte{0, 0}) 182 | assert.Error(t, err) 183 | 184 | conn, err = quic.DialAddr(context.Background(), s.Addr, tlsConf, nil) 185 | assert.NoError(t, err) 186 | 187 | 
stream, err = conn.OpenStreamSync(context.Background()) 188 | assert.NoError(t, err) 189 | 190 | _, err = stream.Write([]byte{0, 0}) 191 | assert.NoError(t, err) 192 | 193 | err = stream.Close() 194 | assert.NoError(t, err) 195 | 196 | _, err = io.ReadAll(stream) 197 | assert.Error(t, err) 198 | 199 | conn, err = quic.DialAddr(context.Background(), s.Addr, tlsConf, nil) 200 | assert.NoError(t, err) 201 | 202 | stream, err = conn.OpenStreamSync(context.Background()) 203 | assert.NoError(t, err) 204 | 205 | msg = new(dns.Msg) 206 | msg.SetEdns0(512, true) 207 | buf, _ = msg.Pack() 208 | 209 | _, err = stream.Write(buf) 210 | assert.NoError(t, err) 211 | 212 | err = stream.Close() 213 | assert.NoError(t, err) 214 | 215 | _, err = io.ReadAll(stream) 216 | assert.Error(t, err) 217 | 218 | err = s.Shutdown() 219 | assert.NoError(t, err) 220 | } 221 | -------------------------------------------------------------------------------- /server/doq/response_writer.go: -------------------------------------------------------------------------------- 1 | package doq 2 | 3 | import ( 4 | "encoding/binary" 5 | "net" 6 | 7 | "github.com/miekg/dns" 8 | "github.com/quic-go/quic-go" 9 | ) 10 | 11 | type ResponseWriter struct { 12 | dns.ResponseWriter 13 | 14 | Conn quic.Connection 15 | Stream quic.Stream 16 | } 17 | 18 | func (w *ResponseWriter) LocalAddr() net.Addr { 19 | return w.Conn.LocalAddr() 20 | } 21 | 22 | func (w *ResponseWriter) RemoteAddr() net.Addr { 23 | return w.Conn.RemoteAddr() 24 | } 25 | 26 | func (w *ResponseWriter) Close() error { 27 | return w.Stream.Close() 28 | } 29 | 30 | func (w *ResponseWriter) Write(m []byte) (int, error) { 31 | return w.Stream.Write(addPrefixLen(m)) 32 | } 33 | 34 | func (w *ResponseWriter) WriteMsg(m *dns.Msg) error { 35 | m.Id = 0 36 | 37 | packed, err := m.Pack() 38 | if err != nil { 39 | _ = w.Conn.CloseWithError(0x1, err.Error()) 40 | return err 41 | } 42 | 43 | _, err = w.Stream.Write(addPrefixLen(packed)) 44 | if err != nil { 45 | 
return err 46 | } 47 | 48 | return nil 49 | } 50 | 51 | func addPrefixLen(msg []byte) (buf []byte) { 52 | buf = make([]byte, 2+len(msg)) 53 | binary.BigEndian.PutUint16(buf, uint16(len(msg))) 54 | copy(buf[2:], msg) 55 | 56 | return buf 57 | } 58 | -------------------------------------------------------------------------------- /server/server_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "crypto/ecdsa" 6 | "crypto/rand" 7 | "crypto/rsa" 8 | "crypto/x509" 9 | "crypto/x509/pkix" 10 | "encoding/base64" 11 | "encoding/pem" 12 | "fmt" 13 | "io" 14 | "math/big" 15 | "net/http" 16 | "net/http/httptest" 17 | "os" 18 | "path/filepath" 19 | "testing" 20 | "time" 21 | 22 | "github.com/semihalev/sdns/middleware" 23 | 24 | "github.com/miekg/dns" 25 | "github.com/semihalev/log" 26 | "github.com/semihalev/sdns/config" 27 | "github.com/semihalev/sdns/middleware/blocklist" 28 | "github.com/semihalev/sdns/mock" 29 | "github.com/stretchr/testify/assert" 30 | ) 31 | 32 | func TestMain(m *testing.M) { 33 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler)) 34 | m.Run() 35 | 36 | os.Exit(0) 37 | } 38 | 39 | func publicKey(priv interface{}) interface{} { 40 | switch k := priv.(type) { 41 | case *rsa.PrivateKey: 42 | return &k.PublicKey 43 | case *ecdsa.PrivateKey: 44 | return &k.PublicKey 45 | default: 46 | return nil 47 | } 48 | } 49 | 50 | func pemBlockForKey(priv interface{}) *pem.Block { 51 | switch k := priv.(type) { 52 | case *rsa.PrivateKey: 53 | return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)} 54 | case *ecdsa.PrivateKey: 55 | b, err := x509.MarshalECPrivateKey(k) 56 | if err != nil { 57 | fmt.Fprintf(os.Stderr, "Unable to marshal ECDSA private key: %v", err) 58 | os.Exit(2) 59 | } 60 | return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b} 61 | default: 62 | return nil 63 | } 64 | } 65 | 66 | func generateCertificate() error { 67 | priv, 
err := rsa.GenerateKey(rand.Reader, 2048) 68 | if err != nil { 69 | return err 70 | } 71 | 72 | template := x509.Certificate{ 73 | SerialNumber: big.NewInt(1), 74 | Subject: pkix.Name{ 75 | Organization: []string{"Acme Co"}, 76 | }, 77 | NotBefore: time.Now(), 78 | NotAfter: time.Now().Add(time.Hour * 24 * 365 * 3), 79 | 80 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, 81 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 82 | BasicConstraintsValid: true, 83 | } 84 | 85 | template.DNSNames = append(template.DNSNames, "localhost") 86 | 87 | derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv) 88 | if err != nil { 89 | return err 90 | } 91 | 92 | certOut, err := os.OpenFile(filepath.Join(os.TempDir(), "test.cert"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) 93 | if err != nil { 94 | return err 95 | } 96 | 97 | err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) 98 | if err != nil { 99 | return err 100 | } 101 | 102 | certOut.Close() 103 | 104 | keyOut, err := os.OpenFile(filepath.Join(os.TempDir(), "test.key"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) 105 | if err != nil { 106 | return err 107 | } 108 | 109 | err = pem.Encode(keyOut, pemBlockForKey(priv)) 110 | if err != nil { 111 | return err 112 | } 113 | 114 | return keyOut.Close() 115 | } 116 | 117 | func Test_logPipe(t *testing.T) { 118 | logReader, logWriter := io.Pipe() 119 | go readlogs(logReader) 120 | _, _ = logWriter.Write([]byte("test test test test test test\n")) 121 | } 122 | 123 | func Test_ServerNoBind(t *testing.T) { 124 | cfg := &config.Config{} 125 | 126 | s := New(cfg) 127 | s.Run(context.Background()) 128 | } 129 | 130 | func Test_ServerBindFail(t *testing.T) { 131 | cfg := &config.Config{} 132 | 133 | cfg.TLSCertificate = "cert" 134 | cfg.TLSPrivateKey = "key" 135 | cfg.LogLevel = "crit" 136 | cfg.Bind = "1:1" 137 | cfg.BindTLS = "1:2" 138 | cfg.BindDOH = "1:3" 139 | cfg.BindDOQ = "1:4" 
140 | 141 | s := New(cfg) 142 | s.Run(context.Background()) 143 | } 144 | 145 | func Test_Server(t *testing.T) { 146 | cfg := &config.Config{} 147 | err := generateCertificate() 148 | assert.NoError(t, err) 149 | 150 | cert := filepath.Join(os.TempDir(), "test.cert") 151 | privkey := filepath.Join(os.TempDir(), "test.key") 152 | 153 | cfg.TLSCertificate = cert 154 | cfg.TLSPrivateKey = privkey 155 | cfg.LogLevel = "crit" 156 | cfg.Bind = "127.0.0.1:0" 157 | cfg.BindTLS = "127.0.0.1:23222" 158 | cfg.BindDOH = "127.0.0.1:23223" 159 | cfg.BindDOQ = "127.0.0.1:23224" 160 | cfg.BlockListDir = filepath.Join(os.TempDir(), "sdns_temp") 161 | 162 | middleware.Register("blocklist", func(cfg *config.Config) middleware.Handler { return blocklist.New(cfg) }) 163 | middleware.Setup(cfg) 164 | 165 | blocklist := middleware.Get("blocklist").(*blocklist.BlockList) 166 | blocklist.Set("test.com.") 167 | 168 | s := New(cfg) 169 | s.Run(context.Background()) 170 | 171 | req := new(dns.Msg) 172 | req.SetQuestion("test.com.", dns.TypeA) 173 | 174 | mw := mock.NewWriter("udp", "127.0.0.1:0") 175 | s.ServeDNS(mw, req) 176 | 177 | assert.True(t, mw.Written()) 178 | if assert.NotNil(t, mw.Msg()) { 179 | assert.Equal(t, true, len(mw.Msg().Answer) > 0) 180 | } 181 | 182 | request, err := http.NewRequest("GET", "/dns-query?name=test.com", nil) 183 | assert.NoError(t, err) 184 | 185 | hw := httptest.NewRecorder() 186 | 187 | s.ServeHTTP(hw, request) 188 | assert.Equal(t, 200, hw.Code) 189 | 190 | data, err := req.Pack() 191 | assert.NoError(t, err) 192 | 193 | dq := base64.RawURLEncoding.EncodeToString(data) 194 | 195 | request, err = http.NewRequest("GET", fmt.Sprintf("/dns-query?dns=%s", dq), nil) 196 | assert.NoError(t, err) 197 | 198 | hw = httptest.NewRecorder() 199 | 200 | s.ServeHTTP(hw, request) 201 | assert.Equal(t, 200, hw.Code) 202 | 203 | request, err = http.NewRequest("GET", "/dns-query?name=example.com", nil) 204 | assert.NoError(t, err) 205 | 206 | hw = httptest.NewRecorder() 207 
| 208 | s.ServeHTTP(hw, request) 209 | assert.Equal(t, 400, hw.Code) 210 | 211 | time.Sleep(2 * time.Second) 212 | 213 | os.Remove(cert) 214 | os.Remove(privkey) 215 | } 216 | -------------------------------------------------------------------------------- /snap/snapcraft.yaml: -------------------------------------------------------------------------------- 1 | name: sdns 2 | adopt-info: sdns 3 | summary: A DNS Resolver Server. 4 | description: | 5 | A high-performance, recursive DNS resolver server with DNSSEC support, focused on preserving privacy. 6 | 7 | architectures: 8 | - build-on: armhf 9 | - build-on: arm64 10 | - build-on: amd64 11 | 12 | grade: stable 13 | confinement: strict 14 | base: core18 15 | 16 | parts: 17 | sdns: 18 | plugin: go 19 | source: https://github.com/semihalev/sdns.git 20 | go-importpath: github.com/semihalev/sdns 21 | source-type: git 22 | override-pull: | 23 | snapcraftctl pull 24 | last_committed_tag="$(git describe --tags --abbrev=0)" 25 | last_committed_tag_ver="$(echo ${last_committed_tag} | sed 's/v//')" 26 | last_released_tag="$(snap info sdns | awk '$1 == "latest/stable:" { print $2 }')" 27 | if [ "${last_committed_tag_ver}" != "${last_released_tag}" ]; then 28 | git fetch 29 | git checkout "${last_committed_tag}" 30 | fi 31 | snapcraftctl set-version "$(git describe --tags | sed 's/v//')" 32 | snapcraftctl set-grade stable 33 | stage: 34 | - bin/sdns 35 | apps: 36 | sdns: 37 | command: sdns 38 | plugs: [network-bind] 39 | daemon: simple 40 | -------------------------------------------------------------------------------- /waitgroup/waitgroup.go: -------------------------------------------------------------------------------- 1 | package waitgroup 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // WaitGroup waits for other same processes based key with timeout. 
type WaitGroup struct {
	mu sync.RWMutex

	// groups holds the shared state of every key that currently has callers.
	groups map[uint64]*call

	// timeout bounds how long a key's context stays alive.
	timeout time.Duration
}

// call is the per-key state. ctx is cancelled by Done or expires after the
// group timeout, and dups counts the callers registered through Add.
type call struct {
	ctx    context.Context
	dups   int
	cancel func()
}

// New returns an empty WaitGroup whose per-key contexts expire after timeout.
func New(timeout time.Duration) *WaitGroup {
	wg := &WaitGroup{timeout: timeout}
	wg.groups = make(map[uint64]*call)

	return wg
}

// Get returns how many callers are registered for key, or 0 when the key is
// not tracked.
func (wg *WaitGroup) Get(key uint64) int {
	wg.mu.RLock()
	defer wg.mu.RUnlock()

	c, found := wg.groups[key]
	if !found {
		return 0
	}

	return c.dups
}

// Wait blocks until the context for key is cancelled or times out. It returns
// immediately when the key is not tracked.
func (wg *WaitGroup) Wait(key uint64) {
	wg.mu.RLock()
	c, found := wg.groups[key]
	wg.mu.RUnlock()

	// Block only after releasing the lock so Add and Done can proceed.
	if found {
		<-c.ctx.Done()
	}
}

// Add registers a caller for key. The first caller creates the key's context,
// bounded by the group timeout; later callers only bump the count.
func (wg *WaitGroup) Add(key uint64) {
	wg.mu.Lock()
	defer wg.mu.Unlock()

	if existing, found := wg.groups[key]; found {
		existing.dups++
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), wg.timeout)
	wg.groups[key] = &call{ctx: ctx, dups: 1, cancel: cancel}
}

// Done unregisters one caller for key. While other callers remain it only
// decrements the count. The last caller cancels the key's context, which
// releases any Wait on it, and removes the key.
75 | func (wg *WaitGroup) Done(key uint64) { 76 | wg.mu.Lock() 77 | defer wg.mu.Unlock() 78 | 79 | if c, ok := wg.groups[key]; ok { 80 | if c.dups > 1 { 81 | c.dups-- 82 | return 83 | } 84 | c.cancel() 85 | } 86 | 87 | delete(wg.groups, key) 88 | } 89 | -------------------------------------------------------------------------------- /waitgroup/waitgroup_test.go: -------------------------------------------------------------------------------- 1 | package waitgroup 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | "time" 7 | 8 | "github.com/miekg/dns" 9 | "github.com/semihalev/sdns/cache" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func Test_WaitGroupWait(t *testing.T) { 14 | wg := New(5 * time.Second) 15 | mu := sync.RWMutex{} 16 | 17 | m := new(dns.Msg) 18 | m.SetQuestion(dns.Fqdn("example.com."), dns.TypeA) 19 | key := cache.Hash(m.Question[0]) 20 | 21 | wg.Add(key) 22 | 23 | count := wg.Get(key) 24 | assert.Equal(t, 1, count) 25 | 26 | key2 := cache.Hash(dns.Question{Name: "none.", Qtype: dns.TypeA, Qclass: dns.ClassINET}) 27 | 28 | count = wg.Get(key2) 29 | assert.Equal(t, 0, count) 30 | 31 | wg.Wait(key2) 32 | 33 | var workers []*string 34 | 35 | for i := 0; i < 5; i++ { 36 | go func() { 37 | w := new(string) 38 | *w = "running" 39 | 40 | mu.Lock() 41 | workers = append(workers, w) 42 | mu.Unlock() 43 | 44 | wg.Wait(key) 45 | 46 | mu.Lock() 47 | *w = "stopped" 48 | mu.Unlock() 49 | }() 50 | } 51 | 52 | time.Sleep(time.Second) 53 | 54 | wg.Done(key) 55 | 56 | time.Sleep(100 * time.Millisecond) 57 | 58 | mu.RLock() 59 | defer mu.RUnlock() 60 | for _, w := range workers { 61 | assert.Equal(t, *w, "stopped") 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /zregister.go: -------------------------------------------------------------------------------- 1 | // Code generated by gen.go DO NOT EDIT. 
package main

import (
	"github.com/semihalev/sdns/config"
	"github.com/semihalev/sdns/middleware"
	"github.com/semihalev/sdns/middleware/accesslist"
	"github.com/semihalev/sdns/middleware/accesslog"
	"github.com/semihalev/sdns/middleware/as112"
	"github.com/semihalev/sdns/middleware/blocklist"
	"github.com/semihalev/sdns/middleware/cache"
	"github.com/semihalev/sdns/middleware/chaos"
	"github.com/semihalev/sdns/middleware/edns"
	"github.com/semihalev/sdns/middleware/failover"
	"github.com/semihalev/sdns/middleware/forwarder"
	"github.com/semihalev/sdns/middleware/hostsfile"
	"github.com/semihalev/sdns/middleware/loop"
	"github.com/semihalev/sdns/middleware/metrics"
	"github.com/semihalev/sdns/middleware/ratelimit"
	"github.com/semihalev/sdns/middleware/recovery"
	"github.com/semihalev/sdns/middleware/resolver"
)

// init registers each built-in middleware under its name with a constructor
// that builds the handler from the loaded *config.Config. This file is
// produced by gen.go: change the middleware list there and regenerate rather
// than editing it by hand.
//
// NOTE(review): the registration order presumably determines the order of the
// handler chain — confirm against middleware.Register before reordering.
func init() {
	middleware.Register("recovery", func(cfg *config.Config) middleware.Handler { return recovery.New(cfg) })
	middleware.Register("loop", func(cfg *config.Config) middleware.Handler { return loop.New(cfg) })
	middleware.Register("metrics", func(cfg *config.Config) middleware.Handler { return metrics.New(cfg) })
	middleware.Register("accesslist", func(cfg *config.Config) middleware.Handler { return accesslist.New(cfg) })
	middleware.Register("ratelimit", func(cfg *config.Config) middleware.Handler { return ratelimit.New(cfg) })
	middleware.Register("edns", func(cfg *config.Config) middleware.Handler { return edns.New(cfg) })
	middleware.Register("accesslog", func(cfg *config.Config) middleware.Handler { return accesslog.New(cfg) })
	middleware.Register("chaos", func(cfg *config.Config) middleware.Handler { return chaos.New(cfg) })
	middleware.Register("hostsfile", func(cfg *config.Config) middleware.Handler { return hostsfile.New(cfg) })
	middleware.Register("blocklist", func(cfg *config.Config) middleware.Handler { return blocklist.New(cfg) })
	middleware.Register("as112", func(cfg *config.Config) middleware.Handler { return as112.New(cfg) })
	middleware.Register("cache", func(cfg *config.Config) middleware.Handler { return cache.New(cfg) })
	middleware.Register("failover", func(cfg *config.Config) middleware.Handler { return failover.New(cfg) })
	middleware.Register("resolver", func(cfg *config.Config) middleware.Handler { return resolver.New(cfg) })
	middleware.Register("forwarder", func(cfg *config.Config) middleware.Handler { return forwarder.New(cfg) })
}