A DNS-based ad blocker with a stylish dual-themed dashboard

⚠️ This is a work-in-progress application and may contain bugs or incomplete features. ⚠️

GoAdBlock is a lightweight, high-performance DNS-based ad blocker written in Go. It intercepts DNS queries for known advertising and tracking domains and prevents them from resolving, blocking ads at the network level before they are downloaded.
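
The blocking decision is a set lookup on the queried name and on each of its parent domains. Below is a minimal, standalone sketch of that idea, simplified from `Blocker.IsBlocked` in `internal/blocker/blocker.go`. The `isBlocked` helper and its plain-map inputs are hypothetical. It leaves out the real code's locking, regex rules, and per-list statistics.

```go
package main

import "strings"

// isBlocked reports whether name, or any parent domain of it, appears in
// blocked, unless the exact name is allowlisted. Hypothetical helper that
// mirrors the idea of Blocker.IsBlocked without locking or statistics.
func isBlocked(name string, blocked, allowed map[string]bool) bool {
	// DNS questions are fully qualified ("ads.example.com."), so normalize first
	name = strings.TrimSuffix(strings.ToLower(name), ".")
	if allowed[name] {
		return false
	}
	labels := strings.Split(name, ".")
	// Check "ads.example.com", then "example.com", then "com"
	for i := range labels {
		if blocked[strings.Join(labels[i:], ".")] {
			return true
		}
	}
	return false
}
```

If `example.com` is on a blocklist, a query for `ads.example.com.` is blocked because its parent domain matches. An allowlisted name passes through untouched.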

## ✨ Features

- DNS-level ad blocking: Blocks ads at the network level for all devices
- Dual-themed dashboard: Choose between TVA (Time Variance Authority) or Cockpit interface
- Real-time statistics: Monitor blocked requests, cache performance, and more
- Client tracking: See which devices are making requests on your network
- Performance optimized: Written in Go for high throughput and low resource usage
- Self-contained binary: Single binary that includes all assets
- Local caching: Improves response times for frequently accessed domains
- Customizable blocklists: Add or remove domains from blocklists
- Cross-platform: Works on Linux, macOS, and Windows

## 📸 Screenshots

TVA Theme

## 🚀 Installation

### Prerequisites

- Go 1.18 or higher

### From Source

```sh
# Clone the repository
git clone https://github.com/vivek-pk/GoAdBlock.git

# Navigate to the project directory
cd GoAdBlock

# Build the project
go build -o goadblock ./cmd/server/main.go

# Run the executable
./goadblock
```

## ⚙️ Configuration

> ⚠️ **TODO**: This section needs to be completed/reviewed

GoAdBlock can be configured using flags or a configuration file:

```sh
# Run with custom DNS port
./goadblock --dns-port=5353

# Run with custom web interface port
./goadblock --http-port=8080

# Use a config file
./goadblock --config=config.yaml
```

Example config file:

```yaml
dns:
  port: 53
  upstream: '8.8.8.8'
  cache_size: 5000
  cache_ttl: 3600

http:
  port: 8080
  username: 'admin'
  password: 'changeme'

blocklists:
  - 'https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts'
  - 'https://adaway.org/hosts.txt'
```

## 📊 Usage

1. Set your router's DNS server to point to the machine running GoAdBlock
2. Alternatively, configure individual devices to use GoAdBlock as their DNS server
3. Open the dashboard at `http://<server-ip>:8080`, replacing `<server-ip>` with the address of the machine running GoAdBlock
4. Toggle between themes using the theme switcher in the sidebar
5. Monitor blocking performance through the visual dashboard
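
After clients point at GoAdBlock, you can confirm from Go that blocking works by sending lookups straight to the server instead of the system resolver. The sketch below uses the standard library's `net.Resolver` with a custom `Dial`. It forces UDP because the server in `internal/dns/server.go` listens on UDP only. The server address you pass in is an assumption; replace it with your own.

```go
package main

import (
	"context"
	"net"
	"time"
)

// newResolver returns a resolver that sends every DNS query to serverAddr
// (host:port) over UDP, bypassing the system's configured DNS servers.
func newResolver(serverAddr string) *net.Resolver {
	return &net.Resolver{
		PreferGo: true, // prefer Go's built-in resolver, which is the one that uses Dial
		Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
			d := net.Dialer{Timeout: 2 * time.Second}
			// Ignore the default resolver address and dial GoAdBlock instead
			return d.DialContext(ctx, "udp", serverAddr)
		},
	}
}

// lookup resolves host through the DNS server at serverAddr.
func lookup(serverAddr, host string) ([]string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	return newResolver(serverAddr).LookupHost(ctx, host)
}
```

For a domain on one of your blocklists, `lookup("<server-ip>:53", domain)` should return only `0.0.0.0` and/or `::`. These are the null addresses the server answers with by default.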

## 🛠️ Development

### Project Structure

```
/
├── cmd/
│   └── server/          # Application entry point
├── internal/
│   ├── api/             # Web API and dashboard
│   │   ├── static/      # Static assets (JS, CSS)
│   │   └── templates/   # HTML templates
│   ├── blocker/         # Blocklist management
│   ├── cache/           # DNS cache implementation
│   ├── config/          # Configuration handling
│   └── dns/             # DNS server implementation
└── pkg/                 # Public packages
```

## 🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

1. Fork the repository
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add some amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request

## 📝 License

This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.

## 🙏 Acknowledgments

- Special thanks to everyone who contributed to this project
- UI themes inspired by Marvel's Time Variance Authority and aviation cockpit designs
- Built with Go, Alpine.js, Chart.js, and TailwindCSS

If you find this project useful, consider giving it a star! ⭐
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4=
github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.20.0 h1:zrxIyR3RQIOsarIrgL8+sAvALXul9jeEPa06Y0Ph6vY=
github.com/spf13/viper v1.20.0/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
--------------------------------------------------------------------------------
/internal/dns/server.go:
--------------------------------------------------------------------------------
package dns

import (
	"context"
	"fmt"
	"log"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/miekg/dns"
	"github.com/vivek-pk/goadblock/internal/blocker"
)

// Server represents a DNS server
type Server struct {
	blocker         *blocker.Blocker
	notifier        BlockNotifier
	server          *dns.Server
	cache           *DNSCache
	upstreamAddrs   []string
	currentUpstream uint32 // round-robin counter, updated atomically by concurrent handlers
	metrics         *Metrics
	shutdown        chan struct{}
	shutdownOnce    sync.Once // makes Shutdown safe to call more than once
	apiNotifier     APINotifier
	Ready           chan struct{}
	blockingMode    string
	blockingIP      net.IP
}

// ServerConfig holds the tunable settings for a Server
type ServerConfig struct {
	UpstreamServers []string
	BlockingMode    string
	BlockingIP      string
	CacheSize       int
}

// DNSCache stores upstream answers keyed by query name and record type
type DNSCache struct {
	entries map[string]*CacheEntry
	mu      sync.RWMutex
}

type CacheEntry struct {
	Answer    []dns.RR
	ExpiresAt time.Time
}

type Metrics struct {
	TotalQueries   int64
	BlockedQueries int64
	CacheHits      int64
	CacheMisses    int64
	mu             sync.RWMutex
}

// APINotifier is notified of every A/AAAA query the server handles
type APINotifier interface {
	AddQuery(domain string, clientIP string, blocked bool)
}

// BlockNotifier is an interface for components that need to be notified of blocked domains
type BlockNotifier interface {
	OnDomainBlocked(domain string, clientIP string, reason string)
}

// NewServer creates a Server, filling in defaults for any unset config fields
func NewServer(blocker *blocker.Blocker, apiNotifier APINotifier, config ServerConfig) *Server {
	if len(config.UpstreamServers) == 0 {
		config.UpstreamServers = []string{
			"8.8.8.8:53", // Google
			"1.1.1.1:53", // Cloudflare
		}
	}
	if config.BlockingMode == "" {
		config.BlockingMode = "zero_ip"
	}
	if config.BlockingIP == "" {
		config.BlockingIP = "0.0.0.0"
	}
	if config.CacheSize <= 0 {
		config.CacheSize = 10000
	}

	// net.ParseIP returns nil for malformed input; fall back to 0.0.0.0
	blockingIP := net.ParseIP(config.BlockingIP)
	if blockingIP == nil {
		log.Printf("Invalid blocking IP %q, falling back to 0.0.0.0", config.BlockingIP)
		blockingIP = net.IPv4zero
	}

	return &Server{
		blocker:     blocker,
		apiNotifier: apiNotifier,
		cache: &DNSCache{
			// CacheSize is only an initial capacity hint; expired entries are not evicted
			entries: make(map[string]*CacheEntry, config.CacheSize),
		},
		upstreamAddrs: config.UpstreamServers,
		metrics:       &Metrics{},
		shutdown:      make(chan struct{}),
		Ready:         make(chan struct{}),
		blockingMode:  config.BlockingMode,
		blockingIP:    blockingIP,
	}
}

// NewServerSimple is a backward-compatible wrapper that uses the default configuration
func NewServerSimple(blocker *blocker.Blocker, apiNotifier APINotifier) *Server {
	return NewServer(blocker, apiNotifier, ServerConfig{
		UpstreamServers: []string{
			"8.8.8.8:53", // Google
			"1.1.1.1:53", // Cloudflare
		},
		BlockingMode: "zero_ip",
		BlockingIP:   "0.0.0.0",
		CacheSize:    10000,
	})
}

func (s *Server) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
	s.metrics.incrementTotal()

	m := new(dns.Msg)
	m.SetReply(r)
	m.Compress = false

	switch r.Opcode {
	case dns.OpcodeQuery:
		for _, q := range m.Question {
			switch q.Qtype {
			case dns.TypeA, dns.TypeAAAA:
				clientIP, _, _ := net.SplitHostPort(w.RemoteAddr().String())
				isBlocked, reason := s.blocker.IsBlocked(q.Name)
				log.Printf("DNS query: %s, blocked: %v, reason: %s", q.Name, isBlocked, reason)

				// Notify API server of query
				if s.apiNotifier != nil {
					s.apiNotifier.AddQuery(q.Name, clientIP, isBlocked)
				}

				if isBlocked {
					// Notify block listeners
					if s.notifier != nil {
						s.notifier.OnDomainBlocked(q.Name, clientIP, reason)
					}

					s.metrics.incrementBlocked()
					if q.Qtype == dns.TypeA {
						// Answer with the configured IPv4 blocking IP (0.0.0.0 by default)
						blockIP := s.blockingIP.To4()
						if blockIP == nil {
							blockIP = net.IPv4zero
						}
						m.Answer = append(m.Answer, &dns.A{
							Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
							A:   blockIP,
						})
					} else {
						m.Answer = append(m.Answer, &dns.AAAA{
							Hdr:  dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60},
							AAAA: net.IPv6zero, // Block IPv6 too
						})
					}

					log.Printf("Blocked domain %s, returning null IP", q.Name)
				} else {
					s.resolve(m, r, q)
				}
			default:
				// Forward other record types (MX, TXT, SRV, ...) instead of
				// returning an empty answer
				s.resolve(m, r, q)
			}
		}
	}

	w.WriteMsg(m)
}

// resolve answers q from the cache or, on a miss, from an upstream server
func (s *Server) resolve(m, r *dns.Msg, q dns.Question) {
	if answer := s.checkCache(q.Name, q.Qtype); answer != nil {
		// Copy into m.Answer so the cached slice is never aliased by the reply
		m.Answer = append(m.Answer, answer...)
		s.metrics.incrementCacheHit()
		return
	}

	s.metrics.incrementCacheMiss()
	resp, err := s.queryUpstream(r)
	if err != nil || resp == nil {
		log.Printf("Upstream query for %s failed: %v", q.Name, err)
		m.Rcode = dns.RcodeServerFailure
		return
	}

	// Propagate the upstream status (e.g. NXDOMAIN) along with the answers
	m.Rcode = resp.Rcode
	m.Answer = append(m.Answer, resp.Answer...)
	s.updateCache(q.Name, q.Qtype, resp.Answer)
}

// queryUpstream forwards r to the upstream servers in round-robin order,
// trying the remaining servers if one fails
func (s *Server) queryUpstream(r *dns.Msg) (*dns.Msg, error) {
	n := len(s.upstreamAddrs)
	if n == 0 {
		return nil, fmt.Errorf("no upstream servers configured")
	}

	// Handlers run concurrently, so the counter must be updated atomically
	start := int(atomic.AddUint32(&s.currentUpstream, 1) % uint32(n))

	var lastErr error
	for i := 0; i < n; i++ {
		resp, err := dns.Exchange(r, s.upstreamAddrs[(start+i)%n])
		if err == nil {
			return resp, nil
		}
		lastErr = err
	}
	return nil, lastErr
}

func (s *Server) checkCache(name string, qtype uint16) []dns.RR {
	s.cache.mu.RLock()
	defer s.cache.mu.RUnlock()

	key := getCacheKey(name, qtype)
	if entry, exists := s.cache.entries[key]; exists && time.Now().Before(entry.ExpiresAt) {
		return entry.Answer
	}
	return nil
}

func (s *Server) updateCache(name string, qtype uint16, answer []dns.RR) {
	if len(answer) == 0 {
		return
	}

	s.cache.mu.Lock()
	defer s.cache.mu.Unlock()

	// Cache for 5 minutes
	s.cache.entries[getCacheKey(name, qtype)] = &CacheEntry{
		Answer:    answer,
		ExpiresAt: time.Now().Add(5 * time.Minute),
	}
}

func getCacheKey(name string, qtype uint16) string {
	return fmt.Sprintf("%s:%d", name, qtype)
}

// Metrics methods
func (m *Metrics) incrementTotal() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.TotalQueries++
	log.Printf("Total queries: %d", m.TotalQueries) // Debug log
}

func (m *Metrics) incrementBlocked() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.BlockedQueries++
	log.Printf("Blocked queries: %d", m.BlockedQueries) // Debug log
}

func (m *Metrics) incrementCacheHit() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.CacheHits++
	log.Printf("Cache hits: %d", m.CacheHits) // Debug log
}

func (m *Metrics) incrementCacheMiss() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.CacheMisses++
	log.Printf("Cache misses: %d", m.CacheMisses) // Debug log
}

func (s *Server) GetMetrics() *Metrics {
	return s.metrics
}

func (s *Server) Start(addr string) error {
	// Use a dedicated mux rather than dns.DefaultServeMux so that several
	// servers (e.g. in tests) do not share global handler state
	mux := dns.NewServeMux()
	mux.HandleFunc(".", s.handleRequest)

	s.server = &dns.Server{
		Addr:    addr,
		Net:     "udp",
		Handler: mux,
		// Signal ready only once the socket is actually bound
		NotifyStartedFunc: func() { close(s.Ready) },
	}

	errChan := make(chan error, 1)
	go func() {
		errChan <- s.server.ListenAndServe()
	}()

	// Wait for either shutdown signal or error
	select {
	case <-s.shutdown:
		return nil
	case err := <-errChan:
		return err
	}
}

func (s *Server) Shutdown(ctx context.Context) error {
	// Signal shutdown; the Once prevents a panic on a second call
	s.shutdownOnce.Do(func() { close(s.shutdown) })

	// Shut down the DNS server, honoring the caller's deadline
	if s.server != nil {
		return s.server.ShutdownContext(ctx)
	}
	return nil
}

func logQuery(domain string, isBlocked bool, clientIP net.IP) {
	status := "allowed"
	if isBlocked {
		status = "blocked"
	}
	log.Printf("DNS Query from %s: %s - %s", clientIP, domain, status)
}

// GetBlocker returns the blocker used by this server
func (s *Server) GetBlocker() *blocker.Blocker {
	return s.blocker
}
--------------------------------------------------------------------------------
/internal/blocker/blocker.go:
--------------------------------------------------------------------------------
1 | package blocker
2 |
3 | import (
4 | "bufio"
5 | "fmt"
6 | "io"
7 | "log"
8 | "net/http"
9 | "regexp"
10 | "strings"
11 | "sync"
12 | )
13 |
14 | // BlockList represents a named collection of blocked domains
15 | type BlockList struct {
16 | Name string
17 | Domains map[string]struct{}
18 | Count int
19 | }
20 |
21 | // Blocker holds domain blocking information
22 | type Blocker struct {
23 | blocklists map[string]*BlockList
24 | whitelist map[string]struct{}
25 | blockRegexes []*regexp.Regexp
26 | mu sync.RWMutex
27 | blocklistStats map[string]int // Track blocks per blocklist
28 | }
29 |
30 | // New creates a new Blocker
31 | func New() *Blocker {
32 | return &Blocker{
33 | blocklists: make(map[string]*BlockList),
34 | whitelist: make(map[string]struct{}),
35 | blockRegexes: make([]*regexp.Regexp, 0),
36 | blocklistStats: make(map[string]int),
37 | }
38 | }
39 |
40 | // Update the IsBlocked method to return both a boolean and a reason string
41 | func (b *Blocker) IsBlocked(domain string) (bool, string) {
42 | b.mu.RLock()
43 | defer b.mu.RUnlock()
44 |
45 | domain = strings.ToLower(domain)
46 | domain = strings.TrimSuffix(domain, ".") // Remove trailing dot which DNS queries often have
47 |
48 | // Check whitelist first
49 | if _, ok := b.whitelist[domain]; ok {
50 | log.Printf("Domain %s is whitelisted, allowing", domain)
51 | return false, ""
52 | }
53 |
54 | // Check exact domain match in blocklists
55 | for listName, list := range b.blocklists {
56 | if _, ok := list.Domains[domain]; ok {
57 | log.Printf("Domain %s found in blocklist %s", domain, listName)
58 | b.blocklistStats[listName]++
59 | return true, listName
60 | }
61 |
62 | // Check parent domains (subdomains)
63 | parts := strings.Split(domain, ".")
64 | for i := 1; i < len(parts); i++ {
65 | parentDomain := strings.Join(parts[i:], ".")
66 | if _, ok := list.Domains[parentDomain]; ok {
67 | log.Printf("Domain %s matched parent domain %s in blocklist %s",
68 | domain, parentDomain, listName)
69 | b.blocklistStats[listName]++
70 | return true, listName
71 | }
72 | }
73 | }
74 |
75 | // Check regex patterns
76 | for _, regex := range b.blockRegexes {
77 | if regex.MatchString(domain) {
78 | log.Printf("Domain %s matched regex pattern: %s", domain, regex.String())
79 | return true, "regex:" + regex.String()
80 | }
81 | }
82 |
83 | log.Printf("Domain %s not found in any blocklist, allowing", domain)
84 | return false, ""
85 | }
86 |
87 | // LoadFromURL loads blocked domains from a URL
88 | func (b *Blocker) LoadFromURL(url string, name string) error {
89 | if name == "" {
90 | name = url // Use URL as name if not provided
91 | }
92 |
93 | resp, err := http.Get(url)
94 | if err != nil {
95 | return err
96 | }
97 | defer resp.Body.Close()
98 |
99 | return b.loadFromReader(resp.Body, name)
100 | }
101 |
102 | func (b *Blocker) loadFromReader(reader io.Reader, listName string) error {
103 | b.mu.Lock()
104 | defer b.mu.Unlock()
105 |
106 | // Create new blocklist or get existing one
107 | list, exists := b.blocklists[listName]
108 | if !exists {
109 | list = &BlockList{
110 | Name: listName,
111 | Domains: make(map[string]struct{}),
112 | }
113 | b.blocklists[listName] = list
114 | b.blocklistStats[listName] = 0
115 | }
116 |
117 | scanner := bufio.NewScanner(reader)
118 | for scanner.Scan() {
119 | line := strings.TrimSpace(scanner.Text())
120 |
121 | // Skip empty lines and comments
122 | if line == "" || strings.HasPrefix(line, "#") {
123 | continue
124 | }
125 |
126 | // Parse hosts file format (0.0.0.0 example.com or 127.0.0.1 example.com)
127 | fields := strings.Fields(line)
128 | if len(fields) >= 2 {
129 | domain := strings.ToLower(fields[1])
130 | list.Domains[domain] = struct{}{}
131 | }
132 | }
133 |
134 | // Update count
135 | list.Count = len(list.Domains)
136 |
137 | return scanner.Err()
138 | }
139 |
140 | // LoadMultipleLists loads multiple blocklists
141 | func (b *Blocker) LoadMultipleLists(sources map[string]string) error {
142 | for name, url := range sources {
143 | if err := b.LoadFromURL(url, name); err != nil {
144 | return fmt.Errorf("failed to load blocklist %s: %w", name, err)
145 | }
146 | }
147 | return nil
148 | }
149 |
150 | // AddToWhitelist adds a domain to the whitelist
151 | func (b *Blocker) AddToWhitelist(domain string) {
152 | b.mu.Lock()
153 | defer b.mu.Unlock()
154 |
155 | domain = strings.ToLower(domain)
156 | b.whitelist[domain] = struct{}{}
157 | }
158 |
159 | // RemoveFromWhitelist removes a domain from the whitelist
160 | func (b *Blocker) RemoveFromWhitelist(domain string) {
161 | b.mu.Lock()
162 | defer b.mu.Unlock()
163 |
164 | domain = strings.ToLower(domain)
165 | delete(b.whitelist, domain)
166 | }
167 |
168 | // IsWhitelisted checks if a domain is whitelisted
169 | func (b *Blocker) IsWhitelisted(domain string) bool {
170 | b.mu.RLock()
171 | defer b.mu.RUnlock()
172 |
173 | domain = strings.ToLower(domain)
174 | _, ok := b.whitelist[domain]
175 | return ok
176 | }
177 |
178 | // AddBlockRegex adds a regex pattern for blocking
179 | func (b *Blocker) AddBlockRegex(pattern string) error {
180 | regex, err := regexp.Compile(pattern)
181 | if err != nil {
182 | return err
183 | }
184 |
185 | b.mu.Lock()
186 | defer b.mu.Unlock()
187 |
188 | b.blockRegexes = append(b.blockRegexes, regex)
189 | return nil
190 | }
191 |
192 | // RemoveBlockRegex removes a regex pattern by its string representation
193 | func (b *Blocker) RemoveBlockRegex(pattern string) {
194 | b.mu.Lock()
195 | defer b.mu.Unlock()
196 |
197 | // Find and remove the regex pattern
198 | for i, regex := range b.blockRegexes {
199 | if regex.String() == pattern {
200 | b.blockRegexes = append(b.blockRegexes[:i], b.blockRegexes[i+1:]...)
201 | break
202 | }
203 | }
204 | }
205 |
206 | // GetBlocklistStats returns statistics about blocklists
207 | func (b *Blocker) GetBlocklistStats() map[string]map[string]int {
208 | b.mu.RLock()
209 | defer b.mu.RUnlock()
210 |
211 | stats := make(map[string]map[string]int)
212 |
213 | for name, list := range b.blocklists {
214 | stats[name] = map[string]int{
215 | "domains": list.Count,
216 | "blocks": b.blocklistStats[name],
217 | }
218 | }
219 |
220 | return stats
221 | }
222 |
223 | // GetWhitelist returns the current whitelist
224 | func (b *Blocker) GetWhitelist() []string {
225 | b.mu.RLock()
226 | defer b.mu.RUnlock()
227 |
228 | whitelist := make([]string, 0, len(b.whitelist))
229 | for domain := range b.whitelist {
230 | whitelist = append(whitelist, domain)
231 | }
232 |
233 | return whitelist
234 | }
235 |
236 | // GetRegexPatterns returns the current regex patterns
237 | func (b *Blocker) GetRegexPatterns() []string {
238 | b.mu.RLock()
239 | defer b.mu.RUnlock()
240 |
241 | patterns := make([]string, len(b.blockRegexes))
242 | for i, regex := range b.blockRegexes {
243 | patterns[i] = regex.String()
244 | }
245 |
246 | return patterns
247 | }
248 |
249 | // AddDomainToBlocklist adds a domain to a specific blocklist
250 | func (b *Blocker) AddDomainToBlocklist(domain, listName string) {
251 | b.mu.Lock()
252 | defer b.mu.Unlock()
253 |
254 | domain = strings.ToLower(domain)
255 |
256 | // Create blocklist if it doesn't exist
257 | if _, exists := b.blocklists[listName]; !exists {
258 | b.blocklists[listName] = &BlockList{
259 | Name: listName,
260 | Domains: make(map[string]struct{}),
261 | }
262 | b.blocklistStats[listName] = 0
263 | }
264 |
265 | b.blocklists[listName].Domains[domain] = struct{}{}
266 | b.blocklists[listName].Count = len(b.blocklists[listName].Domains)
267 | }
268 |
269 | // RemoveDomainFromBlocklist removes a domain from a specific blocklist
270 | func (b *Blocker) RemoveDomainFromBlocklist(domain, listName string) bool {
271 | b.mu.Lock()
272 | defer b.mu.Unlock()
273 |
274 | domain = strings.ToLower(domain)
275 |
276 | // Check if blocklist exists
277 | list, exists := b.blocklists[listName]
278 | if !exists {
279 | return false
280 | }
281 |
282 | // Check if domain exists in blocklist
283 | if _, ok := list.Domains[domain]; !ok {
284 | return false
285 | }
286 |
287 | // Remove domain
288 | delete(list.Domains, domain)
289 | list.Count = len(list.Domains)
290 |
291 | return true
292 | }
293 |
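The blocklist methods above share one pattern. They use a mutex-guarded map of per-list domain sets with lower-cased keys. A list is created the first time a domain is added to it, and removal returns a boolean reporting whether anything was deleted. The standalone sketch below reproduces that pattern so it can be run on its own.

`listSet`, `newListSet` and `Domains` are hypothetical names and are not part of GoAdBlock. `GetWhitelist` returns domains in Go's unspecified map iteration order. `Domains` differs from it by sorting its copy, which gives stable output; this is a design choice of the sketch, not the project's behavior.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
	"sync"
)

// listSet is a hypothetical, stripped-down stand-in for the Blocker's
// per-list domain sets. It mirrors AddDomainToBlocklist and
// RemoveDomainFromBlocklist but is not the project's actual type.
type listSet struct {
	mu    sync.RWMutex
	lists map[string]map[string]struct{}
}

func newListSet() *listSet {
	return &listSet{lists: make(map[string]map[string]struct{})}
}

// Add lower-cases the domain and creates the named list on first use.
func (l *listSet) Add(domain, list string) {
	l.mu.Lock()
	defer l.mu.Unlock()
	domain = strings.ToLower(domain)
	if _, ok := l.lists[list]; !ok {
		l.lists[list] = make(map[string]struct{})
	}
	l.lists[list][domain] = struct{}{}
}

// Remove reports whether the domain was present in the named list.
func (l *listSet) Remove(domain, list string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	set, ok := l.lists[list]
	if !ok {
		return false
	}
	domain = strings.ToLower(domain)
	if _, ok := set[domain]; !ok {
		return false
	}
	delete(set, domain)
	return true
}

// Domains returns a sorted copy, so callers never touch the shared map.
func (l *listSet) Domains(list string) []string {
	l.mu.RLock()
	defer l.mu.RUnlock()
	out := make([]string, 0, len(l.lists[list]))
	for d := range l.lists[list] {
		out = append(out, d)
	}
	sort.Strings(out)
	return out
}

func main() {
	s := newListSet()
	s.Add("Ads.Example.COM", "custom")
	s.Add("tracker.example.net", "custom")
	fmt.Println(s.Domains("custom"))                   // [ads.example.com tracker.example.net]
	fmt.Println(s.Remove("ADS.example.com", "custom")) // true
	fmt.Println(s.Remove("missing.example", "custom")) // false
}
```

Returning a copy rather than the map itself lets the dashboard iterate over the result without holding the blocker's lock.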
--------------------------------------------------------------------------------
/internal/dns/server_test.go:
--------------------------------------------------------------------------------
1 | package dns
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "net"
7 | "testing"
8 | "time"
9 |
10 | "github.com/miekg/dns"
11 | "github.com/vivek-pk/goadblock/internal/blocker"
12 | )
13 |
14 | // mockNotifier records each query notification the DNS server sends
15 | type mockNotifier struct {
16 | queries []struct {
17 | domain string
18 | clientIP string
19 | blocked bool
20 | }
21 | }
22 |
23 | func (m *mockNotifier) AddQuery(domain string, clientIP string, blocked bool) {
24 | m.queries = append(m.queries, struct {
25 | domain string
26 | clientIP string
27 | blocked bool
28 | }{domain, clientIP, blocked})
29 | }
30 |
31 | // findAvailablePort finds an available UDP port
32 | func findAvailablePort() (int, error) {
33 | addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
34 | if err != nil {
35 | return 0, err
36 | }
37 |
38 | l, err := net.ListenUDP("udp", addr)
39 | if err != nil {
40 | return 0, err
41 | }
42 | defer l.Close()
43 |
44 | return l.LocalAddr().(*net.UDPAddr).Port, nil
45 | }
46 |
47 | // setupTestServer starts a DNS server on a free UDP port and returns it, its address, and a cleanup func
48 | func setupTestServer(t *testing.T) (*Server, string, func()) {
49 | port, err := findAvailablePort()
50 | if err != nil {
51 | t.Fatalf("Failed to find available port: %v", err)
52 | }
53 |
54 | adblocker := blocker.New()
55 | _ = adblocker.LoadFromURL("https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts", "default")
56 |
57 | // Create mock notifier
58 | notifier := &mockNotifier{}
59 |
60 | // Create server config
61 | config := ServerConfig{
62 | UpstreamServers: []string{"8.8.8.8:53", "1.1.1.1:53"},
63 | BlockingMode: "zero_ip",
64 | BlockingIP: "0.0.0.0",
65 | CacheSize: 1000,
66 | }
67 |
68 | // Pass notifier and config to NewServer
69 | server := NewServer(adblocker, notifier, config)
70 | addr := fmt.Sprintf(":%d", port)
71 | errChan := make(chan error, 1)
72 |
73 | go func() {
74 | if err := server.Start(addr); err != nil {
75 | errChan <- err
76 | }
77 | }()
78 |
79 | // Wait for server to start
80 | startTimeout := time.After(5 * time.Second)
81 | for {
82 | select {
83 | case err := <-errChan:
84 | t.Fatalf("Server failed to start: %v", err)
85 | case <-startTimeout:
86 | t.Fatal("Server startup timed out")
87 | case <-time.After(100 * time.Millisecond):
88 | if isServerReady(port) {
89 | return server, fmt.Sprintf("127.0.0.1:%d", port), func() {
90 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
91 | defer cancel()
92 | server.Shutdown(ctx)
93 | }
94 | }
95 | }
96 | }
97 | }
98 |
99 | func isServerReady(port int) bool {
100 | c := &dns.Client{
101 | Timeout: 500 * time.Millisecond,
102 | }
103 | m := new(dns.Msg)
104 | m.SetQuestion("google.com.", dns.TypeA)
105 |
106 | _, _, err := c.Exchange(m, fmt.Sprintf("127.0.0.1:%d", port))
107 | return err == nil
108 | }
109 |
110 | func TestDNSServer(t *testing.T) {
111 | _, addr, cleanup := setupTestServer(t)
112 | defer cleanup()
113 |
114 | // Configure DNS client with timeout
115 | c := &dns.Client{
116 | Timeout: 2 * time.Second,
117 | }
118 |
119 | tests := []struct {
120 | name string
121 | domain string
122 | qtype uint16
123 | shouldBlock bool
124 | }{
125 | {"Known ad domain A", "doubleclick.net.", dns.TypeA, true},
126 | {"Known ad domain AAAA", "doubleclick.net.", dns.TypeAAAA, true},
127 | {"Google ads domain", "googleadservices.com.", dns.TypeA, true},
128 | {"Regular domain", "google.com.", dns.TypeA, false},
129 | {"Another regular domain", "github.com.", dns.TypeA, false},
130 | }
131 |
132 | for _, tt := range tests {
133 | t.Run(tt.name, func(t *testing.T) {
134 | m := new(dns.Msg)
135 | m.SetQuestion(tt.domain, tt.qtype)
136 |
137 | // Retry logic for DNS queries
138 | var resp *dns.Msg
139 | var err error
140 | for retries := 3; retries > 0; retries-- {
141 | resp, _, err = c.Exchange(m, addr)
142 | if err == nil {
143 | break
144 | }
145 | time.Sleep(100 * time.Millisecond)
146 | }
147 |
148 | if err != nil {
149 | t.Fatalf("Query failed: %v", err)
150 | }
151 |
152 | if len(resp.Answer) == 0 {
153 | t.Fatal("Expected answer section in response")
154 | }
155 |
156 | // Check the first answer record; fail on an unexpected record type
157 | switch rr := resp.Answer[0].(type) {
158 | case *dns.A:
159 | isZeroIP := rr.A.Equal(net.IPv4zero)
160 | if tt.shouldBlock != isZeroIP {
161 | t.Errorf("Expected blocked=%v for %s, got IP=%v",
162 | tt.shouldBlock, tt.domain, rr.A)
163 | }
164 | case *dns.AAAA:
165 | isZeroIP := rr.AAAA.Equal(net.IPv6zero)
166 | if tt.shouldBlock != isZeroIP {
167 | t.Errorf("Expected blocked=%v for %s, got IP=%v",
168 | tt.shouldBlock, tt.domain, rr.AAAA)
169 | }
170 | default:
171 | // e.g. a CNAME, from which blocking cannot be judged
172 | t.Errorf("Unexpected answer record %T for %s", rr, tt.domain)
173 | }
174 | })
175 | }
176 | }
177 |
178 | func TestCaching(t *testing.T) {
179 | server, addr, cleanup := setupTestServer(t)
180 | defer cleanup()
181 |
182 | domain := "example.com."
183 | before := server.GetMetrics()
184 | initialHits, initialMisses := before.CacheHits, before.CacheMisses
185 |
186 | // Build the query once; both lookups reuse it
187 | m := new(dns.Msg)
188 | m.SetQuestion(domain, dns.TypeA)
189 | c := &dns.Client{Timeout: 2 * time.Second}
190 |
191 | // First query - should miss the cache and go upstream
192 | if _, _, err := c.Exchange(m, addr); err != nil {
193 | t.Fatalf("First query failed: %v", err)
194 | }
195 |
196 | // Second query - should be answered from the cache
197 | if _, _, err := c.Exchange(m, addr); err != nil {
198 | t.Fatalf("Second query failed: %v", err)
199 | }
200 |
201 | // Re-read the metrics; the snapshot above was taken before either query
202 | after := server.GetMetrics()
203 | if after.CacheHits != initialHits+1 {
204 | t.Errorf("Expected %d cache hits, got %d", initialHits+1, after.CacheHits)
205 | }
206 | if after.CacheMisses != initialMisses+1 {
207 | t.Errorf("Expected %d cache misses, got %d", initialMisses+1, after.CacheMisses)
208 | }
209 | }
210 |
211 | // TestQueryNotifications checks that each query is reported to the notifier with its client IP and block status
212 | func TestQueryNotifications(t *testing.T) {
213 | notifier := &mockNotifier{}
214 | adblocker := blocker.New()
215 | _ = adblocker.LoadFromURL("https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts", "default")
216 |
217 | // Create server config
218 | config := ServerConfig{
219 | UpstreamServers: []string{"8.8.8.8:53", "1.1.1.1:53"},
220 | BlockingMode: "zero_ip",
221 | BlockingIP: "0.0.0.0",
222 | CacheSize: 1000,
223 | }
224 |
225 | server := NewServer(adblocker, notifier, config)
226 | port, err := findAvailablePort()
227 | if err != nil {
228 | t.Fatalf("Failed to find available port: %v", err)
229 | }
230 |
231 | go server.Start(fmt.Sprintf(":%d", port))
232 | defer func() {
233 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
234 | defer cancel()
235 | server.Shutdown(ctx)
236 | }()
237 |
238 | // Wait for server to start
239 | time.Sleep(time.Second)
240 |
241 | // Make some test queries
242 | c := &dns.Client{Timeout: 2 * time.Second}
243 | addr := fmt.Sprintf("127.0.0.1:%d", port)
244 |
245 | queries := []struct {
246 | domain string
247 | shouldBlock bool
248 | }{
249 | {"google.com.", false},
250 | {"doubleclick.net.", true},
251 | {"example.com.", false},
252 | }
253 |
254 | for _, q := range queries {
255 | m := new(dns.Msg)
256 | m.SetQuestion(q.domain, dns.TypeA)
257 | _, _, err := c.Exchange(m, addr)
258 | if err != nil {
259 | t.Fatalf("Query failed for %s: %v", q.domain, err)
260 | }
261 | }
262 |
263 | // Give some time for notifications to be processed
264 | time.Sleep(100 * time.Millisecond)
265 |
266 | // Verify notifications
267 | if len(notifier.queries) != len(queries) {
268 | t.Errorf("Expected %d notifications, got %d", len(queries), len(notifier.queries))
269 | }
270 |
271 | for i, q := range queries {
272 | if i >= len(notifier.queries) {
273 | break
274 | }
275 | if notifier.queries[i].domain != q.domain {
276 | t.Errorf("Query %d: expected domain %s, got %s", i, q.domain, notifier.queries[i].domain)
277 | }
278 | if notifier.queries[i].clientIP != "127.0.0.1" {
279 | t.Errorf("Query %d: expected client IP 127.0.0.1, got %s", i, notifier.queries[i].clientIP)
280 | }
281 | if notifier.queries[i].blocked != q.shouldBlock {
282 | t.Errorf("Query %d: expected blocked=%v, got blocked=%v", i, q.shouldBlock, notifier.queries[i].blocked)
283 | }
284 | }
285 | }
286 |
--------------------------------------------------------------------------------
/internal/api/server.go:
--------------------------------------------------------------------------------
1 | package api
2 |
3 | import (
4 | "context"
5 | "embed"
6 | "encoding/json"
7 | "fmt"
8 | "html/template"
9 | "io/fs"
10 | "log"
11 | "net/http"
12 | "sort"
13 | "sync"
14 | "time"
15 |
16 | "github.com/google/uuid"
17 | "github.com/gorilla/mux"
18 | "github.com/vivek-pk/goadblock/internal/dns"
19 | )
20 |
21 | //go:embed templates/*
22 | var templateFS embed.FS
23 |
24 | type Query struct {
25 | ID string `json:"id"`
26 | Domain string `json:"domain"`
27 | Blocked bool `json:"blocked"`
28 | Timestamp time.Time `json:"timestamp"`
29 | }
30 |
31 | type HourlyStats struct {
32 | Requests int
33 | Blocks int
34 | }
35 |
36 | type ClientStats struct {
37 | IP string `json:"ip"`
38 | TotalQueries int64 `json:"totalQueries"`
39 | BlockedQueries int64 `json:"blockedQueries"`
40 | LastSeen time.Time `json:"lastSeen"`
41 | }
42 |
43 | type APIServer struct {
44 | dnsServer *dns.Server
45 | port int
46 | startTime time.Time
47 | recentQueries []Query
48 | queriesLock sync.RWMutex
49 | templates *template.Template
50 | server *http.Server
51 | router *mux.Router
52 | hourlyStats [24]HourlyStats
53 | hourlyStatsMu sync.RWMutex
54 | lastHourIndex int
55 | clientStats map[string]*ClientStats
56 | clientStatsMu sync.RWMutex
57 | }
58 |
59 | func NewAPIServer(dnsServer *dns.Server, port int) (*APIServer, error) {
60 | tmpl := InitTemplates()
61 |
62 | server := &APIServer{
63 | dnsServer: dnsServer,
64 | port: port,
65 | startTime: time.Now(),
66 | recentQueries: make([]Query, 0, 100),
67 | templates: tmpl,
68 | router: mux.NewRouter(), // Initialize the router
69 | clientStats: make(map[string]*ClientStats),
70 | }
71 |
72 | // Call setupRoutes to register all routes
73 | server.setupRoutes()
74 |
75 | // Set up static file serving from embedded files
76 | ServeStaticFiles(server.router)
77 |
78 | return server, nil
79 | }
80 |
81 | // InitTemplates parses the HTML templates embedded in the binary
82 | func InitTemplates() *template.Template {
83 | tmpl := template.New("")
84 |
85 | // Parse the shared sidebar partial (parse order is not significant: {{template}} calls resolve at execution time)
86 | template.Must(tmpl.ParseFS(embeddedFiles, "templates/sidebar.html"))
87 |
88 | // Then parse every page template; the glob re-parses sidebar.html, which html/template allows before the first Execute
89 | template.Must(tmpl.ParseFS(embeddedFiles, "templates/*.html"))
90 |
91 | return tmpl
92 | }
93 |
94 | // AddQuery records a query in the recent-queries list and updates the hourly and per-client stats
95 | func (s *APIServer) AddQuery(domain string, clientIP string, blocked bool) {
96 | s.queriesLock.Lock()
97 | defer s.queriesLock.Unlock()
98 |
99 | query := Query{
100 | ID: uuid.New().String(),
101 | Domain: domain,
102 | Blocked: blocked,
103 | Timestamp: time.Now(),
104 | }
105 |
106 | // Add to front of slice
107 | s.recentQueries = append([]Query{query}, s.recentQueries...)
108 |
109 | // Keep only last 100 queries
110 | if len(s.recentQueries) > 100 {
111 | s.recentQueries = s.recentQueries[:100]
112 | }
113 |
114 | s.trackQuery(blocked)
115 | s.trackClientQuery(clientIP, blocked)
116 | }
117 |
118 | func (s *APIServer) trackQuery(blocked bool) {
119 | s.hourlyStatsMu.Lock()
120 | defer s.hourlyStatsMu.Unlock()
121 |
122 | currentHour := time.Now().Hour()
123 | // Clear every bucket skipped since the last query (and the new hour) so idle hours don't show the previous day's counts
124 | for h := s.lastHourIndex; h != currentHour; {
125 | h = (h + 1) % 24
126 | s.hourlyStats[h] = HourlyStats{}
127 | }
128 | s.lastHourIndex = currentHour
129 | s.hourlyStats[currentHour].Requests++
130 | if blocked {
131 | s.hourlyStats[currentHour].Blocks++
132 | }
133 | }
134 |
135 | func (s *APIServer) trackClientQuery(ip string, blocked bool) {
136 | s.clientStatsMu.Lock()
137 | defer s.clientStatsMu.Unlock()
138 |
139 | stats, exists := s.clientStats[ip]
140 | if !exists {
141 | stats = &ClientStats{
142 | IP: ip,
143 | }
144 | s.clientStats[ip] = stats
145 | }
146 |
147 | stats.TotalQueries++
148 | if blocked {
149 | stats.BlockedQueries++
150 | }
151 | stats.LastSeen = time.Now()
152 | }
153 |
154 | // handleQueries returns the most recent queries (newest first, at most 100) as JSON
155 | func (s *APIServer) handleQueries(w http.ResponseWriter, r *http.Request) {
156 | s.queriesLock.RLock()
157 | defer s.queriesLock.RUnlock()
158 |
159 | w.Header().Set("Content-Type", "application/json")
160 | json.NewEncoder(w).Encode(map[string]interface{}{
161 | "queries": s.recentQueries,
162 | })
163 | }
164 |
165 | func (s *APIServer) handleHourlyStats(w http.ResponseWriter, r *http.Request) {
166 | s.hourlyStatsMu.RLock()
167 | defer s.hourlyStatsMu.RUnlock()
168 |
169 | currentHour := time.Now().Hour()
170 | hours := make([]string, 24)
171 | requests := make([]int, 24)
172 | blocks := make([]int, 24)
173 |
174 | for i := 0; i < 24; i++ {
175 | hour := (currentHour - 23 + i + 24) % 24
176 | hours[i] = fmt.Sprintf("%02d:00", hour)
177 | stats := s.hourlyStats[hour]
178 | requests[i] = stats.Requests
179 | blocks[i] = stats.Blocks
180 | }
181 |
182 | response := map[string]interface{}{
183 | "hours": hours,
184 | "requests": requests,
185 | "blocks": blocks,
186 | }
187 |
188 | w.Header().Set("Content-Type", "application/json")
189 | json.NewEncoder(w).Encode(response)
190 | }
191 |
192 | func (s *APIServer) handleClients(w http.ResponseWriter, r *http.Request) {
193 | s.clientStatsMu.RLock()
194 | defer s.clientStatsMu.RUnlock()
195 |
196 | clients := make([]*ClientStats, 0, len(s.clientStats))
197 | for _, stats := range s.clientStats {
198 | clients = append(clients, stats)
199 | }
200 |
201 | // Sort by last seen, most recent first
202 | sort.Slice(clients, func(i, j int) bool {
203 | return clients[i].LastSeen.After(clients[j].LastSeen)
204 | })
205 |
206 | w.Header().Set("Content-Type", "application/json")
207 | json.NewEncoder(w).Encode(map[string]interface{}{
208 | "clients": clients,
209 | })
210 | }
211 |
212 | func (s *APIServer) Start() error {
213 | s.server = &http.Server{
214 | Addr: fmt.Sprintf(":%d", s.port),
215 | Handler: s.router, // Use the router
216 | ReadTimeout: 10 * time.Second,
217 | WriteTimeout: 10 * time.Second,
218 | }
219 |
220 | return s.server.ListenAndServe()
221 | }
222 |
223 | func (s *APIServer) setupRoutes() {
224 | // Log every incoming request (debug aid)
225 | s.router.Use(func(next http.Handler) http.Handler {
226 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
227 | log.Printf("Request: %s %s", r.Method, r.URL.Path)
228 | next.ServeHTTP(w, r)
229 | })
230 | })
231 |
232 | // IMPORTANT: Register static file handler BEFORE other routes
233 | staticFS, err := fs.Sub(embeddedFiles, "static")
234 | if err != nil {
235 | log.Fatalf("Failed to create sub-filesystem for static files: %v", err)
236 | }
237 | fileServer := http.FileServer(http.FS(staticFS))
238 | s.router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", fileServer))
239 |
240 | // Page routes
241 | s.router.HandleFunc("/", s.handleDashboard).Methods("GET")
242 | s.router.HandleFunc("/blocklists", s.handleBlocklistsPage).Methods("GET")
243 | s.router.HandleFunc("/settings", s.handleSettingsPage).Methods("GET")
244 | s.router.HandleFunc("/about", s.handleAboutPage).Methods("GET")
245 |
246 | // API endpoints
247 | s.router.HandleFunc("/api/v1/metrics", s.handleMetrics).Methods("GET")
248 | s.router.HandleFunc("/api/v1/status", s.handleStatus).Methods("GET")
249 | s.router.HandleFunc("/api/v1/queries", s.handleQueries).Methods("GET")
250 | s.router.HandleFunc("/api/v1/stats/hourly", s.handleHourlyStats).Methods("GET")
251 | s.router.HandleFunc("/api/v1/clients", s.handleClients).Methods("GET")
252 |
253 | // Blocklist management routes
254 | s.router.HandleFunc("/api/v1/blocklists", s.handleGetBlocklists).Methods("GET")
255 | s.router.HandleFunc("/api/v1/blocklist/domain", s.handleAddDomainToBlocklist).Methods("POST")
256 | s.router.HandleFunc("/api/v1/blocklist/domain", s.handleRemoveDomainFromBlocklist).Methods("DELETE")
257 |
258 | // Whitelist management routes
259 | s.router.HandleFunc("/api/v1/whitelist", s.handleGetWhitelist).Methods("GET")
260 | s.router.HandleFunc("/api/v1/whitelist", s.handleAddToWhitelist).Methods("POST")
261 | s.router.HandleFunc("/api/v1/whitelist", s.handleRemoveFromWhitelist).Methods("DELETE")
262 |
263 | // Regex pattern routes
264 | s.router.HandleFunc("/api/v1/regex", s.handleGetRegexPatterns).Methods("GET")
265 | s.router.HandleFunc("/api/v1/regex", s.handleAddRegexPattern).Methods("POST")
266 | s.router.HandleFunc("/api/v1/regex", s.handleRemoveRegexPattern).Methods("DELETE")
271 | }
272 |
273 | func (s *APIServer) Shutdown(ctx context.Context) error {
274 | return s.server.Shutdown(ctx)
275 | }
276 |
277 | func (s *APIServer) handleDashboard(w http.ResponseWriter, r *http.Request) {
278 | s.templates.ExecuteTemplate(w, "dashboard.html", nil)
279 | }
280 |
281 | func (s *APIServer) handleMetrics(w http.ResponseWriter, r *http.Request) {
282 | metrics := s.dnsServer.GetMetrics()
283 | response := map[string]interface{}{
284 | "totalQueries": metrics.TotalQueries,
285 | "blockedQueries": metrics.BlockedQueries,
286 | "cacheHits": metrics.CacheHits,
287 | "cacheMisses": metrics.CacheMisses,
288 | }
289 |
290 | w.Header().Set("Content-Type", "application/json")
291 | if err := json.NewEncoder(w).Encode(response); err != nil {
292 | http.Error(w, "Failed to encode metrics", http.StatusInternalServerError)
293 | return
294 | }
295 | }
296 |
297 | func (s *APIServer) handleStatus(w http.ResponseWriter, r *http.Request) {
298 | status := map[string]interface{}{
299 | "status": "running",
300 | "uptime": time.Since(s.startTime).String(),
301 | }
302 | w.Header().Set("Content-Type", "application/json")
303 | json.NewEncoder(w).Encode(status)
304 | }
305 |
306 | // SetDNSServer sets the DNS server whose metrics the API reports
307 | func (s *APIServer) SetDNSServer(server *dns.Server) {
308 | s.dnsServer = server
309 | }
310 |
311 | // Page handlers render the embedded HTML templates
312 | func (s *APIServer) handleBlocklistsPage(w http.ResponseWriter, r *http.Request) {
313 | s.templates.ExecuteTemplate(w, "blocklists.html", nil)
314 | }
315 |
316 | func (s *APIServer) handleSettingsPage(w http.ResponseWriter, r *http.Request) {
317 | s.templates.ExecuteTemplate(w, "settings.html", nil)
318 | }
319 |
320 | func (s *APIServer) handleAboutPage(w http.ResponseWriter, r *http.Request) {
321 | s.templates.ExecuteTemplate(w, "about.html", nil)
322 | }
323 |
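`handleHourlyStats` reads `hourlyStats` as a 24-slot ring indexed by hour of day. It starts at `(currentHour - 23 + i + 24) % 24`, so the oldest hour comes first and the current hour comes last. The clock-free model below makes that indexing testable by passing the hour explicitly. `hourRing`, `record` and `window` are hypothetical names.

`record` clears every slot skipped since the previous event, not only the new hour's slot. This way an hour with no traffic never shows counts left over from the previous day.

```go
package main

import "fmt"

// hourRing is a hypothetical, clock-free model of the dashboard's 24-slot
// hourly counters. record takes the hour explicitly so it can be tested.
type hourRing struct {
	requests [24]int
	blocks   [24]int
	last     int
}

// record clears every slot skipped since the previous event (including the
// new hour itself) before counting, so idle hours show zero.
func (r *hourRing) record(hour int, blocked bool) {
	for h := r.last; h != hour; {
		h = (h + 1) % 24
		r.requests[h], r.blocks[h] = 0, 0
	}
	r.last = hour
	r.requests[hour]++
	if blocked {
		r.blocks[hour]++
	}
}

// window returns labels and request counts oldest-first, ending at the
// current hour, using the same index formula as handleHourlyStats.
func (r *hourRing) window(current int) ([]string, []int) {
	labels := make([]string, 24)
	counts := make([]int, 24)
	for i := 0; i < 24; i++ {
		h := (current - 23 + i + 24) % 24
		labels[i] = fmt.Sprintf("%02d:00", h)
		counts[i] = r.requests[h]
	}
	return labels, counts
}

func main() {
	var r hourRing
	r.record(22, false)
	r.record(23, true)
	r.record(1, false) // hour 0 was idle: its slot is cleared, not reused
	labels, counts := r.window(1)
	fmt.Println(labels[0], labels[23]) // 02:00 01:00
	fmt.Println(counts[20:])           // [1 1 0 1]
}
```

Only the hour of day is stored, so a gap of 24 hours or more is not fully detected: some slots can still hold older counts. Keeping the full timestamp of the last update instead of just its hour would close that gap.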
--------------------------------------------------------------------------------
/internal/api/templates/sidebar.html:
--------------------------------------------------------------------------------
1 |
2 |