├── .codecov.yml
├── .github
├── ISSUE_TEMPLATE
│ ├── bug_report.yaml
│ └── config.yml
├── dependabot.yml
└── workflows
│ ├── codecov.yml
│ ├── codeql-action.yml
│ ├── docker-hub.yml
│ ├── docker-publish.yml
│ ├── go.yml
│ ├── golangci-lint.yml
│ └── goreleaser.yml
├── .gitignore
├── .goreleaser.yml
├── CNAME
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── SECURITY.md
├── _config.yml
├── api
├── README.md
├── api.go
├── api_test.go
├── context.go
├── group.go
├── router.go
└── tree.go
├── authcache
├── authserver.go
├── authserver_test.go
├── ns_cache.go
└── ns_cache_test.go
├── cache
├── cache.go
├── cache_test.go
├── hash.go
├── hash_test.go
├── shard.go
└── shard_test.go
├── config
├── config.go
└── config_test.go
├── contrib
└── linux
│ ├── adduser.sh
│ ├── sdns.conf
│ └── sdns.service
├── dnsutil
├── dnsutil.go
├── dnsutil_test.go
├── ttl.go
└── ttl_test.go
├── doc.go
├── docker-compose.yml
├── gen.go
├── go.mod
├── go.sum
├── logo.png
├── middleware
├── accesslist
│ ├── accesslist.go
│ └── accesslist_test.go
├── accesslog
│ ├── accesslog.go
│ └── accesslog_test.go
├── as112
│ ├── as112.go
│ └── as112_test.go
├── blocklist
│ ├── blocklist.go
│ ├── blocklist_test.go
│ ├── updater.go
│ └── updater_test.go
├── cache
│ ├── cache.go
│ ├── cache_test.go
│ └── item.go
├── chain.go
├── chain_test.go
├── chaos
│ ├── chaos.go
│ └── chaos_test.go
├── edns
│ ├── edns.go
│ └── edns_test.go
├── failover
│ ├── failover.go
│ └── failover_test.go
├── forwarder
│ ├── forwarder.go
│ └── forwarder_test.go
├── hostsfile
│ ├── hostsfile.go
│ └── hostsfile_test.go
├── loop
│ ├── loop.go
│ └── loop_test.go
├── metrics
│ ├── metrics.go
│ └── metrics_test.go
├── middleware.go
├── middleware_test.go
├── ratelimit
│ ├── ratelimit.go
│ └── ratelimit_test.go
├── recovery
│ ├── recovery.go
│ └── recovery_test.go
├── resolver
│ ├── auto_trust_anchor.go
│ ├── auto_trust_anchor_test.go
│ ├── client.go
│ ├── client_test.go
│ ├── handler.go
│ ├── handler_test.go
│ ├── nsec3.go
│ ├── nsec3_test.go
│ ├── resolver.go
│ ├── resolver_test.go
│ ├── singleinflight.go
│ ├── utils.go
│ └── utils_test.go
└── response_writer.go
├── mock
├── writer.go
└── writer_test.go
├── response
├── typify.go
└── typify_test.go
├── sdns.go
├── server
├── doh
│ ├── doh.go
│ ├── doh_test.go
│ ├── msg.go
│ ├── msg_test.go
│ ├── qtype.go
│ └── qtype_test.go
├── doq
│ ├── doq.go
│ ├── doq_test.go
│ └── response_writer.go
├── server.go
└── server_test.go
├── snap
└── snapcraft.yaml
├── waitgroup
├── waitgroup.go
└── waitgroup_test.go
└── zregister.go
/.codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 | status:
3 | project:
4 | default:
5 | target: 40%
6 | threshold: null
7 | patch: false
8 | changes: false
9 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: "Bug Report"
2 | description: Create a new issue for a bug.
3 | title: "[BUG] -
"
4 | labels: [
5 | "bug"
6 | ]
7 | assignees:
8 | - semihalev
9 | body:
10 | - type: markdown
11 | attributes:
12 | value: |
13 | Thank you for dedicating time to report this issue. We appreciate your effort. Kindly complete the form below.
14 | - type: textarea
15 | id: description
16 | attributes:
17 | label: "Description"
18 | description: Please enter an explicit description of your issue
19 | placeholder: Short and explicit description of your incident...
20 | validations:
21 | required: true
22 | - type: textarea
23 | id: reprod
24 | attributes:
25 | label: "Reproduction steps"
26 | description: Please list the exact steps needed to reproduce the issue
27 | render: bash
28 | validations:
29 | required: true
30 | - type: textarea
31 | id: sdns-version
32 | attributes:
33 | label: sdns version
34 | description: "Paste the output of `sdns --version`"
35 | render: bash
36 | validations:
37 | required: true
38 | - type: textarea
39 | id: logs
40 | attributes:
41 | label: "Logs"
42 | description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
43 | render: bash
44 | validations:
45 | required: false
46 | - type: dropdown
47 | id: os
48 | attributes:
49 | label: "OS"
50 | description: What is the impacted environment?
51 | multiple: true
52 | options:
53 | - Windows
54 | - Linux
55 | - MacOS
56 | - FreeBSD
57 | - Other
58 | - type: textarea
59 | id: ctx
60 | attributes:
61 | label: Additional context
62 | description: Anything else you would like to add
63 | validations:
64 | required: false
65 | validations:
66 | required: false
67 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | contact_links:
2 | - name: Feature Request
3 | url: https://github.com/semihalev/sdns/discussions/new?category=ideas
4 | about: Discuss a new feature for SDNS
5 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "gomod"
4 | directory: "/"
5 | schedule:
6 | interval: "daily"
7 | time: "15:30"
8 | labels:
9 | - "dependencies"
10 | - package-ecosystem: "github-actions"
11 | directory: "/"
12 | schedule:
13 | interval: "daily"
14 | time: "15:30"
15 | labels:
16 | - "dependencies"
17 |
--------------------------------------------------------------------------------
/.github/workflows/codecov.yml:
--------------------------------------------------------------------------------
1 | name: Codecov
2 |
3 | on:
4 | push:
5 | pull_request:
6 |
7 | jobs:
8 | run:
9 | runs-on: ubuntu-latest
10 |
11 | steps:
12 | - uses: actions/checkout@master
13 |
14 | - name: Set up Go 1.x
15 | uses: actions/setup-go@v5
16 | with:
17 | go-version: ^1.22
18 |
19 | - name: Generate coverage report
20 | run: make test
21 |
22 | - name: Upload coverage to Codecov
23 | uses: codecov/codecov-action@v5.3.1
24 | with:
25 | fail_ci_if_error: true
26 | token: ${{ secrets.CODECOV_TOKEN }}
27 | file: ./coverage.out
28 |
--------------------------------------------------------------------------------
/.github/workflows/codeql-action.yml:
--------------------------------------------------------------------------------
1 | name: "Code Scanning - Action"
2 |
3 | on:
4 | push:
5 | pull_request:
6 | schedule:
7 | - cron: '0 0 * * 0'
8 |
9 | jobs:
10 | CodeQL-Build:
11 | # CodeQL runs on ubuntu-latest, windows-latest, and macos-latest
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - name: Checkout repository
16 | uses: actions/checkout@v4
17 | with:
18 | # Must fetch at least the immediate parents so that if this is
19 | # a pull request then we can checkout the head of the pull request.
20 | # Only include this option if you are running this workflow on pull requests.
21 | fetch-depth: 2
22 |
23 | # If this run was triggered by a pull request event then checkout
24 | # the head of the pull request instead of the merge commit.
25 | # Only include this step if you are running this workflow on pull requests.
26 | - run: git checkout HEAD^2
27 | if: ${{ github.event_name == 'pull_request' }}
28 |
29 | # Initializes the CodeQL tools for scanning.
30 | - name: Initialize CodeQL
31 | uses: github/codeql-action/init@v3
32 | # Override language selection by uncommenting this and choosing your languages
33 | # with:
34 | # languages: go, javascript, csharp, python, cpp, java
35 |
36 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
37 | # If this step fails, then you should remove it and run the build manually (see below).
38 | - name: Autobuild
39 | uses: github/codeql-action/autobuild@v3
40 |
41 | # ℹ️ Command-line programs to run using the OS shell.
42 | # 📚 https://git.io/JvXDl
43 |
44 | # ✏️ If the Autobuild fails above, remove it and uncomment the following
45 | # three lines and modify them (or add more) to build your code if your
46 | # project uses a compiled language
47 |
48 | #- run: |
49 | # make bootstrap
50 | # make release
51 |
52 | - name: Perform CodeQL Analysis
53 | uses: github/codeql-action/analyze@v3
54 |
--------------------------------------------------------------------------------
/.github/workflows/docker-hub.yml:
--------------------------------------------------------------------------------
1 | name: "Docker Hub"
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches:
7 | - "master"
8 | tags:
9 | - v*
10 | jobs:
11 | docker:
12 | runs-on: ubuntu-latest
13 | if: github.event_name == 'push'
14 |
15 | steps:
16 | -
17 | name: Checkout
18 | uses: actions/checkout@v4
19 | -
20 | name: Docker meta
21 | id: meta
22 | uses: docker/metadata-action@v5
23 | with:
24 | images: |
25 | c1982/sdns
26 | tags: |
27 | type=semver,pattern={{version}}
28 | type=raw,value=latest,enable={{is_default_branch}}
29 | -
30 | name: Set up QEMU
31 | uses: docker/setup-qemu-action@v3
32 | -
33 | name: Set up Docker Buildx
34 | uses: docker/setup-buildx-action@v3
35 | -
36 | name: Login to Docker Hub
37 | uses: docker/login-action@v3
38 | with:
39 | username: ${{ secrets.DOCKERHUB_USERNAME }}
40 | password: ${{ secrets.DOCKERHUB_TOKEN }}
41 | -
42 | name: Build and push
43 | uses: docker/build-push-action@v6
44 | with:
45 | context: .
46 | platforms: linux/amd64,linux/arm64
47 | push: true
48 | tags: ${{ steps.meta.outputs.tags }}
49 | labels: ${{ steps.meta.outputs.labels }}
50 |
--------------------------------------------------------------------------------
/.github/workflows/docker-publish.yml:
--------------------------------------------------------------------------------
1 | name: "Docker Package"
2 |
3 | on:
4 | push:
5 | # Publish `master` as Docker `latest` image.
6 | branches:
7 | - master
8 |
9 | # Publish `v1.2.3` tags as releases.
10 | tags:
11 | - v*
12 |
13 | env:
14 | # TODO: Change variable to your image's name.
15 | IMAGE_NAME: sdns
16 |
17 | jobs:
18 | # Push image to GitHub Packages.
19 | # See also https://docs.docker.com/docker-hub/builds/
20 | push:
21 | runs-on: ubuntu-latest
22 | if: github.event_name == 'push'
23 |
24 | steps:
25 | - uses: actions/checkout@v4
26 |
27 | - name: Version
28 | id: get_version
29 | run: |
30 | VERSION=$(echo "${{ github.ref }}" | sed -e 's,.*/\(.*\),\1,')
31 | [ "$VERSION" == "master" ] && VERSION=latest
32 | echo ::set-output name=VERSION::$VERSION
33 |
34 | - name: Build image
35 | run: docker build -t docker.pkg.github.com/${{ github.repository }}/sdns:${{ steps.get_version.outputs.VERSION }} .
36 |
37 | - name: Log into registry
38 | run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login docker.pkg.github.com -u ${{ github.actor }} --password-stdin
39 |
40 | - name: Push image
41 | run: docker push docker.pkg.github.com/${{ github.repository }}/sdns:${{ steps.get_version.outputs.VERSION }}
42 |
--------------------------------------------------------------------------------
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
1 | name: Go
2 |
3 | on:
4 | push:
5 | branches: [ master ]
6 | pull_request:
7 | branches: [ master ]
8 |
9 | jobs:
10 |
11 | build:
12 | name: Build
13 | runs-on: ${{ matrix.os }}
14 | strategy:
15 | matrix:
16 | os: [ubuntu-latest, macos-latest, windows-latest]
17 |
18 | steps:
19 | - name: Set up Go 1.x
20 | uses: actions/setup-go@v5
21 | with:
22 | go-version: ^1.22
23 | id: go
24 |
25 | - name: Check out code into the Go module directory
26 | uses: actions/checkout@v4
27 |
28 | - name: Get dependencies
29 | run: |
30 | go get -v -t -d ./...
31 |
32 | - name: Build
33 | run: go build -v .
34 |
35 | - name: Test
36 | run: make test
37 |
--------------------------------------------------------------------------------
/.github/workflows/golangci-lint.yml:
--------------------------------------------------------------------------------
1 | name: Linter
2 | on:
3 | push:
4 | tags:
5 | - v*
6 | branches:
7 | - master
8 | pull_request:
9 | jobs:
10 | golangci:
11 | name: lint
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v4
15 | - uses: actions/setup-go@v5
16 | with:
17 | cache: false
18 | go-version: ^1.22
19 | - name: golangci-lint
20 | uses: golangci/golangci-lint-action@v6.4.1
21 | with:
22 | # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version.
23 | version: v1.62.2
24 |
25 | # Optional: working directory, useful for monorepos
26 | # working-directory: somedir
27 |
28 | # Optional: golangci-lint command line arguments.
29 | # args: --disable typecheck
30 |
31 | # Optional: show only new issues if it's a pull request. The default value is `false`.
32 | # only-new-issues: true
33 |
--------------------------------------------------------------------------------
/.github/workflows/goreleaser.yml:
--------------------------------------------------------------------------------
1 | name: Releaser
2 |
3 | on:
4 | push:
5 | tags:
6 | - 'v[0-9]+.[0-9]+.[0-9]+*'
7 |
8 | jobs:
9 | goreleaser:
10 | runs-on: ubuntu-latest
11 | steps:
12 | -
13 | name: Checkout
14 | uses: actions/checkout@v4
15 | with:
16 | fetch-depth: 0
17 | -
18 | name: Set up Go 1.x
19 | uses: actions/setup-go@v5
20 | with:
21 | go-version: ^1.22
22 | -
23 | name: Run GoReleaser
24 | uses: goreleaser/goreleaser-action@v6.2.1
25 | with:
26 | version: latest
27 | args: release --clean
28 | env:
29 | GITHUB_TOKEN: ${{ secrets.GH_API_TOKEN }}
30 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pem
2 | *.crt
3 | *.key
4 | sdns
5 | sdns.conf
6 | coverage.out
7 | db
8 | access.log
9 | vendor
10 |
--------------------------------------------------------------------------------
/.goreleaser.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | env_files:
4 | github_token: ~/.GH_TOKEN.txt
5 |
6 | env:
7 | - CGO_ENABLED=0
8 |
9 | before:
10 | hooks:
11 | - go mod download
12 | - go generate ./...
13 |
14 | builds:
15 | - id: nonf
16 | targets:
17 | - darwin_amd64
18 | - darwin_arm64
19 | - windows_amd64
20 | - freebsd_amd64
21 | - openbsd_amd64
22 | - netbsd_amd64
23 | flags:
24 | - -trimpath
25 | ldflags:
26 | - -s -w
27 | - id: linux
28 | goos:
29 | - linux
30 | goarch:
31 | - arm
32 | - arm64
33 | - mips
34 | - mipsle
35 | - mips64
36 | - mips64le
37 | goarm:
38 | - 5
39 | - 6
40 | - 7
41 | gomips:
42 | - softfloat
43 | flags:
44 | - -trimpath
45 | ldflags:
46 | - -s -w
47 | - id: nf
48 | goos:
49 | - linux
50 | goarch:
51 | - amd64
52 | flags:
53 | - -trimpath
54 | ldflags:
55 | - -s -w
56 |
57 | release:
58 | github:
59 | owner: semihalev
60 | name: sdns
61 | prerelease: false
62 | draft: true
63 |
64 | archives:
65 | -
66 | name_template: "{{ .ProjectName }}-{{ .Version }}_{{ .Os }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}"
67 | format_overrides:
68 | - goos: windows
69 | format: zip
70 | wrap_in_directory: true
71 | files:
72 | - README.md
73 | - LICENSE
74 |
75 | checksum:
76 | name_template: '{{ .ProjectName }}-{{ .Version }}_sha256sums.txt'
77 | algorithm: sha256
78 |
79 | changelog:
80 | disable: true
81 |
82 | nfpms:
83 | - file_name_template: '{{ .ProjectName }}_{{ .Version }}_{{- if eq .Arch "amd64" }}x86_64{{- else }}{{ .Arch }}{{ end }}'
84 | builds:
85 | - nf
86 | homepage: https://sdns.dev
87 | description: A high-performance, recursive DNS resolver server with DNSSEC support, focused on preserving privacy
88 | maintainer: Yasar Alev
89 | license: MIT
90 | bindir: /usr/bin
91 | contents:
92 | - src: "./contrib/linux/sdns.service"
93 | dst: "/lib/systemd/system/sdns.service"
94 | - src: "./contrib/linux/sdns.conf"
95 | dst: "/etc/sdns.conf"
96 | type: config
97 | - dst: "/var/lib/sdns"
98 | type: dir
99 | scripts:
100 | postinstall: "contrib/linux/adduser.sh"
101 | release: 1
102 | formats:
103 | - deb
104 | - rpm
105 | overrides:
106 | deb:
107 | dependencies:
108 | - systemd-sysv
109 | rpm:
110 | dependencies:
111 | - systemd
112 |
113 | brews:
114 | -
115 | repository:
116 | owner: semihalev
117 | name: homebrew-tap
118 | directory: Formula
119 | homepage: https://sdns.dev
120 | description: A high-performance, recursive DNS resolver server with DNSSEC support, focused on preserving privacy
121 | dependencies:
122 | - name: go
123 | type: build
124 | commit_author:
125 | name: semihalev
126 | email: semihalev@gmail.com
127 | service: |
128 | run [opt_bin/"sdns", "-config", etc/"sdns.conf"]
129 | keep_alive true
130 | require_root true
131 | error_log_path var/"log/sdns.log"
132 | log_path var/"log/sdns.log"
133 | working_dir opt_prefix
134 | test: |
135 | fork do
136 | exec bin/"sdns", "-config", testpath/"sdns.conf"
137 | end
138 | sleep(2)
139 | assert_predicate testpath/"sdns.conf", :exist?
140 |
--------------------------------------------------------------------------------
/CNAME:
--------------------------------------------------------------------------------
1 | sdns.dev
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | murat3ok@gmail.com.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | https://www.contributor-covenant.org/version/2/0/code\_of\_conduct.html.
120 |
121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 | enforcement ladder](https://github.com/mozilla/diversity).
123 |
124 | [homepage]: https://www.contributor-covenant.org
125 |
126 | For answers to common questions about this code of conduct, see the FAQ at
127 | https://www.contributor-covenant.org/faq. Translations are available at
128 | https://www.contributor-covenant.org/translations.
129 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to SDNS
2 |
3 | First and foremost, thank you for considering contributing to SDNS! It's people like you that make SDNS such a great tool.
4 |
5 | ## Getting Started
6 |
7 | * Make sure you have a [GitHub account](https://github.com/signup/free).
8 | * Fork the repository on GitHub.
9 | * Decide if you want to work on an existing issue or if you want to propose a new feature or bug fix.
10 |
11 | ## Making Changes
12 |
13 | 1. Create a new branch in your fork from the main branch. Name your branch something descriptive.
14 | 2. Make the changes in your fork.
15 | 3. If you're adding a feature or fixing a bug, please add or modify existing tests if applicable.
16 | 4. Run all tests to ensure your changes don't negatively impact existing code.
17 | 5. Commit your changes to your branch. Keep commit messages clear and concise, stating what you did and why.
18 |
19 | ## Submitting Changes
20 |
21 | 1. Push your changes to your fork on GitHub.
22 | 2. Open a pull request against the main branch of the original repository.
23 | 3. Please ensure your pull request description clearly describes the problem and solution and relates to any issues it addresses.
24 |
25 | ## Additional Resources
26 |
27 | * [Issue tracker](https://github.com/semihalev/sdns/issues)
28 | * [General GitHub documentation](https://docs.github.com/)
29 | * [GitHub pull request documentation](https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests)
30 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# Build stage: produce a fully static, stripped sdns binary.
FROM golang:alpine AS builder

WORKDIR /go/src/github.com/semihalev/sdns

# Toolchain for the external static link below. Installed before any source
# is copied so this layer stays cached when only the source changes.
RUN apk --no-cache add \
    ca-certificates \
    gcc \
    binutils-gold \
    git \
    musl-dev

# Download modules in their own layer; it is invalidated only when
# go.mod/go.sum change, not on every source edit.
COPY go.mod go.sum ./
RUN go mod download

COPY . /go/src/github.com/semihalev/sdns/

RUN go build -ldflags "-linkmode external -extldflags -static -s -w" -o /tmp/sdns \
    && strip --strip-all /tmp/sdns

# Runtime stage: only the binary plus the CA bundle (presumably needed for
# outbound TLS, e.g. DoH/DoT upstreams — confirm).
FROM scratch

COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
COPY --from=builder /tmp/sdns /sdns

EXPOSE 53/tcp
EXPOSE 53/udp
EXPOSE 853
EXPOSE 8053
EXPOSE 8080

ENTRYPOINT ["/sdns"]
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2018 semihalev, https://github.com/semihalev
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# Go toolchain command; may be overridden from the environment or command line.
GO ?= go
# Binary name removed by `make clean`.
BIN = sdns
# Every package in the module, expanded once when the Makefile is parsed.
TESTFOLDER := $(shell $(GO) list ./...)

# Default goal: regenerate code, tidy modules, run the tests, then build.
all: generate tidy test build
6 |
# test runs each package's tests with the race detector, echoing the verbose
# output, and appends each package's coverage profile (minus its "mode:" line)
# to coverage.out. It aborts on the first package whose `go test` exits
# non-zero. Checking the exit status also catches build failures, which print
# "FAIL pkg [build failed]" rather than a "--- FAIL" line. The grep is kept as
# a belt-and-braces check.
.PHONY: test
test:
	echo "mode: atomic" > coverage.out
	for d in $(TESTFOLDER); do \
		$(GO) test -v -covermode=atomic -race -coverprofile=profile.out $$d > profiles.out; \
		status=$$?; \
		cat profiles.out; \
		if [ $$status -ne 0 ] || grep -q "^--- FAIL" profiles.out; then \
			rm -f profiles.out profile.out; \
			exit 1; \
		fi; \
		if [ -f profile.out ]; then \
			grep -v "mode:" profile.out >> coverage.out; \
			rm -f profile.out; \
		fi; \
		rm -f profiles.out; \
	done
24 |
# Housekeeping targets. `generate` runs `go generate` in the module root
# (presumably producing zregister.go via gen.go, which `clean` deletes).
.PHONY: generate tidy build clean

generate:
	$(GO) generate

tidy:
	$(GO) mod tidy

build:
	$(GO) build

clean:
	rm -rf $(BIN) zregister.go
41 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Security Policy
2 |
3 | ## Supported Versions
4 |
5 | Only the following version is currently being supported with security updates:
6 |
7 | | Version | Supported |
8 | | ------- | ------------------ |
9 | | 1.3.x | :white\_check\_mark: |
10 | | < 1.3 | :x: |
11 |
12 | ## Reporting a Vulnerability
13 |
14 | We take security issues seriously. If you discover a security vulnerability in this project, please follow these steps:
15 |
16 | 1. **Open an Issue**: Once you've made sure you're on the latest version and the vulnerability still exists, open an issue on our GitHub repository. Describe the vulnerability in detail, including the steps to reproduce if possible.
17 | 2. **Discussion**: After you report the vulnerability, we'll engage in a discussion with you on the issue to understand it better and evaluate its impact.
18 | 3. **Resolution**: We will address the security issue and release a new version with the necessary patches as soon as possible.
19 |
20 | Your efforts to responsibly disclose your findings are sincerely appreciated and will be acknowledged.
21 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-dinky
2 | title: SDNS
3 | show_downloads: true
4 | plugins:
5 | - jemoji
6 |
--------------------------------------------------------------------------------
/api/README.md:
--------------------------------------------------------------------------------
1 | # HTTP API
2 |
3 | You can manage blocks through a basic HTTP API.
4 |
5 | ## Authentication
6 |
7 | An API bearer token can be set in the sdns config. If the token is set, API requests must include it in the `Authorization` header.
8 | ### Example Header
9 | `Authorization: Bearer my_very_long_token`
10 |
11 | ## Actions
12 |
13 | ### GET /api/v1/block/set/:key
14 |
15 | It is used to create a new block.
16 |
17 | __request__
18 |
19 | > curl http://localhost:8080/api/v1/block/set/domain.com
20 |
21 | __response__
22 |
23 | ```json
24 | {"success":true}
25 | ```
26 |
27 | ### GET /api/v1/block/get/:key
28 |
29 | Retrieves an existing block.
30 |
31 | __request__
32 |
33 | > curl http://localhost:8080/api/v1/block/get/domain.com
34 |
35 | __response__
36 |
37 | ```json
38 | {"success":true}
39 | ```
40 | or
41 |
42 | ```json
43 | {"error":"domain.com not found"}
44 | ```
45 |
46 | ### GET /api/v1/block/exists/:key
47 |
48 | Checks whether a block exists for the given key.
49 |
50 | __request__
51 |
52 | > curl http://localhost:8080/api/v1/block/exists/domain.com
53 |
54 | __response__
55 |
56 | ```json
57 | {"exists":true}
58 | ```
59 |
60 | ### GET /api/v1/block/remove/:key
61 |
62 | Deletes the block.
63 |
64 | __request__
65 |
66 | > curl http://localhost:8080/api/v1/block/remove/domain.com
67 |
68 | __response__
69 |
70 | ```json
71 | {"success":true}
72 | ```
73 |
74 | ### GET /api/v1/purge/domain/type
75 |
76 | Purge a cached query.
77 |
78 | __request__
79 |
80 | > curl http://localhost:8080/api/v1/purge/example.com/MX
81 |
82 | __response__
83 |
84 | ```json
85 | {"success":true}
86 | ```
87 |
88 | ### GET /metrics
89 |
90 | Exports the Prometheus metrics.
91 |
--------------------------------------------------------------------------------
/api/api.go:
--------------------------------------------------------------------------------
1 | package api
2 |
import (
	"context"
	"crypto/subtle"
	"encoding/base64"
	"net/http"
	"net/http/pprof"
	"os"
	"strings"
	"time"

	"github.com/miekg/dns"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/semihalev/log"
	"github.com/semihalev/sdns/config"
	"github.com/semihalev/sdns/dnsutil"
	"github.com/semihalev/sdns/middleware"
	"github.com/semihalev/sdns/middleware/blocklist"
)
20 |
21 | // API type
22 | type API struct {
23 | addr string
24 | bearerToken string
25 | router *Router
26 | blocklist *blocklist.BlockList
27 | }
28 |
// debugpprof enables the /debug/pprof endpoints on the API server. It is
// controlled by the SDNS_PPROF environment variable (see envEnabled).
var debugpprof bool

func init() {
	debugpprof = envEnabled("SDNS_PPROF")
}

// envEnabled reports whether the environment variable name switches a
// feature on. An unset variable means off. A value that strconv.ParseBool
// understands ("1", "t", "true", "0", "f", "false", ...) is honoured, so
// SDNS_PPROF=false really disables profiling. Any other value, including an
// empty one, keeps the historical "present means on" behaviour.
func envEnabled(name string) bool {
	v, ok := os.LookupEnv(name)
	if !ok {
		return false
	}

	on, err := strconv.ParseBool(v)
	return err != nil || on
}
34 |
35 | // New return new api
36 | func New(cfg *config.Config) *API {
37 | var bl *blocklist.BlockList
38 |
39 | b := middleware.Get("blocklist")
40 | if b != nil {
41 | bl = b.(*blocklist.BlockList)
42 | }
43 |
44 | a := &API{
45 | addr: cfg.API,
46 | blocklist: bl,
47 | router: NewRouter(),
48 | bearerToken: cfg.BearerToken,
49 | }
50 |
51 | return a
52 | }
53 |
54 | func (a *API) checkToken(ctx *Context) bool {
55 | if a.bearerToken == "" {
56 | return true
57 | }
58 |
59 | authHeader := ctx.Request.Header.Get("Authorization")
60 | if authHeader == "" {
61 | ctx.JSON(http.StatusUnauthorized, Json{"error": "unauthorized"})
62 | return false
63 | }
64 |
65 | tokenSplit := strings.Split(authHeader, " ")
66 | if len(tokenSplit) != 2 {
67 | ctx.JSON(http.StatusUnauthorized, Json{"error": "unauthorized"})
68 | return false
69 | }
70 |
71 | if tokenSplit[0] == "Bearer" && a.bearerToken == tokenSplit[1] {
72 | return true
73 | }
74 |
75 | ctx.JSON(http.StatusUnauthorized, Json{"error": "unauthorized"})
76 | return false
77 | }
78 |
79 | func (a *API) existsBlock(ctx *Context) {
80 | if !a.checkToken(ctx) {
81 | return
82 | }
83 |
84 | ctx.JSON(http.StatusOK, Json{"exists": a.blocklist.Exists(ctx.Param("key"))})
85 | }
86 |
87 | func (a *API) getBlock(ctx *Context) {
88 | if !a.checkToken(ctx) {
89 | return
90 | }
91 |
92 | if ok, _ := a.blocklist.Get(ctx.Param("key")); !ok {
93 | ctx.JSON(http.StatusNotFound, Json{"error": ctx.Param("key") + " not found"})
94 | } else {
95 | ctx.JSON(http.StatusOK, Json{"success": ok})
96 | }
97 | }
98 |
99 | func (a *API) removeBlock(ctx *Context) {
100 | if !a.checkToken(ctx) {
101 | return
102 | }
103 |
104 | ctx.JSON(http.StatusOK, Json{"success": a.blocklist.Remove(ctx.Param("key"))})
105 | }
106 |
107 | func (a *API) setBlock(ctx *Context) {
108 | if !a.checkToken(ctx) {
109 | return
110 | }
111 |
112 | ctx.JSON(http.StatusOK, Json{"success": a.blocklist.Set(ctx.Param("key"))})
113 | }
114 |
115 | func (a *API) metrics(ctx *Context) {
116 | if !a.checkToken(ctx) {
117 | return
118 | }
119 |
120 | promhttp.Handler().ServeHTTP(ctx.Writer, ctx.Request)
121 | }
122 |
123 | func (a *API) purge(ctx *Context) {
124 | if !a.checkToken(ctx) {
125 | return
126 | }
127 |
128 | qtype := strings.ToUpper(ctx.Param("qtype"))
129 | qname := dns.Fqdn(ctx.Param("qname"))
130 |
131 | bqname := base64.StdEncoding.EncodeToString([]byte(qtype + ":" + qname))
132 |
133 | req := new(dns.Msg)
134 | req.SetQuestion(dns.Fqdn(bqname), dns.TypeNULL)
135 | req.Question[0].Qclass = dns.ClassCHAOS
136 |
137 | _, _ = dnsutil.ExchangeInternal(context.Background(), req)
138 |
139 | ctx.JSON(http.StatusOK, Json{"success": true})
140 | }
141 |
142 | // Run API server
143 | func (a *API) Run(ctx context.Context) {
144 | if a.addr == "" {
145 | return
146 | }
147 |
148 | if debugpprof {
149 | profiler := a.router.Group("/debug")
150 | {
151 | profiler.GET("/", func(ctx *Context) {
152 | http.Redirect(ctx.Writer, ctx.Request, profiler.path+"/pprof/", http.StatusMovedPermanently)
153 | })
154 | profiler.GET("/pprof/", func(ctx *Context) { pprof.Index(ctx.Writer, ctx.Request) })
155 | profiler.GET("/pprof/*", func(ctx *Context) { pprof.Index(ctx.Writer, ctx.Request) })
156 | profiler.GET("/pprof/cmdline", func(ctx *Context) { pprof.Cmdline(ctx.Writer, ctx.Request) })
157 | profiler.GET("/pprof/profile", func(ctx *Context) { pprof.Profile(ctx.Writer, ctx.Request) })
158 | profiler.GET("/pprof/symbol", func(ctx *Context) { pprof.Symbol(ctx.Writer, ctx.Request) })
159 | profiler.GET("/pprof/trace", func(ctx *Context) { pprof.Trace(ctx.Writer, ctx.Request) })
160 | }
161 | }
162 |
163 | if a.blocklist != nil {
164 | block := a.router.Group("/api/v1/block")
165 | {
166 | block.GET("/exists/:key", a.existsBlock)
167 | block.GET("/get/:key", a.getBlock)
168 | block.GET("/remove/:key", a.removeBlock)
169 | block.GET("/set/:key", a.setBlock)
170 | }
171 | }
172 |
173 | a.router.GET("/api/v1/purge/:qname/:qtype", a.purge)
174 |
175 | a.router.GET("/metrics", a.metrics)
176 |
177 | srv := &http.Server{
178 | Addr: a.addr,
179 | Handler: a.router,
180 | }
181 |
182 | go func() {
183 | if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
184 | log.Error("Start API server failed", "error", err.Error())
185 | }
186 | }()
187 |
188 | log.Info("API server listening...", "addr", a.addr)
189 | if a.bearerToken != "" {
190 | log.Info("API authorization bearer token", "token", a.bearerToken)
191 | }
192 |
193 | go func() {
194 | <-ctx.Done()
195 |
196 | log.Info("API server stopping...", "addr", a.addr)
197 |
198 | apiCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
199 | defer cancel()
200 |
201 | if err := srv.Shutdown(apiCtx); err != nil {
202 | log.Error("Shutdown API server failed:", "error", err.Error())
203 | }
204 | }()
205 | }
206 |
--------------------------------------------------------------------------------
/api/context.go:
--------------------------------------------------------------------------------
1 | package api
2 |
3 | import (
4 | "encoding/json"
5 | "net/http"
6 | )
7 |
type (
	// Context carries the per-request state handed to a Handler. The Router
	// recycles Contexts through a pool, so a handler must not keep one after
	// it returns.
	Context struct {
		Request *http.Request
		Writer  http.ResponseWriter
		Handler Handler
		Params  *Params
	}

	// Handler serves a routed request.
	Handler func(ctx *Context)

	// Param is one named path parameter, e.g. "key" for "/get/:key".
	Param struct {
		Key   string
		Value string
	}

	// Params holds the path parameters of the current request.
	Params []Param

	// Json is a convenience type for building JSON response bodies.
	Json map[string]any
)

// JSON writes data as a JSON response with the given status code. If data
// cannot be marshalled, a bare 500 is written instead.
func (ctx *Context) JSON(code int, data any) {
	buf, err := json.Marshal(data)
	if err != nil {
		ctx.Writer.WriteHeader(http.StatusInternalServerError)
		return
	}

	// Headers must be set before WriteHeader; changes made afterwards are
	// not sent to the client.
	ctx.Writer.Header().Set("Content-Type", "application/json")
	ctx.Writer.WriteHeader(code)

	// A write error means the client went away; there is nobody to report to.
	_, _ = ctx.Writer.Write(buf)
}

// Param returns the value of the named path parameter, or "" if it is absent.
func (ctx *Context) Param(key string) string {
	for _, p := range *ctx.Params {
		if p.Key == key {
			return p.Value
		}
	}

	return ""
}

// addParameter appends a path parameter. It reuses the slice's spare
// capacity and grows it when a route has more parameters than preallocated,
// instead of panicking on an out-of-range reslice.
func (ctx *Context) addParameter(key, value string) {
	*ctx.Params = append(*ctx.Params, Param{Key: key, Value: value})
}
60 |
--------------------------------------------------------------------------------
/api/group.go:
--------------------------------------------------------------------------------
1 | package api
2 |
3 | type Group struct {
4 | parent *Router
5 | path string
6 | }
7 |
8 | func (g *Group) Handle(method, path string, handle Handler) {
9 | g.parent.Handle(method, g.path+path, handle)
10 | }
11 |
12 | func (g *Group) GET(path string, handle Handler) {
13 | g.parent.GET(g.path+path, handle)
14 | }
15 |
16 | func (g *Group) POST(path string, handle Handler) {
17 | g.parent.POST(g.path+path, handle)
18 | }
19 |
--------------------------------------------------------------------------------
/api/router.go:
--------------------------------------------------------------------------------
1 | package api
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "os"
7 | "runtime/debug"
8 | "sync"
9 |
10 | "github.com/semihalev/log"
11 | )
12 |
13 | type Router struct {
14 | get Tree
15 | post Tree
16 | delete Tree
17 | put Tree
18 | patch Tree
19 | head Tree
20 | connect Tree
21 | trace Tree
22 | options Tree
23 |
24 | ctxPool sync.Pool
25 | }
26 |
27 | var extraHeaders = map[string]string{
28 | "Server": "sdns",
29 | "Access-Control-Allow-Origin": "*",
30 | "Access-Control-Allow-Methods": "GET,POST",
31 | "Cache-Control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0",
32 | "Pragma": "no-cache",
33 | }
34 |
35 | func NewRouter() *Router {
36 | r := &Router{}
37 |
38 | r.ctxPool.New = func() any {
39 | params := make(Params, 0, 20)
40 | return &Context{Params: ¶ms}
41 | }
42 |
43 | return r
44 | }
45 |
46 | func (rt *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
47 | defer func() {
48 | if r := recover(); r != nil {
49 | http.Error(w, "Internal Server Error", http.StatusInternalServerError)
50 | log.Error("Recovered in API", "recover", r)
51 |
52 | _, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n\n", r))
53 | debug.PrintStack()
54 | }
55 | }()
56 |
57 | for k, v := range extraHeaders {
58 | w.Header().Set(k, v)
59 | }
60 |
61 | if len(r.URL.Path) > 2048 {
62 | http.Error(w, "Bad Request", http.StatusBadRequest)
63 | return
64 | }
65 |
66 | ctx := rt.getContext(w, r)
67 |
68 | if r.Method[0] == 'G' {
69 | rt.get.Lookup(ctx)
70 | } else {
71 | tree := rt.selectTree(r.Method)
72 | if tree != nil {
73 | tree.Lookup(ctx)
74 | }
75 | }
76 |
77 | if ctx.Handler == nil {
78 | http.NotFound(w, r)
79 | rt.putContext(ctx)
80 | return
81 | }
82 |
83 | ctx.Handler(ctx)
84 |
85 | rt.putContext(ctx)
86 | }
87 |
88 | func (rt *Router) Handle(method, path string, handle Handler) {
89 | tree := rt.selectTree(method)
90 | tree.Add(path, handle)
91 | }
92 |
93 | func (rt *Router) GET(path string, handle Handler) {
94 | rt.get.Add(path, handle)
95 | }
96 |
97 | func (rt *Router) POST(path string, handle Handler) {
98 | rt.post.Add(path, handle)
99 | }
100 |
101 | func (rt *Router) Group(rp string) *Group {
102 | return &Group{parent: rt, path: rp}
103 | }
104 |
105 | func (rt *Router) getContext(w http.ResponseWriter, r *http.Request) *Context {
106 | ctx := rt.ctxPool.Get().(*Context)
107 |
108 | ctx.Request = r
109 | ctx.Writer = w
110 | ctx.Handler = nil
111 | (*ctx.Params) = (*ctx.Params)[:0]
112 |
113 | return ctx
114 | }
115 |
116 | func (rt *Router) putContext(ctx *Context) {
117 | rt.ctxPool.Put(ctx)
118 | }
119 |
120 | func (rt *Router) selectTree(method string) *Tree {
121 | switch method {
122 | case http.MethodGet:
123 | return &rt.get
124 | case http.MethodPost:
125 | return &rt.post
126 | case http.MethodDelete:
127 | return &rt.delete
128 | case http.MethodPut:
129 | return &rt.put
130 | case http.MethodPatch:
131 | return &rt.patch
132 | case http.MethodHead:
133 | return &rt.head
134 | case http.MethodConnect:
135 | return &rt.connect
136 | case http.MethodTrace:
137 | return &rt.trace
138 | case http.MethodOptions:
139 | return &rt.options
140 | default:
141 | return nil
142 | }
143 | }
144 |
--------------------------------------------------------------------------------
/authcache/authserver.go:
--------------------------------------------------------------------------------
1 | package authcache
2 |
3 | import (
4 | "sort"
5 | "sync"
6 | "sync/atomic"
7 | "time"
8 | )
9 |
// AuthServer is one authoritative name server address together with the
// round-trip statistics used to rank it among its peers.
type AuthServer struct {
	// place atomic members at the start to fix alignment for ARM32

	// Rtt accumulates round-trip samples in nanoseconds and Count is the
	// number of samples it holds; Sort periodically collapses them into a
	// single averaged sample. Both are accessed atomically.
	Rtt   int64
	Count int64

	Addr    string
	Version Version
}

// Version is the IP family of an AuthServer address.
type Version byte

const (
	// IPv4 mode
	IPv4 Version = 0x1

	// IPv6 mode
	IPv6 Version = 0x2
)

// NewAuthServer returns a server for addr with empty statistics.
func NewAuthServer(addr string, version Version) *AuthServer {
	return &AuthServer{
		Addr:    addr,
		Version: version,
	}
}

// String returns "IPv4", "IPv6" or "Unknown".
func (v Version) String() string {
	switch v {
	case IPv4:
		return "IPv4"
	case IPv6:
		return "IPv6"
	default:
		return "Unknown"
	}
}

// String describes the server as "<version>:<addr> rtt:<avg> health:[...]".
// Health is UNKNOWN with no samples, POOR when the average round trip is at
// least one second and GOOD otherwise.
func (a *AuthServer) String() string {
	count := atomic.LoadInt64(&a.Count)
	total := atomic.LoadInt64(&a.Rtt)

	if count == 0 {
		count = 1
	}

	// Rtt is a running sum of samples, so health must be judged on the
	// average (the value displayed), not on the total.
	avg := time.Duration(total / count)

	var health string
	switch {
	case total <= 0:
		health = "UNKNOWN"
	case avg >= time.Second:
		health = "POOR"
	default:
		health = "GOOD"
	}

	rtt := avg.Round(time.Millisecond)

	return a.Version.String() + ":" + a.Addr + " rtt:" + rtt.String() + " health:[" + health + "]"
}

// AuthServers type
type AuthServers struct {
	sync.RWMutex
	// place atomic members at the start to fix alignment for ARM32
	Called     uint64
	ErrorCount uint32

	Zone string

	List []*AuthServer
	Nss  []string

	CheckingDisable bool
	Checked         bool
}

// Sort orders serversList by ascending Rtt. Before sorting, each server's
// accumulated Rtt is replaced by its average and Count reset to 1. When
// called is a multiple of 1000, the statistics are cleared instead so the
// ranking is rebuilt from fresh samples.
func Sort(serversList []*AuthServer, called uint64) {
	for _, s := range serversList {
		//clear stats and re-start again
		if called%1e3 == 0 {
			atomic.StoreInt64(&s.Rtt, 0)
			atomic.StoreInt64(&s.Count, 0)

			continue
		}

		rtt := atomic.LoadInt64(&s.Rtt)
		count := atomic.LoadInt64(&s.Count)

		if count > 0 {
			// average rtt
			atomic.StoreInt64(&s.Rtt, rtt/count)
			atomic.StoreInt64(&s.Count, 1)
		}
	}
	sort.Slice(serversList, func(i, j int) bool {
		return atomic.LoadInt64(&serversList[i].Rtt) < atomic.LoadInt64(&serversList[j].Rtt)
	})
}
111 |
--------------------------------------------------------------------------------
/authcache/authserver_test.go:
--------------------------------------------------------------------------------
1 | package authcache
2 |
3 | import (
4 | "fmt"
5 | "math/rand"
6 | "testing"
7 | "time"
8 |
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | func Test_TrySort(t *testing.T) {
13 | s := &AuthServers{
14 | List: []*AuthServer{},
15 | }
16 |
17 | for i := 0; i < 10; i++ {
18 | s.List = append(s.List, NewAuthServer(fmt.Sprintf("0.0.0.%d:53", i), IPv4))
19 | s.List = append(s.List, NewAuthServer(fmt.Sprintf("[::%d]:53", i), IPv6))
20 | }
21 |
22 | r := rand.New(rand.NewSource(time.Now().UnixNano()))
23 | for i := 0; i < 2000; i++ {
24 | for j := range s.List {
25 | s.List[j].Count++
26 | s.List[j].Rtt += (time.Duration(r.Intn(2000-0)+0) * time.Millisecond).Nanoseconds()
27 | Sort(s.List, uint64(i))
28 | }
29 | }
30 |
31 | assert.Equal(t, int64(1), s.List[0].Count)
32 | }
33 |
--------------------------------------------------------------------------------
/authcache/ns_cache.go:
--------------------------------------------------------------------------------
1 | package authcache
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/miekg/dns"
7 | "github.com/semihalev/sdns/cache"
8 | )
9 |
10 | // NS represents a cache entry
11 | type NS struct {
12 | Servers *AuthServers
13 | DSRR []dns.RR
14 | TTL time.Duration
15 |
16 | ut time.Time
17 | }
18 |
19 | // NSCache type
20 | type NSCache struct {
21 | cache *cache.Cache
22 |
23 | now func() time.Time
24 | }
25 |
26 | // NewNSCache return new cache
27 | func NewNSCache() *NSCache {
28 | n := &NSCache{
29 | cache: cache.New(defaultCap),
30 | now: time.Now,
31 | }
32 |
33 | return n
34 | }
35 |
36 | // Get returns the entry for a key or an error
37 | func (n *NSCache) Get(key uint64) (*NS, error) {
38 | el, ok := n.cache.Get(key)
39 |
40 | if !ok {
41 | return nil, cache.ErrCacheNotFound
42 | }
43 |
44 | ns := el.(*NS)
45 |
46 | elapsed := n.now().UTC().Sub(ns.ut)
47 |
48 | if elapsed >= ns.TTL {
49 | return nil, cache.ErrCacheExpired
50 | }
51 |
52 | return ns, nil
53 | }
54 |
55 | // Set sets a keys value to a NS
56 | func (n *NSCache) Set(key uint64, dsRR []dns.RR, servers *AuthServers, ttl time.Duration) {
57 | if ttl > maximumTTL {
58 | ttl = maximumTTL
59 | } else if ttl < minimumTTL {
60 | ttl = minimumTTL
61 | }
62 |
63 | n.cache.Add(key, &NS{
64 | Servers: servers,
65 | DSRR: dsRR,
66 | TTL: ttl,
67 | ut: n.now().UTC().Round(time.Second),
68 | })
69 | }
70 |
71 | // Remove remove a cache
72 | func (n *NSCache) Remove(key uint64) {
73 | n.cache.Remove(key)
74 | }
75 |
76 | const (
77 | maximumTTL = 12 * time.Hour
78 | minimumTTL = 1 * time.Hour
79 | defaultCap = 1024 * 256
80 | )
81 |
--------------------------------------------------------------------------------
/authcache/ns_cache_test.go:
--------------------------------------------------------------------------------
1 | package authcache
2 |
3 | import (
4 | "testing"
5 | "time"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/semihalev/sdns/cache"
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | func Test_NSCache(t *testing.T) {
13 | nscache := NewNSCache()
14 |
15 | m := new(dns.Msg)
16 | m.SetQuestion(dns.Fqdn("example.com."), dns.TypeA)
17 | key := cache.Hash(m.Question[0])
18 |
19 | a := NewAuthServer("0.0.0.0:53", IPv4)
20 | _ = a.String()
21 |
22 | servers := &AuthServers{List: []*AuthServer{a}}
23 |
24 | _, err := nscache.Get(key)
25 | assert.Error(t, err)
26 | assert.Equal(t, err.Error(), "cache not found")
27 |
28 | nscache.Set(key, nil, servers, time.Hour)
29 |
30 | _, err = nscache.Get(key)
31 | assert.NoError(t, err)
32 |
33 | nscache.now = func() time.Time {
34 | return time.Now().Add(30 * time.Minute)
35 | }
36 | _, err = nscache.Get(key)
37 | assert.NoError(t, err)
38 |
39 | nscache.now = func() time.Time {
40 | return time.Now().Add(2 * time.Hour)
41 | }
42 | _, err = nscache.Get(key)
43 | assert.Error(t, err)
44 | assert.Equal(t, err.Error(), "cache expired")
45 |
46 | _, err = nscache.Get(key)
47 | assert.Error(t, err)
48 |
49 | nscache.Remove(key)
50 | }
51 |
--------------------------------------------------------------------------------
/cache/cache.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016-2020 The CoreDNS authors and contributors
2 | // Adapted for SDNS usage by Semih Alev.
3 |
4 | package cache
5 |
6 | import (
7 | "errors"
8 | )
9 |
10 | var (
11 | // ErrCacheNotFound error
12 | ErrCacheNotFound = errors.New("cache not found")
13 | // ErrCacheExpired error
14 | ErrCacheExpired = errors.New("cache expired")
15 | )
16 |
17 | // Cache is cache.
18 | type Cache struct {
19 | shards [shardSize]*shard
20 | }
21 |
22 | // New returns a new cache.
23 | func New(size int) *Cache {
24 | ssize := size / shardSize
25 | if ssize < 4 {
26 | ssize = 4
27 | }
28 |
29 | c := &Cache{}
30 |
31 | // Initialize all the shards
32 | for i := 0; i < shardSize; i++ {
33 | c.shards[i] = newShard(ssize)
34 | }
35 | return c
36 | }
37 |
38 | // Get looks up element index under key.
39 | func (c *Cache) Get(key uint64) (interface{}, bool) {
40 | shard := key & (shardSize - 1)
41 | return c.shards[shard].Get(key)
42 | }
43 |
44 | // Add adds a new element to the cache. If the element already exists it is overwritten.
45 | func (c *Cache) Add(key uint64, el interface{}) {
46 | shard := key & (shardSize - 1)
47 | c.shards[shard].Add(key, el)
48 | }
49 |
50 | // Remove removes the element indexed with key.
51 | func (c *Cache) Remove(key uint64) {
52 | shard := key & (shardSize - 1)
53 | c.shards[shard].Remove(key)
54 | }
55 |
56 | // Len returns the number of elements in the cache.
57 | func (c *Cache) Len() int {
58 | l := 0
59 | for _, s := range c.shards {
60 | l += s.Len()
61 | }
62 | return l
63 | }
64 |
--------------------------------------------------------------------------------
/cache/cache_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016-2020 The CoreDNS authors and contributors
2 | // Adapted for SDNS usage by Semih Alev.
3 |
4 | package cache
5 |
6 | import (
7 | "fmt"
8 | "testing"
9 | )
10 |
11 | func TestCacheAddAndGet(t *testing.T) {
12 | c := New(4)
13 | c.Add(1, 1)
14 |
15 | if _, found := c.Get(1); !found {
16 | t.Fatal("Failed to find inserted record")
17 | }
18 | }
19 |
20 | func TestCacheLen(t *testing.T) {
21 | c := New(4)
22 |
23 | c.Add(1, 1)
24 | if l := c.Len(); l != 1 {
25 | t.Fatalf("Cache size should %d, got %d", 1, l)
26 | }
27 |
28 | c.Add(1, 1)
29 | if l := c.Len(); l != 1 {
30 | t.Fatalf("Cache size should %d, got %d", 1, l)
31 | }
32 |
33 | c.Add(2, 2)
34 | if l := c.Len(); l != 2 {
35 | t.Fatalf("Cache size should %d, got %d", 2, l)
36 | }
37 | }
38 |
39 | func TestCacheRemove(t *testing.T) {
40 | c := New(4)
41 |
42 | c.Add(1, 1)
43 | if l := c.Len(); l != 1 {
44 | t.Fatalf("Cache size should %d, got %d", 1, l)
45 | }
46 |
47 | c.Remove(1)
48 | if l := c.Len(); l != 0 {
49 | t.Fatalf("Cache size should %d, got %d", 1, l)
50 | }
51 | }
52 |
53 | func BenchmarkCacheGet(b *testing.B) {
54 | const items = 1 << 16
55 | c := New(12 * items)
56 | v := []byte("xyza")
57 | for i := 0; i < items; i++ {
58 | c.Add(uint64(i), v)
59 | }
60 |
61 | b.ReportAllocs()
62 | b.SetBytes(items)
63 | b.RunParallel(func(pb *testing.PB) {
64 | for pb.Next() {
65 | for i := 0; i < items; i++ {
66 | b, _ := c.Get(uint64(i))
67 | if string(b.([]byte)) != string(v) {
68 | panic(fmt.Errorf("BUG: invalid value obtained; got %q; want %q", b, v))
69 | }
70 | }
71 | }
72 | })
73 | }
74 |
75 | func BenchmarkCacheSet(b *testing.B) {
76 | const items = 1 << 16
77 | c := New(12 * items)
78 | b.ReportAllocs()
79 | b.SetBytes(items)
80 | b.RunParallel(func(pb *testing.PB) {
81 | v := []byte("xyza")
82 | for pb.Next() {
83 | for i := 0; i < items; i++ {
84 | c.Add(uint64(i), v)
85 | }
86 | }
87 | })
88 | }
89 |
90 | func BenchmarkCacheSetGet(b *testing.B) {
91 | const items = 1 << 16
92 | c := New(12 * items)
93 | b.ReportAllocs()
94 | b.SetBytes(2 * items)
95 | b.RunParallel(func(pb *testing.PB) {
96 | v := []byte("xyza")
97 | for pb.Next() {
98 | for i := 0; i < items; i++ {
99 | c.Add(uint64(i), v)
100 | }
101 | for i := 0; i < items; i++ {
102 | b, _ := c.Get(uint64(i))
103 | if string(b.([]byte)) != string(v) {
104 | panic(fmt.Errorf("BUG: invalid value obtained; got %q; want %q", b, v))
105 | }
106 | }
107 | }
108 | })
109 | }
110 |
--------------------------------------------------------------------------------
/cache/hash.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016-2020 The CoreDNS authors and contributors
2 | // Adapted for SDNS usage by Semih Alev.
3 |
4 | package cache
5 |
6 | import (
7 | "bytes"
8 | "hash"
9 | "sync"
10 |
11 | "github.com/cespare/xxhash/v2"
12 | "github.com/miekg/dns"
13 | )
14 |
15 | // Hash returns a hash for cache
16 | func Hash(q dns.Question, cd ...bool) uint64 {
17 | h := AcquireHash()
18 | defer ReleaseHash(h)
19 |
20 | buf := AcquireBuf()
21 | defer ReleaseBuf(buf)
22 |
23 | buf.Write([]byte{uint8(q.Qclass >> 8), uint8(q.Qclass & 0xff)})
24 | buf.Write([]byte{uint8(q.Qtype >> 8), uint8(q.Qtype & 0xff)})
25 |
26 | if len(cd) > 0 && cd[0] {
27 | buf.WriteByte(1)
28 | }
29 |
30 | for i := range q.Name {
31 | c := q.Name[i]
32 | if c >= 'A' && c <= 'Z' {
33 | c += 'a' - 'A'
34 | }
35 | buf.WriteByte(c)
36 | }
37 |
38 | _, _ = h.Write(buf.Bytes())
39 |
40 | return h.Sum64()
41 | }
42 |
43 | var bufferPool sync.Pool
44 | var hashPool sync.Pool
45 |
46 | // AcquireHash returns a hash from pool
47 | func AcquireHash() hash.Hash64 {
48 | v := hashPool.Get()
49 | if v == nil {
50 | return xxhash.New()
51 | }
52 | return v.(hash.Hash64)
53 | }
54 |
55 | // ReleaseHash returns hash to pool
56 | func ReleaseHash(h hash.Hash64) {
57 | h.Reset()
58 | hashPool.Put(h)
59 | }
60 |
61 | // AcquireBuf returns a buf from pool
62 | func AcquireBuf() *bytes.Buffer {
63 | v := bufferPool.Get()
64 | if v == nil {
65 | return &bytes.Buffer{}
66 | }
67 | return v.(*bytes.Buffer)
68 | }
69 |
70 | // ReleaseBuf returns buf to pool
71 | func ReleaseBuf(buf *bytes.Buffer) {
72 | buf.Reset()
73 | bufferPool.Put(buf)
74 | }
75 |
--------------------------------------------------------------------------------
/cache/hash_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016-2020 The CoreDNS authors and contributors
2 | // Adapted for SDNS usage by Semih Alev.
3 |
4 | package cache
5 |
6 | import (
7 | "testing"
8 |
9 | "github.com/miekg/dns"
10 | "github.com/stretchr/testify/assert"
11 | )
12 |
13 | func Test_Hash(t *testing.T) {
14 |
15 | q := dns.Question{Name: "goOgle.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
16 |
17 | asset := Hash(q)
18 |
19 | assert.Equal(t, uint64(13726664550454464700), asset)
20 |
21 | asset = Hash(q, true)
22 |
23 | assert.Equal(t, uint64(8882204296994448420), asset)
24 | }
25 |
26 | func Benchmark_Hash(b *testing.B) {
27 | q := dns.Question{Name: "goOgle.com.", Qtype: dns.TypeA, Qclass: dns.ClassANY}
28 |
29 | b.ResetTimer()
30 |
31 | for i := 0; i < b.N; i++ {
32 | Hash(q)
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/cache/shard.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016-2020 The CoreDNS authors and contributors
2 | // Adapted for SDNS usage by Semih Alev.
3 |
4 | package cache
5 |
6 | import "sync"
7 |
// shard is one partition of Cache: a size-bounded map guarded by its own
// RWMutex. When full, an arbitrary entry is evicted (Go map iteration order
// is unspecified), i.e. eviction is effectively random.
type shard struct {
	items map[uint64]interface{}
	size  int

	sync.RWMutex
}

// newShard returns a new shard with size.
func newShard(size int) *shard { return &shard{items: make(map[uint64]interface{}), size: size} }

// Add stores el under key, overwriting any existing element. If key is new
// and the shard is already full, one arbitrary element is evicted first.
//
// The capacity check, eviction and insert happen under a single write lock,
// so concurrent Adds cannot push the shard past its size, and overwriting an
// existing key never evicts an unrelated live entry.
func (s *shard) Add(key uint64, el interface{}) {
	s.Lock()
	defer s.Unlock()

	if _, exists := s.items[key]; !exists && len(s.items) >= s.size {
		s.evictLocked()
	}
	s.items[key] = el
}

// Remove removes the element indexed by key from the cache.
func (s *shard) Remove(key uint64) {
	s.Lock()
	delete(s.items, key)
	s.Unlock()
}

// Evict removes an arbitrary element from the cache; it is a no-op when the
// shard is empty.
func (s *shard) Evict() {
	s.Lock()
	s.evictLocked()
	s.Unlock()
}

// evictLocked deletes one arbitrary element. The caller must hold s.Lock.
func (s *shard) evictLocked() {
	for k := range s.items {
		delete(s.items, k)
		return
	}
}

// Get looks up the element indexed under key.
func (s *shard) Get(key uint64) (interface{}, bool) {
	s.RLock()
	el, found := s.items[key]
	s.RUnlock()
	return el, found
}

// Len returns the current length of the cache.
func (s *shard) Len() int {
	s.RLock()
	l := len(s.items)
	s.RUnlock()
	return l
}

const shardSize = 256
77 |
--------------------------------------------------------------------------------
/cache/shard_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016-2020 The CoreDNS authors and contributors
2 | // Adapted for SDNS usage by Semih Alev.
3 |
4 | package cache
5 |
6 | import "testing"
7 |
8 | func TestShardAddAndGet(t *testing.T) {
9 | s := newShard(4)
10 | s.Add(1, 1)
11 |
12 | if _, found := s.Get(1); !found {
13 | t.Fatal("Failed to find inserted record")
14 | }
15 | }
16 |
17 | func TestShardLen(t *testing.T) {
18 | s := newShard(4)
19 |
20 | s.Add(1, 1)
21 | if l := s.Len(); l != 1 {
22 | t.Fatalf("Shard size should %d, got %d", 1, l)
23 | }
24 |
25 | s.Add(1, 1)
26 | if l := s.Len(); l != 1 {
27 | t.Fatalf("Shard size should %d, got %d", 1, l)
28 | }
29 |
30 | s.Add(2, 2)
31 | if l := s.Len(); l != 2 {
32 | t.Fatalf("Shard size should %d, got %d", 2, l)
33 | }
34 | }
35 |
36 | func TestShardEvict(t *testing.T) {
37 | s := newShard(1)
38 | s.Evict() // empty cache
39 |
40 | s.Add(1, 1)
41 | s.Add(2, 2)
42 | // 1 should be gone
43 |
44 | if _, found := s.Get(1); found {
45 | t.Fatal("Found item that should have been evicted")
46 | }
47 | }
48 |
49 | func TestShardLenEvict(t *testing.T) {
50 | s := newShard(4)
51 | s.Add(1, 1)
52 | s.Add(2, 1)
53 | s.Add(3, 1)
54 | s.Add(4, 1)
55 |
56 | if l := s.Len(); l != 4 {
57 | t.Fatalf("Shard size should %d, got %d", 4, l)
58 | }
59 |
60 | // This should evict one element
61 | s.Add(5, 1)
62 | if l := s.Len(); l != 4 {
63 | t.Fatalf("Shard size should %d, got %d", 4, l)
64 | }
65 | }
66 |
--------------------------------------------------------------------------------
/config/config_test.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "os"
5 | "testing"
6 |
7 | "github.com/semihalev/log"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func Test_config(t *testing.T) {
12 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
13 |
14 | const configFile = "example.conf"
15 |
16 | err := generateConfig(configFile)
17 | assert.NoError(t, err)
18 |
19 | _, err = Load(configFile, "0.0.0")
20 | assert.NoError(t, err)
21 |
22 | os.Remove(configFile)
23 | os.Remove("db")
24 | }
25 |
26 | func Test_configError(t *testing.T) {
27 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
28 |
29 | const configFile = ""
30 |
31 | _, err := Load(configFile, "0.0.0")
32 | assert.Error(t, err)
33 |
34 | os.Remove("db")
35 | }
36 |
--------------------------------------------------------------------------------
/contrib/linux/adduser.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Creates the unprivileged "sdns" system group and user and the state
# directory /var/lib/sdns owned by them. Safe to run more than once: existing
# accounts are left untouched instead of making groupadd/useradd fail.
set -e

if ! getent group sdns >/dev/null 2>&1; then
    groupadd --system sdns
fi

if ! id -u sdns >/dev/null 2>&1; then
    useradd --system -d /var/lib/sdns -s /usr/sbin/nologin -g sdns sdns
fi

mkdir -p /var/lib/sdns
chown sdns:sdns /var/lib/sdns
7 |
--------------------------------------------------------------------------------
/contrib/linux/sdns.service:
--------------------------------------------------------------------------------
[Unit]
Description=SDNS - Fast DNS Resolver
ConditionPathExists=/var/lib/sdns
Wants=network.target
After=network.target

[Service]
Type=simple
User=sdns
Group=sdns
LimitNOFILE=131072
Restart=on-failure
RestartSec=10
Environment="SDNS_DEBUGNS=false"
# SDNS_PPROF is intentionally not set: sdns versions that only check whether
# the variable exists enable /debug/pprof even for SDNS_PPROF=false.
WorkingDirectory=/var/lib/sdns
ExecStart=/usr/bin/sdns --config=/etc/sdns.conf
# "syslog" is deprecated in systemd and treated as "journal".
StandardOutput=journal
StandardError=journal
SyslogIdentifier=sdns
AmbientCapabilities=CAP_NET_BIND_SERVICE

[Install]
WantedBy=multi-user.target
26 |
--------------------------------------------------------------------------------
/dnsutil/ttl.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016-2020 The CoreDNS authors and contributors
2 | // Adapted for SDNS usage by Semih Alev.
3 |
4 | package dnsutil
5 |
6 | import (
7 | "time"
8 |
9 | "github.com/miekg/dns"
10 | "github.com/semihalev/sdns/response"
11 | )
12 |
13 | // MinimalTTL scans the message returns the lowest TTL found taking into the response.Type of the message.
14 | func MinimalTTL(m *dns.Msg, mt response.Type) time.Duration {
15 | if mt != response.NoError && mt != response.NameError && mt != response.NoData {
16 | return MinimalDefaultTTL
17 | }
18 |
19 | // No records or OPT is the only record, return a short ttl as a fail safe.
20 | if len(m.Answer)+len(m.Ns) == 0 &&
21 | (len(m.Extra) == 0 || (len(m.Extra) == 1 && m.Extra[0].Header().Rrtype == dns.TypeOPT)) {
22 | return MinimalDefaultTTL
23 | }
24 |
25 | minTTL := MaximumDefaulTTL
26 | for _, r := range m.Answer {
27 | if r.Header().Ttl < uint32(minTTL.Seconds()) {
28 | minTTL = time.Duration(r.Header().Ttl) * time.Second
29 | }
30 | }
31 | for _, r := range m.Ns {
32 | if r.Header().Ttl < uint32(minTTL.Seconds()) {
33 | minTTL = time.Duration(r.Header().Ttl) * time.Second
34 | }
35 | }
36 |
37 | for _, r := range m.Extra {
38 | if r.Header().Rrtype == dns.TypeOPT {
39 | // OPT records use TTL field for extended rcode and flags
40 | continue
41 | }
42 | if r.Header().Ttl < uint32(minTTL.Seconds()) {
43 | minTTL = time.Duration(r.Header().Ttl) * time.Second
44 | }
45 | }
46 | return minTTL
47 | }
48 |
const (
	// MinimalDefaultTTL is the absolute lowest TTL.
	MinimalDefaultTTL = 5 * time.Second
	// MaximumDefaultTTL is the maximum TTL.
	MaximumDefaultTTL = 24 * time.Hour
	// MaximumDefaulTTL is the original, misspelled name of MaximumDefaultTTL.
	// It is kept so existing references continue to compile.
	MaximumDefaulTTL = MaximumDefaultTTL
)
55 |
--------------------------------------------------------------------------------
/dnsutil/ttl_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016-2020 The CoreDNS authors and contributors
2 | // Adapted for SDNS usage by Semih Alev.
3 |
4 | package dnsutil
5 |
6 | import (
7 | "testing"
8 | "time"
9 |
10 | "github.com/miekg/dns"
11 | "github.com/semihalev/sdns/response"
12 | )
13 |
14 | // See https://github.com/kubernetes/dns/issues/121, add some specific tests for those use cases.
15 |
16 | func makeRR(data string) dns.RR {
17 | r, _ := dns.NewRR(data)
18 |
19 | return r
20 | }
21 |
22 | func TestMinimalTTL(t *testing.T) {
23 | utc := time.Now().UTC()
24 |
25 | mt, _ := response.Typify(nil, utc)
26 | if mt != response.OtherError {
27 | t.Fatalf("Expected type to be response.NoData, got %s", mt)
28 | }
29 |
30 | dur := MinimalTTL(nil, mt) // minTTL on msg is 3600 (neg. ttl on SOA)
31 | if dur != time.Duration(MinimalDefaultTTL) {
32 | t.Fatalf("Expected minttl duration to be %d, got %d", 1800, dur)
33 | }
34 |
35 | m := new(dns.Msg)
36 | m.SetQuestion("z.alm.im.", dns.TypeA)
37 | m.SetEdns0(dns.DefaultMsgSize, true)
38 |
39 | mt, _ = response.Typify(m, utc)
40 | if mt != response.NoError {
41 | t.Fatalf("Expected type to be response.NoData, got %s", mt)
42 | }
43 | dur = MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA)
44 | if dur != time.Duration(MinimalDefaultTTL) {
45 | t.Fatalf("Expected minttl duration to be %d, got %d", 1800, dur)
46 | }
47 |
48 | m.Ns = []dns.RR{
49 | makeRR("alm.im. 1800 IN SOA ivan.ns.cloudflare.com. dns.cloudflare.com. 2025042470 10000 2400 604800 3600"),
50 | }
51 |
52 | mt, _ = response.Typify(m, utc)
53 | if mt != response.NoData {
54 | t.Fatalf("Expected type to be response.NoData, got %s", mt)
55 | }
56 |
57 | dur = MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA)
58 | if dur != time.Duration(1800*time.Second) {
59 | t.Fatalf("Expected minttl duration to be %d, got %d", 1800, dur)
60 | }
61 |
62 | m.Extra = []dns.RR{
63 | makeRR("alm.im. 1200 IN A 127.0.0.1"),
64 | }
65 |
66 | m.Rcode = dns.RcodeNameError
67 | mt, _ = response.Typify(m, utc)
68 | if mt != response.NameError {
69 | t.Fatalf("Expected type to be response.NameError, got %s", mt)
70 | }
71 | dur = MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA)
72 | if dur != time.Duration(1200*time.Second) {
73 | t.Fatalf("Expected minttl duration to be %d, got %d", 1800, dur)
74 | }
75 |
76 | m.Answer = []dns.RR{
77 | makeRR("z.alm.im. 600 IN A 127.0.0.1"),
78 | }
79 | dur = MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA)
80 | if dur != time.Duration(600*time.Second) {
81 | t.Fatalf("Expected minttl duration to be %d, got %d", 1800, dur)
82 | }
83 | }
84 |
85 | func BenchmarkMinimalTTL(b *testing.B) {
86 | m := new(dns.Msg)
87 | m.SetQuestion("example.org.", dns.TypeA)
88 | m.Ns = []dns.RR{
89 | makeRR("a.example.org. 1800 IN A 127.0.0.53"),
90 | makeRR("b.example.org. 1900 IN A 127.0.0.53"),
91 | makeRR("c.example.org. 1600 IN A 127.0.0.53"),
92 | makeRR("d.example.org. 1100 IN A 127.0.0.53"),
93 | makeRR("e.example.org. 1000 IN A 127.0.0.53"),
94 | }
95 | m.Extra = []dns.RR{
96 | makeRR("a.example.org. 1800 IN A 127.0.0.53"),
97 | makeRR("b.example.org. 1600 IN A 127.0.0.53"),
98 | makeRR("c.example.org. 1400 IN A 127.0.0.53"),
99 | makeRR("d.example.org. 1200 IN A 127.0.0.53"),
100 | makeRR("e.example.org. 1100 IN A 127.0.0.53"),
101 | }
102 |
103 | utc := time.Now().UTC()
104 | mt, _ := response.Typify(m, utc)
105 |
106 | b.ResetTimer()
107 | for i := 0; i < b.N; i++ {
108 | dur := MinimalTTL(m, mt)
109 | if dur != 1000*time.Second {
110 | b.Fatalf("Wrong MinimalTTL %d, expected %d", dur, 1000*time.Second)
111 | }
112 | }
113 | }
114 |
--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
1 | /*
Sdns is a high-performance, recursive DNS resolver server with DNSSEC support, focused on preserving privacy.
See https://sdns.dev for more information.
4 | */
5 | package main // import "github.com/semihalev/sdns"
6 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3.8"
2 |
3 | services:
4 | sdns:
5 | image: c1982/sdns
6 | container_name: sdns
7 | restart: unless-stopped
8 | ports:
9 | - 127.0.0.1:53:53
10 | - 127.0.0.1:53:53/udp
11 | read_only: false
12 |
--------------------------------------------------------------------------------
/gen.go:
--------------------------------------------------------------------------------
1 | //go:build ignore
2 |
3 | package main
4 |
5 | import (
6 | "fmt"
7 | "os"
8 | "path/filepath"
9 | "sort"
10 | )
11 |
// middlewareList names the middleware packages under middlewareDir, in
// chain order. The order is significant: main emits the
// middleware.Register calls in exactly this order, and handlers are called
// in that same order.
var middlewareList = []string{
	"recovery",
	"loop",
	"metrics",
	"accesslist",
	"ratelimit",
	"edns",
	"accesslog",
	"chaos",
	"hostsfile",
	"blocklist",
	"as112",
	"cache",
	"failover",
	"resolver",
	"forwarder",
}
30 |
31 | func main() {
32 | var pathlist []string
33 | for _, name := range middlewareList {
34 | stat, err := os.Stat(filepath.Join(middlewareDir, name))
35 | if err != nil {
36 | fmt.Println(err)
37 | os.Exit(1)
38 | }
39 | if !stat.IsDir() {
40 | fmt.Println("path is not directory")
41 | os.Exit(1)
42 | }
43 | pathlist = append(pathlist, filepath.Join(prefixDir, middlewareDir, name))
44 | }
45 |
46 | file, err := os.Create(filename)
47 | if err != nil {
48 | fmt.Println(err)
49 | os.Exit(1)
50 | }
51 |
52 | defer file.Close()
53 |
54 | file.WriteString("// Code generated by gen.go DO NOT EDIT.\n")
55 |
56 | file.WriteString("\npackage main\n\nimport (\n")
57 | file.WriteString("\t\"github.com/semihalev/sdns/config\"\n")
58 | file.WriteString("\t\"github.com/semihalev/sdns/middleware\"\n")
59 |
60 | sort.StringSlice(pathlist).Sort()
61 |
62 | for _, path := range pathlist {
63 | file.WriteString("\t\"" + path + "\"\n")
64 | }
65 |
66 | file.WriteString(")\n\n")
67 |
68 | file.WriteString("func init() {\n")
69 | for _, name := range middlewareList {
70 | file.WriteString("\tmiddleware.Register(\"" + name + "\", func(cfg *config.Config) middleware.Handler { return " + name + ".New(cfg) })\n")
71 | }
72 |
73 | file.WriteString("}\n")
74 | }
75 |
const (
	// prefixDir is the module path that generated import paths start with.
	prefixDir = "github.com/semihalev/sdns"
	// middlewareDir is the directory, relative to the module root, holding
	// one sub-package per middleware.
	middlewareDir = "middleware"
	// filename is the generated output file.
	filename = "zregister.go"
)
81 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/semihalev/sdns
2 |
3 | require (
4 | github.com/BurntSushi/toml v1.4.0
5 | github.com/cespare/xxhash/v2 v2.3.0
6 | github.com/miekg/dns v1.1.63
7 | github.com/prometheus/client_golang v1.20.5
8 | github.com/quic-go/quic-go v0.49.0
9 | github.com/semihalev/log v0.1.1
10 | github.com/stretchr/testify v1.10.0
11 | github.com/yl2chen/cidranger v1.0.2
12 | golang.org/x/time v0.10.0
13 | )
14 |
15 | require (
16 | github.com/beorn7/perks v1.0.1 // indirect
17 | github.com/davecgh/go-spew v1.1.1 // indirect
18 | github.com/go-stack/stack v1.8.0 // indirect
19 | github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
20 | github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
21 | github.com/klauspost/compress v1.17.9 // indirect
22 | github.com/kr/text v0.2.0 // indirect
23 | github.com/mattn/go-colorable v0.1.7 // indirect
24 | github.com/mattn/go-isatty v0.0.19 // indirect
25 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
26 | github.com/onsi/ginkgo/v2 v2.9.5 // indirect
27 | github.com/pmezard/go-difflib v1.0.0 // indirect
28 | github.com/prometheus/client_model v0.6.1 // indirect
29 | github.com/prometheus/common v0.55.0 // indirect
30 | github.com/prometheus/procfs v0.15.1 // indirect
31 | github.com/quic-go/qpack v0.5.1 // indirect
32 | go.uber.org/mock v0.5.0 // indirect
33 | golang.org/x/crypto v0.31.0 // indirect
34 | golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
35 | golang.org/x/mod v0.18.0 // indirect
36 | golang.org/x/net v0.33.0 // indirect
37 | golang.org/x/sync v0.10.0 // indirect
38 | golang.org/x/sys v0.28.0 // indirect
39 | golang.org/x/text v0.21.0 // indirect
40 | golang.org/x/tools v0.22.0 // indirect
41 | google.golang.org/protobuf v1.34.2 // indirect
42 | gopkg.in/yaml.v3 v3.0.1 // indirect
43 | )
44 |
45 | go 1.22
46 |
47 | toolchain go1.22.5
48 |
--------------------------------------------------------------------------------
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/semihalev/sdns/86830209f5550c03071fde5adf33dea0ff6e086b/logo.png
--------------------------------------------------------------------------------
/middleware/accesslist/accesslist.go:
--------------------------------------------------------------------------------
1 | package accesslist
2 |
3 | import (
4 | "context"
5 | "net"
6 |
7 | "github.com/semihalev/log"
8 | "github.com/semihalev/sdns/config"
9 | "github.com/semihalev/sdns/middleware"
10 | "github.com/yl2chen/cidranger"
11 | )
12 |
// AccessList is a middleware that lets a query continue down the chain only
// when the client IP falls inside one of the configured CIDR ranges.
type AccessList struct {
	// ranger holds the allowed networks parsed from cfg.AccessList.
	ranger cidranger.Ranger
}
18 | // New return accesslist
19 | func New(cfg *config.Config) *AccessList {
20 | if len(cfg.AccessList) == 0 {
21 | cfg.AccessList = append(cfg.AccessList, "0.0.0.0/0")
22 | cfg.AccessList = append(cfg.AccessList, "::0/0")
23 | }
24 |
25 | a := new(AccessList)
26 | a.ranger = cidranger.NewPCTrieRanger()
27 | for _, cidr := range cfg.AccessList {
28 | _, ipnet, err := net.ParseCIDR(cidr)
29 | if err != nil {
30 | log.Error("Access list parse cidr failed", "error", err.Error())
31 | continue
32 | }
33 |
34 | _ = a.ranger.Insert(cidranger.NewBasicRangerEntry(*ipnet))
35 |
36 | }
37 |
38 | return a
39 | }
40 |
41 | // Name return middleware name
42 | func (a *AccessList) Name() string { return name }
43 |
44 | // ServeDNS implements the Handle interface.
45 | func (a *AccessList) ServeDNS(ctx context.Context, ch *middleware.Chain) {
46 | if ch.Writer.Internal() {
47 | ch.Next(ctx)
48 | return
49 | }
50 |
51 | allowed, _ := a.ranger.Contains(ch.Writer.RemoteIP())
52 |
53 | if !allowed {
54 | //no reply to client
55 | ch.Cancel()
56 | return
57 | }
58 |
59 | ch.Next(ctx)
60 | }
61 |
// name is the identifier returned by AccessList.Name.
const name = "accesslist"
63 |
--------------------------------------------------------------------------------
/middleware/accesslist/accesslist_test.go:
--------------------------------------------------------------------------------
1 | package accesslist
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/semihalev/log"
8 | "github.com/semihalev/sdns/config"
9 | "github.com/semihalev/sdns/middleware"
10 | "github.com/semihalev/sdns/mock"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func Test_AccesslistDefaults(t *testing.T) {
15 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
16 |
17 | cfg := new(config.Config)
18 | cfg.AccessList = []string{}
19 |
20 | a := New(cfg)
21 |
22 | ch := middleware.NewChain([]middleware.Handler{a})
23 |
24 | mw := mock.NewWriter("udp", "8.8.8.8:0")
25 | ch.Writer = mw
26 | a.ServeDNS(context.Background(), ch)
27 | }
28 |
29 | func Test_Accesslist(t *testing.T) {
30 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
31 |
32 | cfg := new(config.Config)
33 | cfg.AccessList = []string{"127.0.0.1/32", "1"}
34 |
35 | middleware.Register("accesslist", func(cfg *config.Config) middleware.Handler { return New(cfg) })
36 | middleware.Setup(cfg)
37 |
38 | a := middleware.Get("accesslist").(*AccessList)
39 | assert.Equal(t, "accesslist", a.Name())
40 |
41 | ch := middleware.NewChain([]middleware.Handler{})
42 |
43 | mw := mock.NewWriter("udp", "127.0.0.255:0")
44 | ch.Writer = mw
45 | a.ServeDNS(context.Background(), ch)
46 |
47 | mw = mock.NewWriter("udp", "0.0.0.0:0")
48 | ch.Writer = mw
49 | a.ServeDNS(context.Background(), ch)
50 | }
51 |
--------------------------------------------------------------------------------
/middleware/accesslog/accesslog.go:
--------------------------------------------------------------------------------
1 | package accesslog
2 |
3 | import (
4 | "context"
5 | "os"
6 | "strconv"
7 | "strings"
8 | "time"
9 |
10 | "github.com/miekg/dns"
11 | "github.com/semihalev/log"
12 | "github.com/semihalev/sdns/config"
13 | "github.com/semihalev/sdns/middleware"
14 | )
15 |
// AccessLog is a middleware that appends one line to the file named by
// cfg.AccessLog for each response written to a non-internal client.
type AccessLog struct {
	// cfg is the configuration New was called with.
	cfg     *config.Config
	// logFile is nil when cfg.AccessLog is empty or the file could not be
	// opened; ServeDNS then logs nothing.
	logFile *os.File
}
21 |
22 | // New returns a new AccessLog
23 | func New(cfg *config.Config) *AccessLog {
24 | var logFile *os.File
25 | var err error
26 |
27 | if cfg.AccessLog != "" {
28 | logFile, err = os.OpenFile(cfg.AccessLog, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0600)
29 | if err != nil {
30 | log.Error("Access log file open failed", "error", strings.Trim(err.Error(), "\n"))
31 | }
32 | }
33 |
34 | return &AccessLog{
35 | cfg: cfg,
36 | logFile: logFile,
37 | }
38 | }
39 |
40 | // Name return middleware name
41 | func (a *AccessLog) Name() string { return name }
42 |
43 | // ServeDNS implements the Handle interface.
44 | func (a *AccessLog) ServeDNS(ctx context.Context, ch *middleware.Chain) {
45 | ch.Next(ctx)
46 |
47 | w := ch.Writer
48 |
49 | if a.logFile != nil && w.Written() && !w.Internal() {
50 | resp := w.Msg()
51 |
52 | cd := "-cd"
53 | if resp.CheckingDisabled {
54 | cd = "+cd"
55 | }
56 |
57 | record := []string{
58 | w.RemoteIP().String() + " -",
59 | "[" + time.Now().Format("02/Jan/2006:15:04:05 -0700") + "]",
60 | formatQuestion(resp.Question[0]),
61 | w.Proto(),
62 | cd,
63 | dns.RcodeToString[resp.Rcode],
64 | strconv.Itoa(resp.Len()),
65 | }
66 |
67 | _, err := a.logFile.WriteString(strings.Join(record, " ") + "\n")
68 | if err != nil {
69 | log.Error("Access log write failed", "error", strings.Trim(err.Error(), "\n"))
70 | }
71 | }
72 | }
73 |
74 | func formatQuestion(q dns.Question) string {
75 | return "\"" + strings.ToLower(q.Name) + " " + dns.ClassToString[q.Qclass] + " " + dns.TypeToString[q.Qtype] + "\""
76 | }
77 |
// name is the identifier returned by AccessLog.Name.
const name = "accesslog"
79 |
--------------------------------------------------------------------------------
/middleware/accesslog/accesslog_test.go:
--------------------------------------------------------------------------------
1 | package accesslog
2 |
3 | import (
4 | "context"
5 | "os"
6 | "testing"
7 |
8 | "github.com/miekg/dns"
9 | "github.com/semihalev/log"
10 | "github.com/semihalev/sdns/config"
11 | "github.com/semihalev/sdns/middleware"
12 | "github.com/semihalev/sdns/mock"
13 | "github.com/stretchr/testify/assert"
14 | )
15 |
16 | func Test_accesslog(t *testing.T) {
17 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
18 |
19 | cfg := &config.Config{
20 | AccessLog: "access_test.log",
21 | }
22 |
23 | middleware.Register("accesslog", func(cfg *config.Config) middleware.Handler { return New(cfg) })
24 | middleware.Setup(cfg)
25 | a := middleware.Get("accesslog").(*AccessLog)
26 |
27 | assert.Equal(t, "accesslog", a.Name())
28 | assert.NotNil(t, a.logFile)
29 |
30 | ch := middleware.NewChain([]middleware.Handler{a})
31 |
32 | mw := mock.NewWriter("udp", "127.0.0.1:0")
33 | req := new(dns.Msg)
34 | req.SetQuestion("test.com.", dns.TypeA)
35 |
36 | ch.Reset(mw, req)
37 |
38 | resp := new(dns.Msg)
39 | resp.SetRcode(req, dns.RcodeServerFailure)
40 | resp.Question = req.Copy().Question
41 |
42 | _ = ch.Writer.WriteMsg(resp)
43 |
44 | a.ServeDNS(context.Background(), ch)
45 |
46 | assert.Equal(t, dns.RcodeServerFailure, mw.Msg().Rcode)
47 |
48 | resp.CheckingDisabled = true
49 | a.ServeDNS(context.Background(), ch)
50 |
51 | assert.True(t, resp.CheckingDisabled)
52 |
53 | assert.NoError(t, a.logFile.Close())
54 |
55 | a.ServeDNS(context.Background(), ch)
56 |
57 | assert.NoError(t, os.Remove(cfg.AccessLog))
58 | }
59 |
--------------------------------------------------------------------------------
/middleware/as112/as112_test.go:
--------------------------------------------------------------------------------
1 | package as112
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/semihalev/log"
9 | "github.com/semihalev/sdns/config"
10 | "github.com/semihalev/sdns/middleware"
11 | "github.com/semihalev/sdns/mock"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | func Test_AS112(t *testing.T) {
16 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
17 |
18 | cfg := new(config.Config)
19 | cfg.EmptyZones = []string{
20 | "10.in-addr.arpa.",
21 | "example.arpa",
22 | }
23 |
24 | middleware.Register("as112", func(cfg *config.Config) middleware.Handler { return New(cfg) })
25 | middleware.Setup(cfg)
26 |
27 | a := middleware.Get("as112").(*AS112)
28 |
29 | assert.Equal(t, "as112", a.Name())
30 |
31 | ch := middleware.NewChain([]middleware.Handler{})
32 |
33 | req := new(dns.Msg)
34 | req.SetQuestion("10.in-addr.arpa.", dns.TypeSOA)
35 | ch.Request = req
36 |
37 | mw := mock.NewWriter("udp", "127.0.0.1:0")
38 | ch.Writer = mw
39 |
40 | a.ServeDNS(context.Background(), ch)
41 | assert.Equal(t, true, len(mw.Msg().Answer) > 0)
42 | assert.Equal(t, dns.RcodeSuccess, mw.Rcode())
43 |
44 | req.SetQuestion("10.in-addr.arpa.", dns.TypeNS)
45 |
46 | mw = mock.NewWriter("udp", "127.0.0.1:0")
47 | ch.Writer = mw
48 | a.ServeDNS(context.Background(), ch)
49 | assert.Equal(t, true, len(mw.Msg().Answer) > 0)
50 | assert.Equal(t, dns.RcodeSuccess, mw.Rcode())
51 |
52 | req.SetQuestion("10.in-addr.arpa.", dns.TypeSOA)
53 |
54 | mw = mock.NewWriter("udp", "127.0.0.1:0")
55 | ch.Writer = mw
56 | a.ServeDNS(context.Background(), ch)
57 | assert.Equal(t, true, len(mw.Msg().Answer) > 0)
58 | assert.Equal(t, dns.RcodeSuccess, mw.Rcode())
59 |
60 | req.SetQuestion("10.in-addr.arpa.", dns.TypeDS)
61 |
62 | mw = mock.NewWriter("udp", "127.0.0.1:0")
63 | ch.Writer = mw
64 | a.ServeDNS(context.Background(), ch)
65 | assert.False(t, mw.Written())
66 |
67 | req.SetQuestion("20.in-addr.arpa.", dns.TypeNS)
68 |
69 | mw = mock.NewWriter("udp", "127.0.0.1:0")
70 | ch.Writer = mw
71 | a.ServeDNS(context.Background(), ch)
72 | assert.False(t, mw.Written())
73 |
74 | req.SetQuestion("example.com.", dns.TypeNS)
75 |
76 | mw = mock.NewWriter("udp", "127.0.0.1:0")
77 | ch.Writer = mw
78 | a.ServeDNS(context.Background(), ch)
79 | assert.False(t, mw.Written())
80 |
81 | req.SetQuestion("10.10.in-addr.arpa.", dns.TypeSOA)
82 |
83 | mw = mock.NewWriter("udp", "127.0.0.1:0")
84 | ch.Writer = mw
85 | a.ServeDNS(context.Background(), ch)
86 | assert.Equal(t, true, len(mw.Msg().Ns) > 0)
87 | assert.Equal(t, dns.RcodeNameError, mw.Rcode())
88 |
89 | req.SetQuestion("10.10.in-addr.arpa.", dns.TypeA)
90 |
91 | mw = mock.NewWriter("udp", "127.0.0.1:0")
92 | ch.Writer = mw
93 | a.ServeDNS(context.Background(), ch)
94 | assert.Equal(t, true, len(mw.Msg().Ns) > 0)
95 | assert.Equal(t, dns.RcodeNameError, mw.Rcode())
96 |
97 | req.SetQuestion("10.10.in-addr.arpa.", dns.TypeNS)
98 |
99 | mw = mock.NewWriter("udp", "127.0.0.1:0")
100 | ch.Writer = mw
101 | a.ServeDNS(context.Background(), ch)
102 | assert.Equal(t, true, len(mw.Msg().Ns) > 0)
103 | assert.Equal(t, dns.RcodeNameError, mw.Rcode())
104 | }
105 |
--------------------------------------------------------------------------------
/middleware/blocklist/blocklist.go:
--------------------------------------------------------------------------------
1 | package blocklist
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "net"
7 | "os"
8 | "path/filepath"
9 | "sync"
10 |
11 | "github.com/miekg/dns"
12 | "github.com/semihalev/sdns/config"
13 | "github.com/semihalev/sdns/middleware"
14 | )
15 |
// BlockList is a middleware that answers queries for blocked names itself
// instead of passing them down the chain.
type BlockList struct {
	// mu guards m and w.
	mu sync.RWMutex

	// nullroute and null6route are the addresses returned for blocked A and
	// AAAA queries; parsed from cfg.Nullroute and cfg.Nullroutev6.
	nullroute  net.IP
	null6route net.IP

	// m is the set of blocked names, keyed by canonical name. w is the set of
	// whitelisted canonical names that Set/set refuse to add to m.
	m map[string]bool
	w map[string]bool

	cfg *config.Config
}
28 |
29 | // New returns a new BlockList
30 | func New(cfg *config.Config) *BlockList {
31 | b := &BlockList{
32 | nullroute: net.ParseIP(cfg.Nullroute),
33 | null6route: net.ParseIP(cfg.Nullroutev6),
34 |
35 | m: make(map[string]bool),
36 | w: make(map[string]bool),
37 |
38 | cfg: cfg,
39 | }
40 |
41 | go b.fetchBlocklists()
42 |
43 | return b
44 | }
45 |
46 | // Name return middleware name
47 | func (b *BlockList) Name() string { return name }
48 |
49 | // ServeDNS implements the Handle interface.
50 | func (b *BlockList) ServeDNS(ctx context.Context, ch *middleware.Chain) {
51 | w, req := ch.Writer, ch.Request
52 |
53 | q := req.Question[0]
54 |
55 | if !b.Exists(q.Name) {
56 | ch.Next(ctx)
57 | return
58 | }
59 |
60 | msg := new(dns.Msg)
61 | msg.SetReply(req)
62 | msg.Authoritative, msg.RecursionAvailable = true, true
63 |
64 | switch q.Qtype {
65 | case dns.TypeA:
66 | rrHeader := dns.RR_Header{
67 | Name: q.Name,
68 | Rrtype: dns.TypeA,
69 | Class: dns.ClassINET,
70 | Ttl: 3600,
71 | }
72 | a := &dns.A{Hdr: rrHeader, A: b.nullroute}
73 | msg.Answer = append(msg.Answer, a)
74 | case dns.TypeAAAA:
75 | rrHeader := dns.RR_Header{
76 | Name: q.Name,
77 | Rrtype: dns.TypeAAAA,
78 | Class: dns.ClassINET,
79 | Ttl: 3600,
80 | }
81 | a := &dns.AAAA{Hdr: rrHeader, AAAA: b.null6route}
82 | msg.Answer = append(msg.Answer, a)
83 | default:
84 | rrHeader := dns.RR_Header{
85 | Name: q.Name,
86 | Rrtype: dns.TypeSOA,
87 | Class: dns.ClassINET,
88 | Ttl: 86400,
89 | }
90 | soa := &dns.SOA{
91 | Hdr: rrHeader,
92 | Ns: q.Name,
93 | Mbox: ".",
94 | Serial: 0,
95 | Refresh: 28800,
96 | Retry: 7200,
97 | Expire: 604800,
98 | Minttl: 86400,
99 | }
100 | msg.Extra = append(msg.Answer, soa)
101 | }
102 |
103 | _ = w.WriteMsg(msg)
104 |
105 | ch.Cancel()
106 | }
107 |
108 | // Get returns the entry for a key or an error
109 | func (b *BlockList) Get(key string) (bool, error) {
110 | b.mu.RLock()
111 | defer b.mu.RUnlock()
112 |
113 | key = dns.CanonicalName(key)
114 | val, ok := b.m[key]
115 |
116 | if !ok {
117 | return false, errors.New("block not found")
118 | }
119 |
120 | return val, nil
121 | }
122 |
123 | // Remove removes an entry from the cache
124 | func (b *BlockList) Remove(key string) bool {
125 | if !b.Exists(key) {
126 | return false
127 | }
128 |
129 | b.mu.Lock()
130 | defer b.mu.Unlock()
131 |
132 | key = dns.CanonicalName(key)
133 | delete(b.m, key)
134 | b.save()
135 |
136 | return true
137 | }
138 |
139 | // Set sets a value in the BlockList
140 | func (b *BlockList) Set(key string) bool {
141 | b.mu.Lock()
142 | defer b.mu.Unlock()
143 |
144 | key = dns.CanonicalName(key)
145 |
146 | if b.w[key] {
147 | return false
148 | }
149 |
150 | b.m[key] = true
151 | b.save()
152 |
153 | return true
154 | }
155 |
156 | func (b *BlockList) set(key string) bool {
157 | b.mu.Lock()
158 | defer b.mu.Unlock()
159 |
160 | key = dns.CanonicalName(key)
161 |
162 | if b.w[key] {
163 | return false
164 | }
165 |
166 | b.m[key] = true
167 |
168 | return true
169 | }
170 |
171 | // Exists returns whether or not a key exists in the cache
172 | func (b *BlockList) Exists(key string) bool {
173 | b.mu.RLock()
174 | defer b.mu.RUnlock()
175 |
176 | key = dns.CanonicalName(key)
177 | _, ok := b.m[key]
178 |
179 | return ok
180 | }
181 |
182 | // Length returns the caches length
183 | func (b *BlockList) Length() int {
184 | b.mu.RLock()
185 | defer b.mu.RUnlock()
186 |
187 | return len(b.m)
188 | }
189 |
190 | func (b *BlockList) save() {
191 | path := filepath.Join(b.cfg.BlockListDir, "local")
192 |
193 | file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
194 | if err != nil {
195 | return
196 | }
197 |
198 | _, _ = file.WriteString("# The file generated by auto. DO NOT EDIT\n")
199 | for d := range b.m {
200 | _, _ = file.WriteString(d + "\n")
201 | }
202 |
203 | _ = file.Close()
204 | }
205 |
// name is the identifier returned by BlockList.Name.
const name = "blocklist"
207 |
--------------------------------------------------------------------------------
/middleware/blocklist/blocklist_test.go:
--------------------------------------------------------------------------------
1 | package blocklist
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "os"
7 | "path/filepath"
8 | "strings"
9 | "testing"
10 |
11 | "github.com/miekg/dns"
12 | "github.com/semihalev/sdns/config"
13 | "github.com/semihalev/sdns/middleware"
14 | "github.com/semihalev/sdns/mock"
15 | "github.com/stretchr/testify/assert"
16 | )
17 |
18 | func Test_BlockList(t *testing.T) {
19 | testDomain := "test.com."
20 |
21 | cfg := new(config.Config)
22 | cfg.Nullroute = "0.0.0.0"
23 | cfg.Nullroutev6 = "::0"
24 | cfg.BlockListDir = filepath.Join(os.TempDir(), "sdns_temp")
25 |
26 | middleware.Register("blocklist", func(cfg *config.Config) middleware.Handler { return New(cfg) })
27 | middleware.Setup(cfg)
28 |
29 | blocklist := middleware.Get("blocklist").(*BlockList)
30 |
31 | assert.Equal(t, "blocklist", blocklist.Name())
32 | blocklist.Set(testDomain)
33 |
34 | ch := middleware.NewChain([]middleware.Handler{})
35 |
36 | req := new(dns.Msg)
37 | req.SetQuestion("test.com.", dns.TypeA)
38 | ch.Request = req
39 |
40 | mw := mock.NewWriter("udp", "127.0.0.1:0")
41 | ch.Writer = mw
42 |
43 | blocklist.ServeDNS(context.Background(), ch)
44 | assert.Equal(t, true, len(mw.Msg().Answer) > 0)
45 |
46 | req.SetQuestion("test.com.", dns.TypeAAAA)
47 | ch.Request = req
48 |
49 | blocklist.ServeDNS(context.Background(), ch)
50 | assert.Equal(t, true, len(mw.Msg().Answer) > 0)
51 |
52 | req.SetQuestion("test.com.", dns.TypeNS)
53 | ch.Request = req
54 |
55 | blocklist.ServeDNS(context.Background(), ch)
56 | assert.Equal(t, true, len(mw.Msg().Extra) > 0)
57 |
58 | mw = mock.NewWriter("udp", "127.0.0.1:0")
59 | ch.Writer = mw
60 | req.SetQuestion("test2.com.", dns.TypeA)
61 | blocklist.ServeDNS(context.Background(), ch)
62 | assert.Nil(t, mw.Msg())
63 |
64 | assert.Equal(t, blocklist.Exists(testDomain), true)
65 | assert.Equal(t, blocklist.Exists(strings.ToUpper(testDomain)), true)
66 |
67 | _, err := blocklist.Get(testDomain)
68 | assert.NoError(t, err)
69 |
70 | assert.Equal(t, blocklist.Length(), 1)
71 |
72 | if exists := blocklist.Exists(fmt.Sprintf("%sfuzz", testDomain)); exists {
73 | t.Error("fuzz existed in block blocklist")
74 | }
75 |
76 | if blocklistLen := blocklist.Length(); blocklistLen != 1 {
77 | t.Error("invalid length: ", blocklistLen)
78 | }
79 |
80 | blocklist.Remove(testDomain)
81 | assert.Equal(t, blocklist.Exists(testDomain), false)
82 |
83 | _, err = blocklist.Get(testDomain)
84 | assert.Error(t, err)
85 |
86 | blocklist.Set(testDomain)
87 | }
88 |
--------------------------------------------------------------------------------
/middleware/blocklist/updater.go:
--------------------------------------------------------------------------------
1 | package blocklist
2 |
3 | import (
4 | "bufio"
5 | "fmt"
6 | "io"
7 | "net/http"
8 | "net/url"
9 | "os"
10 | "path/filepath"
11 | "strings"
12 | "sync"
13 | "time"
14 |
15 | "github.com/miekg/dns"
16 | "github.com/semihalev/log"
17 | )
18 |
// timesSeen counts, per URL host, how many block lists have been scheduled
// for download. fetchBlocklist uses the count to give each download from the
// same host a distinct "<host>.<n>.tmp" file name. Nothing in this file resets
// it, and it is not guarded by any lock.
// NOTE(review): concurrent fetchBlocklist calls (e.g. two BlockList instances
// updating at the same time, as can happen in tests) would race on this
// map — confirm that cannot happen in production.
var timesSeen = make(map[string]int)
20 |
21 | func (b *BlockList) fetchBlocklists() {
22 | if b.cfg.BlockListDir == "" {
23 | b.cfg.BlockListDir = filepath.Join(b.cfg.Directory, "blacklists")
24 | }
25 |
26 | <-time.After(time.Second)
27 |
28 | if err := b.updateBlocklists(); err != nil {
29 | log.Error("Update blocklists failed", "error", err.Error())
30 | }
31 |
32 | if err := b.readBlocklists(); err != nil {
33 | log.Error("Read blocklists failed", "dir", b.cfg.BlockListDir, "error", err.Error())
34 | }
35 | }
36 |
37 | func (b *BlockList) updateBlocklists() error {
38 | if _, err := os.Stat(b.cfg.BlockListDir); os.IsNotExist(err) {
39 | if err := os.Mkdir(b.cfg.BlockListDir, 0750); err != nil {
40 | return fmt.Errorf("error creating blacklist directory: %s", err)
41 | }
42 | }
43 |
44 | b.mu.Lock()
45 | for _, entry := range b.cfg.Whitelist {
46 | b.w[dns.CanonicalName(entry)] = true
47 | }
48 | b.mu.Unlock()
49 |
50 | for _, entry := range b.cfg.Blocklist {
51 | b.set(entry)
52 | }
53 |
54 | b.fetchBlocklist()
55 |
56 | return nil
57 | }
58 |
59 | func (b *BlockList) downloadBlocklist(uri, name string) error {
60 | filePath := filepath.FromSlash(fmt.Sprintf("%s/%s", b.cfg.BlockListDir, name))
61 |
62 | output, err := os.Create(filePath)
63 | if err != nil {
64 | return fmt.Errorf("error creating file: %s", err)
65 | }
66 |
67 | defer func() {
68 | err := output.Close()
69 | if err != nil {
70 | log.Warn("Blocklist file close failed", "name", name, "error", err.Error())
71 | }
72 | }()
73 |
74 | response, err := http.Get(uri)
75 | if err != nil {
76 | return fmt.Errorf("error downloading source: %s", err)
77 | }
78 | defer response.Body.Close()
79 |
80 | if _, err := io.Copy(output, response.Body); err != nil {
81 | return fmt.Errorf("error copying output: %s", err)
82 | }
83 |
84 | return nil
85 | }
86 |
87 | func (b *BlockList) fetchBlocklist() {
88 | var wg sync.WaitGroup
89 |
90 | for _, uri := range b.cfg.BlockLists {
91 | wg.Add(1)
92 |
93 | u, _ := url.Parse(uri)
94 | host := u.Host
95 | timesSeen[host] = timesSeen[host] + 1
96 | fileName := fmt.Sprintf("%s.%d.tmp", host, timesSeen[host])
97 |
98 | go func(uri string, name string) {
99 | log.Info("Fetching blacklist", "uri", uri)
100 | if err := b.downloadBlocklist(uri, name); err != nil {
101 | log.Error("Fetching blacklist", "uri", uri, "error", err.Error())
102 | }
103 |
104 | wg.Done()
105 | }(uri, fileName)
106 | }
107 |
108 | wg.Wait()
109 | }
110 |
111 | func (b *BlockList) readBlocklists() error {
112 | log.Info("Loading blocked domains...", "path", b.cfg.BlockListDir)
113 |
114 | if _, err := os.Stat(b.cfg.BlockListDir); os.IsNotExist(err) {
115 | log.Warn("Path not found, skipping...", "path", b.cfg.BlockListDir)
116 | return nil
117 | }
118 |
119 | err := filepath.Walk(b.cfg.BlockListDir, func(path string, f os.FileInfo, _ error) error {
120 | if !f.IsDir() {
121 | file, err := os.Open(filepath.FromSlash(path))
122 | if err != nil {
123 | return fmt.Errorf("error opening file: %s", err)
124 | }
125 |
126 | if err = b.parseHostFile(file); err != nil {
127 | _ = file.Close()
128 | return fmt.Errorf("error parsing hostfile %s", err)
129 | }
130 |
131 | _ = file.Close()
132 |
133 | if filepath.Ext(path) == ".tmp" {
134 | _ = os.Remove(filepath.FromSlash(path))
135 | }
136 | }
137 |
138 | return nil
139 | })
140 |
141 | if err != nil {
142 | return fmt.Errorf("error walking location %s", err)
143 | }
144 |
145 | log.Info("Blocked domains loaded", "total", b.Length())
146 |
147 | return nil
148 | }
149 |
150 | func (b *BlockList) parseHostFile(file *os.File) error {
151 | scanner := bufio.NewScanner(file)
152 | for scanner.Scan() {
153 | line := scanner.Text()
154 | line = strings.TrimSpace(line)
155 | isComment := strings.HasPrefix(line, "#")
156 |
157 | if !isComment && line != "" {
158 | fields := strings.Fields(line)
159 |
160 | if len(fields) > 1 && !strings.HasPrefix(fields[1], "#") {
161 | line = fields[1]
162 | } else {
163 | line = fields[0]
164 | }
165 |
166 | line = dns.CanonicalName(line)
167 |
168 | if !b.Exists(line) {
169 | b.set(line)
170 | }
171 | }
172 | }
173 |
174 | if err := scanner.Err(); err != nil {
175 | return fmt.Errorf("error scanning hostfile: %s", err)
176 | }
177 |
178 | return nil
179 | }
180 |
--------------------------------------------------------------------------------
/middleware/blocklist/updater_test.go:
--------------------------------------------------------------------------------
1 | package blocklist
2 |
3 | import (
4 | "os"
5 | "path/filepath"
6 | "testing"
7 |
8 | "github.com/semihalev/log"
9 | "github.com/semihalev/sdns/config"
10 | "github.com/stretchr/testify/assert"
11 | )
12 |
13 | const (
14 | testDomain = "www.google.com"
15 | )
16 |
17 | func Test_UpdateBlocklists(t *testing.T) {
18 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
19 |
20 | tempDir := filepath.Join(os.TempDir(), "sdns_temp")
21 |
22 | cfg := new(config.Config)
23 | cfg.BlockListDir = tempDir
24 | cfg.Whitelist = append(cfg.Whitelist, testDomain)
25 | cfg.Blocklist = append(cfg.Blocklist, testDomain)
26 |
27 | cfg.BlockLists = []string{}
28 | cfg.BlockLists = append(cfg.BlockLists, "https://raw.githubusercontent.com/quidsup/notrack/master/trackers.txt")
29 | cfg.BlockLists = append(cfg.BlockLists, "https://test.dev/hosts")
30 |
31 | b := New(cfg)
32 |
33 | err := b.updateBlocklists()
34 | assert.NoError(t, err)
35 |
36 | err = b.readBlocklists()
37 | assert.NoError(t, err)
38 | }
39 |
--------------------------------------------------------------------------------
/middleware/cache/item.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016-2020 The CoreDNS authors and contributors
2 | // Adapted for SDNS usage by Semih Alev.
3 |
4 | package cache
5 |
6 | import (
7 | "time"
8 |
9 | "github.com/miekg/dns"
10 | "golang.org/x/time/rate"
11 | )
12 |
// item is a cached DNS response: the header flags and record sections taken
// from an upstream message, plus the bookkeeping needed to compute the
// remaining TTL when the entry is served.
type item struct {
	// Response code and header flags copied from the cached message.
	Rcode              int
	Authoritative      bool
	AuthenticatedData  bool
	RecursionAvailable bool
	// Record sections. newItem keeps Answer and Ns as the message's own
	// slices (not copies); Extra is rebuilt without OPT records.
	Answer []dns.RR
	Ns     []dns.RR
	Extra  []dns.RR

	// Limiter is built by newItem from its queryRate argument.
	// NOTE(review): how callers consult it is not visible here — confirm.
	Limiter *rate.Limiter

	// origTTL is the lifetime in whole seconds the item was stored with, and
	// stored is the UTC time it was stored; ttl(now) derives the remaining TTL.
	origTTL uint32
	stored  time.Time

	// prefetching is not read or written in this file.
	// NOTE(review): presumably marks an in-flight prefetch — confirm with callers.
	prefetching bool
}
29 |
30 | func newItem(m *dns.Msg, now time.Time, d time.Duration, queryRate int) *item {
31 | i := new(item)
32 | i.Rcode = m.Rcode
33 | i.Authoritative = m.Authoritative
34 | i.AuthenticatedData = m.AuthenticatedData
35 | i.RecursionAvailable = m.RecursionAvailable
36 | i.Answer = m.Answer
37 | i.Ns = m.Ns
38 | i.Extra = make([]dns.RR, len(m.Extra))
39 | // Don't copy OPT records as these are hop-by-hop.
40 | j := 0
41 | for _, e := range m.Extra {
42 | if e.Header().Rrtype == dns.TypeOPT {
43 | continue
44 | }
45 | i.Extra[j] = e
46 | j++
47 | }
48 | i.Extra = i.Extra[:j]
49 |
50 | i.origTTL = uint32(d.Seconds())
51 | i.stored = now.UTC()
52 |
53 | limit := rate.Limit(0)
54 | if queryRate > 0 {
55 | limit = rate.Every(time.Second / time.Duration(queryRate))
56 | }
57 |
58 | i.Limiter = rate.NewLimiter(limit, queryRate)
59 |
60 | return i
61 | }
62 |
63 | // toMsg turns i into a message, it tailors the reply to m.
64 | // The Authoritative bit is always set to 0, because the answer is from the cache.
65 | func (i *item) toMsg(m *dns.Msg, now time.Time) *dns.Msg {
66 | m1 := new(dns.Msg)
67 | m1.SetReply(m)
68 |
69 | m1.Authoritative = false
70 | m1.AuthenticatedData = i.AuthenticatedData
71 | m1.RecursionAvailable = i.RecursionAvailable
72 | m1.Rcode = i.Rcode
73 |
74 | m1.Answer = i.Answer
75 | m1.Ns = i.Ns
76 | m1.Extra = i.Extra
77 |
78 | m1.Answer = make([]dns.RR, len(i.Answer))
79 | m1.Ns = make([]dns.RR, len(i.Ns))
80 | m1.Extra = make([]dns.RR, len(i.Extra))
81 |
82 | ttl := uint32(i.ttl(now))
83 | for j, r := range i.Answer {
84 | m1.Answer[j] = dns.Copy(r)
85 | m1.Answer[j].Header().Ttl = ttl
86 | }
87 | for j, r := range i.Ns {
88 | m1.Ns[j] = dns.Copy(r)
89 | m1.Ns[j].Header().Ttl = ttl
90 | }
91 | // newItem skips OPT records, so we can just use i.Extra as is.
92 | for j, r := range i.Extra {
93 | m1.Extra[j] = dns.Copy(r)
94 | m1.Extra[j].Header().Ttl = ttl
95 | }
96 | return m1
97 | }
98 |
99 | func (i *item) ttl(now time.Time) int {
100 | ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds())
101 | return ttl
102 | }
103 |
--------------------------------------------------------------------------------
/middleware/chain.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/miekg/dns"
7 | )
8 |
// Chain runs one request through an ordered list of middleware handlers.
// A handler continues the chain by calling Next, or stops it with Cancel or
// CancelWithRcode.
type Chain struct {
	// Writer receives the response. Handlers may temporarily swap in a
	// wrapping writer for the rest of the chain.
	Writer  ResponseWriter
	// Request is the query being served.
	Request *dns.Msg

	handlers []Handler

	// head is the index of the handler the next call to Next will run.
	head int
	// tail is zeroed by Reset but not otherwise used in this file.
	tail int
	// count is how many handlers may still run; Cancel sets it to 0.
	count int
}
20 |
21 | // NewChain return new fresh chain
22 | func NewChain(handlers []Handler) *Chain {
23 | return &Chain{
24 | Writer: &responseWriter{},
25 | handlers: handlers,
26 | count: len(handlers),
27 | }
28 | }
29 |
30 | // Next call next dns handler in the chain
31 | func (ch *Chain) Next(ctx context.Context) {
32 | if ch.count == 0 {
33 | return
34 | }
35 |
36 | handler := ch.handlers[ch.head]
37 | ch.head = (ch.head + 1) % len(ch.handlers)
38 | ch.count--
39 |
40 | handler.ServeDNS(ctx, ch)
41 | }
42 |
43 | // Cancel next calls
44 | func (ch *Chain) Cancel() {
45 | ch.count = 0
46 | }
47 |
48 | // CancelWithRcode next calls with rcode
49 | func (ch *Chain) CancelWithRcode(rcode int, do bool) {
50 | m := new(dns.Msg)
51 | m.Extra = ch.Request.Extra
52 | m.SetRcode(ch.Request, rcode)
53 |
54 | m.RecursionAvailable = true
55 | m.RecursionDesired = true
56 |
57 | if opt := m.IsEdns0(); opt != nil {
58 | opt.SetDo(do)
59 | }
60 |
61 | _ = ch.Writer.WriteMsg(m)
62 |
63 | ch.count = 0
64 | }
65 |
66 | // Reset the chain variables
67 | func (ch *Chain) Reset(w dns.ResponseWriter, r *dns.Msg) {
68 | ch.Writer.Reset(w)
69 | ch.Request = r
70 | ch.count = len(ch.handlers)
71 | ch.head, ch.tail = 0, 0
72 | }
73 |
--------------------------------------------------------------------------------
/middleware/chain_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/semihalev/sdns/mock"
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
// Test_Chain exercises Chain with a single dummy handler. It covers the
// write-once behavior of the chain's response writer, Reset, Cancel and
// CancelWithRcode.
func Test_Chain(t *testing.T) {
	w := mock.NewWriter("tcp", "127.0.0.1:0")
	ch := NewChain([]Handler{&dummy{}})
	req := new(dns.Msg)
	req.SetQuestion("test.com.", dns.TypeA)
	req.SetEdns0(512, true)
	ch.Reset(w, req)

	ch.Next(context.Background())

	req.Rcode = dns.RcodeSuccess
	err := ch.Writer.WriteMsg(req)
	assert.NoError(t, err)

	data, err := req.Pack()
	assert.NoError(t, err)

	assert.Equal(t, true, ch.Writer.Written())
	assert.Equal(t, dns.RcodeSuccess, ch.Writer.Rcode())

	// A second write through the same writer is rejected.
	_, err = ch.Writer.Write(data)
	assert.Equal(t, errAlreadyWritten, err)

	// After Reset, a raw Write succeeds and a following WriteMsg is rejected.
	ch.Reset(mock.NewWriter("tcp", "127.0.0.1:0"), req)
	size, err := ch.Writer.Write(data)
	assert.NoError(t, err)
	assert.Equal(t, len(data), size)
	assert.NotNil(t, ch.Writer.Msg())

	err = ch.Writer.WriteMsg(req)
	assert.Equal(t, errAlreadyWritten, err)

	// Writing an empty payload is an error.
	ch.Reset(mock.NewWriter("tcp", "127.0.0.1:0"), req)
	_, err = ch.Writer.Write([]byte{})
	assert.Error(t, err)

	assert.Equal(t, "tcp", ch.Writer.Proto())
	assert.Equal(t, "127.0.0.1", ch.Writer.RemoteIP().String())

	ch.Cancel()
	assert.Equal(t, 0, ch.count)

	// CancelWithRcode writes an error reply and stops the chain.
	ch.Reset(mock.NewWriter("tcp", "127.0.0.1:0"), req)

	ch.CancelWithRcode(dns.RcodeServerFailure, true)
	assert.True(t, ch.Writer.Written())
	assert.Equal(t, dns.RcodeServerFailure, ch.Writer.Rcode())
	assert.Equal(t, 0, ch.count)
}
61 |
--------------------------------------------------------------------------------
/middleware/chaos/chaos.go:
--------------------------------------------------------------------------------
1 | package chaos
2 |
3 | import (
4 | "context"
5 | "os"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/semihalev/sdns/config"
9 | "github.com/semihalev/sdns/middleware"
10 | )
11 |
12 | // Chaos type
13 | type Chaos struct {
14 | chaos bool
15 | version string
16 | }
17 |
18 | // New return accesslist
19 | func New(cfg *config.Config) *Chaos {
20 | return &Chaos{
21 | version: "SDNS v" + cfg.ServerVersion() + " (github.com/semihalev/sdns)",
22 | chaos: cfg.Chaos,
23 | }
24 | }
25 |
26 | // Name return middleware name
27 | func (c *Chaos) Name() string { return name }
28 |
29 | // ServeDNS implements the Handle interface.
30 | func (c *Chaos) ServeDNS(ctx context.Context, ch *middleware.Chain) {
31 | w, req := ch.Writer, ch.Request
32 |
33 | q := req.Question[0]
34 |
35 | if q.Qclass != dns.ClassCHAOS || q.Qtype != dns.TypeTXT || !c.chaos {
36 | ch.Next(ctx)
37 | return
38 | }
39 |
40 | resp := new(dns.Msg)
41 | resp.SetReply(req)
42 |
43 | switch q.Name {
44 | case "version.bind.", "version.server.":
45 | resp.Answer = []dns.RR{
46 | &dns.TXT{
47 | Hdr: dns.RR_Header{
48 | Name: q.Name,
49 | Rrtype: dns.TypeTXT,
50 | Class: q.Qclass,
51 | },
52 | Txt: []string{c.version},
53 | }}
54 | case "hostname.bind.", "id.server.":
55 | hostname, err := os.Hostname()
56 | if err != nil {
57 | hostname = "unknown"
58 | }
59 |
60 | resp.Answer = []dns.RR{
61 | &dns.TXT{
62 | Hdr: dns.RR_Header{
63 | Name: q.Name,
64 | Rrtype: dns.TypeTXT,
65 | Class: q.Qclass,
66 | },
67 | Txt: []string{limitTXTLength(hostname)},
68 | }}
69 | default:
70 | ch.Next(ctx)
71 | return
72 | }
73 |
74 | _ = w.WriteMsg(resp)
75 | ch.Cancel()
76 | }
77 |
// limitTXTLength trims s to at most 255 bytes, the maximum length of a single
// TXT character-string. Shorter strings are returned unchanged.
func limitTXTLength(s string) string {
	const maxLen = 255
	if len(s) > maxLen {
		return s[:maxLen]
	}
	return s
}

// name is the identifier this middleware reports from Name.
const name = "chaos"
86 |
--------------------------------------------------------------------------------
/middleware/chaos/chaos_test.go:
--------------------------------------------------------------------------------
1 | package chaos
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/semihalev/sdns/config"
9 | "github.com/semihalev/sdns/middleware"
10 | "github.com/semihalev/sdns/mock"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
// Test_Chaos checks that only CHAOS-class TXT queries for known names are
// answered. Anything else is passed on and, because the chain is empty, left
// unanswered.
func Test_Chaos(t *testing.T) {
	cfg := new(config.Config)
	cfg.Chaos = true

	middleware.Register("chaos", func(cfg *config.Config) middleware.Handler { return New(cfg) })
	middleware.Setup(cfg)

	c := middleware.Get("chaos").(*Chaos)
	assert.Equal(t, "chaos", c.Name())

	// Empty chain: a query the handler passes on produces no reply.
	ch := middleware.NewChain([]middleware.Handler{})

	// version.bind. TXT in the default class set by SetQuestion is not answered.
	mw := mock.NewWriter("udp", "127.0.0.1:0")
	req := new(dns.Msg)
	req.SetQuestion("version.bind.", dns.TypeTXT)
	ch.Reset(mw, req)
	c.ServeDNS(context.Background(), ch)

	assert.False(t, mw.Written())

	// The same query in the CHAOS class is answered.
	mw = mock.NewWriter("udp", "127.0.0.1:0")
	req.Question[0].Qclass = dns.ClassCHAOS
	ch.Reset(mw, req)
	c.ServeDNS(context.Background(), ch)

	assert.True(t, mw.Written())
	assert.Equal(t, dns.RcodeSuccess, mw.Rcode())

	// hostname.bind. is answered too.
	mw = mock.NewWriter("udp", "127.0.0.1:0")
	req.Question[0].Name = "hostname.bind."
	ch.Reset(mw, req)
	c.ServeDNS(context.Background(), ch)

	assert.True(t, mw.Written())
	assert.Equal(t, dns.RcodeSuccess, mw.Rcode())

	// Unknown CHAOS names are passed on.
	mw = mock.NewWriter("udp", "127.0.0.1:0")
	req.Question[0].Name = "unknown.bind."
	ch.Reset(mw, req)
	c.ServeDNS(context.Background(), ch)

	assert.False(t, mw.Written())
}
57 |
--------------------------------------------------------------------------------
/middleware/edns/edns.go:
--------------------------------------------------------------------------------
1 | package edns
2 |
3 | import (
4 | "context"
5 | "encoding/hex"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/semihalev/sdns/config"
9 | "github.com/semihalev/sdns/dnsutil"
10 | "github.com/semihalev/sdns/middleware"
11 | )
12 |
13 | // EDNS type
14 | type EDNS struct {
15 | cookiesecret string
16 | nsidstr string
17 | }
18 |
19 | // New return edns
20 | func New(cfg *config.Config) *EDNS {
21 | return &EDNS{cookiesecret: cfg.CookieSecret, nsidstr: cfg.NSID}
22 | }
23 |
24 | // Name return middleware name
25 | func (e *EDNS) Name() string { return name }
26 |
// ResponseWriter wraps the chain's writer for the rest of the chain and
// applies the request's EDNS0 parameters to every reply written through it.
type ResponseWriter struct {
	middleware.ResponseWriter
	*EDNS

	// opt is the OPT record returned by dnsutil.SetEdns0 for the request;
	// WriteMsg appends it to replies unless noedns is set.
	opt *dns.OPT
	// size is the reply size limit above which UDP replies are truncated.
	size int
	// do is the DO flag reported by dnsutil.SetEdns0; when false, DNSSEC
	// records are cleared from replies.
	do bool
	// cookie is the client cookie reported by dnsutil.SetEdns0; empty means none.
	cookie string
	// nsid reports whether NSID was requested.
	nsid bool
	// noedns is true when the request carried no OPT record.
	noedns bool
	// noad is true when the request set neither AD nor DO; the AD bit is
	// then cleared on replies.
	noad bool
}
40 |
41 | // ServeDNS implements the Handle interface.
42 | func (e *EDNS) ServeDNS(ctx context.Context, ch *middleware.Chain) {
43 | w, req := ch.Writer, ch.Request
44 |
45 | if req.Opcode > 0 {
46 | _ = dnsutil.NotSupported(w, req)
47 |
48 | ch.Cancel()
49 | return
50 | }
51 |
52 | noedns := req.IsEdns0() == nil
53 |
54 | opt, size, cookie, nsid, do := dnsutil.SetEdns0(req)
55 | if opt.Version() != 0 {
56 | opt.SetVersion(0)
57 |
58 | ch.CancelWithRcode(dns.RcodeBadVers, do)
59 |
60 | return
61 | }
62 |
63 | switch w.Proto() {
64 | case "tcp", "doq", "doh":
65 | size = dns.MaxMsgSize
66 | }
67 |
68 | if noedns {
69 | size = dns.MinMsgSize
70 | }
71 |
72 | ch.Writer = &ResponseWriter{
73 | ResponseWriter: w,
74 | EDNS: e,
75 |
76 | opt: opt,
77 | size: size,
78 | do: do,
79 | cookie: cookie,
80 | noedns: noedns,
81 | nsid: nsid,
82 | noad: !req.AuthenticatedData && !do,
83 | }
84 |
85 | ch.Next(ctx)
86 |
87 | ch.Writer = w
88 | }
89 |
90 | // WriteMsg implements the ctx.ResponseWriter interface
91 | func (w *ResponseWriter) WriteMsg(m *dns.Msg) error {
92 | m.Compress = true
93 |
94 | if !w.do {
95 | m = dnsutil.ClearDNSSEC(m)
96 | }
97 | m = dnsutil.ClearOPT(m)
98 |
99 | if !w.noedns {
100 | w.opt.SetDo(w.do)
101 | w.setCookie()
102 | w.setNSID()
103 | m.Extra = append(m.Extra, w.opt)
104 | }
105 |
106 | if w.noad {
107 | m.AuthenticatedData = false
108 | }
109 |
110 | if w.Proto() == "udp" && m.Len() > w.size {
111 | m.Truncated = true
112 | m.Answer = []dns.RR{}
113 | m.Ns = []dns.RR{}
114 | m.AuthenticatedData = false
115 | }
116 |
117 | return w.ResponseWriter.WriteMsg(m)
118 | }
119 |
120 | func (w *ResponseWriter) setCookie() {
121 | if w.cookie == "" {
122 | return
123 | }
124 |
125 | w.opt.Option = append(w.opt.Option, &dns.EDNS0_COOKIE{
126 | Code: dns.EDNS0COOKIE,
127 | Cookie: dnsutil.GenerateServerCookie(w.cookiesecret, w.RemoteIP().String(), w.cookie),
128 | })
129 | }
130 |
131 | func (w *ResponseWriter) setNSID() {
132 | if w.nsidstr == "" || !w.nsid {
133 | return
134 | }
135 |
136 | w.opt.Option = append(w.opt.Option, &dns.EDNS0_NSID{
137 | Code: dns.EDNS0NSID,
138 | Nsid: hex.EncodeToString([]byte(w.nsidstr)),
139 | })
140 | }
141 |
142 | const name = "edns"
143 |
--------------------------------------------------------------------------------
/middleware/edns/edns_test.go:
--------------------------------------------------------------------------------
1 | package edns
2 |
3 | import (
4 | "context"
5 | "net"
6 | "testing"
7 |
8 | "github.com/miekg/dns"
9 | "github.com/semihalev/sdns/config"
10 | "github.com/semihalev/sdns/middleware"
11 | "github.com/semihalev/sdns/mock"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | type dummy struct{}
16 |
17 | func (d *dummy) ServeDNS(ctx context.Context, ch *middleware.Chain) {
18 | w, req := ch.Writer, ch.Request
19 |
20 | m := new(dns.Msg)
21 | m.SetReply(req)
22 |
23 | rrHeader := dns.RR_Header{
24 | Name: req.Question[0].Name,
25 | Rrtype: dns.TypeA,
26 | Class: dns.ClassINET,
27 | Ttl: 3600,
28 | }
29 | a := &dns.A{Hdr: rrHeader, A: net.ParseIP("127.0.0.1")}
30 |
31 | for i := 0; i < 100; i++ {
32 | m.Answer = append(m.Answer, a)
33 | }
34 |
35 | _ = w.WriteMsg(m)
36 | }
37 |
38 | func (d *dummy) Name() string { return "dummy" }
39 |
// Test_EDNS drives the EDNS middleware in front of the dummy handler. It
// covers non-EDNS requests, an unsupported EDNS version, and UDP truncation
// versus TCP, and finally sends a request carrying a client cookie.
func Test_EDNS(t *testing.T) {
	testDomain := "example.com."

	cfg := new(config.Config)

	middleware.Register("edns", func(cfg *config.Config) middleware.Handler { return New(cfg) })
	middleware.Setup(cfg)

	edns := middleware.Get("edns").(*EDNS)
	assert.Equal(t, "edns", edns.Name())

	ch := middleware.NewChain([]middleware.Handler{edns, &dummy{}})

	req := new(dns.Msg)
	req.SetQuestion(testDomain, dns.TypeA)

	// Request without EDNS: the reply carries no OPT record.
	mw := mock.NewWriter("tcp", "127.0.0.1:0")
	ch.Reset(mw, req)
	ch.Next(context.Background())

	assert.True(t, ch.Writer.Written())
	assert.Equal(t, dns.RcodeSuccess, ch.Writer.Rcode())
	assert.Nil(t, ch.Writer.Msg().IsEdns0())

	// EDNS version 100 is rejected with BADVERS.
	req.SetEdns0(4096, true)
	opt := req.IsEdns0()
	opt.SetVersion(100)

	mw = mock.NewWriter("udp", "127.0.0.1:0")
	ch.Reset(mw, req)
	ch.Next(context.Background())

	assert.True(t, ch.Writer.Written())
	assert.Equal(t, dns.RcodeBadVers, ch.Writer.Rcode())

	// With a 512-byte buffer, the large reply fits over TCP...
	opt = req.IsEdns0()
	opt.SetVersion(0)
	opt.SetUDPSize(512)

	mw = mock.NewWriter("tcp", "127.0.0.1:0")
	ch.Reset(mw, req)
	ch.Next(context.Background())

	if assert.True(t, ch.Writer.Written()) {
		assert.False(t, ch.Writer.Msg().Truncated)
	}

	// ...but is truncated over UDP.
	mw = mock.NewWriter("udp", "127.0.0.1:0")
	ch.Reset(mw, req)
	ch.Next(context.Background())

	if assert.True(t, ch.Writer.Written()) {
		assert.True(t, ch.Writer.Msg().Truncated)
	}

	// Client cookie with a 4096-byte buffer; this step makes no assertions.
	opt.Option = append(opt.Option, &dns.EDNS0_COOKIE{
		Code:   dns.EDNS0COOKIE,
		Cookie: "testtesttesttest",
	})
	opt.SetUDPSize(4096)
	mw = mock.NewWriter("udp", "127.0.0.1:0")
	ch.Reset(mw, req)
	ch.Next(context.Background())
}
104 |
--------------------------------------------------------------------------------
/middleware/failover/failover.go:
--------------------------------------------------------------------------------
1 | package failover
2 |
3 | import (
4 | "context"
5 | "net"
6 | "strings"
7 | "time"
8 |
9 | "github.com/miekg/dns"
10 | "github.com/semihalev/log"
11 | "github.com/semihalev/sdns/config"
12 | "github.com/semihalev/sdns/dnsutil"
13 | "github.com/semihalev/sdns/middleware"
14 | )
15 |
// Failover retries SERVFAIL replies against a list of fallback servers.
type Failover struct {
	// servers holds the validated "ip:port" fallback addresses, tried in order.
	servers []string
}

// ResponseWriter wraps the chain's writer so replies pass through the
// failover logic before being written.
type ResponseWriter struct {
	middleware.ResponseWriter

	f *Failover
}
27 |
28 | // New return failover
29 | func New(cfg *config.Config) *Failover {
30 | fallbackservers := []string{}
31 | for _, s := range cfg.FallbackServers {
32 | host, _, _ := net.SplitHostPort(s)
33 |
34 | if ip := net.ParseIP(host); ip != nil && ip.To4() != nil {
35 | fallbackservers = append(fallbackservers, s)
36 | } else if ip != nil && ip.To16() != nil {
37 | fallbackservers = append(fallbackservers, s)
38 | } else {
39 | log.Error("Fallback server is not correct. Check your config.", "server", s)
40 | }
41 | }
42 |
43 | return &Failover{servers: fallbackservers}
44 | }
45 |
46 | // Name return middleware name
47 | func (f *Failover) Name() string { return name }
48 |
49 | // ServeDNS implements the Handle interface.
50 | func (f *Failover) ServeDNS(ctx context.Context, ch *middleware.Chain) {
51 | w := ch.Writer
52 |
53 | ch.Writer = &ResponseWriter{ResponseWriter: w, f: f}
54 |
55 | ch.Next(ctx)
56 |
57 | ch.Writer = w
58 | }
59 |
60 | // WriteMsg implements the ctx.ResponseWriter interface
61 | func (w *ResponseWriter) WriteMsg(m *dns.Msg) error {
62 | if len(m.Question) == 0 || len(w.f.servers) == 0 {
63 | return w.ResponseWriter.WriteMsg(m)
64 | }
65 |
66 | if m.Rcode != dns.RcodeServerFailure || !m.RecursionDesired {
67 | return w.ResponseWriter.WriteMsg(m)
68 | }
69 |
70 | req := new(dns.Msg)
71 | req.SetQuestion(m.Question[0].Name, m.Question[0].Qtype)
72 | req.Question[0].Qclass = m.Question[0].Qclass
73 | req.SetEdns0(dnsutil.DefaultMsgSize, true)
74 | req.CheckingDisabled = m.CheckingDisabled
75 |
76 | ctx := context.Background()
77 |
78 | for _, server := range w.f.servers {
79 | ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
80 | defer cancel()
81 | resp, err := dnsutil.Exchange(ctx, req, server, "udp")
82 | if err != nil {
83 | log.Info("Failover query failed", "query", formatQuestion(req.Question[0]), "error", err.Error())
84 | continue
85 | }
86 |
87 | resp.Id = m.Id
88 |
89 | return w.ResponseWriter.WriteMsg(resp)
90 | }
91 |
92 | return w.ResponseWriter.WriteMsg(m)
93 | }
94 |
95 | func formatQuestion(q dns.Question) string {
96 | return strings.ToLower(q.Name) + " " + dns.ClassToString[q.Qclass] + " " + dns.TypeToString[q.Qtype]
97 | }
98 |
99 | const name = "failover"
100 |
--------------------------------------------------------------------------------
/middleware/failover/failover_test.go:
--------------------------------------------------------------------------------
1 | package failover
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/semihalev/log"
9 | "github.com/semihalev/sdns/config"
10 | "github.com/semihalev/sdns/middleware"
11 | "github.com/semihalev/sdns/mock"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | type dummy struct{}
16 |
17 | func (d *dummy) ServeDNS(ctx context.Context, ch *middleware.Chain) {
18 | w, req := ch.Writer, ch.Request
19 |
20 | m := new(dns.Msg)
21 | m.SetRcode(req, dns.RcodeServerFailure)
22 |
23 | _ = w.WriteMsg(m)
24 | }
25 |
26 | func (d *dummy) Name() string { return "dummy" }
27 |
28 | func Test_Failover(t *testing.T) {
29 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
30 |
31 | cfg := new(config.Config)
32 | cfg.FallbackServers = []string{"[::255]:53", "8.8.8.8:53", "1"}
33 |
34 | middleware.Register("failover", func(cfg *config.Config) middleware.Handler { return New(cfg) })
35 | middleware.Setup(cfg)
36 |
37 | f := middleware.Get("failover").(*Failover)
38 | assert.Equal(t, "failover", f.Name())
39 |
40 | ch := middleware.NewChain([]middleware.Handler{f, &dummy{}})
41 |
42 | ctx := context.Background()
43 |
44 | req := new(dns.Msg)
45 | req.SetQuestion("example.com.", dns.TypeA)
46 | req.RecursionDesired = false
47 |
48 | mw := mock.NewWriter("udp", "127.0.0.1:0")
49 | ch.Writer = mw
50 | ch.Request = req
51 |
52 | ch.Reset(mw, req)
53 | ch.Next(ctx)
54 |
55 | assert.Equal(t, dns.RcodeServerFailure, mw.Rcode())
56 |
57 | req.RecursionDesired = true
58 |
59 | ch.Reset(mw, req)
60 | ch.Next(ctx)
61 |
62 | assert.Equal(t, mw.Rcode(), dns.RcodeSuccess)
63 |
64 | f.servers = []string{}
65 |
66 | ch.Reset(mw, req)
67 | ch.Next(ctx)
68 |
69 | assert.Equal(t, mw.Rcode(), dns.RcodeServerFailure)
70 |
71 | f.servers = []string{"[::255]:53"}
72 |
73 | ch.Reset(mw, req)
74 | ch.Next(ctx)
75 |
76 | assert.Equal(t, mw.Rcode(), dns.RcodeServerFailure)
77 | }
78 |
--------------------------------------------------------------------------------
/middleware/forwarder/forwarder.go:
--------------------------------------------------------------------------------
1 | package forwarder
2 |
3 | import (
4 | "context"
5 | "net"
6 | "strings"
7 |
8 | "github.com/miekg/dns"
9 | "github.com/semihalev/log"
10 | "github.com/semihalev/sdns/config"
11 | "github.com/semihalev/sdns/dnsutil"
12 | "github.com/semihalev/sdns/middleware"
13 | )
14 |
// server is one upstream forwarder.
type server struct {
	// Addr is the upstream's "ip:port" address.
	Addr string
	// Proto is the transport passed to dnsutil.Exchange: "udp" by default,
	// or "tcp-tls" for entries configured with a tls:// prefix.
	Proto string
}

// Forwarder sends queries to the configured upstream servers.
type Forwarder struct {
	servers []*server
	// dnssec is true when cfg.DNSSEC is "on".
	dnssec bool
}
25 |
26 | // New return forwarder
27 | func New(cfg *config.Config) *Forwarder {
28 | forwarderservers := []*server{}
29 | for _, s := range cfg.ForwarderServers {
30 | srv := &server{Proto: "udp"}
31 |
32 | if strings.HasPrefix(s, "tls://") {
33 | s = strings.TrimPrefix(s, "tls://")
34 | srv.Proto = "tcp-tls"
35 | }
36 |
37 | host, _, _ := net.SplitHostPort(s)
38 |
39 | if ip := net.ParseIP(host); ip != nil && ip.To4() != nil {
40 | srv.Addr = s
41 | forwarderservers = append(forwarderservers, srv)
42 | } else if ip != nil && ip.To16() != nil {
43 | srv.Addr = s
44 | forwarderservers = append(forwarderservers, srv)
45 | } else {
46 | log.Error("Forwarder server is not correct. Check your config.", "server", s)
47 | }
48 | }
49 |
50 | return &Forwarder{servers: forwarderservers, dnssec: cfg.DNSSEC == "on"}
51 | }
52 |
53 | // Name return middleware name
54 | func (f *Forwarder) Name() string { return name }
55 |
56 | // ServeDNS implements the Handle interface.
57 | func (f *Forwarder) ServeDNS(ctx context.Context, ch *middleware.Chain) {
58 | w, req := ch.Writer, ch.Request
59 |
60 | if len(req.Question) == 0 || len(f.servers) == 0 {
61 | ch.CancelWithRcode(dns.RcodeServerFailure, true)
62 | return
63 | }
64 |
65 | if !req.CheckingDisabled {
66 | req.CheckingDisabled = !f.dnssec
67 | }
68 |
69 | for _, server := range f.servers {
70 | resp, err := dnsutil.Exchange(ctx, req, server.Addr, server.Proto)
71 | if err != nil {
72 | log.Info("forwarder query failed", "query", formatQuestion(req.Question[0]), "error", err.Error())
73 | continue
74 | }
75 |
76 | resp.Id = req.Id
77 | if !f.dnssec {
78 | resp.CheckingDisabled = false
79 | }
80 |
81 | _ = w.WriteMsg(resp)
82 | return
83 | }
84 |
85 | ch.CancelWithRcode(dns.RcodeServerFailure, true)
86 | }
87 |
88 | func formatQuestion(q dns.Question) string {
89 | return strings.ToLower(q.Name) + " " + dns.ClassToString[q.Qclass] + " " + dns.TypeToString[q.Qtype]
90 | }
91 |
92 | const name = "forwarder"
93 |
--------------------------------------------------------------------------------
/middleware/forwarder/forwarder_test.go:
--------------------------------------------------------------------------------
1 | package forwarder
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/semihalev/log"
9 | "github.com/semihalev/sdns/config"
10 | "github.com/semihalev/sdns/middleware"
11 | "github.com/semihalev/sdns/mock"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | func Test_Forwarder(t *testing.T) {
16 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
17 |
18 | cfg := new(config.Config)
19 | cfg.ForwarderServers = []string{"[::255]:53", "8.8.8.8:53", "1", "tls://8.8.8.8:853"}
20 |
21 | middleware.Register("forwarder", func(cfg *config.Config) middleware.Handler { return New(cfg) })
22 | middleware.Setup(cfg)
23 |
24 | f := middleware.Get("forwarder").(*Forwarder)
25 | assert.Equal(t, "forwarder", f.Name())
26 |
27 | ch := middleware.NewChain([]middleware.Handler{f})
28 |
29 | ctx := context.Background()
30 |
31 | req := new(dns.Msg)
32 | req.SetQuestion("example.com.", dns.TypeA)
33 | req.RecursionDesired = false
34 |
35 | mw := mock.NewWriter("udp", "127.0.0.1:0")
36 | ch.Writer = mw
37 | ch.Request = req
38 |
39 | ch.Reset(mw, req)
40 | ch.Next(ctx)
41 |
42 | assert.Equal(t, dns.RcodeSuccess, mw.Rcode())
43 |
44 | req.RecursionDesired = true
45 |
46 | ch.Reset(mw, req)
47 | ch.Next(ctx)
48 |
49 | assert.Equal(t, mw.Rcode(), dns.RcodeSuccess)
50 |
51 | f.servers = []*server{}
52 |
53 | ch.Reset(mw, req)
54 | ch.Next(ctx)
55 |
56 | assert.Equal(t, mw.Rcode(), dns.RcodeServerFailure)
57 |
58 | srv := &server{Addr: "[::255]:53", Proto: "udp"}
59 | f.servers = []*server{srv}
60 |
61 | ch.Reset(mw, req)
62 | ch.Next(ctx)
63 |
64 | assert.Equal(t, mw.Rcode(), dns.RcodeServerFailure)
65 |
66 | srv = &server{Addr: "8.8.8.8:853", Proto: "tcp-tls"}
67 | f.servers = []*server{srv}
68 |
69 | ch.Reset(mw, req)
70 | ch.Next(ctx)
71 |
72 | assert.Equal(t, mw.Rcode(), dns.RcodeSuccess)
73 | }
74 |
--------------------------------------------------------------------------------
/middleware/loop/loop.go:
--------------------------------------------------------------------------------
1 | package loop
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/miekg/dns"
7 | "github.com/semihalev/log"
8 | "github.com/semihalev/sdns/config"
9 | "github.com/semihalev/sdns/middleware"
10 | )
11 |
12 | // Loop dummy type
13 | type Loop struct{}
14 |
15 | type ctxKey string
16 |
17 | // New return loop
18 | func New(cfg *config.Config) *Loop {
19 | return &Loop{}
20 | }
21 |
22 | // Name return middleware name
23 | func (l *Loop) Name() string { return name }
24 |
25 | // ServeDNS implements the Handle interface.
26 | func (l *Loop) ServeDNS(ctx context.Context, ch *middleware.Chain) {
27 | req := ch.Request
28 |
29 | if len(req.Question) == 0 {
30 | ch.Cancel()
31 | return
32 | }
33 |
34 | qKey := req.Question[0].Name + ":" + dns.TypeToString[req.Question[0].Qtype]
35 |
36 | key := ctxKey("loopcheck:" + qKey)
37 |
38 | if v := ctx.Value(key); v != nil {
39 | count := v.(uint64)
40 |
41 | if count > 10 {
42 | log.Warn("Loop detected", "query", qKey)
43 | ch.CancelWithRcode(dns.RcodeServerFailure, false)
44 | return
45 | }
46 |
47 | count++
48 | ctx = context.WithValue(ctx, key, count)
49 | } else {
50 | ctx = context.WithValue(ctx, key, uint64(1))
51 | }
52 |
53 | ch.Next(ctx)
54 | }
55 |
56 | const name = "loop"
57 |
--------------------------------------------------------------------------------
/middleware/loop/loop_test.go:
--------------------------------------------------------------------------------
1 | package loop
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/semihalev/log"
9 | "github.com/semihalev/sdns/config"
10 | "github.com/semihalev/sdns/middleware"
11 | "github.com/semihalev/sdns/mock"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
// Test_loop checks loop detection and the empty-question case.
func Test_loop(t *testing.T) {
	log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))

	middleware.Register("loop", func(cfg *config.Config) middleware.Handler { return New(cfg) })
	middleware.Setup(&config.Config{})

	l := middleware.Get("loop").(*Loop)

	assert.Equal(t, "loop", l.Name())

	// The direct ServeDNS call plus these 11 chained copies give 12 nested
	// visits for the same question, enough to trip the loop limit.
	ch := middleware.NewChain([]middleware.Handler{l, l, l, l, l, l, l, l, l, l, l})

	ctx := context.Background()

	mw := mock.NewWriter("udp", "127.0.0.1:0")
	req := new(dns.Msg)
	req.SetQuestion("example.com.", dns.TypeA)
	ch.Reset(mw, req)
	l.ServeDNS(ctx, ch)
	assert.Equal(t, dns.RcodeServerFailure, mw.Msg().Rcode)

	// A request with no question is dropped without a reply.
	mw = mock.NewWriter("udp", "127.0.0.1:0")
	req.Question = []dns.Question{}
	ch.Reset(mw, req)
	l.ServeDNS(ctx, ch)
	assert.Nil(t, mw.Msg())
}
42 |
--------------------------------------------------------------------------------
/middleware/metrics/metrics.go:
--------------------------------------------------------------------------------
1 | package metrics
2 |
3 | import (
4 | "context"
5 | "sync"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/prometheus/client_golang/prometheus"
9 | "github.com/semihalev/sdns/config"
10 | "github.com/semihalev/sdns/middleware"
11 | )
12 |
13 | // Metrics type
14 | type Metrics struct {
15 | queries *prometheus.CounterVec
16 | }
17 |
18 | // New return new metrics
19 | func New(cfg *config.Config) *Metrics {
20 | m := &Metrics{
21 | queries: prometheus.NewCounterVec(
22 | prometheus.CounterOpts{
23 | Name: "dns_queries_total",
24 | Help: "How many DNS queries processed",
25 | },
26 | []string{"qtype", "rcode"},
27 | ),
28 | }
29 | _ = prometheus.Register(m.queries)
30 |
31 | return m
32 | }
33 |
34 | // Name return middleware name
35 | func (m *Metrics) Name() string { return name }
36 |
37 | // ServeDNS implements the Handle interface.
38 | func (m *Metrics) ServeDNS(ctx context.Context, ch *middleware.Chain) {
39 | ch.Next(ctx)
40 |
41 | if !ch.Writer.Written() {
42 | return
43 | }
44 |
45 | labels := AcquireLabels()
46 | defer ReleaseLabels(labels)
47 |
48 | labels["qtype"] = dns.TypeToString[ch.Request.Question[0].Qtype]
49 | labels["rcode"] = dns.RcodeToString[ch.Writer.Rcode()]
50 |
51 | m.queries.With(labels).Inc()
52 | }
53 |
54 | var labelsPool sync.Pool
55 |
56 | // AcquireLabels returns a label from pool
57 | func AcquireLabels() prometheus.Labels {
58 | x := labelsPool.Get()
59 | if x == nil {
60 | return prometheus.Labels{"qtype": "", "rcode": ""}
61 | }
62 |
63 | return x.(prometheus.Labels)
64 | }
65 |
66 | // ReleaseLabels returns labels to pool
67 | func ReleaseLabels(labels prometheus.Labels) {
68 | labelsPool.Put(labels)
69 | }
70 |
71 | const name = "metrics"
72 |
--------------------------------------------------------------------------------
/middleware/metrics/metrics_test.go:
--------------------------------------------------------------------------------
1 | package metrics
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/semihalev/sdns/config"
9 | "github.com/semihalev/sdns/middleware"
10 | "github.com/semihalev/sdns/mock"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func Test_Metrics(t *testing.T) {
15 | middleware.Register("metrics", func(cfg *config.Config) middleware.Handler { return New(cfg) })
16 | middleware.Setup(&config.Config{})
17 |
18 | m := middleware.Get("metrics").(*Metrics)
19 |
20 | assert.Equal(t, "metrics", m.Name())
21 |
22 | ch := middleware.NewChain([]middleware.Handler{})
23 |
24 | mw := mock.NewWriter("udp", "127.0.0.1:0")
25 | req := new(dns.Msg)
26 | req.SetQuestion("test.com.", dns.TypeA)
27 |
28 | ch.Reset(mw, req)
29 |
30 | m.ServeDNS(context.Background(), ch)
31 | assert.Equal(t, dns.RcodeServerFailure, mw.Rcode())
32 |
33 | _ = ch.Writer.WriteMsg(req)
34 | assert.Equal(t, true, ch.Writer.Written())
35 |
36 | m.ServeDNS(context.Background(), ch)
37 | assert.Equal(t, dns.RcodeSuccess, mw.Rcode())
38 | }
39 |
--------------------------------------------------------------------------------
/middleware/middleware.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
import (
	"context"
	"fmt"
	"plugin"
	"sort"
	"sync"

	"github.com/semihalev/log"
	"github.com/semihalev/sdns/config"
)
12 |
13 | // Handler interface
14 | type Handler interface {
15 | Name() string
16 | ServeDNS(context.Context, *Chain)
17 | }
18 |
19 | type middleware struct {
20 | mu sync.RWMutex
21 |
22 | cfg *config.Config
23 | handlers []handler
24 | }
25 |
26 | type handler struct {
27 | name string
28 | new func(*config.Config) Handler
29 | }
30 |
31 | var (
32 | chainHandlers []Handler
33 | setup bool
34 | m middleware
35 | )
36 |
37 | // Register a middleware
38 | func Register(name string, new func(*config.Config) Handler) {
39 | RegisterAt(name, new, len(m.handlers))
40 | }
41 |
42 | // RegisterAt a middleware at an index
43 | func RegisterAt(name string, new func(*config.Config) Handler, idx int) {
44 | log.Debug("Register middleware", "name", name, "index", idx)
45 |
46 | m.mu.Lock()
47 | defer m.mu.Unlock()
48 |
49 | m.handlers = append(m.handlers, handler{})
50 | copy(m.handlers[idx+1:], m.handlers[idx:])
51 | m.handlers[idx] = handler{name: name, new: new}
52 | }
53 |
54 | // RegisterBefore a middleware before another middleware
55 | func RegisterBefore(name string, new func(*config.Config) Handler, before string) {
56 | log.Debug("Register middleware", "name", name, "before", before)
57 |
58 | m.mu.Lock()
59 | defer m.mu.Unlock()
60 |
61 | for idx, v := range m.handlers {
62 | if v.name == before {
63 | m.handlers = append(m.handlers, handler{})
64 | copy(m.handlers[idx+1:], m.handlers[idx:])
65 | m.handlers[idx] = handler{name: name, new: new}
66 | return
67 | }
68 | }
69 |
70 | panic(fmt.Sprintf("Middleware %s not found", before))
71 | }
72 |
73 | // Setup handlers
74 | func Setup(cfg *config.Config) {
75 | if setup {
76 | panic("middleware setup already done")
77 | }
78 |
79 | m.cfg = cfg
80 |
81 | LoadExternalPlugins()
82 |
83 | m.mu.Lock()
84 | defer m.mu.Unlock()
85 |
86 | for i, handler := range m.handlers {
87 | h := handler.new(m.cfg)
88 | chainHandlers = append(chainHandlers, h)
89 |
90 | log.Debug("Middleware registered", "name", h.Name(), "index", i)
91 | }
92 |
93 | setup = true
94 | }
95 |
96 | // LoadExternalPlugins load external plugins into chain
97 | func LoadExternalPlugins() {
98 | for name, pcfg := range m.cfg.Plugins {
99 | pl, err := plugin.Open(pcfg.Path)
100 | if err != nil {
101 | log.Error("Plugin open failed", "plugin", name, "error", err.Error())
102 | continue
103 | }
104 |
105 | newFuncSym, err := pl.Lookup("New")
106 | if err != nil {
107 | log.Error("Plugin new function lookup failed", "plugin", name, "error", err.Error())
108 | continue
109 | }
110 |
111 | newFn, ok := newFuncSym.(func(cfg *config.Config) Handler)
112 |
113 | if !ok {
114 | log.Error("Plugin new function assert failed", "plugin", name)
115 | continue
116 | }
117 |
118 | RegisterBefore(name, newFn, "cache")
119 | log.Info("Plugin successfully loaded", "plugin", name, "path", pcfg.Path)
120 | }
121 | }
122 |
123 | // Handlers return registered handlers
124 | func Handlers() []Handler {
125 | handlers := chainHandlers
126 | return handlers
127 | }
128 |
129 | // List return names of handlers
130 | func List() (list []string) {
131 | m.mu.RLock()
132 | defer m.mu.RUnlock()
133 |
134 | for _, handler := range m.handlers {
135 | list = append(list, handler.name)
136 | }
137 |
138 | return list
139 | }
140 |
141 | // Get return a handler by name
142 | func Get(name string) Handler {
143 | if !setup {
144 | return nil
145 | }
146 |
147 | m.mu.RLock()
148 | defer m.mu.RUnlock()
149 |
150 | for i, handler := range m.handlers {
151 | if handler.name == name {
152 | if len(chainHandlers) <= i {
153 | return nil
154 | }
155 | return chainHandlers[i]
156 | }
157 | }
158 |
159 | return nil
160 | }
161 |
162 | // Ready return true if middleware setup was done
163 | func Ready() bool {
164 | m.mu.RLock()
165 | defer m.mu.RUnlock()
166 |
167 | return setup
168 | }
169 |
--------------------------------------------------------------------------------
/middleware/middleware_test.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/semihalev/sdns/config"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | type dummy struct{}
12 |
13 | func (d *dummy) ServeDNS(ctx context.Context, ch *Chain) { ch.Next(ctx) }
14 | func (d *dummy) Name() string { return "dummy" }
15 |
16 | func Test_Middleware(t *testing.T) {
17 | Register("dummy", func(*config.Config) Handler {
18 | return &dummy{}
19 | })
20 |
21 | cfg := &config.Config{}
22 |
23 | d := Get("dummy")
24 | assert.Nil(t, d)
25 |
26 | assert.NotPanics(t, func() {
27 | Setup(cfg)
28 | })
29 |
30 | assert.True(t, Ready())
31 |
32 | assert.Panics(t, func() {
33 | Setup(cfg)
34 | })
35 | assert.True(t, len(List()) == 1)
36 | assert.True(t, len(Handlers()) == 1)
37 |
38 | d = Get("dummy")
39 | assert.NotNil(t, d)
40 |
41 | d = Get("none")
42 | assert.Nil(t, d)
43 |
44 | chainHandlers = []Handler{}
45 | d = Get("dummy")
46 | assert.Nil(t, d)
47 | }
48 |
49 | func Test_RegisterAt(t *testing.T) {
50 | m.handlers = []handler{}
51 |
52 | Register("dummy", func(*config.Config) Handler {
53 | return &dummy{}
54 | })
55 | RegisterAt("dummy2", func(*config.Config) Handler {
56 | return &dummy{}
57 | }, 0)
58 |
59 | assert.True(t, len(m.handlers) == 2)
60 | assert.True(t, m.handlers[0].name == "dummy2")
61 | assert.True(t, m.handlers[1].name == "dummy")
62 |
63 | RegisterBefore("dummy3", func(*config.Config) Handler {
64 | return &dummy{}
65 | }, "dummy")
66 | assert.True(t, len(m.handlers) == 3)
67 | assert.True(t, m.handlers[0].name == "dummy2")
68 | assert.True(t, m.handlers[1].name == "dummy3")
69 | assert.True(t, m.handlers[2].name == "dummy")
70 |
71 | assert.Panics(t, func() {
72 | RegisterAt("dummy4", func(*config.Config) Handler {
73 | return &dummy{}
74 | }, 4)
75 | })
76 | assert.Panics(t, func() {
77 | RegisterAt("dummy5", func(*config.Config) Handler {
78 | return &dummy{}
79 | }, -1)
80 | })
81 | assert.Panics(t, func() {
82 | RegisterBefore("dummy6", func(*config.Config) Handler {
83 | return &dummy{}
84 | }, "noexist")
85 | })
86 |
87 | m.handlers = []handler{}
88 | }
89 |
--------------------------------------------------------------------------------
/middleware/ratelimit/ratelimit.go:
--------------------------------------------------------------------------------
1 | package ratelimit
2 |
3 | import (
4 | "context"
5 | "net"
6 | "sync/atomic"
7 | "time"
8 |
9 | "github.com/cespare/xxhash/v2"
10 | "github.com/miekg/dns"
11 | "github.com/semihalev/sdns/cache"
12 | "github.com/semihalev/sdns/config"
13 | "github.com/semihalev/sdns/dnsutil"
14 | "github.com/semihalev/sdns/middleware"
15 | "golang.org/x/time/rate"
16 | )
17 |
18 | type limiter struct {
19 | rl *rate.Limiter
20 | cookie atomic.Value
21 | }
22 |
23 | // RateLimit type
24 | type RateLimit struct {
25 | cookiesecret string
26 |
27 | cache *cache.Cache
28 | rate int
29 | }
30 |
31 | // New return accesslist
32 | func New(cfg *config.Config) *RateLimit {
33 | r := &RateLimit{
34 | cache: cache.New(cacheSize),
35 | cookiesecret: cfg.CookieSecret,
36 | rate: cfg.ClientRateLimit,
37 | }
38 |
39 | return r
40 | }
41 |
42 | // Name return middleware name
43 | func (r *RateLimit) Name() string { return name }
44 |
45 | // ServeDNS implements the Handle interface.
46 | func (r *RateLimit) ServeDNS(ctx context.Context, ch *middleware.Chain) {
47 | w, req := ch.Writer, ch.Request
48 |
49 | if w.Internal() {
50 | ch.Next(ctx)
51 | return
52 | }
53 |
54 | if r.rate == 0 {
55 | ch.Next(ctx)
56 | return
57 | }
58 |
59 | if w.RemoteIP() == nil {
60 | ch.Next(ctx)
61 | return
62 | } else if w.RemoteIP().IsLoopback() {
63 | ch.Next(ctx)
64 | return
65 | }
66 |
67 | var cachedcookie, clientcookie, servercookie string
68 |
69 | l := r.getLimiter(w.RemoteIP())
70 | cachedcookie = l.cookie.Load().(string)
71 |
72 | if opt := req.IsEdns0(); opt != nil {
73 | for _, option := range opt.Option {
74 | switch option.Option() {
75 | case dns.EDNS0COOKIE:
76 | if len(option.String()) >= cookieSize {
77 | clientcookie = option.String()[:cookieSize]
78 | servercookie = dnsutil.GenerateServerCookie(r.cookiesecret, w.RemoteIP().String(), clientcookie)
79 |
80 | if cachedcookie == "" || cachedcookie == option.String() {
81 | ch.Next(ctx)
82 |
83 | l.cookie.Store(servercookie)
84 | return
85 | }
86 |
87 | if w.Proto() == "udp" {
88 | if !l.rl.Allow() {
89 | ch.Cancel()
90 | return
91 | }
92 |
93 | l.cookie.Store(servercookie)
94 | option.(*dns.EDNS0_COOKIE).Cookie = servercookie
95 |
96 | ch.CancelWithRcode(dns.RcodeBadCookie, false)
97 |
98 | return
99 | }
100 | }
101 | }
102 | }
103 | }
104 |
105 | if !l.rl.Allow() {
106 | //no reply to client
107 | ch.Cancel()
108 | return
109 | }
110 |
111 | ch.Next(ctx)
112 |
113 | if servercookie != "" {
114 | l.cookie.Store(servercookie)
115 | }
116 | }
117 |
118 | func (r *RateLimit) getLimiter(remoteip net.IP) *limiter {
119 | xxhash := xxhash.New()
120 | _, _ = xxhash.Write(remoteip)
121 | key := xxhash.Sum64()
122 |
123 | if v, ok := r.cache.Get(key); ok {
124 | return v.(*limiter)
125 | }
126 |
127 | limit := rate.Limit(0)
128 | if r.rate > 0 {
129 | limit = rate.Every(time.Minute / time.Duration(r.rate))
130 | }
131 |
132 | rl := rate.NewLimiter(limit, r.rate)
133 |
134 | l := &limiter{rl: rl}
135 | l.cookie.Store("")
136 |
137 | r.cache.Add(key, l)
138 |
139 | return l
140 | }
141 |
const (
	// name is the registry name of this middleware.
	name = "ratelimit"

	// cacheSize is passed to cache.New for the per-client limiter cache.
	cacheSize = 256 * 100

	// cookieSize is how many leading characters of a cookie option's string
	// form are taken as the client cookie; shorter options are ignored.
	cookieSize = 16
)
148 |
--------------------------------------------------------------------------------
/middleware/ratelimit/ratelimit_test.go:
--------------------------------------------------------------------------------
1 | package ratelimit
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/semihalev/sdns/config"
9 | "github.com/semihalev/sdns/middleware"
10 | "github.com/semihalev/sdns/mock"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func Test_RateLimit(t *testing.T) {
15 | cfg := new(config.Config)
16 | cfg.ClientRateLimit = 1
17 |
18 | middleware.Register("ratelimit", func(cfg *config.Config) middleware.Handler { return New(cfg) })
19 | middleware.Setup(cfg)
20 |
21 | r := middleware.Get("ratelimit").(*RateLimit)
22 |
23 | assert.Equal(t, "ratelimit", r.Name())
24 |
25 | ch := middleware.NewChain([]middleware.Handler{})
26 |
27 | req := new(dns.Msg)
28 | req.SetQuestion("example.com.", dns.TypeA)
29 | req.SetEdns0(4096, true)
30 |
31 | opt := req.IsEdns0()
32 | opt.Option = append(opt.Option, &dns.EDNS0_COOKIE{
33 | Code: dns.EDNS0COOKIE,
34 | Cookie: "testtesttesttest",
35 | })
36 |
37 | mw := mock.NewWriter("udp", "")
38 | ch.Reset(mw, req)
39 | r.ServeDNS(context.Background(), ch)
40 |
41 | mw = mock.NewWriter("udp", "10.0.0.1:0")
42 | ch.Reset(mw, req)
43 | r.ServeDNS(context.Background(), ch)
44 | r.ServeDNS(context.Background(), ch)
45 | if assert.True(t, mw.Written()) {
46 | assert.Equal(t, dns.RcodeBadCookie, mw.Rcode())
47 | }
48 |
49 | opt.Option = nil
50 | opt.Option = append(opt.Option, &dns.EDNS0_COOKIE{
51 | Code: dns.EDNS0COOKIE,
52 | Cookie: "testtesttesttest",
53 | })
54 |
55 | mw = mock.NewWriter("udp", "10.0.0.1:0")
56 | ch.Reset(mw, req)
57 | r.ServeDNS(context.Background(), ch)
58 | assert.False(t, mw.Written())
59 |
60 | mw = mock.NewWriter("tcp", "10.0.0.2:0")
61 | ch.Reset(mw, req)
62 | r.ServeDNS(context.Background(), ch)
63 | r.ServeDNS(context.Background(), ch)
64 | assert.False(t, mw.Written())
65 |
66 | opt.Option = nil
67 | mw = mock.NewWriter("udp", "10.0.0.1:0")
68 | ch.Reset(mw, req)
69 | r.ServeDNS(context.Background(), ch)
70 | r.ServeDNS(context.Background(), ch)
71 | assert.False(t, mw.Written())
72 |
73 | mw = mock.NewWriter("udp", "0.0.0.0:0")
74 | ch.Reset(mw, req)
75 | r.ServeDNS(context.Background(), ch)
76 |
77 | mw = mock.NewWriter("udp", "127.0.0.1:0")
78 | ch.Reset(mw, req)
79 | r.ServeDNS(context.Background(), ch)
80 |
81 | r.rate = 0
82 |
83 | r.ServeDNS(context.Background(), ch)
84 | }
85 |
--------------------------------------------------------------------------------
/middleware/recovery/recovery.go:
--------------------------------------------------------------------------------
1 | package recovery
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "os"
7 | "runtime/debug"
8 |
9 | "github.com/miekg/dns"
10 | "github.com/semihalev/log"
11 | "github.com/semihalev/sdns/config"
12 | "github.com/semihalev/sdns/middleware"
13 | )
14 |
15 | // Recovery dummy type
16 | type Recovery struct{}
17 |
18 | // New return recovery
19 | func New(cfg *config.Config) *Recovery {
20 | return &Recovery{}
21 | }
22 |
23 | // Name return middleware name
24 | func (r *Recovery) Name() string { return name }
25 |
26 | // ServeDNS implements the Handle interface.
27 | func (r *Recovery) ServeDNS(ctx context.Context, ch *middleware.Chain) {
28 | defer func() {
29 | if r := recover(); r != nil {
30 | ch.CancelWithRcode(dns.RcodeServerFailure, false)
31 |
32 | log.Error("Recovered in ServeDNS", "recover", r)
33 |
34 | _, _ = os.Stderr.WriteString(fmt.Sprintf("panic: %v\n\n", r))
35 | debug.PrintStack()
36 | }
37 | }()
38 |
39 | ch.Next(ctx)
40 | }
41 |
42 | const name = "recovery"
43 |
--------------------------------------------------------------------------------
/middleware/recovery/recovery_test.go:
--------------------------------------------------------------------------------
1 | package recovery
2 |
3 | import (
4 | "context"
5 | "os"
6 | "testing"
7 |
8 | "github.com/miekg/dns"
9 | "github.com/semihalev/log"
10 | "github.com/semihalev/sdns/config"
11 | "github.com/semihalev/sdns/middleware"
12 | "github.com/semihalev/sdns/mock"
13 | "github.com/stretchr/testify/assert"
14 | )
15 |
16 | func Test_Recovery(t *testing.T) {
17 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
18 |
19 | stderr := os.Stderr
20 | os.Stderr, _ = os.Open(os.DevNull)
21 |
22 | middleware.Register("recovery", func(cfg *config.Config) middleware.Handler { return New(cfg) })
23 | middleware.Setup(&config.Config{})
24 |
25 | r := middleware.Get("recovery").(*Recovery)
26 |
27 | assert.Equal(t, "recovery", r.Name())
28 |
29 | ch := middleware.NewChain([]middleware.Handler{r, nil})
30 |
31 | mw := mock.NewWriter("udp", "127.0.0.1:0")
32 | req := new(dns.Msg)
33 | req.SetQuestion("test.com.", dns.TypeA)
34 |
35 | ch.Reset(mw, req)
36 |
37 | r.ServeDNS(context.Background(), ch)
38 |
39 | assert.Equal(t, dns.RcodeServerFailure, mw.Msg().Rcode)
40 |
41 | ch = middleware.NewChain([]middleware.Handler{r})
42 | ch.Reset(mw, req)
43 | r.ServeDNS(context.Background(), ch)
44 |
45 | os.Stderr = stderr
46 | }
47 |
--------------------------------------------------------------------------------
/middleware/resolver/auto_trust_anchor_test.go:
--------------------------------------------------------------------------------
1 | package resolver_test
2 |
3 | import (
4 | "context"
5 | "crypto/x509"
6 | "encoding/base64"
7 | "net"
8 | "os"
9 | "path/filepath"
10 | "testing"
11 | "time"
12 |
13 | "github.com/miekg/dns"
14 | "github.com/semihalev/log"
15 | "github.com/semihalev/sdns/authcache"
16 | "github.com/semihalev/sdns/config"
17 | "github.com/semihalev/sdns/middleware/resolver"
18 | "github.com/stretchr/testify/assert"
19 | )
20 |
21 | var (
22 | privateKey = "MIICXAIBAAKBgQCWIKiOFx/LqVppaUSfW2a9hEnfUS+Qb752/fL3odiGQxCrxcmcEXvn+APSN3ipRetdLdHeB7FSZQ4eIhBtgKjBuFlqQj8pnZOWhV16w80HFjYg/ea9nhG8IziTzK/lsSIk2cDTe1k9kD5WUaLRijLJEEy7gLkOOFmt3Ho675dw2QIDAQABAoGANYEsMX/iUBZqZ5kh4N2Vb0O/hDyOBB8fNY9qUYE4BxnNzjpukRXWICVPT1N/yGxn5syWuFfrhZ8IegrP6gbpnZ1ViYRONOkrfoGOm/U71IL8mlr/NCrxAd/ifB4Db1HOEvlewwQ3G8+HE7HBAjYpup+w4Yw/Du2Cw6dtlJ9MmWUCQQDDHj0MCxWus38EHBwueVjmKq/gE1oZpuLCGmjVZIXlA7yw4IlQU27Y+XlEdVJIMRIUQ1K7Zdw/KFU+aKfBmYy3AkEAxPijYGdWPSZDZn/9tPMfdBtipz1wXHREHLBNOPOcgP4TjVoqBY8Yl6ZUwwjTA8C2JZ1ZU4oSUyLuecbH/N1+7wJAVb2w99zbH1UTSLwNikKa1TIG7UGzwzf5x3ARh0xQJk4ZGeThkmHHgSNHrdScXsrpdewLq/vb6AkSRIV6ynFuSwJAHLk5cfR/0fkDeS4O/FU77/2SXFsMSJ8304suJ7D20KS8iy9r01Wzu2GpGKvvwatXpJKWlSUcWP1OE3oWbdyLBwJBAMCcuKf9EIw9Wgkt9KKhJXKpSqUr1xN+3WZf4bmQl4nT1mITMPcmnQla/JYepnspYrt06L16Ed8vf4u8AbEW68I="
23 | )
24 |
25 | type dummyHandler struct {
26 | DNSKEY []dns.RR
27 | }
28 |
29 | func (d *dummyHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
30 | resp := new(dns.Msg)
31 | resp.Question = r.Question
32 | resp.Authoritative = true
33 | resp.SetRcode(r, dns.RcodeSuccess)
34 |
35 | if r.Question[0].Qtype == dns.TypeDNSKEY {
36 | resp.Answer = append(resp.Answer, d.DNSKEY...)
37 | }
38 |
39 | _ = w.WriteMsg(resp)
40 | }
41 |
42 | func makeRootKeysConfig() *config.Config {
43 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
44 |
45 | cfg := new(config.Config)
46 | cfg.RootServers = []string{"127.0.0.1:44302"}
47 | cfg.RootKeys = []string{
48 | ". 86400 IN DNSKEY 257 3 8 AwEAAZYgqI4XH8upWmlpRJ9bZr2ESd9RL5Bvvnb98veh2IZDEKvFyZwRe+f4A9I3eKlF610t0d4HsVJlDh4iEG2AqMG4WWpCPymdk5aFXXrDzQcWNiD95r2eEbwjOJPMr+WxIiTZwNN7WT2QPlZRotGKMskQTLuAuQ44Wa3cejrvl3DZ",
49 | ". 86400 IN DNSKEY 257 3 8 AwEAAcjLi71oV55rThFidre9DgnEgJwOmPPg0XwWmFkz3uNoT3+SaT6hErHuJS2I8+vc4rZIoGaNdlOrsNBEqyfaikDniq6+PwdNFK8Adt8xBCh9YOZkexdb8i59MABbv1TtJ130O9L8OQ9MOfJfyLm9UknV4D5y8HDDOBcjkJ2U4DRx",
50 | ". 86400 IN DNSKEY 257 3 8 AwEAAc78m2ldR7iPdjdFZlGheNgdUclcSrPSx+E5s0XWiW6nBaDDawTICkwWI7m7Uzuva1myKkZKgidtwmmxS1P6/xjsRCn1xEjPXvim5Xzr0gjsp16KFQsR8IALGu6dxJYn7WHB+UdT3yiV0x6FwAVb/ilYsOMmn3S/oaaTx4Oh7OEL",
51 | ". 86400 IN DNSKEY 385 3 8 AwEAAcZMCRaHx5n6GWwAWFrwhnfheNafQfaoBcn1IfwmQ8RD/0V+WAeFU+CVH+zinmqamv/V+zF2FF03WZjuq5HpOtqFVAKsQmC3Wb6DttjwzgNs0Iywgy/Ae8QZp03WApVmzcr4hDvxXeP5ABMwf8vR7gF/JtArb2Mlnekh/7sWs/wr",
52 | }
53 | cfg.Maxdepth = 30
54 | cfg.Expire = 600
55 | cfg.CacheSize = 1024
56 | cfg.Timeout.Duration = 2 * time.Second
57 | cfg.Directory = filepath.Join(os.TempDir(), "sdns_temp_autota")
58 | cfg.IPv6Access = false
59 |
60 | _ = os.Mkdir(cfg.Directory, 0777)
61 |
62 | return cfg
63 | }
64 |
65 | func runTestServer() {
66 | buf, _ := base64.StdEncoding.DecodeString(privateKey)
67 | privKey, _ := x509.ParsePKCS1PrivateKey(buf)
68 |
69 | dnskey1 := &dns.DNSKEY{
70 | Hdr: dns.RR_Header{
71 | Name: ".",
72 | Rrtype: dns.TypeDNSKEY,
73 | Class: dns.ClassINET,
74 | Ttl: 86400,
75 | },
76 | Algorithm: 8,
77 | Flags: 257,
78 | Protocol: 3,
79 | PublicKey: "AwEAAZYgqI4XH8upWmlpRJ9bZr2ESd9RL5Bvvnb98veh2IZDEKvFyZwRe+f4A9I3eKlF610t0d4HsVJlDh4iEG2AqMG4WWpCPymdk5aFXXrDzQcWNiD95r2eEbwjOJPMr+WxIiTZwNN7WT2QPlZRotGKMskQTLuAuQ44Wa3cejrvl3DZ",
80 | }
81 |
82 | dnskey2 := &dns.DNSKEY{
83 | Hdr: dns.RR_Header{
84 | Name: ".",
85 | Rrtype: dns.TypeDNSKEY,
86 | Class: dns.ClassINET,
87 | Ttl: 86400,
88 | },
89 | Algorithm: 8,
90 | Flags: 385,
91 | Protocol: 3,
92 | PublicKey: "AwEAAcjLi71oV55rThFidre9DgnEgJwOmPPg0XwWmFkz3uNoT3+SaT6hErHuJS2I8+vc4rZIoGaNdlOrsNBEqyfaikDniq6+PwdNFK8Adt8xBCh9YOZkexdb8i59MABbv1TtJ130O9L8OQ9MOfJfyLm9UknV4D5y8HDDOBcjkJ2U4DRx",
93 | }
94 |
95 | rrsigdk := &dns.RRSIG{
96 | Hdr: dns.RR_Header{
97 | Name: ".",
98 | Rrtype: dns.TypeRRSIG,
99 | Class: dns.ClassINET,
100 | Ttl: 86400,
101 | },
102 | TypeCovered: dns.TypeDNSKEY,
103 | Algorithm: 8,
104 | SignerName: ".",
105 | KeyTag: dnskey1.KeyTag(),
106 | Inception: uint32(time.Now().UTC().Unix()),
107 | Expiration: uint32(time.Now().UTC().Add(15 * 24 * time.Hour).Unix()),
108 | OrigTtl: 3600,
109 | }
110 |
111 | dnskey3 := &dns.DNSKEY{
112 | Hdr: dns.RR_Header{
113 | Name: ".",
114 | Rrtype: dns.TypeDNSKEY,
115 | Class: dns.ClassINET,
116 | Ttl: 86400,
117 | },
118 | Algorithm: 8,
119 | Flags: 257,
120 | Protocol: 3,
121 | }
122 |
123 | _, _ = dnskey3.Generate(1024)
124 |
125 | _ = rrsigdk.Sign(privKey, []dns.RR{dnskey1, dnskey2, dnskey3})
126 |
127 | /*if s, ok := privkey.(*rsa.PrivateKey); ok {
128 | buf := x509.MarshalPKCS1PrivateKey(s)
129 | fmt.Println(base64.StdEncoding.EncodeToString(buf))
130 | rrsig.Sign(s, []dns.RR{dnskey})
131 | }*/
132 |
133 | go func() {
134 | _ = dns.ListenAndServe("127.0.0.1:44302", "udp", &dummyHandler{DNSKEY: []dns.RR{dnskey1, dnskey2, dnskey3, rrsigdk}})
135 | }()
136 | }
137 |
138 | func Test_autota(t *testing.T) {
139 | runTestServer()
140 |
141 | cfg := makeRootKeysConfig()
142 |
143 | r := resolver.NewResolver(cfg)
144 |
145 | time.Sleep(time.Second)
146 |
147 | req := new(dns.Msg)
148 | req.SetQuestion(".", dns.TypeDNSKEY)
149 | req.SetEdns0(1400, true)
150 |
151 | rootservers := &authcache.AuthServers{}
152 | rootservers.Zone = "."
153 |
154 | for _, s := range cfg.RootServers {
155 | host, _, _ := net.SplitHostPort(s)
156 | if ip := net.ParseIP(host); ip != nil && ip.To4() != nil {
157 | rootservers.List = append(rootservers.List, authcache.NewAuthServer(s, authcache.IPv4))
158 | }
159 | }
160 |
161 | resp, err := r.Resolve(context.Background(), req, rootservers, true, 30, 0, false, nil)
162 |
163 | assert.True(t, resp.AuthenticatedData)
164 | assert.NoError(t, err)
165 | assert.Len(t, resp.Answer, 4)
166 |
167 | os.Remove(filepath.Join(cfg.Directory, "trust-anchor.db"))
168 | }
169 |
--------------------------------------------------------------------------------
/middleware/resolver/client.go:
--------------------------------------------------------------------------------
1 | package resolver
2 |
3 | // Originally this Client from github.com/miekg/dns
4 | // Adapted for resolver package usage by Semih Alev.
5 |
6 | import (
7 | "encoding/binary"
8 | "errors"
9 | "io"
10 | "net"
11 | "sync"
12 | "time"
13 |
14 | "github.com/miekg/dns"
15 | )
16 |
// headerSize is the length in bytes of a DNS message header; reads shorter
// than this are rejected with dns.ErrShortRead.
const headerSize = 12

// Conn wraps a net.Conn for exchanging DNS messages. On non-packet (stream)
// connections each message carries a two-byte big-endian length prefix.
type Conn struct {
	net.Conn        // underlying transport
	UDPSize  uint16 // minimum receive buffer size for UDP messages
}
26 |
27 | // Exchange performs a synchronous query
28 | func (co *Conn) Exchange(m *dns.Msg) (r *dns.Msg, rtt time.Duration, err error) {
29 |
30 | opt := m.IsEdns0()
31 | // If EDNS0 is used use that for size.
32 | if opt != nil && opt.UDPSize() >= dns.MinMsgSize {
33 | co.UDPSize = opt.UDPSize()
34 | }
35 |
36 | if opt == nil && co.UDPSize < dns.MinMsgSize {
37 | co.UDPSize = dns.MinMsgSize
38 | }
39 |
40 | t := time.Now()
41 |
42 | if err = co.WriteMsg(m); err != nil {
43 | return nil, 0, err
44 | }
45 |
46 | r, err = co.ReadMsg()
47 | if err == nil && r.Id != m.Id {
48 | err = dns.ErrId
49 | }
50 |
51 | rtt = time.Since(t)
52 |
53 | return r, rtt, err
54 | }
55 |
56 | // ReadMsg reads a message from the connection co.
57 | // If the received message contains a TSIG record the transaction signature
58 | // is verified. This method always tries to return the message, however if an
59 | // error is returned there are no guarantees that the returned message is a
60 | // valid representation of the packet read.
61 | func (co *Conn) ReadMsg() (*dns.Msg, error) {
62 | var (
63 | p []byte
64 | n int
65 | err error
66 | )
67 |
68 | if _, ok := co.Conn.(net.PacketConn); ok {
69 | p = AcquireBuf(co.UDPSize)
70 | n, err = co.Read(p)
71 | } else {
72 | var length uint16
73 | if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
74 | return nil, err
75 | }
76 |
77 | p = AcquireBuf(length)
78 | n, err = io.ReadFull(co.Conn, p)
79 | }
80 |
81 | if err != nil {
82 | return nil, err
83 | } else if n < headerSize {
84 | return nil, dns.ErrShortRead
85 | }
86 |
87 | defer ReleaseBuf(p)
88 |
89 | m := new(dns.Msg)
90 | if err := m.Unpack(p); err != nil {
91 | // If an error was returned, we still want to allow the user to use
92 | // the message, but naively they can just check err if they don't want
93 | // to use an erroneous message
94 | return m, err
95 | }
96 | return m, err
97 | }
98 |
99 | // Read implements the net.Conn read method.
100 | func (co *Conn) Read(p []byte) (n int, err error) {
101 | if co.Conn == nil {
102 | return 0, dns.ErrConnEmpty
103 | }
104 |
105 | if _, ok := co.Conn.(net.PacketConn); ok {
106 | // UDP connection
107 | return co.Conn.Read(p)
108 | }
109 |
110 | var length uint16
111 | if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
112 | return 0, err
113 | }
114 | if int(length) > len(p) {
115 | return 0, io.ErrShortBuffer
116 | }
117 |
118 | return io.ReadFull(co.Conn, p[:length])
119 | }
120 |
121 | // WriteMsg sends a message through the connection co.
122 | // If the message m contains a TSIG record the transaction
123 | // signature is calculated.
124 | func (co *Conn) WriteMsg(m *dns.Msg) (err error) {
125 | size := uint16(m.Len()) + 1
126 |
127 | out := AcquireBuf(size)
128 | defer ReleaseBuf(out)
129 |
130 | out, err = m.PackBuffer(out)
131 | if err != nil {
132 | return err
133 | }
134 | _, err = co.Write(out)
135 | return err
136 | }
137 |
138 | // Write implements the net.Conn Write method.
139 | func (co *Conn) Write(p []byte) (int, error) {
140 | if len(p) > dns.MaxMsgSize {
141 | return 0, errors.New("message too large")
142 | }
143 |
144 | if _, ok := co.Conn.(net.PacketConn); ok {
145 | return co.Conn.Write(p)
146 | }
147 |
148 | l := make([]byte, 2)
149 | binary.BigEndian.PutUint16(l, uint16(len(p)))
150 |
151 | n, err := (&net.Buffers{l, p}).WriteTo(co.Conn)
152 | return int(n), err
153 | }
154 |
// bufferPool recycles byte slices, stored as *[]byte, for packing and reading
// DNS messages.
var bufferPool sync.Pool

// AcquireBuf returns a byte slice of length size. A pooled buffer is reused
// when its capacity suffices; otherwise a new one is allocated. Reused
// buffers are not zeroed.
func AcquireBuf(size uint16) []byte {
	need := int(size)
	if ptr, ok := bufferPool.Get().(*[]byte); ok && cap(*ptr) >= need {
		return (*ptr)[:need]
	}
	return make([]byte, need)
}

// ReleaseBuf hands buf back to the pool for reuse by AcquireBuf.
func ReleaseBuf(buf []byte) {
	bufferPool.Put(&buf)
}
174 |
--------------------------------------------------------------------------------
/middleware/resolver/client_test.go:
--------------------------------------------------------------------------------
1 | package resolver
2 |
3 | import (
4 | "net"
5 | "testing"
6 | "time"
7 |
8 | "github.com/miekg/dns"
9 | "github.com/semihalev/sdns/dnsutil"
10 | "github.com/stretchr/testify/assert"
11 | )
12 |
13 | func Test_ClientTimeout(t *testing.T) {
14 | req := new(dns.Msg)
15 | req.SetQuestion(".", dns.TypeNS)
16 | req.SetEdns0(dnsutil.DefaultMsgSize, true)
17 |
18 | dialer := &net.Dialer{Deadline: time.Now().Add(2 * time.Second)}
19 | co := &Conn{}
20 |
21 | var err error
22 | co.Conn, err = dialer.Dial("udp4", "127.1.0.255:53")
23 | assert.NoError(t, err)
24 |
25 | err = co.SetDeadline(time.Now().Add(2 * time.Second))
26 | assert.NoError(t, err)
27 |
28 | _, _, err = co.Exchange(req)
29 | assert.Error(t, err)
30 | assert.NoError(t, co.Close())
31 | }
32 |
33 | func Test_Client(t *testing.T) {
34 | req := new(dns.Msg)
35 | req.SetQuestion(".", dns.TypeNS)
36 | req.SetEdns0(dnsutil.DefaultMsgSize, true)
37 |
38 | dialer := &net.Dialer{Deadline: time.Now().Add(2 * time.Second)}
39 | co := &Conn{}
40 |
41 | var err error
42 | co.Conn, err = dialer.Dial("udp4", "198.41.0.4:53")
43 | assert.NoError(t, err)
44 |
45 | err = co.SetDeadline(time.Now().Add(2 * time.Second))
46 | assert.NoError(t, err)
47 |
48 | r, _, err := co.Exchange(req)
49 | assert.NoError(t, err)
50 | assert.NotNil(t, r)
51 | }
52 |
--------------------------------------------------------------------------------
/middleware/resolver/handler.go:
--------------------------------------------------------------------------------
1 | package resolver
2 |
3 | import (
4 | "context"
5 | "os"
6 | "time"
7 |
8 | "github.com/miekg/dns"
9 | "github.com/semihalev/log"
10 | "github.com/semihalev/sdns/authcache"
11 | "github.com/semihalev/sdns/cache"
12 | "github.com/semihalev/sdns/config"
13 | "github.com/semihalev/sdns/dnsutil"
14 | "github.com/semihalev/sdns/middleware"
15 | )
16 |
17 | // DNSHandler type
18 | type DNSHandler struct {
19 | resolver *Resolver
20 | cfg *config.Config
21 | }
22 |
23 | type ctxKey string
24 |
25 | var debugns bool
26 |
27 | func init() {
28 | _, debugns = os.LookupEnv("SDNS_DEBUGNS")
29 | }
30 |
31 | // New returns a new Handler
32 | func New(cfg *config.Config) *DNSHandler {
33 | if cfg.Maxdepth == 0 {
34 | cfg.Maxdepth = 30
35 | }
36 |
37 | if cfg.QueryTimeout.Duration == 0 {
38 | cfg.QueryTimeout.Duration = 10 * time.Second
39 | }
40 |
41 | return &DNSHandler{
42 | resolver: NewResolver(cfg),
43 | cfg: cfg,
44 | }
45 | }
46 |
47 | // Name return middleware name
48 | func (h *DNSHandler) Name() string { return name }
49 |
50 | // ServeDNS implements the Handle interface.
51 | func (h *DNSHandler) ServeDNS(ctx context.Context, ch *middleware.Chain) {
52 | if len(h.cfg.ForwarderServers) > 0 {
53 | ch.Next(ctx)
54 | return
55 | }
56 |
57 | w, req := ch.Writer, ch.Request
58 |
59 | if v := ctx.Value(ctxKey("reqid")); v == nil {
60 | ctx = context.WithValue(ctx, ctxKey("reqid"), req.Id)
61 | }
62 | msg := h.handle(ctx, req)
63 |
64 | _ = w.WriteMsg(msg)
65 | }
66 |
67 | func (h *DNSHandler) handle(ctx context.Context, req *dns.Msg) *dns.Msg {
68 | q := req.Question[0]
69 |
70 | do := false
71 | opt := req.IsEdns0()
72 | if opt != nil {
73 | do = opt.Do()
74 | }
75 |
76 | if q.Qtype == dns.TypeANY {
77 | return dnsutil.SetRcode(req, dns.RcodeNotImplemented, do)
78 | }
79 |
80 | // debug ns stats
81 | if debugns && q.Qclass == dns.ClassCHAOS && q.Qtype == dns.TypeHINFO {
82 | return h.nsStats(req)
83 | }
84 |
85 | // check purge query
86 | if q.Qclass == dns.ClassCHAOS && q.Qtype == dns.TypeNULL {
87 | if qname, qtype, ok := dnsutil.ParsePurgeQuestion(req); ok {
88 | if qtype == dns.TypeNS {
89 | h.purge(qname)
90 | }
91 |
92 | resp := dnsutil.SetRcode(req, dns.RcodeSuccess, do)
93 | txt, _ := dns.NewRR(q.Name + ` 20 IN TXT "cache purged"`)
94 |
95 | resp.Extra = append(resp.Extra, txt)
96 |
97 | return resp
98 | }
99 | }
100 |
101 | if q.Name != rootzone && !req.RecursionDesired {
102 | return dnsutil.SetRcode(req, dns.RcodeServerFailure, do)
103 | }
104 |
105 | // we shouldn't send rd and ad flag to aa servers
106 | req.RecursionDesired = false
107 | req.AuthenticatedData = false
108 |
109 | if !req.CheckingDisabled {
110 | req.CheckingDisabled = !h.resolver.dnssec
111 | }
112 |
113 | ctx, cancel := context.WithDeadline(ctx, time.Now().Add(h.cfg.QueryTimeout.Duration))
114 | defer cancel()
115 |
116 | depth := h.cfg.Maxdepth
117 | resp, err := h.resolver.Resolve(ctx, req, h.resolver.rootservers, true, depth, 0, false, nil, q.Name == rootzone)
118 |
119 | if !h.resolver.dnssec {
120 | req.CheckingDisabled = false
121 | if resp != nil {
122 | resp.CheckingDisabled = false
123 | }
124 | }
125 |
126 | if err != nil {
127 | log.Info("Resolve query failed", "query", formatQuestion(q), "error", err.Error())
128 |
129 | return dnsutil.SetRcode(req, dns.RcodeServerFailure, do)
130 | }
131 |
132 | if resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeNotZone {
133 | return dnsutil.SetRcode(req, dns.RcodeServerFailure, do)
134 | }
135 |
136 | return resp
137 | }
138 |
139 | func (h *DNSHandler) nsStats(req *dns.Msg) *dns.Msg {
140 | q := req.Question[0]
141 |
142 | msg := new(dns.Msg)
143 | msg.SetReply(req)
144 |
145 | msg.Authoritative = false
146 | msg.RecursionAvailable = true
147 |
148 | servers := h.resolver.rootservers
149 | ttl := uint32(0)
150 | name := rootzone
151 |
152 | if q.Name != rootzone {
153 | nsKey := cache.Hash(dns.Question{Name: q.Name, Qtype: dns.TypeNS, Qclass: dns.ClassINET}, msg.CheckingDisabled)
154 | ns, err := h.resolver.ncache.Get(nsKey)
155 | if err != nil {
156 | nsKey = cache.Hash(dns.Question{Name: q.Name, Qtype: dns.TypeNS, Qclass: dns.ClassINET}, !msg.CheckingDisabled)
157 | ns, err := h.resolver.ncache.Get(nsKey)
158 | if err == nil {
159 | servers = ns.Servers
160 | name = q.Name
161 | }
162 | } else {
163 | servers = ns.Servers
164 | name = q.Name
165 | }
166 | }
167 |
168 | var serversList []*authcache.AuthServer
169 |
170 | servers.RLock()
171 | serversList = append(serversList, servers.List...)
172 | servers.RUnlock()
173 |
174 | authcache.Sort(serversList, 1)
175 |
176 | rrHeader := dns.RR_Header{
177 | Name: name,
178 | Rrtype: dns.TypeHINFO,
179 | Class: dns.ClassCHAOS,
180 | Ttl: ttl,
181 | }
182 |
183 | for _, server := range serversList {
184 | hinfo := &dns.HINFO{Hdr: rrHeader, Cpu: "Host", Os: server.String()}
185 | msg.Ns = append(msg.Ns, hinfo)
186 | }
187 |
188 | return msg
189 | }
190 |
191 | func (h *DNSHandler) purge(qname string) {
192 | q := dns.Question{Name: qname, Qtype: dns.TypeNS, Qclass: dns.ClassINET}
193 |
194 | key := cache.Hash(q, false)
195 | h.resolver.ncache.Remove(key)
196 |
197 | key = cache.Hash(q, true)
198 | h.resolver.ncache.Remove(key)
199 | }
200 |
// name is the middleware name for this handler; the tests expect
// handler.Name() to report "resolver".
const name = "resolver"
202 |
--------------------------------------------------------------------------------
/middleware/resolver/handler_test.go:
--------------------------------------------------------------------------------
1 | package resolver
2 |
3 | import (
4 | "context"
5 | "encoding/base64"
6 | "os"
7 | "path/filepath"
8 | "testing"
9 | "time"
10 |
11 | "github.com/miekg/dns"
12 | "github.com/semihalev/log"
13 | "github.com/semihalev/sdns/config"
14 | "github.com/semihalev/sdns/dnsutil"
15 | "github.com/semihalev/sdns/middleware"
16 | "github.com/semihalev/sdns/middleware/edns"
17 | "github.com/semihalev/sdns/mock"
18 | "github.com/stretchr/testify/assert"
19 | )
20 |
21 | func makeTestConfig() *config.Config {
22 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
23 |
24 | cfg := new(config.Config)
25 | cfg.RootServers = []string{"192.5.5.241:53"}
26 | cfg.Root6Servers = []string{"[2001:500:2f::f]:53"}
27 | cfg.RootKeys = []string{
28 | ". 172800 IN DNSKEY 257 3 8 AwEAAaz/tAm8yTn4Mfeh5eyI96WSVexTBAvkMgJzkKTOiW1vkIbzxeF3+/4RgWOq7HrxRixHlFlExOLAJr5emLvN7SWXgnLh4+B5xQlNVz8Og8kvArMtNROxVQuCaSnIDdD5LKyWbRd2n9WGe2R8PzgCmr3EgVLrjyBxWezF0jLHwVN8efS3rCj/EWgvIWgb9tarpVUDK/b58Da+sqqls3eNbuv7pr+eoZG+SrDK6nWeL3c6H5Apxz7LjVc1uTIdsIXxuOLYA4/ilBmSVIzuDWfdRUfhHdY6+cn8HFRm+2hM8AnXGXws9555KrUB5qihylGa8subX2Nn6UwNR1AkUTV74bU=",
29 | }
30 | cfg.Maxdepth = 30
31 | cfg.Expire = 600
32 | cfg.CacheSize = 1024
33 | cfg.Timeout.Duration = 2 * time.Second
34 | cfg.Directory = filepath.Join(os.TempDir(), "sdns_temp")
35 | cfg.IPv6Access = true
36 | cfg.DNSSEC = "on"
37 |
38 | if !middleware.Ready() {
39 | middleware.Register("edns", func(cfg *config.Config) middleware.Handler { return edns.New(cfg) })
40 | middleware.Register("resolver", func(cfg *config.Config) middleware.Handler { return New(cfg) })
41 | middleware.Setup(cfg)
42 | }
43 |
44 | return cfg
45 | }
46 |
47 | func Test_handler(t *testing.T) {
48 | makeTestConfig()
49 |
50 | ctx := context.Background()
51 |
52 | handler := middleware.Get("resolver").(*DNSHandler)
53 |
54 | time.Sleep(2 * time.Second)
55 |
56 | assert.Equal(t, "resolver", handler.Name())
57 |
58 | m := new(dns.Msg)
59 | m.SetQuestion("www.apple.com.", dns.TypeA)
60 | r := handler.handle(ctx, m)
61 | assert.Equal(t, len(r.Answer) > 0, true)
62 |
63 | m = new(dns.Msg)
64 | // test again for caches
65 | m.SetQuestion("www.apple.com.", dns.TypeA)
66 | r = handler.handle(ctx, m)
67 | assert.Equal(t, len(r.Answer) > 0, true)
68 |
69 | m = new(dns.Msg)
70 | m.SetEdns0(dnsutil.DefaultMsgSize, true)
71 | m.SetQuestion("dnssec-failed.org.", dns.TypeA)
72 | r = handler.handle(ctx, m)
73 | assert.Equal(t, len(r.Answer) == 0, true)
74 |
75 | m = new(dns.Msg)
76 | m.SetQuestion("example.com.", dns.TypeA)
77 | r = handler.handle(ctx, m)
78 | assert.Equal(t, len(r.Answer) > 0, true)
79 |
80 | m = new(dns.Msg)
81 | m.SetQuestion(".", dns.TypeANY)
82 | r = handler.handle(ctx, m)
83 | assert.Equal(t, r.Rcode, dns.RcodeNotImplemented)
84 |
85 | m = new(dns.Msg)
86 | m.SetQuestion(".", dns.TypeNS)
87 | m.RecursionDesired = false
88 | r = handler.handle(ctx, m)
89 | assert.NotEqual(t, r.Rcode, dns.RcodeServerFailure)
90 | }
91 |
92 | func Test_HandlerHINFO(t *testing.T) {
93 | ctx := context.Background()
94 | cfg := makeTestConfig()
95 | handler := New(cfg)
96 |
97 | m := new(dns.Msg)
98 | m.SetQuestion(".", dns.TypeHINFO)
99 | m.Question[0].Qclass = dns.ClassCHAOS
100 |
101 | debugns = true
102 | resp := handler.handle(ctx, m)
103 |
104 | assert.Equal(t, true, len(resp.Ns) > 0)
105 | }
106 |
107 | func Test_HandlerPurge(t *testing.T) {
108 | ctx := context.Background()
109 | cfg := makeTestConfig()
110 | handler := New(cfg)
111 |
112 | bqname := base64.StdEncoding.EncodeToString([]byte("NS:."))
113 |
114 | req := new(dns.Msg)
115 | req.SetQuestion(dns.Fqdn(bqname), dns.TypeNULL)
116 | req.Question[0].Qclass = dns.ClassCHAOS
117 |
118 | resp := handler.handle(ctx, req)
119 |
120 | assert.Equal(t, true, len(resp.Extra) > 0)
121 | }
122 |
123 | func Test_HandlerServe(t *testing.T) {
124 | cfg := makeTestConfig()
125 | h := New(cfg)
126 |
127 | ch := middleware.NewChain([]middleware.Handler{})
128 | mw := mock.NewWriter("tcp", "127.0.0.1:0")
129 |
130 | req := new(dns.Msg)
131 | req.SetQuestion(".", dns.TypeNS)
132 |
133 | ch.Reset(mw, req)
134 |
135 | h.ServeDNS(context.Background(), ch)
136 | assert.Equal(t, true, ch.Writer.Written())
137 | }
138 |
--------------------------------------------------------------------------------
/middleware/resolver/nsec3.go:
--------------------------------------------------------------------------------
1 | package resolver
2 |
3 | import (
4 | "errors"
5 |
6 | "github.com/miekg/dns"
7 | )
8 |
// Errors returned by the NSEC3 denial-of-existence checks in this file.
var (
	// errNSECTypeExists: the matching NSEC3 bitmap lists the queried type or CNAME.
	errNSECTypeExists = errors.New("NSEC3 record shows question type exists")
	// errNSECMissingCoverage: no NSEC3 record matches or covers a name the proof needs.
	errNSECMissingCoverage = errors.New("NSEC3 record missing for expected encloser")
	// errNSECBadDelegation: a delegation's matching NSEC3 bitmap has DS or SOA set.
	errNSECBadDelegation = errors.New("DS or SOA bit set in NSEC3 type map")
	// errNSECNSMissing: a delegation's matching NSEC3 bitmap lacks NS.
	errNSECNSMissing = errors.New("NS bit not set in NSEC3 type map")
	// errNSECOptOut: the NSEC3 covering an unmatched delegation's next closer name is not opt-out.
	errNSECOptOut = errors.New("Opt-Out bit not set for NSEC3 record covering next closer")
)
16 |
// typesSet reports whether any of types appears in set (an NSEC3 type
// bitmap). It returns false when either slice is empty.
//
// Callers pass only one or two types, so a nested scan is cheaper than the
// temporary map the previous version allocated on every call.
func typesSet(set []uint16, types ...uint16) bool {
	for _, have := range set {
		for _, want := range types {
			if have == want {
				return true
			}
		}
	}
	return false
}
29 |
30 | func findClosestEncloser(name string, nsec []dns.RR) (string, string) {
31 | labelIndices := dns.Split(name)
32 | nc := name
33 | for i := 0; i < len(labelIndices); i++ {
34 | z := name[labelIndices[i]:]
35 | _, err := findMatching(z, nsec)
36 | if err != nil {
37 | continue
38 | }
39 | if i != 0 {
40 | nc = name[labelIndices[i-1]:]
41 | }
42 | return z, nc
43 | }
44 | return "", ""
45 | }
46 |
47 | func findMatching(name string, nsec []dns.RR) ([]uint16, error) {
48 | for _, rr := range nsec {
49 | n := rr.(*dns.NSEC3)
50 | if n.Match(name) {
51 | return n.TypeBitMap, nil
52 | }
53 | }
54 | return nil, errNSECMissingCoverage
55 | }
56 |
57 | func findCoverer(name string, nsec []dns.RR) ([]uint16, bool, error) {
58 | for _, rr := range nsec {
59 | n := rr.(*dns.NSEC3)
60 | if n.Cover(name) {
61 | return n.TypeBitMap, (n.Flags & 1) == 1, nil
62 | }
63 | }
64 | return nil, false, errNSECMissingCoverage
65 | }
66 |
67 | func verifyNameError(msg *dns.Msg, nsec []dns.RR) error {
68 | q := msg.Question[0]
69 | qname := q.Name
70 |
71 | if dname := getDnameTarget(msg); dname != "" {
72 | qname = dname
73 | }
74 |
75 | ce, _ := findClosestEncloser(qname, nsec)
76 | if ce == "" {
77 | return errNSECMissingCoverage
78 | }
79 | _, _, err := findCoverer("*."+ce, nsec)
80 | if err != nil {
81 | return err
82 | }
83 | return nil
84 | }
85 |
86 | func verifyNODATA(msg *dns.Msg, nsec []dns.RR) error {
87 | q := msg.Question[0]
88 | qname := q.Name
89 |
90 | if dname := getDnameTarget(msg); dname != "" {
91 | qname = dname
92 | }
93 |
94 | types, err := findMatching(qname, nsec)
95 | if err != nil {
96 | if q.Qtype != dns.TypeDS {
97 | return err
98 | }
99 |
100 | ce, nc := findClosestEncloser(qname, nsec)
101 | if ce == "" {
102 | return errNSECMissingCoverage
103 | }
104 | _, _, err := findCoverer(nc, nsec)
105 | if err != nil {
106 | return err
107 | }
108 | return nil
109 | }
110 |
111 | if typesSet(types, q.Qtype, dns.TypeCNAME) {
112 | return errNSECTypeExists
113 | }
114 |
115 | return nil
116 | }
117 |
118 | func verifyDelegation(delegation string, nsec []dns.RR) error {
119 | types, err := findMatching(delegation, nsec)
120 | if err != nil {
121 | ce, nc := findClosestEncloser(delegation, nsec)
122 | if ce == "" {
123 | return errNSECMissingCoverage
124 | }
125 | _, optOut, err := findCoverer(nc, nsec)
126 | if err != nil {
127 | return err
128 | }
129 | if !optOut {
130 | return errNSECOptOut
131 | }
132 | return nil
133 | }
134 | if !typesSet(types, dns.TypeNS) {
135 | return errNSECNSMissing
136 | }
137 | if typesSet(types, dns.TypeDS, dns.TypeSOA) {
138 | return errNSECBadDelegation
139 | }
140 | return nil
141 | }
142 |
--------------------------------------------------------------------------------
/middleware/resolver/singleinflight.go:
--------------------------------------------------------------------------------
1 | // Copyright 2013 The Go Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | // Adapted for resolver package usage by Semih Alev.
6 |
7 | package resolver
8 |
9 | import (
10 | "sync"
11 |
12 | "github.com/miekg/dns"
13 | )
14 |
// call is an in-flight or completed singleflight.Do call
type call struct {
	wg   sync.WaitGroup // Done once the leading caller's fn has finished; duplicates Wait on it
	val  *dns.Msg       // fn's result, shared with duplicate callers
	err  error          // fn's error, shared with duplicate callers
	dups int            // number of duplicate callers; incremented under the singleflight lock
}
22 |
// singleflight represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
// The zero value is ready to use; Do creates m on first use.
type singleflight struct {
	sync.RWMutex                  // protects m
	m            map[uint64]*call // lazily initialized
}
29 |
30 | // Do executes and returns the results of the given function, making
31 | // sure that only one execution is in-flight for a given key at a
32 | // time. If a duplicate comes in, the duplicate caller waits for the
33 | // original to complete and receives the same results.
34 | // The return value shared indicates whether v was given to multiple callers.
35 | func (g *singleflight) Do(key uint64, fn func() (*dns.Msg, error)) (v *dns.Msg, shared bool, err error) {
36 | g.Lock()
37 | if g.m == nil {
38 | g.m = make(map[uint64]*call)
39 | }
40 | if c, ok := g.m[key]; ok {
41 | c.dups++
42 | g.Unlock()
43 | c.wg.Wait()
44 | return c.val, true, c.err
45 | }
46 | c := new(call)
47 | c.wg.Add(1)
48 | g.m[key] = c
49 | g.Unlock()
50 |
51 | c.val, c.err = fn()
52 | c.wg.Done()
53 |
54 | g.Lock()
55 | delete(g.m, key)
56 | g.Unlock()
57 |
58 | return c.val, c.dups > 0, c.err
59 | }
60 |
--------------------------------------------------------------------------------
/middleware/resolver/utils_test.go:
--------------------------------------------------------------------------------
1 | package resolver
2 |
3 | import (
4 | "fmt"
5 | "net"
6 | "testing"
7 |
8 | "github.com/miekg/dns"
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | func Test_shuffleStr(t *testing.T) {
13 |
14 | vals := make([]string, 1)
15 |
16 | rr := shuffleStr(vals)
17 |
18 | if len(rr) != 1 {
19 | t.Error("invalid array length")
20 | }
21 | }
22 |
23 | func Test_searchAddr(t *testing.T) {
24 | testDomain := "google.com."
25 |
26 | m := new(dns.Msg)
27 | m.SetQuestion(testDomain, dns.TypeA)
28 |
29 | m.SetEdns0(512, true)
30 | assert.Equal(t, isDO(m), true)
31 |
32 | m.Extra = []dns.RR{}
33 | assert.Equal(t, isDO(m), false)
34 |
35 | a1 := &dns.A{
36 | Hdr: dns.RR_Header{
37 | Name: testDomain,
38 | Rrtype: dns.TypeA,
39 | Class: dns.ClassINET,
40 | Ttl: 10,
41 | },
42 | A: net.ParseIP("127.0.0.1")}
43 |
44 | m.Answer = append(m.Answer, a1)
45 |
46 | a2 := &dns.A{
47 | Hdr: dns.RR_Header{
48 | Name: testDomain,
49 | Rrtype: dns.TypeA,
50 | Class: dns.ClassINET,
51 | Ttl: 10,
52 | },
53 | A: net.ParseIP("192.0.2.1")}
54 |
55 | m.Answer = append(m.Answer, a2)
56 |
57 | addrs, found := searchAddrs(m)
58 | assert.Equal(t, len(addrs), 1)
59 | assert.NotEqual(t, addrs[0], "127.0.0.1")
60 | assert.Equal(t, addrs[0], "192.0.2.1")
61 | assert.Equal(t, found, true)
62 | }
63 |
64 | func Test_extractRRSet(t *testing.T) {
65 | var rr []dns.RR
66 | for i := 0; i < 3; i++ {
67 | a, _ := dns.NewRR(fmt.Sprintf("test.com. 5 IN A 127.0.0.%d", i))
68 | rr = append(rr, a)
69 | }
70 |
71 | rre := extractRRSet(rr, "test.com.", dns.TypeA)
72 | assert.Len(t, rre, 3)
73 | }
74 |
--------------------------------------------------------------------------------
/middleware/response_writer.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "errors"
5 | "net"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/semihalev/sdns/mock"
9 | "github.com/semihalev/sdns/server/doq"
10 | )
11 |
12 | // ResponseWriter implement of dns.ResponseWriter
13 | type ResponseWriter interface {
14 | dns.ResponseWriter
15 | Msg() *dns.Msg
16 | Rcode() int
17 | Written() bool
18 | Reset(dns.ResponseWriter)
19 | Proto() string
20 | RemoteIP() net.IP
21 | Internal() bool
22 | }
23 |
// responseWriter wraps a dns.ResponseWriter and records what is written
// through it.
type responseWriter struct {
	dns.ResponseWriter
	msg      *dns.Msg // message written via Write or WriteMsg; nil after Reset
	size     int      // -1 until written; then Write's byte count, or 0 after WriteMsg
	rcode    int      // rcode of msg; RcodeSuccess after Reset
	proto    string   // "tcp", "udp", "doq" or the mock writer's protocol
	remoteip net.IP   // client IP taken from the remote address
	internal bool     // true when the remote address is 127.0.0.255:0
}
33 |
34 | var _ ResponseWriter = &responseWriter{}
35 | var errAlreadyWritten = errors.New("msg already written")
36 |
37 | func (w *responseWriter) Msg() *dns.Msg {
38 | return w.msg
39 | }
40 |
41 | func (w *responseWriter) Reset(rw dns.ResponseWriter) {
42 | w.ResponseWriter = rw
43 | w.size = -1
44 | w.msg = nil
45 | w.rcode = dns.RcodeSuccess
46 |
47 | switch rw.LocalAddr().(type) {
48 | case (*net.TCPAddr):
49 | w.proto = "tcp"
50 | w.remoteip = w.RemoteAddr().(*net.TCPAddr).IP
51 | case (*net.UDPAddr):
52 | w.proto = "udp"
53 | w.remoteip = w.RemoteAddr().(*net.UDPAddr).IP
54 | }
55 |
56 | switch writer := rw.(type) {
57 | case (*mock.Writer):
58 | w.proto = writer.Proto()
59 | case (*doq.ResponseWriter):
60 | w.proto = "doq"
61 | }
62 |
63 | w.internal = w.RemoteAddr().String() == "127.0.0.255:0"
64 | }
65 |
66 | func (w *responseWriter) RemoteIP() net.IP {
67 | return w.remoteip
68 | }
69 |
70 | func (w *responseWriter) Proto() string {
71 | return w.proto
72 | }
73 |
74 | func (w *responseWriter) Rcode() int {
75 | return w.rcode
76 | }
77 |
78 | func (w *responseWriter) Written() bool {
79 | return w.size != -1
80 | }
81 |
82 | func (w *responseWriter) Write(m []byte) (int, error) {
83 | if w.Written() {
84 | return 0, errAlreadyWritten
85 | }
86 |
87 | w.msg = new(dns.Msg)
88 | err := w.msg.Unpack(m)
89 | if err != nil {
90 | return 0, err
91 | }
92 | w.rcode = w.msg.Rcode
93 |
94 | n, err := w.ResponseWriter.Write(m)
95 | w.size = n
96 | return n, err
97 | }
98 |
99 | func (w *responseWriter) WriteMsg(m *dns.Msg) error {
100 | if w.Written() {
101 | return errAlreadyWritten
102 | }
103 |
104 | w.msg = m
105 | w.rcode = m.Rcode
106 | w.size = 0
107 |
108 | return w.ResponseWriter.WriteMsg(m)
109 | }
110 |
111 | // Internal func
112 | func (w *responseWriter) Internal() bool { return w.internal }
113 |
--------------------------------------------------------------------------------
/mock/writer.go:
--------------------------------------------------------------------------------
1 | package mock
2 |
3 | import (
4 | "net"
5 |
6 | "github.com/miekg/dns"
7 | )
8 |
// Writer is a test double for dns.ResponseWriter. It records the last
// message written to it instead of sending anything.
type Writer struct {
	msg *dns.Msg // last message written; nil until a write

	proto string // protocol name given to NewWriter

	localAddr  net.Addr // fixed 127.0.0.1:53 for the known protocols
	remoteAddr net.Addr // address parsed from NewWriter's addr argument

	remoteip net.IP // IP part of remoteAddr

	internal bool // true when remoteAddr is 127.0.0.255:0
}
22 |
23 | // NewWriter return writer
24 | func NewWriter(proto, addr string) *Writer {
25 | w := &Writer{}
26 |
27 | switch proto {
28 | case "tcp", "doh":
29 | w.localAddr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 53}
30 | w.remoteAddr, _ = net.ResolveTCPAddr("tcp", addr)
31 | w.remoteip = w.remoteAddr.(*net.TCPAddr).IP
32 |
33 | case "udp":
34 | w.localAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 53}
35 | w.remoteAddr, _ = net.ResolveUDPAddr("udp", addr)
36 | w.remoteip = w.remoteAddr.(*net.UDPAddr).IP
37 | }
38 |
39 | w.internal = w.RemoteAddr().String() == "127.0.0.255:0"
40 |
41 | w.proto = proto
42 |
43 | return w
44 | }
45 |
46 | // Rcode return message response code
47 | func (w *Writer) Rcode() int {
48 | if w.msg == nil {
49 | return dns.RcodeServerFailure
50 | }
51 |
52 | return w.msg.Rcode
53 | }
54 |
55 | // Msg return current dns message
56 | func (w *Writer) Msg() *dns.Msg {
57 | return w.msg
58 | }
59 |
60 | // Write func
61 | func (w *Writer) Write(b []byte) (int, error) {
62 | w.msg = new(dns.Msg)
63 | err := w.msg.Unpack(b)
64 | if err != nil {
65 | return 0, err
66 | }
67 | return len(b), nil
68 | }
69 |
70 | // WriteMsg func
71 | func (w *Writer) WriteMsg(msg *dns.Msg) error {
72 | w.msg = msg
73 | return nil
74 | }
75 |
76 | // Written func
77 | func (w *Writer) Written() bool {
78 | return w.msg != nil
79 | }
80 |
81 | // RemoteIP func
82 | func (w *Writer) RemoteIP() net.IP { return w.remoteip }
83 |
84 | // Proto func
85 | func (w *Writer) Proto() string { return w.proto }
86 |
87 | // Reset func
88 | func (w *Writer) Reset(rw dns.ResponseWriter) {}
89 |
90 | // Close func
91 | func (w *Writer) Close() error { return nil }
92 |
93 | // Hijack func
94 | func (w *Writer) Hijack() {}
95 |
96 | // LocalAddr func
97 | func (w *Writer) LocalAddr() net.Addr { return w.localAddr }
98 |
99 | // RemoteAddr func
100 | func (w *Writer) RemoteAddr() net.Addr { return w.remoteAddr }
101 |
102 | // TsigStatus func
103 | func (w *Writer) TsigStatus() error { return nil }
104 |
105 | // TsigTimersOnly func
106 | func (w *Writer) TsigTimersOnly(ok bool) {}
107 |
108 | // Internal func
109 | func (w *Writer) Internal() bool { return w.internal }
110 |
--------------------------------------------------------------------------------
/mock/writer_test.go:
--------------------------------------------------------------------------------
1 | package mock
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/miekg/dns"
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func Test_Writer(t *testing.T) {
11 | mw := NewWriter("udp", "127.0.0.1:0")
12 |
13 | m := new(dns.Msg)
14 | m.SetQuestion("example.com.", dns.TypeA)
15 | err := mw.WriteMsg(m)
16 |
17 | assert.NoError(t, err)
18 | assert.True(t, mw.Written())
19 | assert.Equal(t, mw.Rcode(), dns.RcodeSuccess)
20 | assert.NotNil(t, mw.Msg())
21 | assert.Equal(t, mw.LocalAddr().String(), "127.0.0.1:53")
22 | assert.Equal(t, mw.RemoteAddr().String(), "127.0.0.1:0")
23 | assert.Nil(t, mw.Close())
24 | assert.Nil(t, mw.TsigStatus())
25 |
26 | mw = NewWriter("tcp", "127.0.0.255:0")
27 | assert.False(t, mw.Written())
28 | assert.Equal(t, mw.Rcode(), dns.RcodeServerFailure)
29 |
30 | assert.Equal(t, "tcp", mw.Proto())
31 | assert.Equal(t, "127.0.0.255", mw.RemoteIP().String())
32 |
33 | _, err = mw.Write([]byte{})
34 | assert.Error(t, err)
35 |
36 | data, err := m.Pack()
37 | assert.NoError(t, err)
38 | _, err = mw.Write(data)
39 | assert.NoError(t, err)
40 | assert.True(t, mw.Written())
41 | assert.Equal(t, mw.Rcode(), dns.RcodeSuccess)
42 | assert.True(t, mw.Internal())
43 | }
44 |
--------------------------------------------------------------------------------
/response/typify.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016-2020 The CoreDNS authors and contributors
2 | // Adapted for SDNS usage by Semih Alev.
3 |
4 | package response
5 |
6 | import (
7 | "fmt"
8 | "time"
9 |
10 | "github.com/miekg/dns"
11 | )
12 |
// Type is the type of the message.
type Type int

const (
	// NoError indicates a positive reply
	NoError Type = iota
	// NameError is a NXDOMAIN in header, SOA in auth.
	NameError
	// NoData indicates name found, but not the type: NOERROR in header, SOA in auth.
	NoData
	// Delegation is a msg with a pointer to another nameserver: NOERROR in header, NS in auth, optionally fluff in additional (not checked).
	Delegation
	// Meta indicates a meta message, NOTIFY, or a transfer: qType is IXFR or AXFR.
	Meta
	// Update is an dynamic update message.
	Update
	// OtherError indicates any other error.
	OtherError
	// Expired if sigs expired: don't cache these
	Expired
	// NoCache indicates a no cache reply
	NoCache
)

// toString maps every defined Type to its upper-case name.
var toString = map[Type]string{
	NoError:    "NOERROR",
	NameError:  "NXDOMAIN",
	NoData:     "NODATA",
	Delegation: "DELEGATION",
	Meta:       "META",
	Update:     "UPDATE",
	OtherError: "OTHERERROR",
	Expired:    "EXPIRED",
	NoCache:    "NOCACHE",
}

// String returns the name of t, e.g. "NOERROR". Undefined values are shown
// as "Type(N)" instead of an empty string, which is unreadable in logs.
func (t Type) String() string {
	if s, ok := toString[t]; ok {
		return s
	}
	return fmt.Sprintf("Type(%d)", int(t))
}

// TypeFromString returns the Type whose name is s. If no type matches, it
// returns NoError together with a non-nil error, so callers must check the
// error rather than the returned Type.
func TypeFromString(s string) (Type, error) {
	for t, str := range toString {
		if s == str {
			return t, nil
		}
	}
	return NoError, fmt.Errorf("invalid Type: %s", s)
}
61 |
62 | // Typify classifies a message, it returns the Type.
63 | func Typify(m *dns.Msg, t time.Time) (Type, *dns.OPT) {
64 | if m == nil {
65 | return OtherError, nil
66 | }
67 | opt := m.IsEdns0()
68 | do := false
69 | if opt != nil {
70 | do = opt.Do()
71 | }
72 |
73 | if m.Opcode == dns.OpcodeUpdate {
74 | return Update, opt
75 | }
76 |
77 | // Check transfer and update first
78 | if m.Opcode == dns.OpcodeNotify {
79 | return Meta, opt
80 | }
81 |
82 | if len(m.Question) > 0 {
83 | if m.Question[0].Qtype == dns.TypeAXFR || m.Question[0].Qtype == dns.TypeIXFR {
84 | return Meta, opt
85 | }
86 | }
87 |
88 | // If our message contains any expired sigs and we care about that, we should return expired
89 | if do {
90 | if expired := typifyExpired(m, t); expired {
91 | return Expired, opt
92 | }
93 | }
94 |
95 | if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess {
96 | return NoError, opt
97 | }
98 |
99 | soa := false
100 | ns := 0
101 | for _, r := range m.Ns {
102 | if r.Header().Rrtype == dns.TypeSOA {
103 | soa = true
104 | continue
105 | }
106 | if r.Header().Rrtype == dns.TypeNS {
107 | ns++
108 | }
109 | }
110 |
111 | if !soa && len(m.Question) > 0 {
112 | if len(m.Answer) == 0 && m.Question[0].Qtype == dns.TypeDNSKEY {
113 | return NoCache, opt
114 | }
115 | }
116 |
117 | if soa && m.Rcode == dns.RcodeSuccess {
118 | return NoData, opt
119 | }
120 | if soa && m.Rcode == dns.RcodeNameError {
121 | return NameError, opt
122 | }
123 |
124 | if ns > 0 && m.Rcode == dns.RcodeSuccess {
125 | return Delegation, opt
126 | }
127 |
128 | if m.Rcode == dns.RcodeSuccess {
129 | return NoError, opt
130 | }
131 |
132 | return OtherError, opt
133 | }
134 |
135 | func typifyExpired(m *dns.Msg, t time.Time) bool {
136 | if expired := typifyExpiredRRSIG(m.Answer, t); expired {
137 | return true
138 | }
139 | if expired := typifyExpiredRRSIG(m.Ns, t); expired {
140 | return true
141 | }
142 | if expired := typifyExpiredRRSIG(m.Extra, t); expired {
143 | return true
144 | }
145 | return false
146 | }
147 |
148 | func typifyExpiredRRSIG(rrs []dns.RR, t time.Time) bool {
149 | for _, r := range rrs {
150 | if r.Header().Rrtype != dns.TypeRRSIG {
151 | continue
152 | }
153 | ok := r.(*dns.RRSIG).ValidityPeriod(t)
154 | if !ok {
155 | return true
156 | }
157 | }
158 | return false
159 | }
160 |
--------------------------------------------------------------------------------
/response/typify_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2016-2020 The CoreDNS authors and contributors
2 | // Adapted for SDNS usage by Semih Alev.
3 |
4 | package response
5 |
6 | import (
7 | "testing"
8 | "time"
9 |
10 | "github.com/miekg/dns"
11 | )
12 |
13 | func makeRR(data string) dns.RR {
14 | r, _ := dns.NewRR(data)
15 |
16 | return r
17 | }
18 |
19 | func TestTypifyNilMsg(t *testing.T) {
20 | var m *dns.Msg
21 |
22 | ty, _ := Typify(m, time.Now().UTC())
23 | if ty != OtherError {
24 | t.Errorf("Message wrongly typified, expected OtherError, got %s", ty)
25 | }
26 |
27 | ty, _ = TypeFromString("")
28 | if ty != NoError {
29 | t.Errorf("Message wrongly typified, expected NoError, got %s", ty)
30 | }
31 |
32 | ty, _ = TypeFromString("NOERROR")
33 | if ty != NoError {
34 | t.Errorf("Message wrongly typified, expected NoError, got %s", ty)
35 | }
36 |
37 | ts := ty.String()
38 | if ts != "NOERROR" {
39 | t.Errorf("Type to string wrong, expected NOERROR, got %s", ty)
40 | }
41 | }
42 |
// TestTypify mutates a single message step by step and checks the
// classification after each change. Every step builds on the state left by
// the previous one, so the order matters.
func TestTypify(t *testing.T) {
	m := new(dns.Msg)
	m.SetQuestion("miek.nl.", dns.TypeA)

	utc := time.Now().UTC()

	// Empty NOERROR reply: no answer, no SOA, no NS.
	mt, _ := Typify(m, utc)
	if mt != NoError {
		t.Errorf("Message is wrongly typified, expected NoError, got %s", mt)
	}

	m.Opcode = dns.OpcodeUpdate
	mt, _ = Typify(m, utc)
	if mt != Update {
		t.Errorf("Message is wrongly typified, expected Update, got %s", mt)
	}

	m.Opcode = dns.OpcodeNotify
	mt, _ = Typify(m, utc)
	if mt != Meta {
		t.Errorf("Message is wrongly typified, expected Meta, got %s", mt)
	}

	// Zone transfer questions are Meta regardless of content.
	m.Opcode = dns.OpcodeQuery
	m.SetQuestion("miek.nl.", dns.TypeAXFR)
	mt, _ = Typify(m, utc)
	if mt != Meta {
		t.Errorf("Message is wrongly typified, expected Meta, got %s", mt)
	}

	// SOA in the authority section with NOERROR and no answers.
	m.SetQuestion("miek.nl.", dns.TypeA)
	m.Ns = append(m.Ns, makeRR("nl. 3599 IN SOA ns1.dns.nl. hostmaster.domain-registry.nl. 2018111539 3600 600 2419200 600"))
	mt, _ = Typify(m, utc)
	if mt != NoData {
		t.Errorf("Message is wrongly typified, expected NoData, got %s", mt)
	}

	// Same SOA, now with NXDOMAIN.
	m.Rcode = dns.RcodeNameError
	mt, _ = Typify(m, utc)
	if mt != NameError {
		t.Errorf("Message is wrongly typified, expected NameError, got %s", mt)
	}

	m.Rcode = dns.RcodeServerFailure
	mt, _ = Typify(m, utc)
	if mt != OtherError {
		t.Errorf("Message is wrongly typified, expected OtherError, got %s", mt)
	}

	// From here on the DO bit is set, so RRSIG validity is checked.
	m.SetEdns0(4096, true)

	m.Rcode = dns.RcodeSuccess
	m.Answer = append(m.Answer, makeRR("miek.nl. 3600 IN A 127.0.0.1"))
	mt, _ = Typify(m, utc)
	if mt != NoError {
		t.Errorf("Message is wrongly typified, expected NoError, got %s", mt)
	}

	// This RRSIG's validity window ended in May 2016.
	m.Extra = append(m.Extra,
		makeRR("miek.nl. 1800 IN RRSIG NS 8 2 1800 20160521031301 20160421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
	)
	mt, _ = Typify(m, utc)
	if mt != Expired {
		t.Errorf("Message is wrongly typified, expected Expired, got %s", mt)
	}

	// An expired RRSIG in the answer section is detected as well.
	m.Answer = append(m.Answer,
		makeRR("miek.nl. 1800 IN RRSIG NS 8 2 1800 20160521031301 20160421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
	)
	mt, _ = Typify(m, utc)
	if mt != Expired {
		t.Errorf("Message is wrongly typified, expected Expired, got %s", mt)
	}
}
117 |
118 | func TestTypifyDelegation(t *testing.T) {
119 | m := delegationMsg()
120 | mt, _ := Typify(m, time.Now().UTC())
121 | if mt != Delegation {
122 | t.Errorf("Message is wrongly typified, expected Delegation, got %s", mt)
123 | }
124 | }
125 |
126 | func TestTypifyRRSIG(t *testing.T) {
127 | utc := time.Now().UTC()
128 |
129 | m := delegationMsgRRSIGOK()
130 | if mt, _ := Typify(m, utc); mt != Delegation {
131 | t.Errorf("Message is wrongly typified, expected Delegation, got %s", mt)
132 | }
133 |
134 | // Still a Delegation because EDNS0 OPT DO bool is not set, so we won't check the sigs.
135 | m = delegationMsgRRSIGFail()
136 | if mt, _ := Typify(m, utc); mt != Delegation {
137 | t.Errorf("Message is wrongly typified, expected Delegation, got %s", mt)
138 | }
139 |
140 | m = delegationMsgRRSIGFail()
141 | m = addOpt(m)
142 | if mt, _ := Typify(m, utc); mt != Expired {
143 | t.Errorf("Message is wrongly typified, expected Expired, got %s", mt)
144 | }
145 | }
146 |
147 | func delegationMsg() *dns.Msg {
148 | return &dns.Msg{
149 | Ns: []dns.RR{
150 | makeRR("miek.nl. 3600 IN NS linode.atoom.net."),
151 | makeRR("miek.nl. 3600 IN NS ns-ext.nlnetlabs.nl."),
152 | makeRR("miek.nl. 3600 IN NS omval.tednet.nl."),
153 | },
154 | Extra: []dns.RR{
155 | makeRR("omval.tednet.nl. 3600 IN A 185.49.141.42"),
156 | makeRR("omval.tednet.nl. 3600 IN AAAA 2a04:b900:0:100::42"),
157 | },
158 | }
159 | }
160 |
161 | func delegationMsgRRSIGOK() *dns.Msg {
162 | del := delegationMsg()
163 | del.Ns = append(del.Ns,
164 | makeRR("miek.nl. 1800 IN RRSIG NS 8 2 1800 20170521031301 20170421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
165 | )
166 | return del
167 | }
168 |
169 | func delegationMsgRRSIGFail() *dns.Msg {
170 | del := delegationMsg()
171 | del.Ns = append(del.Ns,
172 | makeRR("miek.nl. 1800 IN RRSIG NS 8 2 1800 20160521031301 20160421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="),
173 | )
174 | return del
175 | }
176 |
177 | func addOpt(m *dns.Msg) *dns.Msg {
178 | return m.SetEdns0(4096, true)
179 | }
180 |
--------------------------------------------------------------------------------
/sdns.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | //go:generate go run gen.go
4 |
5 | import (
6 | "context"
7 | "flag"
8 | "fmt"
9 | "os"
10 | "os/signal"
11 | "runtime/debug"
12 | "syscall"
13 | "time"
14 |
15 | "github.com/semihalev/log"
16 | "github.com/semihalev/sdns/api"
17 | "github.com/semihalev/sdns/config"
18 | "github.com/semihalev/sdns/middleware"
19 | "github.com/semihalev/sdns/server"
20 | )
21 |
// version is the sdns release version, logged at startup and printed by printver.
const version = "1.4.0"

var (
	// flagcfgpath is the config file path set by -c / -config.
	flagcfgpath string
	// flagprintver is set by -v / -version; main then prints the version and exits.
	flagprintver bool

	// cfg is the configuration loaded by setup.
	cfg *config.Config
)
30 |
31 | func init() {
32 | flag.StringVar(&flagcfgpath, "config", "sdns.conf", "Location of the config file. If it doesn't exist, a new one will be generated.")
33 | flag.StringVar(&flagcfgpath, "c", "sdns.conf", "Location of the config file. If it doesn't exist, a new one will be generated.")
34 |
35 | flag.BoolVar(&flagprintver, "version", false, "Show the version of the sdns.")
36 | flag.BoolVar(&flagprintver, "v", false, "Show the version of the sdns.")
37 |
38 | flag.Usage = func() {
39 | fmt.Fprintf(os.Stderr, "Usage:\n sdns [OPTIONS]\n\n")
40 | fmt.Fprintf(os.Stderr, "Options:\n")
41 | fmt.Fprintf(os.Stderr, " -c, --config PATH\tLocation of the config file. If it doesn't exist, a new one will be generated.\n")
42 | fmt.Fprintf(os.Stderr, " -v, --version\t\tShow the version of the sdns and exit.\n")
43 | fmt.Fprintf(os.Stderr, " -h, --help\t\tShow this help and exit.\n\n")
44 | fmt.Fprintf(os.Stderr, "Example:\n")
45 | fmt.Fprintf(os.Stderr, " sdns -c sdns.conf\n\n")
46 | }
47 | }
48 |
49 | func setup() {
50 | var err error
51 |
52 | if cfg, err = config.Load(flagcfgpath, version); err != nil {
53 | log.Crit("Config loading failed", "error", err.Error())
54 | }
55 |
56 | if cfg.LogLevel == "" {
57 | cfg.LogLevel = "info"
58 | }
59 |
60 | lvl, err := log.LvlFromString(cfg.LogLevel)
61 | if err != nil {
62 | log.Crit("Log verbosity level unknown")
63 | }
64 |
65 | log.Root().SetLevel(lvl)
66 | log.Root().SetHandler(log.LvlFilterHandler(lvl, log.StdoutHandler))
67 |
68 | middleware.Setup(cfg)
69 | }
70 |
71 | func run(ctx context.Context) *server.Server {
72 | srv := server.New(cfg)
73 | srv.Run(ctx)
74 |
75 | api := api.New(cfg)
76 | api.Run(ctx)
77 |
78 | return srv
79 | }
80 |
81 | func printver() {
82 | buildInfo, _ := debug.ReadBuildInfo()
83 |
84 | settings := make(map[string]string)
85 | for _, s := range buildInfo.Settings {
86 | settings[s.Key] = s.Value
87 | }
88 |
89 | fmt.Fprintf(os.Stderr, "sdns v%s rev %.7s\nbuilt by %s (%s %s)\n", version,
90 | settings["vcs.revision"], buildInfo.GoVersion, settings["GOOS"], settings["GOARCH"])
91 |
92 | os.Exit(0)
93 | }
94 |
95 | func main() {
96 | flag.Parse()
97 |
98 | if flagprintver {
99 | printver()
100 | }
101 |
102 | log.Info("Starting sdns...", "version", version)
103 |
104 | setup()
105 |
106 | ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
107 | srv := run(ctx)
108 |
109 | <-ctx.Done()
110 |
111 | log.Info("Stopping sdns...")
112 |
113 | stop()
114 |
115 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
116 | defer cancel()
117 |
118 | for !srv.Stopped() {
119 | select {
120 | case <-time.After(100 * time.Millisecond):
121 | continue
122 | case <-ctx.Done():
123 | return
124 | }
125 | }
126 | }
127 |
--------------------------------------------------------------------------------
/server/doh/doh.go:
--------------------------------------------------------------------------------
1 | package doh
2 |
3 | import (
4 | "encoding/base64"
5 | "encoding/json"
6 | "io"
7 | "net/http"
8 | "strings"
9 |
10 | "github.com/miekg/dns"
11 | )
12 |
13 | const minMsgHeaderSize = 12
14 |
15 | // HandleWireFormat handle wire format
16 | func HandleWireFormat(handle func(*dns.Msg) *dns.Msg) func(http.ResponseWriter, *http.Request) {
17 | return func(w http.ResponseWriter, r *http.Request) {
18 | var (
19 | buf []byte
20 | err error
21 | )
22 |
23 | switch r.Method {
24 | case http.MethodGet:
25 | buf, err = base64.RawURLEncoding.DecodeString(r.URL.Query().Get("dns"))
26 | if len(buf) == 0 || err != nil {
27 | http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
28 | return
29 | }
30 | case http.MethodPost:
31 | if r.Header.Get("Content-Type") != "application/dns-message" {
32 | http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType)
33 | return
34 | }
35 |
36 | buf, err = io.ReadAll(r.Body)
37 | if err != nil {
38 | http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
39 | return
40 | }
41 | defer r.Body.Close()
42 | default:
43 | http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
44 | return
45 | }
46 |
47 | if len(buf) < minMsgHeaderSize {
48 | http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
49 | return
50 | }
51 |
52 | req := new(dns.Msg)
53 | if err := req.Unpack(buf); err != nil {
54 | http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
55 | return
56 | }
57 |
58 | msg := handle(req)
59 | if msg == nil {
60 | http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
61 | return
62 | }
63 |
64 | packed, err := msg.Pack()
65 | if err != nil {
66 | http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
67 | return
68 | }
69 |
70 | w.Header().Set("Content-Type", "application/dns-message")
71 |
72 | _, _ = w.Write(packed)
73 | }
74 | }
75 |
76 | // HandleJSON handle json format
77 | func HandleJSON(handle func(*dns.Msg) *dns.Msg) func(http.ResponseWriter, *http.Request) {
78 | return func(w http.ResponseWriter, r *http.Request) {
79 | name := r.URL.Query().Get("name")
80 | if name == "" {
81 | http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
82 | return
83 | }
84 | name = dns.Fqdn(name)
85 |
86 | qtype := ParseQTYPE(r.URL.Query().Get("type"))
87 | if qtype == dns.TypeNone {
88 | http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
89 | return
90 | }
91 |
92 | req := new(dns.Msg)
93 | req.SetQuestion(name, qtype)
94 | req.AuthenticatedData = true
95 |
96 | if r.URL.Query().Get("cd") == "true" {
97 | req.CheckingDisabled = true
98 | }
99 |
100 | opt := &dns.OPT{
101 | Hdr: dns.RR_Header{
102 | Name: ".",
103 | Class: dns.DefaultMsgSize,
104 | Rrtype: dns.TypeOPT,
105 | },
106 | }
107 |
108 | if r.URL.Query().Get("do") == "true" {
109 | opt.SetDo()
110 | }
111 |
112 | req.Extra = append(req.Extra, opt)
113 |
114 | msg := handle(req)
115 | if msg == nil {
116 | http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
117 | return
118 | }
119 |
120 | json, err := json.Marshal(NewMsg(msg))
121 | if err != nil {
122 | http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
123 | return
124 | }
125 |
126 | if strings.Contains(r.Header.Get("Accept"), "text/html") {
127 | w.Header().Set("Content-Type", "application/x-javascript")
128 | } else {
129 | w.Header().Set("Content-Type", "application/dns-json")
130 | }
131 |
132 | _, _ = w.Write(json)
133 | }
134 | }
135 |
--------------------------------------------------------------------------------
/server/doh/doh_test.go:
--------------------------------------------------------------------------------
1 | package doh
2 |
3 | import (
4 | "bytes"
5 | "encoding/base64"
6 | "encoding/json"
7 | "fmt"
8 | "io"
9 | "net/http"
10 | "net/http/httptest"
11 | "testing"
12 |
13 | "github.com/miekg/dns"
14 | "github.com/stretchr/testify/assert"
15 | )
16 |
17 | func handleTest(w http.ResponseWriter, r *http.Request) {
18 | handle := func(req *dns.Msg) *dns.Msg {
19 | msg, _ := dns.Exchange(req, "8.8.8.8:53")
20 |
21 | return msg
22 | }
23 |
24 | var handleFn func(http.ResponseWriter, *http.Request)
25 | if r.Method == http.MethodGet && r.URL.Query().Get("dns") == "" {
26 | handleFn = HandleJSON(handle)
27 | } else {
28 | handleFn = HandleWireFormat(handle)
29 | }
30 |
31 | handleFn(w, r)
32 | }
33 | func Test_dohJSON(t *testing.T) {
34 | t.Parallel()
35 |
36 | w := httptest.NewRecorder()
37 |
38 | request, err := http.NewRequest("GET", "/dns-query?name=www.google.com&type=a&do=true&cd=true", nil)
39 | assert.NoError(t, err)
40 |
41 | request.RemoteAddr = "127.0.0.1:0"
42 |
43 | handleTest(w, request)
44 |
45 | assert.Equal(t, w.Code, http.StatusOK)
46 |
47 | data, err := io.ReadAll(w.Body)
48 | assert.NoError(t, err)
49 |
50 | var dm Msg
51 | err = json.Unmarshal(data, &dm)
52 | assert.NoError(t, err)
53 |
54 | assert.Equal(t, len(dm.Answer) > 0, true)
55 | }
56 |
57 | func Test_dohJSONerror(t *testing.T) {
58 | t.Parallel()
59 |
60 | w := httptest.NewRecorder()
61 |
62 | request, err := http.NewRequest("GET", "/dns-query?name=", nil)
63 | assert.NoError(t, err)
64 |
65 | request.RemoteAddr = "127.0.0.1:0"
66 |
67 | handleTest(w, request)
68 |
69 | assert.Equal(t, w.Code, http.StatusBadRequest)
70 | }
71 |
72 | func Test_dohJSONaccepthtml(t *testing.T) {
73 | t.Parallel()
74 |
75 | w := httptest.NewRecorder()
76 |
77 | request, err := http.NewRequest("GET", "/dns-query?name=www.google.com", nil)
78 | assert.NoError(t, err)
79 |
80 | request.RemoteAddr = "127.0.0.1:0"
81 |
82 | request.Header.Add("Accept", "text/html")
83 | handleTest(w, request)
84 |
85 | assert.Equal(t, w.Code, http.StatusOK)
86 | assert.Equal(t, w.Header().Get("Content-Type"), "application/x-javascript")
87 | }
88 |
89 | func Test_dohWireGET(t *testing.T) {
90 | t.Parallel()
91 |
92 | w := httptest.NewRecorder()
93 |
94 | req := new(dns.Msg)
95 | req.SetQuestion("www.google.com.", dns.TypeA)
96 |
97 | data, err := req.Pack()
98 | assert.NoError(t, err)
99 |
100 | dq := base64.RawURLEncoding.EncodeToString(data)
101 |
102 | request, err := http.NewRequest("GET", fmt.Sprintf("/dns-query?dns=%s", dq), nil)
103 | assert.NoError(t, err)
104 |
105 | request.RemoteAddr = "127.0.0.1:0"
106 |
107 | handleTest(w, request)
108 |
109 | assert.Equal(t, w.Code, http.StatusOK)
110 |
111 | data, err = io.ReadAll(w.Body)
112 | assert.NoError(t, err)
113 |
114 | msg := new(dns.Msg)
115 | err = msg.Unpack(data)
116 | assert.NoError(t, err)
117 |
118 | assert.Equal(t, msg.Rcode, dns.RcodeSuccess)
119 |
120 | assert.Equal(t, len(msg.Answer) > 0, true)
121 | }
122 |
123 | func Test_dohWireGETerror(t *testing.T) {
124 | t.Parallel()
125 |
126 | w := httptest.NewRecorder()
127 |
128 | request, err := http.NewRequest("GET", "/dns-query?dns=", nil)
129 | assert.NoError(t, err)
130 |
131 | request.RemoteAddr = "127.0.0.1:0"
132 |
133 | handleTest(w, request)
134 |
135 | assert.Equal(t, w.Code, http.StatusBadRequest)
136 | }
137 |
138 | func Test_dohWireGETbadquery(t *testing.T) {
139 | t.Parallel()
140 |
141 | w := httptest.NewRecorder()
142 |
143 | request, err := http.NewRequest("GET", "/dns-query?dns=Df4", nil)
144 | assert.NoError(t, err)
145 |
146 | request.RemoteAddr = "127.0.0.1:0"
147 |
148 | handleTest(w, request)
149 |
150 | assert.Equal(t, w.Code, http.StatusBadRequest)
151 | }
152 |
153 | func Test_dohWireHEAD(t *testing.T) {
154 | t.Parallel()
155 |
156 | w := httptest.NewRecorder()
157 |
158 | request, err := http.NewRequest("HEAD", "/dns-query?dns=", nil)
159 | assert.NoError(t, err)
160 |
161 | request.RemoteAddr = "127.0.0.1:0"
162 |
163 | handleTest(w, request)
164 |
165 | assert.Equal(t, w.Code, http.StatusMethodNotAllowed)
166 | }
167 |
168 | func Test_dohWirePOST(t *testing.T) {
169 | t.Parallel()
170 |
171 | w := httptest.NewRecorder()
172 |
173 | req := new(dns.Msg)
174 | req.SetQuestion("www.google.com.", dns.TypeA)
175 |
176 | data, err := req.Pack()
177 | assert.NoError(t, err)
178 |
179 | request, err := http.NewRequest("POST", "/dns-query", bytes.NewReader(data))
180 | assert.NoError(t, err)
181 |
182 | request.RemoteAddr = "127.0.0.1:0"
183 | request.Header.Add("Content-Type", "application/dns-message")
184 |
185 | handleTest(w, request)
186 |
187 | assert.Equal(t, w.Code, http.StatusOK)
188 |
189 | data, err = io.ReadAll(w.Body)
190 | assert.NoError(t, err)
191 |
192 | msg := new(dns.Msg)
193 | err = msg.Unpack(data)
194 | assert.NoError(t, err)
195 |
196 | assert.Equal(t, msg.Rcode, dns.RcodeSuccess)
197 |
198 | assert.Equal(t, len(msg.Answer) > 0, true)
199 | }
200 |
201 | func Test_dohWirePOSTError(t *testing.T) {
202 | t.Parallel()
203 |
204 | w := httptest.NewRecorder()
205 |
206 | request, err := http.NewRequest("POST", "/dns-query", bytes.NewReader([]byte{}))
207 | assert.NoError(t, err)
208 |
209 | request.RemoteAddr = "127.0.0.1:0"
210 | request.Header.Add("Content-Type", "text/html")
211 |
212 | handleTest(w, request)
213 |
214 | assert.Equal(t, w.Code, http.StatusUnsupportedMediaType)
215 | }
216 |
--------------------------------------------------------------------------------
/server/doh/msg.go:
--------------------------------------------------------------------------------
1 | package doh
2 |
3 | import (
4 | "strings"
5 |
6 | "github.com/miekg/dns"
7 | )
8 |
// Question is the JSON form of a question-section entry. Its fields match
// dns.Question in name, type and order. NewMsg relies on that to convert each
// entry with a plain Question(q) type conversion, so keep the two in sync.
// Qclass is not serialised.
type Question struct {
	Name   string `json:"name"`
	Qtype  uint16 `json:"type"`
	Qclass uint16 `json:"-"`
}

// RR is the JSON form of a resource record: owner name, numeric type, TTL,
// and in Data the record's text with the header part removed (see NewMsg).
type RR struct {
	Name string `json:"name"`
	Type uint16 `json:"type"`
	TTL  uint32 `json:"TTL"`
	Data string `json:"data"`
}

// Msg is the JSON document HandleJSON writes for a reply, as built by NewMsg.
//
// Status holds the reply's RCODE. TC, RD, RA, AD and CD mirror the header
// flags Truncated, RecursionDesired, RecursionAvailable, AuthenticatedData
// and CheckingDisabled. Untagged fields are encoded under their Go names.
// Answer and Authority are omitted when empty.
type Msg struct {
	Status    int
	TC        bool
	RD        bool
	RA        bool
	AD        bool
	CD        bool
	Question  []Question
	Answer    []RR `json:",omitempty"`
	Authority []RR `json:",omitempty"`
}
36 |
37 | // NewMsg function
38 | func NewMsg(m *dns.Msg) *Msg {
39 | if m == nil {
40 | return nil
41 | }
42 |
43 | msg := &Msg{
44 | Status: m.Rcode,
45 | TC: m.Truncated,
46 | RD: m.RecursionDesired,
47 | RA: m.RecursionAvailable,
48 | AD: m.AuthenticatedData,
49 | CD: m.CheckingDisabled,
50 | Question: make([]Question, len(m.Question)),
51 | Answer: make([]RR, len(m.Answer)),
52 | Authority: make([]RR, len(m.Ns)),
53 | }
54 |
55 | for i, q := range m.Question {
56 | msg.Question[i] = Question(q)
57 | }
58 |
59 | for i, a := range m.Answer {
60 | msg.Answer[i] = RR{
61 | Name: a.Header().Name,
62 | Type: a.Header().Rrtype,
63 | TTL: a.Header().Ttl,
64 | Data: strings.TrimPrefix(a.String(), a.Header().String()),
65 | }
66 | }
67 |
68 | for i, a := range m.Ns {
69 | msg.Authority[i] = RR{
70 | Name: a.Header().Name,
71 | Type: a.Header().Rrtype,
72 | TTL: a.Header().Ttl,
73 | Data: strings.TrimPrefix(a.String(), a.Header().String()),
74 | }
75 | }
76 |
77 | return msg
78 | }
79 |
--------------------------------------------------------------------------------
/server/doh/msg_test.go:
--------------------------------------------------------------------------------
1 | package doh
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/miekg/dns"
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func Test_Msg(t *testing.T) {
11 | m := NewMsg(nil)
12 | assert.Nil(t, m)
13 |
14 | msg := new(dns.Msg)
15 | msg.SetQuestion(".", dns.TypeNS)
16 |
17 | rr, err := dns.NewRR(". 518400 IN NS a.root-servers.net.")
18 | assert.NoError(t, err)
19 |
20 | msg.Answer = append(msg.Answer, rr)
21 |
22 | rr, err = dns.NewRR("a.gtld-servers.net. 172800 IN A 192.5.6.30")
23 | assert.NoError(t, err)
24 |
25 | msg.Ns = append(msg.Ns, rr)
26 |
27 | m = NewMsg(msg)
28 |
29 | assert.Equal(t, m.Answer[0].Data, msg.Answer[0].(*dns.NS).Ns)
30 | assert.Equal(t, m.Authority[0].Data, msg.Ns[0].(*dns.A).A.String())
31 | }
32 |
--------------------------------------------------------------------------------
/server/doh/qtype.go:
--------------------------------------------------------------------------------
1 | package doh
2 |
3 | import (
4 | "strconv"
5 | "strings"
6 |
7 | "github.com/miekg/dns"
8 | )
9 |
10 | var qtype = map[string]uint16{
11 | "A": dns.TypeA,
12 | "AAAA": dns.TypeAAAA,
13 | "AFSDB": dns.TypeAFSDB,
14 | "AMTRELAY": dns.TypeAMTRELAY,
15 | "ANY": dns.TypeANY,
16 | "APL": dns.TypeAPL,
17 | "ATMA": dns.TypeATMA,
18 | "AVC": dns.TypeAVC,
19 | "AXFR": dns.TypeAXFR,
20 | "CAA": dns.TypeCAA,
21 | "CDNSKEY": dns.TypeCDNSKEY,
22 | "CDS": dns.TypeCDS,
23 | "CERT": dns.TypeCERT,
24 | "CNAME": dns.TypeCNAME,
25 | "CSYNC": dns.TypeCSYNC,
26 | "DHCID": dns.TypeDHCID,
27 | "DLV": dns.TypeDLV,
28 | "DNAME": dns.TypeDNAME,
29 | "DNSKEY": dns.TypeDNSKEY,
30 | "DS": dns.TypeDS,
31 | "EID": dns.TypeEID,
32 | "EUI48": dns.TypeEUI48,
33 | "EUI64": dns.TypeEUI64,
34 | "GID": dns.TypeGID,
35 | "GPOS": dns.TypeGPOS,
36 | "HINFO": dns.TypeHINFO,
37 | "HIP": dns.TypeHIP,
38 | "HTTPS": dns.TypeHTTPS,
39 | "IPSECKEY": dns.TypeIPSECKEY,
40 | "ISDN": dns.TypeISDN,
41 | "IXFR": dns.TypeIXFR,
42 | "KEY": dns.TypeKEY,
43 | "KX": dns.TypeKX,
44 | "L32": dns.TypeL32,
45 | "L64": dns.TypeL64,
46 | "LOC": dns.TypeLOC,
47 | "LP": dns.TypeLP,
48 | "MAILA": dns.TypeMAILA,
49 | "MAILB": dns.TypeMAILB,
50 | "MB": dns.TypeMB,
51 | "MD": dns.TypeMD,
52 | "MF": dns.TypeMF,
53 | "MG": dns.TypeMG,
54 | "MINFO": dns.TypeMINFO,
55 | "MR": dns.TypeMR,
56 | "MX": dns.TypeMX,
57 | "NAPTR": dns.TypeNAPTR,
58 | "NID": dns.TypeNID,
59 | "NIMLOC": dns.TypeNIMLOC,
60 | "NINFO": dns.TypeNINFO,
61 | "NS": dns.TypeNS,
62 | "NSAP-PTR": dns.TypeNSAPPTR,
63 | "NSEC": dns.TypeNSEC,
64 | "NSEC3": dns.TypeNSEC3,
65 | "NSEC3PARAM": dns.TypeNSEC3PARAM,
66 | "NULL": dns.TypeNULL,
67 | "NXT": dns.TypeNXT,
68 | "None": dns.TypeNone,
69 | "OPENPGPKEY": dns.TypeOPENPGPKEY,
70 | "OPT": dns.TypeOPT,
71 | "PTR": dns.TypePTR,
72 | "PX": dns.TypePX,
73 | "RKEY": dns.TypeRKEY,
74 | "RP": dns.TypeRP,
75 | "RRSIG": dns.TypeRRSIG,
76 | "RT": dns.TypeRT,
77 | "Reserved": dns.TypeReserved,
78 | "SIG": dns.TypeSIG,
79 | "SMIMEA": dns.TypeSMIMEA,
80 | "SOA": dns.TypeSOA,
81 | "SPF": dns.TypeSPF,
82 | "SRV": dns.TypeSRV,
83 | "SSHFP": dns.TypeSSHFP,
84 | "SVCB": dns.TypeSVCB,
85 | "TA": dns.TypeTA,
86 | "TALINK": dns.TypeTALINK,
87 | "TKEY": dns.TypeTKEY,
88 | "TLSA": dns.TypeTLSA,
89 | "TSIG": dns.TypeTSIG,
90 | "TXT": dns.TypeTXT,
91 | "UID": dns.TypeUID,
92 | "UINFO": dns.TypeUINFO,
93 | "UNSPEC": dns.TypeUNSPEC,
94 | "URI": dns.TypeURI,
95 | "X25": dns.TypeX25,
96 | "ZONEMD": dns.TypeZONEMD,
97 | }
98 |
99 | // ParseQTYPE function
100 | func ParseQTYPE(s string) uint16 {
101 | if s == "" {
102 | return dns.TypeA
103 | }
104 |
105 | if v, err := strconv.ParseUint(s, 10, 16); err == nil {
106 | return uint16(v)
107 | }
108 |
109 | if v, ok := qtype[strings.ToUpper(s)]; ok {
110 | return v
111 | }
112 |
113 | return dns.TypeNone
114 | }
115 |
--------------------------------------------------------------------------------
/server/doh/qtype_test.go:
--------------------------------------------------------------------------------
1 | package doh
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/miekg/dns"
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func Test_ParseQTYPE(t *testing.T) {
11 | qtype := ParseQTYPE("")
12 | assert.Equal(t, qtype, dns.TypeA)
13 |
14 | qtype = ParseQTYPE("1")
15 | assert.Equal(t, qtype, dns.TypeA)
16 |
17 | qtype = ParseQTYPE("CNAME")
18 | assert.Equal(t, qtype, dns.TypeCNAME)
19 |
20 | qtype = ParseQTYPE("TEST")
21 | assert.Equal(t, qtype, dns.TypeNone)
22 | }
23 |
--------------------------------------------------------------------------------
/server/doq/doq.go:
--------------------------------------------------------------------------------
1 | package doq
2 |
3 | import (
4 | "context"
5 | "crypto/tls"
6 | "io"
7 | "time"
8 |
9 | "github.com/miekg/dns"
10 | "github.com/quic-go/quic-go"
11 | )
12 |
// doqProtos lists the ALPN identifiers the DoQ listener offers: "doq" first,
// followed by older identifiers (presumably draft-era names kept for older
// clients).
var doqProtos = []string{"doq", "doq-i02", "dq", "doq-i00", "doq-i01", "doq-i11"}

// NoError is the DoQ application error code for closing a connection without
// an error (DOQ_NO_ERROR in RFC 9250).
const NoError = 0x0

// ProtocolError is the DoQ application error code for a peer that violated
// the protocol (DOQ_PROTOCOL_ERROR in RFC 9250).
const ProtocolError = 0x2

// minMsgHeaderSize is the smallest acceptable DoQ stream payload: the 2-byte
// length prefix plus the 12-byte fixed DNS header.
const minMsgHeaderSize = 14
20 |
21 | type Server struct {
22 | Addr string
23 | Handler dns.Handler
24 |
25 | ln *quic.Listener
26 | }
27 |
28 | func (s *Server) ListenAndServeQUIC(tlsCert, tlsKey string) error {
29 | cert, err := tls.LoadX509KeyPair(tlsCert, tlsKey)
30 | if err != nil {
31 | return err
32 | }
33 |
34 | tlsConfig := &tls.Config{
35 | Certificates: []tls.Certificate{cert},
36 | NextProtos: doqProtos,
37 | }
38 |
39 | quicConfig := &quic.Config{
40 | MaxIdleTimeout: 5 * time.Second,
41 | MaxStreamReceiveWindow: dns.MaxMsgSize,
42 | }
43 |
44 | listener, err := quic.ListenAddr(s.Addr, tlsConfig, quicConfig)
45 | if err != nil {
46 | return err
47 | }
48 |
49 | s.ln = listener
50 |
51 | for {
52 | conn, err := listener.Accept(context.Background())
53 | if err != nil {
54 | return err
55 | }
56 |
57 | go s.handleConnection(conn)
58 | }
59 | }
60 |
61 | func (s *Server) Shutdown() error {
62 | if s.ln == nil {
63 | return nil
64 | }
65 |
66 | err := s.ln.Close()
67 |
68 | if err == quic.ErrServerClosed {
69 | return nil
70 | }
71 |
72 | return err
73 | }
74 |
75 | func (s *Server) handleConnection(conn quic.Connection) {
76 | var (
77 | stream quic.Stream
78 | buf []byte
79 | err error
80 | )
81 |
82 | for {
83 | stream, err = conn.AcceptStream(context.Background())
84 | if err != nil {
85 | _ = conn.CloseWithError(NoError, "")
86 | return
87 | }
88 |
89 | go func() {
90 | defer stream.Close()
91 |
92 | buf, err = io.ReadAll(stream)
93 | if err != nil {
94 | _ = conn.CloseWithError(ProtocolError, err.Error())
95 | return
96 | }
97 |
98 | if len(buf) < minMsgHeaderSize {
99 | _ = conn.CloseWithError(ProtocolError, "dns msg size too small")
100 | return
101 | }
102 |
103 | req := new(dns.Msg)
104 | if err := req.Unpack(buf[2:]); err != nil {
105 | _ = conn.CloseWithError(ProtocolError, err.Error())
106 | return
107 | }
108 | req.Id = dns.Id()
109 |
110 | w := &ResponseWriter{Conn: conn, Stream: stream}
111 |
112 | s.Handler.ServeDNS(w, req)
113 | }()
114 | }
115 | }
116 |
--------------------------------------------------------------------------------
/server/doq/doq_test.go:
--------------------------------------------------------------------------------
1 | package doq
2 |
3 | import (
4 | "context"
5 | "crypto/ecdsa"
6 | "crypto/rand"
7 | "crypto/rsa"
8 | "crypto/tls"
9 | "crypto/x509"
10 | "crypto/x509/pkix"
11 | "encoding/pem"
12 | "fmt"
13 | "io"
14 | "math/big"
15 | "os"
16 | "path/filepath"
17 | "testing"
18 | "time"
19 |
20 | "github.com/miekg/dns"
21 | "github.com/quic-go/quic-go"
22 | "github.com/stretchr/testify/assert"
23 | )
24 |
// dummyHandler is a test dns.Handler whose pointer type answers every query
// with a fixed A record (see its ServeDNS method).
//
// It deliberately embeds nothing. The previous embedded dns.Handler was
// always nil and only served to promote a ServeDNS that would panic if
// reached through a non-pointer value.
type dummyHandler struct{}
28 |
29 | func makeRR(data string) dns.RR {
30 | r, _ := dns.NewRR(data)
31 |
32 | return r
33 | }
34 |
35 | func (h *dummyHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
36 | msg := new(dns.Msg)
37 | msg.SetReply(r)
38 | msg.Answer = append(msg.Answer, makeRR("example.com. 1800 IN A 0.0.0.0"))
39 |
40 | _ = w.WriteMsg(msg)
41 | }
42 |
// publicKey returns a pointer to the PublicKey embedded in an
// *rsa.PrivateKey or *ecdsa.PrivateKey, and nil for any other value.
func publicKey(priv interface{}) interface{} {
	if rsaKey, ok := priv.(*rsa.PrivateKey); ok {
		return &rsaKey.PublicKey
	}

	if ecKey, ok := priv.(*ecdsa.PrivateKey); ok {
		return &ecKey.PublicKey
	}

	return nil
}
53 |
// pemBlockForKey wraps a private key in a PEM block. RSA keys become an "RSA
// PRIVATE KEY" block (x509.MarshalPKCS1PrivateKey), ECDSA keys an "EC PRIVATE
// KEY" block (x509.MarshalECPrivateKey), and any other value yields nil.
// If the ECDSA key cannot be marshalled, the error is printed to stderr and
// the process exits with status 2.
func pemBlockForKey(priv interface{}) *pem.Block {
	if rsaKey, ok := priv.(*rsa.PrivateKey); ok {
		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(rsaKey)}
	}

	ecKey, ok := priv.(*ecdsa.PrivateKey)
	if !ok {
		return nil
	}

	der, err := x509.MarshalECPrivateKey(ecKey)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Unable to marshal ECDSA private key: %v", err)
		os.Exit(2)
	}

	return &pem.Block{Type: "EC PRIVATE KEY", Bytes: der}
}
69 |
70 | func generateCertificate() error {
71 | priv, err := rsa.GenerateKey(rand.Reader, 2048)
72 | if err != nil {
73 | return err
74 | }
75 |
76 | template := x509.Certificate{
77 | SerialNumber: big.NewInt(1),
78 | Subject: pkix.Name{
79 | Organization: []string{"Acme Co"},
80 | },
81 | NotBefore: time.Now(),
82 | NotAfter: time.Now().Add(time.Hour * 24 * 365 * 3),
83 |
84 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
85 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
86 | BasicConstraintsValid: true,
87 | }
88 |
89 | template.DNSNames = append(template.DNSNames, "localhost")
90 |
91 | derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
92 | if err != nil {
93 | return err
94 | }
95 |
96 | certOut, err := os.OpenFile(filepath.Join(os.TempDir(), "test.cert"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
97 | if err != nil {
98 | return err
99 | }
100 |
101 | err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
102 | if err != nil {
103 | return err
104 | }
105 |
106 | certOut.Close()
107 |
108 | keyOut, err := os.OpenFile(filepath.Join(os.TempDir(), "test.key"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
109 | if err != nil {
110 | return err
111 | }
112 |
113 | err = pem.Encode(keyOut, pemBlockForKey(priv))
114 | if err != nil {
115 | return err
116 | }
117 |
118 | return keyOut.Close()
119 | }
120 |
// Test_doq starts a DoQ server on 127.0.0.1:45853 with a freshly generated
// self-signed certificate and checks, in order, that:
//   - a length-prefixed query on a new stream gets a reply that unpacks;
//   - after the connection has been idle for 6s (longer than the server's
//     5s MaxIdleTimeout), writing to a stream fails;
//   - a 2-byte payload (shorter than minMsgHeaderSize) makes reading the
//     stream fail, because the server closes the connection;
//   - a DNS message written without the 2-byte length prefix fails the same
//     way;
//
// and finally that Shutdown succeeds.
//
// NOTE(review): the test depends on a fixed port and fixed sleeps, and uses
// assert rather than require, so a failed dial would panic on a nil conn
// instead of stopping the test cleanly.
func Test_doq(t *testing.T) {
	err := generateCertificate()
	assert.NoError(t, err)

	cert := filepath.Join(os.TempDir(), "test.cert")
	privkey := filepath.Join(os.TempDir(), "test.key")

	h := &dummyHandler{}

	s := &Server{
		Addr:    "127.0.0.1:45853",
		Handler: h,
	}

	// Serve in the background; ErrServerClosed is the expected outcome of the
	// Shutdown call at the end of the test.
	go func() {
		err := s.ListenAndServeQUIC(cert, privkey)
		if err == quic.ErrServerClosed {
			return
		}
		assert.NoError(t, err)
	}()

	// Give the listener time to come up.
	time.Sleep(time.Second)

	tlsConf := &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"doq"},
	}
	conn, err := quic.DialAddr(context.Background(), s.Addr, tlsConf, nil)
	assert.NoError(t, err)

	// Case 1: a well-formed, length-prefixed query.
	stream, err := conn.OpenStreamSync(context.Background())
	assert.NoError(t, err)

	req := new(dns.Msg)
	req.SetQuestion("example.com.", dns.TypeA)
	req.Id = 0

	buf, err := req.Pack()
	assert.NoError(t, err)

	n, err := stream.Write(addPrefixLen(buf))
	assert.NoError(t, err)
	assert.Greater(t, n, 17)

	// Closing our side lets the server's io.ReadAll of the stream finish.
	err = stream.Close()
	assert.NoError(t, err)

	data, err := io.ReadAll(stream)
	assert.NoError(t, err)

	// Skip the 2-byte length prefix of the reply.
	msg := new(dns.Msg)
	err = msg.Unpack(data[2:])
	assert.NoError(t, err)

	// Case 2: outlive the server's idle timeout, then try to write.
	stream, err = conn.OpenStreamSync(context.Background())
	assert.NoError(t, err)

	time.Sleep(6 * time.Second)

	_, err = stream.Write([]byte{0, 0})
	assert.Error(t, err)

	// Case 3: payload shorter than minMsgHeaderSize.
	conn, err = quic.DialAddr(context.Background(), s.Addr, tlsConf, nil)
	assert.NoError(t, err)

	stream, err = conn.OpenStreamSync(context.Background())
	assert.NoError(t, err)

	_, err = stream.Write([]byte{0, 0})
	assert.NoError(t, err)

	err = stream.Close()
	assert.NoError(t, err)

	_, err = io.ReadAll(stream)
	assert.Error(t, err)

	// Case 4: a DNS message sent without the length prefix.
	conn, err = quic.DialAddr(context.Background(), s.Addr, tlsConf, nil)
	assert.NoError(t, err)

	stream, err = conn.OpenStreamSync(context.Background())
	assert.NoError(t, err)

	msg = new(dns.Msg)
	msg.SetEdns0(512, true)
	buf, _ = msg.Pack()

	_, err = stream.Write(buf)
	assert.NoError(t, err)

	err = stream.Close()
	assert.NoError(t, err)

	_, err = io.ReadAll(stream)
	assert.Error(t, err)

	err = s.Shutdown()
	assert.NoError(t, err)
}
221 |
--------------------------------------------------------------------------------
/server/doq/response_writer.go:
--------------------------------------------------------------------------------
1 | package doq
2 |
3 | import (
4 | "encoding/binary"
5 | "net"
6 |
7 | "github.com/miekg/dns"
8 | "github.com/quic-go/quic-go"
9 | )
10 |
11 | type ResponseWriter struct {
12 | dns.ResponseWriter
13 |
14 | Conn quic.Connection
15 | Stream quic.Stream
16 | }
17 |
18 | func (w *ResponseWriter) LocalAddr() net.Addr {
19 | return w.Conn.LocalAddr()
20 | }
21 |
22 | func (w *ResponseWriter) RemoteAddr() net.Addr {
23 | return w.Conn.RemoteAddr()
24 | }
25 |
26 | func (w *ResponseWriter) Close() error {
27 | return w.Stream.Close()
28 | }
29 |
30 | func (w *ResponseWriter) Write(m []byte) (int, error) {
31 | return w.Stream.Write(addPrefixLen(m))
32 | }
33 |
34 | func (w *ResponseWriter) WriteMsg(m *dns.Msg) error {
35 | m.Id = 0
36 |
37 | packed, err := m.Pack()
38 | if err != nil {
39 | _ = w.Conn.CloseWithError(0x1, err.Error())
40 | return err
41 | }
42 |
43 | _, err = w.Stream.Write(addPrefixLen(packed))
44 | if err != nil {
45 | return err
46 | }
47 |
48 | return nil
49 | }
50 |
// addPrefixLen returns a new slice holding len(msg) as a big-endian uint16
// followed by a copy of msg, which is the stream framing DoQ uses. msg itself
// is not modified.
// NOTE(review): lengths above 65535 are silently truncated by the uint16
// conversion.
func addPrefixLen(msg []byte) (buf []byte) {
	var prefix [2]byte
	binary.BigEndian.PutUint16(prefix[:], uint16(len(msg)))

	buf = make([]byte, 0, len(prefix)+len(msg))
	buf = append(buf, prefix[:]...)
	return append(buf, msg...)
}
58 |
--------------------------------------------------------------------------------
/server/server_test.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "context"
5 | "crypto/ecdsa"
6 | "crypto/rand"
7 | "crypto/rsa"
8 | "crypto/x509"
9 | "crypto/x509/pkix"
10 | "encoding/base64"
11 | "encoding/pem"
12 | "fmt"
13 | "io"
14 | "math/big"
15 | "net/http"
16 | "net/http/httptest"
17 | "os"
18 | "path/filepath"
19 | "testing"
20 | "time"
21 |
22 | "github.com/semihalev/sdns/middleware"
23 |
24 | "github.com/miekg/dns"
25 | "github.com/semihalev/log"
26 | "github.com/semihalev/sdns/config"
27 | "github.com/semihalev/sdns/middleware/blocklist"
28 | "github.com/semihalev/sdns/mock"
29 | "github.com/stretchr/testify/assert"
30 | )
31 |
32 | func TestMain(m *testing.M) {
33 | log.Root().SetHandler(log.LvlFilterHandler(0, log.StdoutHandler))
34 | m.Run()
35 |
36 | os.Exit(0)
37 | }
38 |
// publicKey returns a pointer to the PublicKey embedded in an
// *rsa.PrivateKey or *ecdsa.PrivateKey, and nil for any other value.
func publicKey(priv interface{}) interface{} {
	if rsaKey, ok := priv.(*rsa.PrivateKey); ok {
		return &rsaKey.PublicKey
	}

	if ecKey, ok := priv.(*ecdsa.PrivateKey); ok {
		return &ecKey.PublicKey
	}

	return nil
}
49 |
// pemBlockForKey wraps a private key in a PEM block. RSA keys become an "RSA
// PRIVATE KEY" block (x509.MarshalPKCS1PrivateKey), ECDSA keys an "EC PRIVATE
// KEY" block (x509.MarshalECPrivateKey), and any other value yields nil.
// If the ECDSA key cannot be marshalled, the error is printed to stderr and
// the process exits with status 2.
func pemBlockForKey(priv interface{}) *pem.Block {
	if rsaKey, ok := priv.(*rsa.PrivateKey); ok {
		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(rsaKey)}
	}

	ecKey, ok := priv.(*ecdsa.PrivateKey)
	if !ok {
		return nil
	}

	der, err := x509.MarshalECPrivateKey(ecKey)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Unable to marshal ECDSA private key: %v", err)
		os.Exit(2)
	}

	return &pem.Block{Type: "EC PRIVATE KEY", Bytes: der}
}
65 |
66 | func generateCertificate() error {
67 | priv, err := rsa.GenerateKey(rand.Reader, 2048)
68 | if err != nil {
69 | return err
70 | }
71 |
72 | template := x509.Certificate{
73 | SerialNumber: big.NewInt(1),
74 | Subject: pkix.Name{
75 | Organization: []string{"Acme Co"},
76 | },
77 | NotBefore: time.Now(),
78 | NotAfter: time.Now().Add(time.Hour * 24 * 365 * 3),
79 |
80 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
81 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
82 | BasicConstraintsValid: true,
83 | }
84 |
85 | template.DNSNames = append(template.DNSNames, "localhost")
86 |
87 | derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
88 | if err != nil {
89 | return err
90 | }
91 |
92 | certOut, err := os.OpenFile(filepath.Join(os.TempDir(), "test.cert"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
93 | if err != nil {
94 | return err
95 | }
96 |
97 | err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
98 | if err != nil {
99 | return err
100 | }
101 |
102 | certOut.Close()
103 |
104 | keyOut, err := os.OpenFile(filepath.Join(os.TempDir(), "test.key"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
105 | if err != nil {
106 | return err
107 | }
108 |
109 | err = pem.Encode(keyOut, pemBlockForKey(priv))
110 | if err != nil {
111 | return err
112 | }
113 |
114 | return keyOut.Close()
115 | }
116 |
117 | func Test_logPipe(t *testing.T) {
118 | logReader, logWriter := io.Pipe()
119 | go readlogs(logReader)
120 | _, _ = logWriter.Write([]byte("test test test test test test\n"))
121 | }
122 |
123 | func Test_ServerNoBind(t *testing.T) {
124 | cfg := &config.Config{}
125 |
126 | s := New(cfg)
127 | s.Run(context.Background())
128 | }
129 |
130 | func Test_ServerBindFail(t *testing.T) {
131 | cfg := &config.Config{}
132 |
133 | cfg.TLSCertificate = "cert"
134 | cfg.TLSPrivateKey = "key"
135 | cfg.LogLevel = "crit"
136 | cfg.Bind = "1:1"
137 | cfg.BindTLS = "1:2"
138 | cfg.BindDOH = "1:3"
139 | cfg.BindDOQ = "1:4"
140 |
141 | s := New(cfg)
142 | s.Run(context.Background())
143 | }
144 |
145 | func Test_Server(t *testing.T) {
146 | cfg := &config.Config{}
147 | err := generateCertificate()
148 | assert.NoError(t, err)
149 |
150 | cert := filepath.Join(os.TempDir(), "test.cert")
151 | privkey := filepath.Join(os.TempDir(), "test.key")
152 |
153 | cfg.TLSCertificate = cert
154 | cfg.TLSPrivateKey = privkey
155 | cfg.LogLevel = "crit"
156 | cfg.Bind = "127.0.0.1:0"
157 | cfg.BindTLS = "127.0.0.1:23222"
158 | cfg.BindDOH = "127.0.0.1:23223"
159 | cfg.BindDOQ = "127.0.0.1:23224"
160 | cfg.BlockListDir = filepath.Join(os.TempDir(), "sdns_temp")
161 |
162 | middleware.Register("blocklist", func(cfg *config.Config) middleware.Handler { return blocklist.New(cfg) })
163 | middleware.Setup(cfg)
164 |
165 | blocklist := middleware.Get("blocklist").(*blocklist.BlockList)
166 | blocklist.Set("test.com.")
167 |
168 | s := New(cfg)
169 | s.Run(context.Background())
170 |
171 | req := new(dns.Msg)
172 | req.SetQuestion("test.com.", dns.TypeA)
173 |
174 | mw := mock.NewWriter("udp", "127.0.0.1:0")
175 | s.ServeDNS(mw, req)
176 |
177 | assert.True(t, mw.Written())
178 | if assert.NotNil(t, mw.Msg()) {
179 | assert.Equal(t, true, len(mw.Msg().Answer) > 0)
180 | }
181 |
182 | request, err := http.NewRequest("GET", "/dns-query?name=test.com", nil)
183 | assert.NoError(t, err)
184 |
185 | hw := httptest.NewRecorder()
186 |
187 | s.ServeHTTP(hw, request)
188 | assert.Equal(t, 200, hw.Code)
189 |
190 | data, err := req.Pack()
191 | assert.NoError(t, err)
192 |
193 | dq := base64.RawURLEncoding.EncodeToString(data)
194 |
195 | request, err = http.NewRequest("GET", fmt.Sprintf("/dns-query?dns=%s", dq), nil)
196 | assert.NoError(t, err)
197 |
198 | hw = httptest.NewRecorder()
199 |
200 | s.ServeHTTP(hw, request)
201 | assert.Equal(t, 200, hw.Code)
202 |
203 | request, err = http.NewRequest("GET", "/dns-query?name=example.com", nil)
204 | assert.NoError(t, err)
205 |
206 | hw = httptest.NewRecorder()
207 |
208 | s.ServeHTTP(hw, request)
209 | assert.Equal(t, 400, hw.Code)
210 |
211 | time.Sleep(2 * time.Second)
212 |
213 | os.Remove(cert)
214 | os.Remove(privkey)
215 | }
216 |
--------------------------------------------------------------------------------
/snap/snapcraft.yaml:
--------------------------------------------------------------------------------
name: sdns
adopt-info: sdns
summary: A DNS Resolver Server.
description: |
  A high-performance, recursive DNS resolver server with DNSSEC support, focused on preserving privacy.

architectures:
  - build-on: armhf
  - build-on: arm64
  - build-on: amd64

grade: stable
confinement: strict
# NOTE(review): core18 is likely past its standard support window — confirm.
# Moving to a newer base probably also requires migrating the go plugin
# options below (go-importpath is believed to be core18-only) — verify first.
base: core18

parts:
  sdns:
    plugin: go
    source: https://github.com/semihalev/sdns.git
    go-importpath: github.com/semihalev/sdns
    source-type: git
    # Build the newest tag when it differs from the version published on
    # latest/stable; otherwise build whatever `snapcraftctl pull` checked out.
    override-pull: |
      snapcraftctl pull
      last_committed_tag="$(git describe --tags --abbrev=0)"
      # Strip only a leading "v" (v1.2.3 -> 1.2.3), never a "v" elsewhere.
      last_committed_tag_ver="${last_committed_tag#v}"
      last_released_tag="$(snap info sdns | awk '$1 == "latest/stable:" { print $2 }')"
      if [ "${last_committed_tag_ver}" != "${last_released_tag}" ]; then
        git fetch
        git checkout "${last_committed_tag}"
      fi
      snapcraftctl set-version "$(git describe --tags | sed 's/^v//')"
      snapcraftctl set-grade stable
    stage:
      - bin/sdns

apps:
  sdns:
    command: sdns
    plugs: [network-bind]
    daemon: simple
40 |
--------------------------------------------------------------------------------
/waitgroup/waitgroup.go:
--------------------------------------------------------------------------------
1 | package waitgroup
2 |
3 | import (
4 | "context"
5 | "sync"
6 | "time"
7 | )
8 |
9 | // WaitGroup waits for other same processes based key with timeout.
10 | type WaitGroup struct {
11 | mu sync.RWMutex
12 |
13 | groups map[uint64]*call
14 |
15 | timeout time.Duration
16 | }
17 |
18 | type call struct {
19 | ctx context.Context
20 | dups int
21 | cancel func()
22 | }
23 |
24 | // New return a new WaitGroup with timeout.
25 | func New(timeout time.Duration) *WaitGroup {
26 | return &WaitGroup{
27 | groups: make(map[uint64]*call),
28 |
29 | timeout: timeout,
30 | }
31 | }
32 |
33 | // Get return count of dups with key.
34 | func (wg *WaitGroup) Get(key uint64) int {
35 | wg.mu.RLock()
36 | defer wg.mu.RUnlock()
37 |
38 | if c, ok := wg.groups[key]; ok {
39 | return c.dups
40 | }
41 |
42 | return 0
43 | }
44 |
45 | // Wait blocks until WaitGroup context cancelled or timedout with key.
46 | func (wg *WaitGroup) Wait(key uint64) {
47 | wg.mu.RLock()
48 |
49 | if c, ok := wg.groups[key]; ok {
50 | wg.mu.RUnlock()
51 | <-c.ctx.Done()
52 | return
53 | }
54 |
55 | wg.mu.RUnlock()
56 | }
57 |
58 | // Add adds a new caller or if the caller exists increment dups with key.
59 | func (wg *WaitGroup) Add(key uint64) {
60 | wg.mu.Lock()
61 | defer wg.mu.Unlock()
62 |
63 | if c, ok := wg.groups[key]; ok {
64 | c.dups++
65 | return
66 | }
67 |
68 | c := new(call)
69 | c.dups++
70 | c.ctx, c.cancel = context.WithTimeout(context.Background(), wg.timeout)
71 | wg.groups[key] = c
72 | }
73 |
74 | // Done cancels the group context or if the caller dups more then zero, decrements the dups with key.
75 | func (wg *WaitGroup) Done(key uint64) {
76 | wg.mu.Lock()
77 | defer wg.mu.Unlock()
78 |
79 | if c, ok := wg.groups[key]; ok {
80 | if c.dups > 1 {
81 | c.dups--
82 | return
83 | }
84 | c.cancel()
85 | }
86 |
87 | delete(wg.groups, key)
88 | }
89 |
--------------------------------------------------------------------------------
/waitgroup/waitgroup_test.go:
--------------------------------------------------------------------------------
1 | package waitgroup
2 |
3 | import (
4 | "sync"
5 | "testing"
6 | "time"
7 |
8 | "github.com/miekg/dns"
9 | "github.com/semihalev/sdns/cache"
10 | "github.com/stretchr/testify/assert"
11 | )
12 |
13 | func Test_WaitGroupWait(t *testing.T) {
14 | wg := New(5 * time.Second)
15 | mu := sync.RWMutex{}
16 |
17 | m := new(dns.Msg)
18 | m.SetQuestion(dns.Fqdn("example.com."), dns.TypeA)
19 | key := cache.Hash(m.Question[0])
20 |
21 | wg.Add(key)
22 |
23 | count := wg.Get(key)
24 | assert.Equal(t, 1, count)
25 |
26 | key2 := cache.Hash(dns.Question{Name: "none.", Qtype: dns.TypeA, Qclass: dns.ClassINET})
27 |
28 | count = wg.Get(key2)
29 | assert.Equal(t, 0, count)
30 |
31 | wg.Wait(key2)
32 |
33 | var workers []*string
34 |
35 | for i := 0; i < 5; i++ {
36 | go func() {
37 | w := new(string)
38 | *w = "running"
39 |
40 | mu.Lock()
41 | workers = append(workers, w)
42 | mu.Unlock()
43 |
44 | wg.Wait(key)
45 |
46 | mu.Lock()
47 | *w = "stopped"
48 | mu.Unlock()
49 | }()
50 | }
51 |
52 | time.Sleep(time.Second)
53 |
54 | wg.Done(key)
55 |
56 | time.Sleep(100 * time.Millisecond)
57 |
58 | mu.RLock()
59 | defer mu.RUnlock()
60 | for _, w := range workers {
61 | assert.Equal(t, *w, "stopped")
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/zregister.go:
--------------------------------------------------------------------------------
1 | // Code generated by gen.go DO NOT EDIT.
2 |
3 | package main
4 |
5 | import (
6 | "github.com/semihalev/sdns/config"
7 | "github.com/semihalev/sdns/middleware"
8 | "github.com/semihalev/sdns/middleware/accesslist"
9 | "github.com/semihalev/sdns/middleware/accesslog"
10 | "github.com/semihalev/sdns/middleware/as112"
11 | "github.com/semihalev/sdns/middleware/blocklist"
12 | "github.com/semihalev/sdns/middleware/cache"
13 | "github.com/semihalev/sdns/middleware/chaos"
14 | "github.com/semihalev/sdns/middleware/edns"
15 | "github.com/semihalev/sdns/middleware/failover"
16 | "github.com/semihalev/sdns/middleware/forwarder"
17 | "github.com/semihalev/sdns/middleware/hostsfile"
18 | "github.com/semihalev/sdns/middleware/loop"
19 | "github.com/semihalev/sdns/middleware/metrics"
20 | "github.com/semihalev/sdns/middleware/ratelimit"
21 | "github.com/semihalev/sdns/middleware/recovery"
22 | "github.com/semihalev/sdns/middleware/resolver"
23 | )
24 |
// init registers every built-in middleware with the middleware package,
// pairing each name with a constructor that builds its handler from a
// *config.Config.
//
// NOTE(review): this file is generated by gen.go; change gen.go rather than
// editing it by hand. The registration order is presumably the order in
// which handlers are chained — confirm against middleware.Register.
func init() {
	middleware.Register("recovery", func(cfg *config.Config) middleware.Handler { return recovery.New(cfg) })
	middleware.Register("loop", func(cfg *config.Config) middleware.Handler { return loop.New(cfg) })
	middleware.Register("metrics", func(cfg *config.Config) middleware.Handler { return metrics.New(cfg) })
	middleware.Register("accesslist", func(cfg *config.Config) middleware.Handler { return accesslist.New(cfg) })
	middleware.Register("ratelimit", func(cfg *config.Config) middleware.Handler { return ratelimit.New(cfg) })
	middleware.Register("edns", func(cfg *config.Config) middleware.Handler { return edns.New(cfg) })
	middleware.Register("accesslog", func(cfg *config.Config) middleware.Handler { return accesslog.New(cfg) })
	middleware.Register("chaos", func(cfg *config.Config) middleware.Handler { return chaos.New(cfg) })
	middleware.Register("hostsfile", func(cfg *config.Config) middleware.Handler { return hostsfile.New(cfg) })
	middleware.Register("blocklist", func(cfg *config.Config) middleware.Handler { return blocklist.New(cfg) })
	middleware.Register("as112", func(cfg *config.Config) middleware.Handler { return as112.New(cfg) })
	middleware.Register("cache", func(cfg *config.Config) middleware.Handler { return cache.New(cfg) })
	middleware.Register("failover", func(cfg *config.Config) middleware.Handler { return failover.New(cfg) })
	middleware.Register("resolver", func(cfg *config.Config) middleware.Handler { return resolver.New(cfg) })
	middleware.Register("forwarder", func(cfg *config.Config) middleware.Handler { return forwarder.New(cfg) })
}
42 |
--------------------------------------------------------------------------------