├── .dockerignore ├── .env.example ├── .env.test.example ├── .gitignore ├── .travis.yml ├── Dockerfile ├── LICENSE.txt ├── README.md ├── api ├── api.go ├── api_test.go ├── middleware.go ├── middleware_test.go ├── queue_test.go ├── scan_test.go ├── stats.go └── stats_test.go ├── checker ├── Dockerfile ├── README.md ├── cache.go ├── cache_test.go ├── checker.go ├── cmd │ └── starttls-check │ │ ├── cmd.go │ │ └── cmd_test.go ├── domain.go ├── domain_test.go ├── hostname.go ├── hostname_test.go ├── mta_sts.go ├── mta_sts_test.go ├── result.go ├── result_test.go ├── totals.go └── totals_test.go ├── db ├── Dockerfile ├── db.go ├── scripts │ └── init_tables.sql ├── sqldb.go └── sqldb_test.go ├── docker-compose.yml ├── email ├── email.go ├── email_test.go └── template.go ├── entrypoint.sh ├── go.mod ├── go.sum ├── main.go ├── models ├── domain.go ├── domain_test.go ├── scan.go ├── token.go └── token_test.go ├── policy ├── policy.go └── policy_test.go ├── stats ├── stats.go └── stats_test.go ├── util ├── util.go └── util_test.go ├── validator ├── validator.go └── validator_test.go └── views ├── default.html.tmpl └── scan.html.tmpl /.dockerignore: -------------------------------------------------------------------------------- 1 | # Ignore configuration 2 | .dockerignore 3 | docker-compose.yml 4 | Dockerfile 5 | .env.example 6 | .env 7 | .travis.yml 8 | 9 | # Ignore git 10 | .git 11 | .gitignore 12 | 13 | # Ignore compiled scanner 14 | starttls-backend 15 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # Port to listen for requests on 2 | PORT=8080 3 | # Permitted domains for cross-origin requests, e.g. http://localhost:1313, separated by commas 4 | ALLOWED_ORIGINS= 5 | # Filepath to domain blacklist, eg domain_blacklist.txt 6 | DOMAIN_BLACKLIST= 7 | # Filepath to IP blacklist 8 | IP_BLACKLIST= 9 | 10 | # The name of the database, e.g. 
`starttls` or `starttls_dev` 11 | # (this should be created in advance) 12 | DB_NAME=starttls 13 | # Username and password for database access 14 | DB_USERNAME=postgres 15 | DB_PASSWORD=password 16 | # The database hostname, e.g. `localhost` for local development or `postgres` for Docker 17 | DB_HOST=postgres 18 | # Whether to migrate DB on startup 19 | DB_MIGRATE=false 20 | 21 | # Email sending information 22 | SMTP_USERNAME= 23 | SMTP_PASSWORD= 24 | SMTP_ENDPOINT= 25 | SMTP_PORT= 26 | SMTP_FROM_ADDRESS= 27 | 28 | # Authorize key for AWS SNS email notifications (eg. bounces) 29 | AMAZON_AUTHORIZE_KEY= 30 | 31 | # Error reporting 32 | SENTRY_URL= 33 | 34 | FRONTEND_WEBSITE_LINK= 35 | # Url aggregated scan results, for importing results of our scans of top domains 36 | REMOTE_STATS_URL= 37 | -------------------------------------------------------------------------------- /.env.test.example: -------------------------------------------------------------------------------- 1 | # Test database credentials 2 | TEST_DB_NAME=starttls_test 3 | DB_USERNAME=postgres 4 | DB_PASSWORD=password 5 | DB_HOST=postgres_test 6 | 7 | AMAZON_AUTHORIZE_KEY=test 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | .env.test 3 | starttls-backend 4 | domain_blacklist.txt 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - "1.11" 5 | 6 | addons: 7 | postgresql: "9.6" 8 | 9 | env: 10 | - TEST_DB_NAME=starttls_test GO111MODULE=on 11 | 12 | install: 13 | - go get -u golang.org/x/lint/golint 14 | - go get github.com/mattn/goveralls 15 | 16 | before_script: 17 | - psql -c 'CREATE DATABASE starttls_test;' -U postgres 18 | - psql -c "ALTER USER postgres WITH PASSWORD 'postgres';" -U postgres 19 | - 
psql starttls_test < db/scripts/init_tables.sql 20 | # Repeat the previous command to test idempotence of init_tables script 21 | - psql starttls_test < db/scripts/init_tables.sql 22 | 23 | script: 24 | - golint -set_exit_status ./... 25 | - go test -race -coverprofile=profile.cov -covermode=atomic -v ./... 26 | - $GOPATH/bin/goveralls -coverprofile=profile.cov -service=travis-ci 27 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.11 2 | 3 | WORKDIR /go/src/github.com/EFForg/starttls-backend 4 | 5 | RUN apt-get update && apt-get -y install postgresql-client 6 | 7 | # Download module dependencies 8 | ENV GO111MODULE=on 9 | COPY go.mod . 10 | COPY go.sum . 11 | RUN go mod download 12 | 13 | # Build the binary 14 | COPY . . 15 | RUN go install . 16 | 17 | ENTRYPOINT ["/go/src/github.com/EFForg/starttls-backend/entrypoint.sh"] 18 | CMD ["/go/bin/starttls-backend"] 19 | 20 | EXPOSE 8080 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Note:** The STARTTLS Everywhere project is not currently being maintained. The information and resources in this repository may be outdated. See [this post](https://www.eff.org/deeplinks/2020/04/winding-down-starttls-everywhere-project-and-future-secure-email) for more information. If you are impacted by this news, or rely on the STARTTLS Policy List, you can [read this post](https://www.eff.org/deeplinks/2020/04/technical-deep-dive-winding-down-starttls-policy-list) for a deeper dive.
2 | 3 | # STARTTLS Everywhere Backend API 4 | 5 | [![Build Status](https://travis-ci.com/EFForg/starttls-backend.svg?branch=master)](https://travis-ci.org/EFForg/starttls-backend) 6 | [![Coverage Status](https://coveralls.io/repos/github/EFForg/starttls-backend/badge.svg?branch=master)](https://coveralls.io/github/EFForg/starttls-backend?branch=master) 7 | 8 | starttls-backend is the JSON backend for starttls-everywhere.org. It provides endpoints to run security checks against email domains and manage the status of those domains on EFF's [STARTTLS Everywhere policy list](https://github.com/EFForg/starttls-everywhere). 9 | 10 | ## Setup 11 | 1. Install `go` and `postgres`. 12 | 2. Download the project and copy the configuration files: 13 | ``` 14 | go get github.com/EFForg/starttls-backend 15 | cd $GOPATH/src/github.com/EFForg/starttls-backend 16 | cp .env.example .env 17 | cp .env.test.example .env.test 18 | ``` 19 | 3. Edit `.env` and `.env.test` with your postgres credentials and any other changes. 20 | 4. Ensure `postgres` is running, then run `db/scripts/init_tables.sql` in the appropriate postgres DBs in order to initialize your development and test databases. 21 | 5. Build the scanner and start serving requests: 22 | ``` 23 | go build 24 | ./starttls-backend 25 | ``` 26 | 27 | ### Via Docker 28 | ``` 29 | cp .env.example .env 30 | cp .env.test.example .env.test 31 | docker-compose build 32 | docker-compose up 33 | ``` 34 | 35 | To automatically migrate the database on container start, set `DB_MIGRATE=true` in the `.env` file. 36 | 37 | ## Testing 38 | 39 | Test all packages in this repo with 40 | ``` 41 | go test -v ./... 42 | ``` 43 | 44 | The `main` and `db` packages contain integration tests that require a successful connection to the Postgres database. The remaining packages do not require the database to pass tests. 45 | 46 | ## Configuration 47 | 48 | ### No-scan domains 49 | In case of complaints or abuse, we may not want to continually scan some domains.
You can set the environment variable `DOMAIN_BLACKLIST` to point to a file with a list of newline-separated domains. Attempting to scan those domains from the public-facing website will result in error codes. 50 | 51 | ## Scan API 52 | 53 | Our API objects can look a bit complicated! There's lots of information contained in a TLS scan. 54 | To request a scan: 55 | ``` 56 | POST /api/scan 57 | { "domain": "example.com" } 58 | ``` 59 | 60 | Let's break down exactly what each part of this giant nested response means. All API responses, not just scans, are wrapped in a JSON object, like: 61 | ``` 62 | { 63 | status_code: 200, 64 | message: "", 65 | response: 66 | } 67 | ``` 68 | Or even: 69 | ``` 70 | { 71 | status_code: 400, 72 | message: "query parameter domain not specified", 73 | response: {} 74 | } 75 | ``` 76 | The status codes always correspond with the HTTP status that is given for the response. `message` provides more context into why your request failed. 77 | 78 | ### Scan responses 79 | 80 | Here's an abbreviated scan response. There's extra information on these objects that help 81 | describe the errors we encountered. 
82 | ``` 83 | { 84 | domain: "example.com", 85 | scandata: { 86 | status: 0, 87 | results: { // Individual hostname check results 88 | "mx.example.com": { 89 | "status": 0, 90 | "checks": { 91 | "connectivity": { "status": 0 }, 92 | "certificate": { "status": 0 }, 93 | "starttls": { "status": 0 }, 94 | "version": { "status": 0 }, 95 | } 96 | } 97 | "dummy.example.com": { 98 | "status": 3, 99 | "checks": { 100 | "connectivity": { 101 | "status": 3, 102 | "messages": [ "Error: Could not establish connection" ] 103 | }, 104 | } 105 | }, 106 | }, 107 | preferred_hostnames: ["mx.example.com"], // Hostnames we were able to connect to 108 | extra_results: {"policylist": { "status": 0 }}, 109 | }, 110 | timestamp: 0, 111 | version: 1, 112 | } 113 | ``` 114 | 115 | The meat of the response is in `scandata`, which is a JSON-ification of the `DomainResult` structure returned from the `checker` package. 116 | 117 | ### Domain results 118 | 119 | Here's a quick synopsis of the fields you see in a domain response: 120 | 121 | - `domain`: the domain name that the scan was performed on. 122 | - `status`: Whether the check succeeded overall, and some more specific common failure types. Types 4-6 are types of test failures that are particularly common. 123 | - 0: Success, all TLS tests passed. 124 | - 1: Warning, at least one TLS test produced a warning. 125 | - 2: Failure, at least one TLS test failed. 126 | - 3: Error, something went wrong during the test. 127 | - 4: NoSTARTTLS, at least one of your mailboxes did not advertise STARTTLS. 128 | - 5: CouldNotConnect, could not connect to any mailbox. 129 | - 6: BadHostnameFailure, one of your mailbox's provided certificates didn't match its hostname. 130 | - `message`: A more detailed description of the failure type. 131 | - `preferred_hostnames`: A misnomer, but refers to mailboxes that passed the connectivity test. 132 | - `mta_sts`: result for MTA STS check. 133 | - `extra_results`: A map of other security checks for this domain. 
134 | - `results`: A map of mailbox hostnames to their individual results. 135 | - `timestamp`: Timestamp of when the scan was performed. 136 | - `version`: The scan API's version when it was performed. 137 | 138 | ### Hostname results 139 | 140 | Here's a sample: 141 | ``` 142 | { 143 | "status": 0, 144 | "checks": { 145 | "connectivity": { "status": 0 }, 146 | "certificate": { 147 | "status": 2, 148 | "messages": ["Hostname doesn't match any name in certificate", 149 | "Certificate root is not trusted"] 150 | }, 151 | "starttls": { "status": 0 }, 152 | "version": { "status": 0 }, 153 | } 154 | } 155 | ``` 156 | 157 | - `checks`: A result can have a suite of checks. `checks` is a map from a particular check name to its result. 158 | - `status`: The status of a particular check, or the overall suite. Can be 0 through 3, which are `Success`, `Warning`, `Failure`, `Error`. The overall suite status takes the max status of all the sub-checks. 159 | - `messages`: If the status of a check isn't success, `messages` contains all of its warning and failure messages. 160 | 161 | ### What do we scan for? 162 | 163 | Right now, these are the checks we perform. 164 | 165 | ##### Hostname-level scans 166 | These scans are performed for every hostname-- that is, we try these things for every MX we find for the given domain. 167 | 168 | * *Connectivity*: This one is performed first. It's common for mailservers to use dummy MX records as a spam-prevention tactic, so a hostname that fails to connect doesn't automatically fail the entire TLS scan, unless *no* hostnames succeed in connectivity. 169 | * *STARTTLS*: The checker first connects to the mailbox and looks for a STARTTLS support banner. Then, we actively try to initiate a STARTTLS session. 170 | * *Certificate*: The checker checks for certificate validity, which includes (1) chaining to a valid root in Mozilla's CA store, (2) the hostname matching the certificate, and (3) the certificate not being expired.
171 | * *Version*: The checker checks your mailserver doesn't support obsolete and insecure protocols prior to TLS 1.0. 172 | 173 | ##### Domain-level scans 174 | These scans are performed for the domain itself. 175 | 176 | * *MTA-STS* We check to see whether your email domain follows the MTA-STS specification, and that the MTA-STS policy we find is valid. 177 | * *Policy List* We check to see whether your email domain is on our policy list, or queued to be added. 178 | 179 | ### Rate-limiting, caching, and no-scan lists 180 | 181 | We rate-limit several endpoints to prevent abuse and reduce load on our servers. By default, scan requests are cached-- if you're consistently updating your servers and want to check to see if it's passing, we recommend waiting a few minutes and re-scanning. 182 | 183 | In case of complaints of abuse, we may not want to continually scan some domains, who can elect to prevent automated scans from this service. 184 | -------------------------------------------------------------------------------- /api/api.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "html/template" 7 | "io/ioutil" 8 | "log" 9 | "net/http" 10 | "os" 11 | "strconv" 12 | "strings" 13 | "time" 14 | 15 | "golang.org/x/net/idna" 16 | 17 | "github.com/EFForg/starttls-backend/checker" 18 | "github.com/EFForg/starttls-backend/db" 19 | "github.com/EFForg/starttls-backend/email" 20 | "github.com/EFForg/starttls-backend/models" 21 | "github.com/EFForg/starttls-backend/policy" 22 | "github.com/EFForg/starttls-backend/util" 23 | raven "github.com/getsentry/raven-go" 24 | ) 25 | 26 | //////////////////////////////// 27 | // ***** REST API ***** // 28 | //////////////////////////////// 29 | 30 | // Minimum time to cache each domain scan 31 | const cacheScanTime = time.Minute 32 | 33 | // Type for performing checks against an input domain. 
Returns 34 | // a DomainResult object from the checker. 35 | type checkPerformer func(API, string) (checker.DomainResult, error) 36 | 37 | // API is the HTTP API that this service provides. 38 | // All requests respond with an response JSON, with fields: 39 | // { 40 | // status_code // HTTP status code of request 41 | // message // Any error message accompanying the status_code. If 200, empty. 42 | // response // Response data (as JSON) from this request. 43 | // } 44 | // Any POST request accepts either URL query parameters or data value parameters, 45 | // and prefers the latter if both are present. 46 | type API struct { 47 | Database db.Database 48 | checkDomainOverride checkPerformer 49 | List PolicyList 50 | DontScan map[string]bool 51 | Emailer EmailSender 52 | Templates map[string]*template.Template 53 | } 54 | 55 | // PolicyList interface wraps a policy-list like structure. 56 | // The most important query you can perform is to fetch the policy 57 | // for a particular domain. 58 | type PolicyList interface { 59 | HasDomain(string) bool 60 | Raw() policy.List 61 | } 62 | 63 | // EmailSender interface wraps a back-end that can send e-mails. 64 | type EmailSender interface { 65 | // SendValidation sends a validation e-mail for a particular domain, 66 | // with a particular validation token. 
67 | SendValidation(*models.Domain, string) error 68 | } 69 | 70 | type response struct { 71 | StatusCode int `json:"status_code"` 72 | Message string `json:"message"` 73 | Response interface{} `json:"response"` 74 | templateName string `json:"-"` 75 | } 76 | 77 | type apiHandler func(r *http.Request) response 78 | 79 | func (api *API) checkDomain(domain string) (checker.DomainResult, error) { 80 | if api.checkDomainOverride == nil { 81 | return defaultCheck(*api, domain) 82 | } 83 | return api.checkDomainOverride(*api, domain) 84 | } 85 | 86 | func (api *API) wrapper(handler apiHandler) func(w http.ResponseWriter, r *http.Request) { 87 | return func(w http.ResponseWriter, r *http.Request) { 88 | response := handler(r) 89 | if response.StatusCode == http.StatusInternalServerError { 90 | packet := raven.NewPacket(response.Message, raven.NewHttp(r)) 91 | raven.Capture(packet, nil) 92 | } 93 | if strings.Contains(r.Header.Get("accept"), "text/html") { 94 | api.writeHTML(w, response) 95 | } else { 96 | api.writeJSON(w, response) 97 | } 98 | } 99 | } 100 | 101 | func pingHandler(w http.ResponseWriter, r *http.Request) { 102 | w.WriteHeader(http.StatusOK) 103 | w.Header().Set("Content-Type", "application/json") 104 | } 105 | 106 | // RegisterHandlers binds API functions to the given http server, 107 | // and returns the resulting handler. 108 | func (api *API) RegisterHandlers(mux *http.ServeMux) http.Handler { 109 | mux.HandleFunc("/sns", HandleSESNotification(api.Database)) 110 | mux.HandleFunc("/api/scan", api.wrapper(api.scan)) 111 | // ===================================================================== 112 | // No longer exposing these endpoints due to STARTTLS Everywhere sunset. 
113 | // ===================================================================== 114 | // mux.Handle("/api/queue", 115 | // throttleHandler(time.Hour, 20, http.HandlerFunc(api.wrapper(api.queue)))) 116 | // mux.HandleFunc("/api/validate", api.wrapper(api.validate)) 117 | mux.HandleFunc("/api/stats", api.wrapper(api.stats)) 118 | mux.HandleFunc("/api/ping", pingHandler) 119 | return middleware(mux) 120 | } 121 | 122 | func defaultCheck(api API, domain string) (checker.DomainResult, error) { 123 | policyChan := models.Domain{Name: domain}.AsyncPolicyListCheck(api.Database, api.List) 124 | c := checker.Checker{ 125 | Cache: &checker.ScanCache{ 126 | ScanStore: api.Database, 127 | ExpireTime: 5 * time.Minute, 128 | }, 129 | Timeout: 3 * time.Second, 130 | } 131 | result := c.CheckDomain(domain, nil) 132 | policyResult := <-policyChan 133 | result.ExtraResults["policylist"] = &policyResult 134 | return result, nil 135 | } 136 | 137 | // Scan is the handler for /api/scan. 138 | // POST /api/scan 139 | // domain: Mail domain to scan. 140 | // Scans domain and returns data from it. 141 | // GET /api/scan?domain= 142 | // Retrieves most recent scan for domain. 143 | // Both set a models.Scan JSON as the response. 144 | func (api API) scan(r *http.Request) response { 145 | domain, err := getASCIIDomain(r) 146 | if err != nil { 147 | return response{StatusCode: http.StatusBadRequest, Message: err.Error()} 148 | } 149 | // Check if we shouldn't scan this domain 150 | if api.DontScan != nil { 151 | if _, ok := api.DontScan[domain]; ok { 152 | return response{StatusCode: http.StatusTooManyRequests} 153 | } 154 | } 155 | // POST: Force scan to be conducted 156 | if r.Method == http.MethodPost { 157 | // 0. If last scan was recent and on same scan version, return cached scan. 
158 | scan, err := api.Database.GetLatestScan(domain) 159 | if err == nil && scan.Version == models.ScanVersion && 160 | time.Now().Before(scan.Timestamp.Add(cacheScanTime)) { 161 | return response{ 162 | StatusCode: http.StatusOK, 163 | Response: scan, 164 | templateName: "scan", 165 | } 166 | } 167 | // 1. Conduct scan via starttls-checker 168 | scanData, err := api.checkDomain(domain) 169 | if err != nil { 170 | return response{StatusCode: http.StatusInternalServerError, Message: err.Error()} 171 | } 172 | scan = models.Scan{ 173 | Domain: domain, 174 | Data: scanData, 175 | Timestamp: time.Now(), 176 | Version: models.ScanVersion, 177 | } 178 | // 2. Put scan into DB 179 | err = api.Database.PutScan(scan) 180 | if err != nil { 181 | return response{StatusCode: http.StatusInternalServerError, Message: err.Error()} 182 | } 183 | return response{ 184 | StatusCode: http.StatusOK, 185 | Response: scan, 186 | templateName: "scan", 187 | } 188 | // GET: Just fetch the most recent scan 189 | } else if r.Method == http.MethodGet { 190 | scan, err := api.Database.GetLatestScan(domain) 191 | if err != nil { 192 | return response{StatusCode: http.StatusNotFound, Message: err.Error()} 193 | } 194 | return response{StatusCode: http.StatusOK, Response: scan} 195 | } else { 196 | return response{StatusCode: http.StatusMethodNotAllowed, 197 | Message: "/api/scan only accepts POST and GET requests"} 198 | } 199 | } 200 | 201 | // MaxHostnames is the maximum number of hostnames that can be specified for a single domain's TLS policy. 202 | const MaxHostnames = 8 203 | 204 | // Extracts relevant parameters from http.Request for a POST to /api/queue 205 | // TODO: also validate hostnames as FQDNs. 
206 | func getDomainParams(r *http.Request) (models.Domain, error) { 207 | name, err := getASCIIDomain(r) 208 | if err != nil { 209 | return models.Domain{}, err 210 | } 211 | mtasts := r.FormValue("mta-sts") 212 | domain := models.Domain{ 213 | Name: name, 214 | MTASTS: mtasts == "on", 215 | State: models.StateUnconfirmed, 216 | } 217 | givenEmail, err := getParam("email", r) 218 | if err == nil { 219 | domain.Email = givenEmail 220 | } else { 221 | domain.Email = email.ValidationAddress(&domain) 222 | } 223 | queueWeeks, err := getInt("weeks", r, 4, 52, 4) 224 | if err != nil { 225 | return domain, err 226 | } 227 | domain.QueueWeeks = queueWeeks 228 | 229 | if mtasts != "on" { 230 | for _, hostname := range r.PostForm["hostnames"] { 231 | if len(hostname) == 0 { 232 | continue 233 | } 234 | if !util.ValidDomainName(strings.TrimPrefix(hostname, ".")) { 235 | return domain, fmt.Errorf("Hostname %s is invalid", hostname) 236 | } 237 | domain.MXs = append(domain.MXs, hostname) 238 | } 239 | if len(domain.MXs) == 0 { 240 | return domain, fmt.Errorf("No MX hostnames supplied for domain %s", domain.Name) 241 | } 242 | if len(domain.MXs) > MaxHostnames { 243 | return domain, fmt.Errorf("No more than 8 MX hostnames are permitted") 244 | } 245 | } 246 | return domain, nil 247 | } 248 | 249 | // Queue is the handler for /api/queue 250 | // POST /api/queue?domain= 251 | // domain: Mail domain to queue a TLS policy for. 252 | // mta_sts: "on" if domain supports MTA-STS, else "". 253 | // hostnames: List of MX hostnames to put into this domain's TLS policy. Up to 8. 254 | // Sets models.Domain object as response. 255 | // weeks (optional, default 4): How many weeks is this domain queued for. 256 | // email (optional): Contact email associated with domain. 257 | // GET /api/queue?domain= 258 | // Sets models.Domain object as response. 
259 | func (api API) queue(r *http.Request) response { 260 | // POST: Insert this domain into the queue 261 | if r.Method == http.MethodPost { 262 | domain, err := getDomainParams(r) 263 | if err != nil { 264 | return badRequest(err.Error()) 265 | } 266 | ok, msg, scan := domain.IsQueueable(api.Database, api.Database, api.List) 267 | if !ok { 268 | return badRequest(msg) 269 | } 270 | domain.PopulateFromScan(scan) 271 | token, err := domain.InitializeWithToken(api.Database, api.Database) 272 | if err != nil { 273 | return serverError(err.Error()) 274 | } 275 | if err = api.Emailer.SendValidation(&domain, token); err != nil { 276 | log.Print(err) 277 | return serverError("Unable to send validation e-mail") 278 | } 279 | return response{ 280 | StatusCode: http.StatusOK, 281 | Response: fmt.Sprintf("Thank you for submitting your domain. Please check postmaster@%s to validate that you control the domain.", domain.Name), 282 | } 283 | } 284 | // GET: Retrieve domain status from queue 285 | if r.Method == http.MethodGet { 286 | domainName, err := getASCIIDomain(r) 287 | if err != nil { 288 | return badRequest(err.Error()) 289 | } 290 | domainObj, err := models.GetDomain(api.Database, domainName) 291 | if err != nil { 292 | return response{StatusCode: http.StatusNotFound, Message: err.Error()} 293 | } 294 | return response{ 295 | StatusCode: http.StatusOK, 296 | Response: domainObj, 297 | } 298 | } 299 | return response{StatusCode: http.StatusMethodNotAllowed, 300 | Message: "/api/queue only accepts POST and GET requests"} 301 | } 302 | 303 | // Validate handles requests to /api/validate 304 | // POST /api/validate 305 | // token: token to validate/redeem 306 | // Sets the queued domain name as response. 
307 | func (api API) validate(r *http.Request) response { 308 | token, err := getParam("token", r) 309 | if err != nil { 310 | return response{StatusCode: http.StatusBadRequest, Message: err.Error()} 311 | } 312 | if r.Method != http.MethodPost { 313 | return response{StatusCode: http.StatusMethodNotAllowed, 314 | Message: "/api/validate only accepts POST requests"} 315 | } 316 | tokenData := models.Token{Token: token} 317 | domain, userErr, dbErr := tokenData.Redeem(api.Database, api.Database) 318 | if userErr != nil { 319 | return badRequest(userErr.Error()) 320 | } 321 | if dbErr != nil { 322 | return serverError(dbErr.Error()) 323 | } 324 | return response{StatusCode: http.StatusOK, Response: domain} 325 | } 326 | 327 | // Retrieve "domain" parameter from request as ASCII 328 | // If fails, returns an error. 329 | func getASCIIDomain(r *http.Request) (string, error) { 330 | domain, err := getParam("domain", r) 331 | if err != nil { 332 | return domain, err 333 | } 334 | ascii, err := idna.ToASCII(domain) 335 | if err != nil { 336 | return "", fmt.Errorf("could not convert domain %s to ASCII (%s)", domain, err) 337 | } 338 | return ascii, nil 339 | } 340 | 341 | // Retrieves and lowercases `param` as a query parameter from `http.Request` r. 342 | // If fails, then returns an error. 343 | func getParam(param string, r *http.Request) (string, error) { 344 | unicode := r.FormValue(param) 345 | if unicode == "" { 346 | return "", fmt.Errorf("query parameter %s not specified", param) 347 | } 348 | return strings.ToLower(unicode), nil 349 | } 350 | 351 | // Retrieves `param` as a query parameter from `http.Request` r, and tries to cast it as 352 | // a number between [lowInc, highExc). If fails, then returns an error. 353 | // If `param` isn't specified, return defaultNum. 
354 | func getInt(param string, r *http.Request, lowInc int, highExc int, defaultNum int) (int, error) { 355 | unicode := r.FormValue(param) 356 | if unicode == "" { 357 | return defaultNum, nil 358 | } 359 | n, err := strconv.Atoi(unicode) 360 | if err != nil { 361 | return -1, err 362 | } 363 | if n < lowInc { 364 | return n, fmt.Errorf("expected query parameter %s to be more than or equal to %d, was %d", param, lowInc, n) 365 | } 366 | if n >= highExc { 367 | return n, fmt.Errorf("expected query parameter %s to be less than %d, was %d", param, highExc, n) 368 | } 369 | return n, nil 370 | } 371 | 372 | // Writes `v` as a JSON object to http.ResponseWriter `w`. If an error 373 | // occurs, writes `http.StatusInternalServerError` to `w`. 374 | func (api *API) writeJSON(w http.ResponseWriter, apiResponse response) { 375 | w.Header().Set("Content-Type", "application/json; charset=utf-8") 376 | w.WriteHeader(apiResponse.StatusCode) 377 | b, err := json.MarshalIndent(apiResponse, "", " ") 378 | if err != nil { 379 | msg := fmt.Sprintf("Internal error: could not format JSON. (%s)\n", err) 380 | http.Error(w, msg, http.StatusInternalServerError) 381 | return 382 | } 383 | fmt.Fprintf(w, "%s\n", b) 384 | } 385 | 386 | // ParseTemplates initializes our HTML template data 387 | func (api *API) ParseTemplates(dir string) { 388 | names := []string{"default", "scan"} 389 | api.Templates = make(map[string]*template.Template) 390 | for _, name := range names { 391 | path := fmt.Sprintf("%s/%s.html.tmpl", dir, name) 392 | tmpl, err := template.ParseFiles(path) 393 | if err != nil { 394 | raven.CaptureError(err, nil) 395 | log.Fatal(err) 396 | } 397 | api.Templates[name] = tmpl 398 | } 399 | } 400 | 401 | func (api *API) writeHTML(w http.ResponseWriter, apiResponse response) { 402 | // Add some additional useful fields for use in templates. 
403 | data := struct { 404 | response 405 | BaseURL string 406 | StatusText string 407 | }{ 408 | response: apiResponse, 409 | BaseURL: os.Getenv("FRONTEND_WEBSITE_LINK"), 410 | StatusText: http.StatusText(apiResponse.StatusCode), 411 | } 412 | if apiResponse.templateName == "" { 413 | apiResponse.templateName = "default" 414 | } 415 | tmpl, ok := api.Templates[apiResponse.templateName] 416 | if !ok { 417 | err := fmt.Errorf("Template not found: %s", apiResponse.templateName) 418 | raven.CaptureError(err, nil) 419 | http.Error(w, err.Error(), http.StatusInternalServerError) 420 | return 421 | } 422 | w.WriteHeader(apiResponse.StatusCode) 423 | err := tmpl.Execute(w, data) 424 | if err != nil { 425 | log.Println(err) 426 | raven.CaptureError(err, nil) 427 | } 428 | } 429 | 430 | func badRequest(format string, a ...interface{}) response { 431 | return response{ 432 | StatusCode: http.StatusBadRequest, 433 | Message: fmt.Sprintf(format, a...), 434 | } 435 | } 436 | 437 | func serverError(format string, a ...interface{}) response { 438 | return response{ 439 | StatusCode: http.StatusInternalServerError, 440 | Message: fmt.Sprintf(format, a...), 441 | } 442 | } 443 | 444 | type ravenExtraContent string 445 | 446 | // Class satisfies raven's Interface interface so we can send this as extra context. 447 | // https://github.com/getsentry/raven-go/issues/125 448 | func (r ravenExtraContent) Class() string { 449 | return "extra" 450 | } 451 | 452 | func (r ravenExtraContent) MarshalJSON() ([]byte, error) { 453 | return []byte(r), nil 454 | } 455 | 456 | // HandleSESNotification handles AWS SES bounces and complaints submitted to a webhook 457 | // via AWS SNS (Simple Notification Service). 458 | // The SNS webhook is configured to include a secret API key stored in the environment. 
func HandleSESNotification(database db.Database) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		// The key arrives in the amazon_authorize_key query parameter. An unset
		// AMAZON_AUTHORIZE_KEY must not let a request with an empty key through,
		// so a missing configuration rejects every request.
		// NOTE(review): consider crypto/subtle.ConstantTimeCompare for this check.
		expectedKey := os.Getenv("AMAZON_AUTHORIZE_KEY")
		keyParam := r.URL.Query()["amazon_authorize_key"]
		if expectedKey == "" || len(keyParam) == 0 || keyParam[0] != expectedKey {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		body, err := ioutil.ReadAll(r.Body)
		if err != nil {
			// Report the failure instead of falling through to an implicit 200.
			w.WriteHeader(http.StatusInternalServerError)
			raven.CaptureError(err, nil)
			return
		}

		data := &email.BlacklistRequest{}
		if err = json.Unmarshal(body, data); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			raven.CaptureError(err, nil, ravenExtraContent(body))
			return
		}

		tags := map[string]string{"notification_type": data.Reason}
		raven.CaptureMessage("Received SES notification", tags, ravenExtraContent(data.Raw))

		// Record every recipient; a failure for one is reported to Sentry but
		// neither stops the others nor changes the response status.
		for _, recipient := range data.Recipients {
			if err := database.PutBlacklistedEmail(recipient.EmailAddress, data.Reason, data.Timestamp); err != nil {
				raven.CaptureError(err, nil)
			}
		}

		w.WriteHeader(http.StatusOK)
	}
}
494 | -------------------------------------------------------------------------------- /api/api_test.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | "net/http" 7 | "net/http/httptest" 8 | "net/url" 9 | "os" 10 | "strings" 11 | "testing" 12 | "time" 13 | 14 | "github.com/EFForg/starttls-backend/checker" 15 | "github.com/EFForg/starttls-backend/db" 16 | "github.com/EFForg/starttls-backend/models" 17 | "github.com/EFForg/starttls-backend/policy" 18 | "github.com/joho/godotenv" 19 | ) 20 | 21 | var api *API 22 | var server *httptest.Server 23 | 24 | func mockCheckPerform(message string) func(API, string) (checker.DomainResult, error) { 25 | return func(api API, domain string) (checker.DomainResult, error) { 26 | return
checker.NewSampleDomainResult(domain), nil 27 | } 28 | } 29 | 30 | // Mock PolicyList 31 | type mockList struct { 32 | domains map[string]bool 33 | } 34 | 35 | func (l mockList) Raw() policy.List { 36 | list := policy.List{ 37 | Timestamp: time.Now(), 38 | Expires: time.Now().Add(time.Minute), 39 | Version: "", 40 | Author: "", 41 | PolicyAliases: make(map[string]policy.TLSPolicy), 42 | Policies: make(map[string]policy.TLSPolicy), 43 | } 44 | for domain := range l.domains { 45 | list.Policies[domain] = 46 | policy.TLSPolicy{Mode: "enforce", MXs: []string{"mx.fake.com"}} 47 | } 48 | return list 49 | } 50 | 51 | func (l mockList) HasDomain(domain string) bool { 52 | _, ok := l.domains[domain] 53 | return ok 54 | } 55 | 56 | // Mock emailer 57 | type mockEmailer struct{} 58 | 59 | func (e mockEmailer) SendValidation(domain *models.Domain, token string) error { return nil } 60 | 61 | func testHTMLPost(path string, data url.Values, t *testing.T) ([]byte, int) { 62 | req, err := http.NewRequest("POST", server.URL+path, strings.NewReader(data.Encode())) 63 | if err != nil { 64 | t.Fatal(err) 65 | } 66 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 67 | req.Header.Set("accept", "text/html") 68 | resp, err := http.DefaultClient.Do(req) 69 | if err != nil { 70 | t.Fatal(err) 71 | } 72 | body, _ := ioutil.ReadAll(resp.Body) 73 | if !strings.Contains(strings.ToLower(string(body)), " 46 | ``` 47 | 48 | For instance, running `./starttls-check -domain gmail.com` will 49 | check for the TLS configurations (over SMTP) on port 25 for all the MX domains for `gmail.com`. 50 | 51 | 52 | ## Results 53 | From a preliminary STARTTLS scan on the top 1000 alexa domains, performed 3/8/2018, we found: 54 | - 20.19% of 421 unique MX hostnames don't support STARTTLS. 55 | - 36.01% of the servers which support STARTTLS didn't present valid certificates. 56 | - We're not sure how to define valid certificates. 
On manual inspection, although many certificates are self-signed, it seems that many of these certs are issued for other subdomains owned by the same entity. 57 | 58 | Seems like an improvement from results in [2014](https://research.google.com/pubs/pub43962.html), but we can do better! 59 | 60 | 61 | ## TODO 62 | - [ ] Check DANE 63 | - [ ] Present recommendations for issues 64 | - [ ] Tests 65 | -------------------------------------------------------------------------------- /checker/cache.go: -------------------------------------------------------------------------------- 1 | package checker 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // ScanStore is an interface for using and retrieving scan results. 10 | type ScanStore interface { 11 | GetHostnameScan(string) (HostnameResult, error) 12 | PutHostnameScan(string, HostnameResult) error 13 | } 14 | 15 | // ScanCache wraps a scan storage object. When calling GetScan, only returns a scan 16 | // if there was made in the last ExpireTime window 17 | type ScanCache struct { 18 | ScanStore 19 | ExpireTime time.Duration 20 | } 21 | 22 | // GetHostnameScan retrieves the scan from underlying storage if there is one 23 | // present within the cached time window. 24 | func (c *ScanCache) GetHostnameScan(hostname string) (HostnameResult, error) { 25 | result, err := c.ScanStore.GetHostnameScan(hostname) 26 | if err != nil { 27 | return result, err 28 | } 29 | if time.Now().Sub(result.Timestamp) > c.ExpireTime { 30 | return result, fmt.Errorf("most recent scan for %s expired", hostname) 31 | } 32 | return result, nil 33 | } 34 | 35 | // PutHostnameScan puts in a scan. 36 | func (c *ScanCache) PutHostnameScan(hostname string, result HostnameResult) error { 37 | return c.ScanStore.PutHostnameScan(hostname, result) 38 | } 39 | 40 | // SimpleStore is simple HostnameResult storage backed by map. 
41 | type SimpleStore struct { 42 | m map[string]HostnameResult 43 | mu sync.RWMutex 44 | } 45 | 46 | // GetHostnameScan wraps a map get. Returns error if not present in map. 47 | func (s *SimpleStore) GetHostnameScan(hostname string) (HostnameResult, error) { 48 | s.mu.RLock() 49 | defer s.mu.RUnlock() 50 | result, ok := s.m[hostname] 51 | if !ok { 52 | return result, fmt.Errorf("Couldn't find scan for hostname %s", hostname) 53 | } 54 | return result, nil 55 | } 56 | 57 | // PutHostnameScan wraps a map set. Can never return error. 58 | func (s *SimpleStore) PutHostnameScan(hostname string, result HostnameResult) error { 59 | s.mu.Lock() 60 | defer s.mu.Unlock() 61 | s.m[hostname] = result 62 | return nil 63 | } 64 | 65 | // MakeSimpleCache creates a cache with a SimpleStore backing it. 66 | func MakeSimpleCache(expiryTime time.Duration) *ScanCache { 67 | store := SimpleStore{m: make(map[string]HostnameResult)} 68 | return &ScanCache{ScanStore: &store, ExpireTime: expiryTime} 69 | } 70 | -------------------------------------------------------------------------------- /checker/cache_test.go: -------------------------------------------------------------------------------- 1 | package checker 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | func TestSimpleCacheMap(t *testing.T) { 9 | cache := MakeSimpleCache(time.Hour) 10 | err := cache.PutHostnameScan("anything", HostnameResult{ 11 | Result: &Result{Status: 3}, 12 | Timestamp: time.Now(), 13 | }) 14 | if err != nil { 15 | t.Errorf("Expected scan put to succeed: %v", err) 16 | } 17 | result, err := cache.GetHostnameScan("anything") 18 | if err != nil { 19 | t.Errorf("Expected scan get to succeed: %v", err) 20 | } 21 | if result.Status != 3 { 22 | t.Errorf("Expected scan to have status 3, had status %d", result.Status) 23 | } 24 | } 25 | 26 | func TestSimpleCacheExpires(t *testing.T) { 27 | cache := MakeSimpleCache(0) 28 | cache.PutHostnameScan("anything", HostnameResult{ 29 | Result: &Result{Status: 3}, 30 | 
Timestamp: time.Now(), 31 | }) 32 | _, err := cache.GetHostnameScan("anything") 33 | if err == nil { 34 | t.Errorf("Expected cache to expire and scan get to fail: %v", err) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /checker/checker.go: -------------------------------------------------------------------------------- 1 | package checker 2 | 3 | import ( 4 | "net" 5 | "time" 6 | ) 7 | 8 | // A Checker is used to run checks against SMTP domains and hostnames. 9 | type Checker struct { 10 | // Timeout specifies the maximum timeout for network requests made during 11 | // checks. 12 | // If nil, a default timeout of 10 seconds is used. 13 | Timeout time.Duration 14 | 15 | // Cache specifies the hostname scan cache store and expire time. 16 | // If `nil`, then scans are not cached. 17 | Cache *ScanCache 18 | 19 | // lookupMXOverride specifies an alternate function to retrieve hostnames for a given 20 | // domain. It is used to mock DNS lookups during testing. 21 | lookupMXOverride func(string) ([]*net.MX, error) 22 | 23 | // CheckHostname defines the function that should be used to check each hostname. 24 | // If nil, FullCheckHostname (all hostname checks) will be used. 25 | CheckHostname func(string, string, time.Duration) HostnameResult 26 | 27 | // checkMTASTSOverride is used to mock MTA-STS checks. 
28 | checkMTASTSOverride func(string, map[string]HostnameResult) *MTASTSResult 29 | } 30 | 31 | func (c *Checker) timeout() time.Duration { 32 | if c.Timeout != 0 { 33 | return c.Timeout 34 | } 35 | return 10 * time.Second 36 | } 37 | -------------------------------------------------------------------------------- /checker/cmd/starttls-check/cmd.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "encoding/csv" 6 | "encoding/json" 7 | "flag" 8 | "fmt" 9 | "io" 10 | "log" 11 | "net/http" 12 | "os" 13 | "time" 14 | 15 | "github.com/EFForg/starttls-backend/checker" 16 | ) 17 | 18 | var out io.Writer = os.Stdout 19 | 20 | func setFlags() (domain, filePath, url *string, column *int, aggregate *bool) { 21 | flag.Usage = func() { 22 | fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) 23 | flag.PrintDefaults() 24 | } 25 | domain = flag.String("domain", "", "Domain to check") 26 | filePath = flag.String("file", "", "File path to a CSV of domains to check") 27 | url = flag.String("url", "", "URL of a CSV of domains to check") 28 | column = flag.Int("column", 0, "Zero indexed column of domains") 29 | aggregate = flag.Bool("aggregate", false, "Write aggregated MTA-STS statistics to database, specified by ENV") 30 | 31 | flag.Parse() 32 | if *domain == "" && *filePath == "" && *url == "" { 33 | flag.PrintDefaults() 34 | os.Exit(1) 35 | } 36 | if *domain != "" && (*column != 0 || *aggregate == true) { 37 | log.Println("column and aggregate are not supported for single domain checks") 38 | flag.PrintDefaults() 39 | os.Exit(1) 40 | } 41 | return 42 | } 43 | 44 | // Run a series of security checks on an MTA domain. 45 | // ================================================= 46 | // Validating (START)TLS configurations for all MX domains. 
47 | func main() { 48 | domain, filePath, url, column, aggregate := setFlags() 49 | 50 | c := checker.Checker{ 51 | Cache: checker.MakeSimpleCache(10 * time.Minute), 52 | } 53 | var resultHandler checker.ResultHandler 54 | resultHandler = &domainWriter{} 55 | 56 | if *domain != "" { 57 | // Handle single domain and return 58 | result := c.CheckDomain(*domain, nil) 59 | resultHandler.HandleDomain(result) 60 | os.Exit(0) 61 | } 62 | 63 | var instream io.Reader 64 | var label string 65 | if *filePath != "" { 66 | csvFile, err := os.Open(*filePath) 67 | defer csvFile.Close() 68 | if err != nil { 69 | log.Println(err) 70 | os.Exit(1) 71 | } 72 | instream = bufio.NewReader(csvFile) 73 | label = csvFile.Name() 74 | } else { 75 | resp, err := http.Get(*url) 76 | if err != nil { 77 | log.Println(err) 78 | os.Exit(1) 79 | } 80 | instream = resp.Body 81 | label = *url 82 | } 83 | 84 | domainReader := csv.NewReader(instream) 85 | if *aggregate { 86 | c = checker.Checker{ 87 | CheckHostname: checker.NoopCheckHostname, 88 | } 89 | resultHandler = &checker.AggregatedScan{ 90 | Time: time.Now(), 91 | Source: label, 92 | } 93 | } 94 | c.CheckCSV(domainReader, resultHandler, *column) 95 | json.NewEncoder(out).Encode(resultHandler) 96 | } 97 | 98 | type domainWriter struct{} 99 | 100 | func (w domainWriter) HandleDomain(r checker.DomainResult) { 101 | b, err := json.Marshal(r) 102 | if err != nil { 103 | log.Println(err) 104 | os.Exit(1) 105 | } 106 | fmt.Fprintln(out, string(b)) 107 | } 108 | -------------------------------------------------------------------------------- /checker/cmd/starttls-check/cmd_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "net/http/httptest" 9 | "os" 10 | "regexp" 11 | "strings" 12 | "testing" 13 | "time" 14 | 15 | "github.com/EFForg/starttls-backend/checker" 16 | ) 17 | 18 | func TestUpdateStats(t *testing.T) { 19 | 
out = new(bytes.Buffer) 20 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 21 | fmt.Fprintln(w, `1,foo,localhost 22 | 2,bar,localhost 23 | 3,baz,localhost`) 24 | })) 25 | defer ts.Close() 26 | 27 | oldArgs := os.Args 28 | defer func() { os.Args = oldArgs }() 29 | os.Args = []string{"starttls-checker", "--url", ts.URL, "--aggregate=true", "--column=2"} 30 | 31 | // @TODO make this faster 32 | main() 33 | got := out.(*bytes.Buffer).String() 34 | expected, err := json.Marshal(checker.AggregatedScan{ 35 | Time: time.Time{}, 36 | Source: ts.URL, 37 | Attempted: 3, 38 | }) 39 | if err != nil { 40 | t.Fatal(err) 41 | } 42 | timeJSON, err := json.Marshal(time.Time{}) 43 | if err != nil { 44 | t.Fatal(err) 45 | } 46 | re := regexp.MustCompile( 47 | strings.Replace(string(expected), string(timeJSON), ".*", 1), 48 | ) 49 | 50 | if !re.MatchString(got) { 51 | t.Errorf("Expected:\n%s\nGot:\n%s", expected, got) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /checker/domain.go: -------------------------------------------------------------------------------- 1 | package checker 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "strings" 8 | "time" 9 | 10 | "golang.org/x/net/idna" 11 | ) 12 | 13 | // Reports an error during the domain checks. 14 | func (d DomainResult) reportError(err error) DomainResult { 15 | d.Status = DomainError 16 | d.Message = err.Error() 17 | return d 18 | } 19 | 20 | // DomainStatus indicates the overall status of a single domain. 21 | type DomainStatus int32 22 | 23 | // NOTE: if you change the below structures, remember to fix the documentation in `README.md`. 24 | 25 | // In order of precedence. 
26 | const ( 27 | DomainSuccess DomainStatus = 0 28 | DomainWarning DomainStatus = 1 29 | DomainFailure DomainStatus = 2 30 | DomainError DomainStatus = 3 31 | DomainNoSTARTTLSFailure DomainStatus = 4 32 | DomainCouldNotConnect DomainStatus = 5 33 | DomainBadHostnameFailure DomainStatus = 6 34 | ) 35 | 36 | // DomainResult wraps all the results for a particular mail domain. 37 | type DomainResult struct { 38 | // Domain being checked against. 39 | Domain string `json:"domain"` 40 | // Message if a failure or error occurs on the domain lookup level. 41 | Message string `json:"message,omitempty"` 42 | // Status of this check, inherited from the results of preferred hostnames. 43 | Status DomainStatus `json:"status"` 44 | // Results of this check, on each hostname. 45 | HostnameResults map[string]HostnameResult `json:"results"` 46 | // The list of hostnames which will impact the Status of this result. 47 | // It discards mailboxes that we can't connect to. 48 | PreferredHostnames []string `json:"preferred_hostnames"` 49 | // Expected MX hostnames supplied by the caller of CheckDomain. 50 | MxHostnames []string `json:"mx_hostnames,omitempty"` 51 | // Result of MTA-STS checks 52 | MTASTSResult *MTASTSResult `json:"mta_sts"` 53 | // Extra global results 54 | ExtraResults map[string]*Result `json:"extra_results,omitempty"` 55 | } 56 | 57 | // Class satisfies raven's Interface interface. 
58 | // https://github.com/getsentry/raven-go/issues/125 59 | func (d DomainResult) Class() string { 60 | return "extra" 61 | } 62 | 63 | func (d DomainResult) setStatus(status DomainStatus) DomainResult { 64 | d.Status = DomainStatus(SetStatus(Status(d.Status), Status(status))) 65 | return d 66 | } 67 | 68 | func lookupMXWithTimeout(domain string, timeout time.Duration) ([]*net.MX, error) { 69 | ctx, cancel := context.WithTimeout(context.TODO(), timeout) 70 | defer cancel() 71 | var r net.Resolver 72 | return r.LookupMX(ctx, domain) 73 | } 74 | 75 | // lookupHostnames retrieves the MX hostnames associated with a domain. 76 | func (c *Checker) lookupHostnames(domain string) ([]string, error) { 77 | domainASCII, err := idna.ToASCII(domain) 78 | if err != nil { 79 | return nil, fmt.Errorf("domain name %s couldn't be converted to ASCII", domain) 80 | } 81 | // Allow the Checker to mock DNS lookup. 82 | var mxs []*net.MX 83 | if c.lookupMXOverride != nil { 84 | mxs, err = c.lookupMXOverride(domain) 85 | } else { 86 | mxs, err = lookupMXWithTimeout(domainASCII, c.timeout()) 87 | } 88 | if err != nil || len(mxs) == 0 { 89 | return nil, fmt.Errorf("No MX records found") 90 | } 91 | hostnames := make([]string, 0) 92 | for _, mx := range mxs { 93 | hostnames = append(hostnames, strings.ToLower(mx.Host)) 94 | } 95 | return hostnames, nil 96 | } 97 | 98 | // CheckDomain performs all associated checks for a particular domain. 99 | // First performs an MX lookup, then performs subchecks on each of the 100 | // resulting hostnames. 101 | // 102 | // The status of DomainResult is inherited from the check status of the MX 103 | // records with highest priority. This check succeeds only if the hostname 104 | // checks on the highest priority mailservers succeed. 105 | // 106 | // `domain` is the mail domain to perform the lookup on. 107 | // `expectedHostnames` is the list of expected hostnames. 108 | // If `expectedHostnames` is nil, we don't validate the DNS lookup. 
109 | func (c *Checker) CheckDomain(domain string, expectedHostnames []string) DomainResult { 110 | result := DomainResult{ 111 | Domain: domain, 112 | MxHostnames: expectedHostnames, 113 | HostnameResults: make(map[string]HostnameResult), 114 | ExtraResults: make(map[string]*Result), 115 | } 116 | // 1. Look up hostnames 117 | // 2. Perform and aggregate checks from those hostnames. 118 | // 3. Set a summary message. 119 | hostnames, err := c.lookupHostnames(domain) 120 | if err != nil { 121 | return result.setStatus(DomainCouldNotConnect) 122 | } 123 | checkedHostnames := make([]string, 0) 124 | for _, hostname := range hostnames { 125 | hostnameResult := c.checkHostname(domain, hostname) 126 | result.HostnameResults[hostname] = hostnameResult 127 | if hostnameResult.couldConnect() { 128 | checkedHostnames = append(checkedHostnames, hostname) 129 | } 130 | } 131 | result.PreferredHostnames = checkedHostnames 132 | result.MTASTSResult = c.checkMTASTS(domain, result.HostnameResults) 133 | 134 | // Derive Domain code from Hostname results. 135 | if len(checkedHostnames) == 0 { 136 | // We couldn't connect to any of those hostnames. 137 | return result.setStatus(DomainCouldNotConnect) 138 | } 139 | for _, hostname := range checkedHostnames { 140 | hostnameResult := result.HostnameResults[hostname] 141 | // Any of the connected hostnames don't support STARTTLS. 142 | if !hostnameResult.couldSTARTTLS() { 143 | return result.setStatus(DomainNoSTARTTLSFailure) 144 | } 145 | // Any of the connected hostnames don't have a match? 146 | if expectedHostnames != nil && !PolicyMatches(hostname, expectedHostnames) { 147 | return result.setStatus(DomainBadHostnameFailure) 148 | } 149 | result = result.setStatus(DomainStatus(hostnameResult.Status)) 150 | } 151 | // result.setStatus(DomainStatus(result.ExtraResults["mta-sts"].Status)) 152 | return result 153 | } 154 | 155 | // NewSampleDomainResult returns a sample successful domain result for testing. 
156 | // This is exported so other packages can use it in their integration tests. 157 | func NewSampleDomainResult(domain string) DomainResult { 158 | hostname := "mx." + domain 159 | return DomainResult{ 160 | Domain: domain, 161 | Status: DomainSuccess, 162 | HostnameResults: map[string]HostnameResult{ 163 | hostname: HostnameResult{ 164 | Domain: domain, 165 | Hostname: hostname, 166 | Result: &Result{ 167 | Checks: map[string]*Result{ 168 | Connectivity: MakeResult(Connectivity), 169 | STARTTLS: MakeResult(STARTTLS), 170 | Certificate: MakeResult(Certificate), 171 | Version: MakeResult(Version), 172 | }, 173 | }, 174 | }, 175 | }, 176 | PreferredHostnames: []string{hostname}, 177 | MTASTSResult: &MTASTSResult{ 178 | Result: &Result{ 179 | Status: Success, 180 | Checks: map[string]*Result{ 181 | MTASTSText: MakeResult(MTASTSText), 182 | MTASTSPolicyFile: MakeResult(MTASTSPolicyFile), 183 | }, 184 | }, 185 | Mode: "enforce", 186 | MXs: []string{"." + domain}, 187 | }, 188 | ExtraResults: map[string]*Result{ 189 | PolicyList: MakeResult(PolicyList), 190 | }, 191 | } 192 | } 193 | -------------------------------------------------------------------------------- /checker/domain_test.go: -------------------------------------------------------------------------------- 1 | package checker 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | // fake DNS map for "resolving" MX lookups 11 | var mxLookup = map[string][]string{ 12 | "empty": []string{}, 13 | "changes": []string{"changes"}, 14 | "domain": []string{"hostname1", "hostname2"}, 15 | "domain.tld": []string{"mail2.domain.tld", "mail1.domain.tld"}, 16 | "noconnection": []string{"noconnection", "noconnection"}, 17 | "noconnection2": []string{"noconnection", "nostarttlsconnect"}, 18 | "nostarttls": []string{"nostarttls", "noconnection"}, 19 | } 20 | 21 | // Fake hostname checks :) 22 | var hostnameResults = map[string]Result{ 23 | "noconnection": Result{ 24 | Status: 3, 25 | Checks: 
map[string]*Result{ 26 | Connectivity: {Connectivity, 3, nil, nil}, 27 | }, 28 | }, 29 | "nostarttls": Result{ 30 | Status: 2, 31 | Checks: map[string]*Result{ 32 | Connectivity: {Connectivity, 0, nil, nil}, 33 | STARTTLS: {STARTTLS, 2, nil, nil}, 34 | }, 35 | }, 36 | "nostarttlsconnect": Result{ 37 | Status: 3, 38 | Checks: map[string]*Result{ 39 | Connectivity: {Connectivity, 0, nil, nil}, 40 | STARTTLS: {STARTTLS, 3, nil, nil}, 41 | }, 42 | }, 43 | } 44 | 45 | func mockCheckMTASTS(domain string, hostnameResults map[string]HostnameResult) *MTASTSResult { 46 | r := MakeMTASTSResult() 47 | r.Mode = "testing" 48 | return r 49 | } 50 | 51 | func mockLookupMX(domain string) ([]*net.MX, error) { 52 | if domain == "error" { 53 | return nil, fmt.Errorf("No MX records found") 54 | } 55 | result := []*net.MX{} 56 | for _, host := range mxLookup[domain] { 57 | result = append(result, &net.MX{Host: host}) 58 | } 59 | return result, nil 60 | } 61 | 62 | func mockCheckHostname(domain string, hostname string, _ time.Duration) HostnameResult { 63 | if result, ok := hostnameResults[hostname]; ok { 64 | return HostnameResult{ 65 | Result: &result, 66 | Timestamp: time.Now(), 67 | } 68 | } 69 | // For caching test: "changes" result changes after first scan 70 | if hostname == "changes" { 71 | hostnameResults["changes"] = hostnameResults["nostarttls"] 72 | } 73 | // by default return successful check 74 | return HostnameResult{ 75 | Result: &Result{ 76 | Status: 0, 77 | Checks: map[string]*Result{ 78 | Connectivity: {Connectivity, 0, nil, nil}, 79 | STARTTLS: {STARTTLS, 0, nil, nil}, 80 | Certificate: {Certificate, 0, nil, nil}, 81 | Version: {Version, 0, nil, nil}, 82 | }, 83 | }, 84 | Timestamp: time.Now(), 85 | } 86 | } 87 | 88 | // Test helpers. 89 | 90 | // If expectedHostnames is nil, we just assume that whatever lookup occurs is correct. 
91 | type domainTestCase struct { 92 | // Test case parameters 93 | domain string 94 | expectedHostnames []string 95 | // Expected result of test case. 96 | expect DomainStatus 97 | } 98 | 99 | // Perform a single test check 100 | func (test domainTestCase) check(t *testing.T, got DomainStatus) { 101 | if got != test.expect { 102 | t.Errorf("Testing %s with hostnames %s: Expected status code %d, got code %d", 103 | test.domain, test.expectedHostnames, test.expect, got) 104 | } 105 | } 106 | 107 | func performTests(t *testing.T, tests []domainTestCase) { 108 | performTestsWithCacheTimeout(t, tests, time.Hour) 109 | } 110 | 111 | func performTestsWithCacheTimeout(t *testing.T, tests []domainTestCase, cacheExpiry time.Duration) { 112 | c := Checker{ 113 | Timeout: time.Second, 114 | Cache: MakeSimpleCache(cacheExpiry), 115 | lookupMXOverride: mockLookupMX, 116 | CheckHostname: mockCheckHostname, 117 | checkMTASTSOverride: mockCheckMTASTS, 118 | } 119 | for _, test := range tests { 120 | if test.expectedHostnames == nil { 121 | test.expectedHostnames = mxLookup[test.domain] 122 | } 123 | got := c.CheckDomain(test.domain, test.expectedHostnames).Status 124 | test.check(t, got) 125 | } 126 | } 127 | 128 | // Test cases. 
129 | 130 | func TestBadMXLookup(t *testing.T) { 131 | tests := []domainTestCase{ 132 | {"empty", []string{}, DomainCouldNotConnect}, 133 | } 134 | performTests(t, tests) 135 | } 136 | 137 | func TestNoExpectedHostnames(t *testing.T) { 138 | tests := []domainTestCase{ 139 | {"domain", []string{}, DomainBadHostnameFailure}, 140 | {"domain", []string{"hostname"}, DomainBadHostnameFailure}, 141 | {"domain", []string{"hostname1"}, DomainBadHostnameFailure}, 142 | {"domain", []string{"hostname1", "hostname2"}, DomainSuccess}, 143 | {"domain", nil, DomainSuccess}, 144 | } 145 | performTests(t, tests) 146 | } 147 | 148 | func TestWildcardHostnames(t *testing.T) { 149 | tests := []domainTestCase{ 150 | {"domain.tld", []string{".tld"}, DomainBadHostnameFailure}, 151 | {"domain.tld", []string{".domain.tld"}, DomainSuccess}, 152 | } 153 | performTests(t, tests) 154 | } 155 | 156 | func TestHostnamesNoConnection(t *testing.T) { 157 | tests := []domainTestCase{ 158 | {domain: "noconnection", expect: DomainCouldNotConnect}, 159 | } 160 | performTests(t, tests) 161 | } 162 | 163 | func TestHostnamesNoSTARTTLS(t *testing.T) { 164 | tests := []domainTestCase{ 165 | {domain: "nostarttls", expect: DomainNoSTARTTLSFailure}, 166 | {domain: "noconnection2", expect: DomainNoSTARTTLSFailure}, 167 | } 168 | performTests(t, tests) 169 | } 170 | 171 | func TestHostnameScanCached(t *testing.T) { 172 | // "Changes" result status should change from 0 => 5 after first scan, 173 | // but since it's cached, we should always get 0 (the result from the 174 | // first scan) 175 | delete(hostnameResults, "changes") 176 | tests := []domainTestCase{ 177 | {domain: "changes", expect: 0}, 178 | {domain: "changes", expect: 0}, 179 | {domain: "changes", expect: 0}} 180 | performTests(t, tests) 181 | } 182 | 183 | func TestHostnameScanExpires(t *testing.T) { 184 | delete(hostnameResults, "changes") 185 | tests := []domainTestCase{ 186 | {domain: "changes", expect: 0}, 187 | {domain: "changes", expect: 4}} 
188 | performTestsWithCacheTimeout(t, tests, 0) 189 | } 190 | 191 | func TestNewSampleDomainResult(t *testing.T) { 192 | NewSampleDomainResult("example.com") 193 | } 194 | -------------------------------------------------------------------------------- /checker/hostname.go: -------------------------------------------------------------------------------- 1 | package checker 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | "net" 7 | "net/smtp" 8 | "os" 9 | "strings" 10 | "time" 11 | ) 12 | 13 | // HostnameResult wraps the results of a security check against a particular hostname. 14 | type HostnameResult struct { 15 | *Result 16 | Domain string `json:"domain"` 17 | Hostname string `json:"hostname"` 18 | Timestamp time.Time `json:"-"` 19 | } 20 | 21 | func (h HostnameResult) couldConnect() bool { 22 | return h.subcheckSucceeded(Connectivity) 23 | } 24 | 25 | func (h HostnameResult) couldSTARTTLS() bool { 26 | return h.subcheckSucceeded(STARTTLS) 27 | } 28 | 29 | // PolicyMatches return true iff a given mx matches an array of patterns. 30 | // It is modelled after PolicyMatches in Appendix B of the MTA-STS RFC 8641. 31 | // Also used to validate hostnames on the STARTTLS Everywhere policy list. 32 | func PolicyMatches(mx string, patterns []string) bool { 33 | mx = strings.TrimSuffix(mx, ".") // If FQDN, might end with . 
34 | mx = withoutPort(mx) // If URL, might include port 35 | mx = strings.ToLower(mx) // Lowercase for comparison 36 | for _, pattern := range patterns { 37 | pattern = strings.ToLower(pattern) 38 | 39 | // Literal match 40 | if pattern == mx { 41 | return true 42 | } 43 | // Wildcard match 44 | if strings.HasPrefix(pattern, "*.") || strings.HasPrefix(pattern, ".") { 45 | pattern = strings.TrimPrefix(pattern, "*") 46 | mxParts := strings.SplitN(mx, ".", 2) 47 | if len(mxParts) > 1 && mxParts[1] == pattern[1:] { 48 | return true 49 | } 50 | } 51 | } 52 | return false 53 | } 54 | 55 | func withoutPort(url string) string { 56 | if strings.Contains(url, ":") { 57 | return url[0:strings.LastIndex(url, ":")] 58 | } 59 | return url 60 | } 61 | 62 | // Retrieves this machine's hostname, if specified. 63 | func getThisHostname() string { 64 | hostname := os.Getenv("HOSTNAME") 65 | if len(hostname) == 0 { 66 | return "localhost" 67 | } 68 | return hostname 69 | } 70 | 71 | // Performs an SMTP dial with a short timeout. 72 | // https://github.com/golang/go/issues/16436 73 | func smtpDialWithTimeout(hostname string, timeout time.Duration) (*smtp.Client, error) { 74 | if _, _, err := net.SplitHostPort(hostname); err != nil { 75 | hostname += ":25" 76 | } 77 | conn, err := net.DialTimeout("tcp", hostname, timeout) 78 | if err != nil { 79 | return nil, err 80 | } 81 | client, err := smtp.NewClient(conn, hostname) 82 | if err != nil { 83 | return client, err 84 | } 85 | return client, client.Hello(getThisHostname()) 86 | } 87 | 88 | // Simply tries to StartTLS with the server. 
89 | func checkStartTLS(client *smtp.Client) *Result { 90 | result := MakeResult(STARTTLS) 91 | ok, _ := client.Extension("StartTLS") 92 | if !ok { 93 | return result.Failure("Server does not advertise support for STARTTLS.") 94 | } 95 | config := tls.Config{InsecureSkipVerify: true} 96 | if err := client.StartTLS(&config); err != nil { 97 | return result.Failure("Could not complete a TLS handshake.") 98 | } 99 | return result.Success() 100 | } 101 | 102 | // If no MX matching policy was provided, then we'll default to accepting matches 103 | // based on the mail domain and the MX hostname. 104 | // 105 | // Returns a list containing the domain and hostname. 106 | func defaultValidMX(domain, hostname string) []string { 107 | if strings.HasSuffix(hostname, ".") { 108 | hostname = hostname[0 : len(hostname)-1] 109 | } 110 | return []string{domain, hostname} 111 | } 112 | 113 | // Validates that a certificate chain is valid for this system roots. 114 | func verifyCertChain(state tls.ConnectionState) error { 115 | pool := x509.NewCertPool() 116 | for _, peerCert := range state.PeerCertificates[1:] { 117 | pool.AddCert(peerCert) 118 | } 119 | _, err := state.PeerCertificates[0].Verify(x509.VerifyOptions{ 120 | Roots: certRoots, 121 | Intermediates: pool, 122 | }) 123 | return err 124 | } 125 | 126 | // certRoots is the certificate roots to use for verifying 127 | // a TLS certificate. It is nil by default so that the system 128 | // root certs are used. 129 | // 130 | // It is a global variable because it is used as a test hook. 131 | var certRoots *x509.CertPool 132 | 133 | // Checks that the certificate presented is valid for a particular hostname, unexpired, 134 | // and chains to a trusted root. 
135 | func checkCert(client *smtp.Client, domain, hostname string) *Result { 136 | result := MakeResult(Certificate) 137 | state, ok := client.TLSConnectionState() 138 | if !ok { 139 | return result.Error("TLS not initiated properly.") 140 | } 141 | cert := state.PeerCertificates[0] 142 | // If hostname is an FQDN, it might end with '.' 143 | hostname = strings.TrimSuffix(hostname, ".") 144 | err := cert.VerifyHostname(withoutPort(hostname)) 145 | if err != nil { 146 | result.Failure("Name in cert doesn't match hostname: %v", err) 147 | } 148 | err = verifyCertChain(state) 149 | if err != nil { 150 | return result.Failure("Certificate root is not trusted: %v", err) 151 | } 152 | return result.Success() 153 | } 154 | 155 | func tlsConfigForCipher(ciphers []uint16) tls.Config { 156 | return tls.Config{ 157 | InsecureSkipVerify: true, 158 | CipherSuites: ciphers, 159 | } 160 | } 161 | 162 | // Checks to see that insecure ciphers are disabled. 163 | func checkTLSCipher(hostname string, timeout time.Duration) *Result { 164 | result := MakeResult("cipher") 165 | badCiphers := []uint16{ 166 | tls.TLS_RSA_WITH_RC4_128_SHA, 167 | tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 168 | tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA} 169 | client, err := smtpDialWithTimeout(hostname, timeout) 170 | if err != nil { 171 | return result.Error("Could not establish connection with hostname %s", hostname) 172 | } 173 | defer client.Close() 174 | config := tlsConfigForCipher(badCiphers) 175 | err = client.StartTLS(&config) 176 | if err == nil { 177 | return result.Failure("Server should NOT be able to negotiate any ciphers with RC4.") 178 | } 179 | return result.Success() 180 | } 181 | 182 | func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration) *Result { 183 | result := MakeResult(Version) 184 | 185 | // Check the TLS version of the existing connection. 
186 | tlsConnectionState, ok := client.TLSConnectionState() 187 | if !ok { 188 | // We shouldn't end up here because we already checked that STARTTLS succeeded. 189 | return result.Error("Could not check TLS connection version.") 190 | } 191 | if tlsConnectionState.Version < tls.VersionTLS12 { 192 | result = result.Warning("Server should support TLSv1.2, but doesn't.") 193 | } 194 | 195 | // Attempt to connect with an old SSL version. 196 | client, err := smtpDialWithTimeout(hostname, timeout) 197 | if err != nil { 198 | return result.Error("Could not establish connection: %v", err) 199 | } 200 | defer client.Close() 201 | config := tls.Config{ 202 | InsecureSkipVerify: true, 203 | MinVersion: tls.VersionSSL30, 204 | MaxVersion: tls.VersionSSL30, 205 | } 206 | err = client.StartTLS(&config) 207 | if err == nil { 208 | return result.Failure("Server should NOT support SSLv2/3, but does.") 209 | } 210 | return result.Success() 211 | } 212 | 213 | // checkHostname returns the result of c.CheckHostname or FullCheckHostname, 214 | // using or updating the Checker's cache. 215 | func (c *Checker) checkHostname(domain string, hostname string) HostnameResult { 216 | check := c.CheckHostname 217 | if check == nil { 218 | // If CheckHostname hasn't been set, default to the full set of checks. 219 | check = FullCheckHostname 220 | } 221 | 222 | if c.Cache == nil { 223 | return check(domain, hostname, c.timeout()) 224 | } 225 | hostnameResult, err := c.Cache.GetHostnameScan(hostname) 226 | if err != nil { 227 | hostnameResult = check(domain, hostname, c.timeout()) 228 | c.Cache.PutHostnameScan(hostname, hostnameResult) 229 | } 230 | return hostnameResult 231 | } 232 | 233 | // NoopCheckHostname returns a fake error result containing `domain` and `hostname`. 
234 | func NoopCheckHostname(domain string, hostname string, _ time.Duration) HostnameResult { 235 | r := HostnameResult{ 236 | Domain: domain, 237 | Hostname: hostname, 238 | Result: MakeResult("hostnames"), 239 | } 240 | r.addCheck(MakeResult(Connectivity).Error("Skipping hostname checks")) 241 | return r 242 | } 243 | 244 | // FullCheckHostname performs a series of checks against a hostname for an email domain. 245 | // `domain` is the mail domain that this server serves email for. 246 | // `hostname` is the hostname for this server. 247 | func FullCheckHostname(domain string, hostname string, timeout time.Duration) HostnameResult { 248 | result := HostnameResult{ 249 | Domain: domain, 250 | Hostname: hostname, 251 | Result: MakeResult("hostnames"), 252 | Timestamp: time.Now(), 253 | } 254 | 255 | // Connect to the SMTP server and use that connection to perform as many checks as possible. 256 | connectivityResult := MakeResult(Connectivity) 257 | client, err := smtpDialWithTimeout(hostname, timeout) 258 | if err != nil { 259 | result.addCheck(connectivityResult.Error("Could not establish connection: %v", err)) 260 | return result 261 | } 262 | defer client.Close() 263 | result.addCheck(connectivityResult.Success()) 264 | 265 | result.addCheck(checkStartTLS(client)) 266 | if result.Status != Success { 267 | return result 268 | } 269 | result.addCheck(checkCert(client, domain, hostname)) 270 | // result.addCheck(checkTLSCipher(hostname)) 271 | 272 | // Creates a new connection to check for SSLv2/3 support because we can't call starttls twice. 
273 | result.addCheck(checkTLSVersion(client, hostname, timeout)) 274 | 275 | return result 276 | } 277 | -------------------------------------------------------------------------------- /checker/hostname_test.go: -------------------------------------------------------------------------------- 1 | package checker 2 | 3 | import ( 4 | "bufio" 5 | "crypto/rand" 6 | "crypto/tls" 7 | "crypto/x509" 8 | "encoding/pem" 9 | "math/big" 10 | "net" 11 | "os" 12 | "strings" 13 | "testing" 14 | "time" 15 | 16 | "github.com/mhale/smtpd" 17 | ) 18 | 19 | func TestMain(m *testing.M) { 20 | certString = createCert(key, "localhost") 21 | certStringHostnameMismatch = createCert(key, "you_give_love_a_bad_name") 22 | code := m.Run() 23 | os.Exit(code) 24 | } 25 | 26 | const testTimeout = 250 * time.Millisecond 27 | 28 | // Code follows pattern from crypto/tls/generate_cert.go 29 | // to generate a cert from a PEM-encoded RSA private key. 30 | func createCert(keyData string, commonName string) string { 31 | // 1. Convert privkey from PEM to DER. 32 | block, _ := pem.Decode([]byte(key)) 33 | privKey, _ := x509.ParsePKCS1PrivateKey(block.Bytes) 34 | // 2. Generate cert with private key. 35 | template := x509.Certificate{ 36 | SerialNumber: big.NewInt(0), 37 | NotBefore: time.Now(), 38 | NotAfter: time.Now().Add(time.Minute), 39 | IsCA: true, 40 | DNSNames: []string{commonName}, 41 | } 42 | certDER, _ := x509.CreateCertificate(rand.Reader, &template, &template, &(privKey.PublicKey), privKey) 43 | // 3. 
Convert cert to PEM format (for consumption by crypto/tls) 44 | b := pem.Block{Type: "CERTIFICATE", Bytes: certDER} 45 | certPEM := pem.EncodeToMemory(&b) 46 | return string(certPEM) 47 | } 48 | 49 | func TestPolicyMatch(t *testing.T) { 50 | var tests = []struct { 51 | hostname string 52 | policyMX string 53 | want bool 54 | }{ 55 | // Equal matches 56 | {"example.com", "example.com", true}, 57 | {"mx.example.com", "mx.example.com", true}, 58 | 59 | // Not equal matches 60 | {"different.org", "example.com", false}, 61 | {"not.example.com", "example.com", false}, 62 | 63 | // base domain shouldn't match wildcard 64 | {"example.com", ".example.com", false}, 65 | {"example.com", "*.example.com", false}, 66 | 67 | // Invalid wildcard shouldn't match. 68 | {"mx.example.com", "*mx.example.com", false}, 69 | 70 | // Single-level subdomain match 71 | {"mx.example.com", "*.example.com", true}, 72 | {"mx.example.com", ".example.com", true}, 73 | {"mx.mx.example.com", "*.mx.example.com", true}, 74 | {"mx.mx.example.com", ".mx.example.com", true}, 75 | 76 | // Wildcard may match left-most label only 77 | {"mx.example.com", "mx.*.com", false}, 78 | 79 | // No multi-level subdomain matching. 80 | {"mx.mx.example.com", "*.example.com", false}, 81 | 82 | // No partial subdomain matches 83 | {"mx.example.com", "mx.*ple.com", false}, 84 | 85 | // Hostname should not use wildcards. 
86 | {"*.example.com", "mx.example.com", false}, 87 | {"*.example.com", "mx.mx.example.com", false}, 88 | {"*.example.com", ".mx.example.com", false}, 89 | 90 | // Some more edge cases 91 | {"mx.example.com", "..example.com", false}, 92 | } 93 | 94 | for _, test := range tests { 95 | policy := []string{test.policyMX} 96 | if got := PolicyMatches(test.hostname, policy); got != test.want { 97 | t.Errorf("policyMatch(%q, %q) = %v", test.hostname, policy, got) 98 | } 99 | } 100 | } 101 | 102 | func TestNoConnection(t *testing.T) { 103 | result := FullCheckHostname("", "example.com", testTimeout) 104 | 105 | expected := Result{ 106 | Status: 3, 107 | Checks: map[string]*Result{ 108 | "connectivity": {Connectivity, 3, nil, nil}, 109 | }, 110 | } 111 | compareStatuses(t, expected, result) 112 | } 113 | 114 | func TestNoTLS(t *testing.T) { 115 | ln := smtpListenAndServe(t, &tls.Config{}) 116 | defer ln.Close() 117 | 118 | result := FullCheckHostname("", ln.Addr().String(), testTimeout) 119 | 120 | expected := Result{ 121 | Status: 2, 122 | Checks: map[string]*Result{ 123 | Connectivity: {Connectivity, 0, nil, nil}, 124 | STARTTLS: {STARTTLS, 2, nil, nil}, 125 | }, 126 | } 127 | compareStatuses(t, expected, result) 128 | } 129 | 130 | func TestSelfSigned(t *testing.T) { 131 | cert, err := tls.X509KeyPair([]byte(certString), []byte(key)) 132 | if err != nil { 133 | t.Fatal(err) 134 | } 135 | ln := smtpListenAndServe(t, &tls.Config{Certificates: []tls.Certificate{cert}}) 136 | defer ln.Close() 137 | 138 | result := FullCheckHostname("", ln.Addr().String(), testTimeout) 139 | 140 | expected := Result{ 141 | Status: 2, 142 | Checks: map[string]*Result{ 143 | Connectivity: {Connectivity, 0, nil, nil}, 144 | STARTTLS: {STARTTLS, 0, nil, nil}, 145 | Certificate: {Certificate, 2, nil, nil}, 146 | Version: {Version, 0, nil, nil}, 147 | }, 148 | } 149 | compareStatuses(t, expected, result) 150 | } 151 | 152 | func TestNoTLS12(t *testing.T) { 153 | cert, err := 
tls.X509KeyPair([]byte(certString), []byte(key)) 154 | if err != nil { 155 | t.Fatal(err) 156 | } 157 | ln := smtpListenAndServe(t, &tls.Config{ 158 | MinVersion: tls.VersionTLS11, 159 | MaxVersion: tls.VersionTLS11, 160 | Certificates: []tls.Certificate{cert}, 161 | }) 162 | defer ln.Close() 163 | 164 | result := FullCheckHostname("", ln.Addr().String(), testTimeout) 165 | 166 | expected := Result{ 167 | Status: 2, 168 | Checks: map[string]*Result{ 169 | Connectivity: {Connectivity, 0, nil, nil}, 170 | STARTTLS: {STARTTLS, 0, nil, nil}, 171 | Certificate: {Certificate, 2, nil, nil}, 172 | Version: {Version, 1, nil, nil}, 173 | }, 174 | } 175 | compareStatuses(t, expected, result) 176 | } 177 | 178 | func TestSuccessWithFakeCA(t *testing.T) { 179 | cert, err := tls.X509KeyPair([]byte(certString), []byte(key)) 180 | if err != nil { 181 | t.Fatal(err) 182 | } 183 | ln := smtpListenAndServe(t, &tls.Config{Certificates: []tls.Certificate{cert}}) 184 | defer ln.Close() 185 | 186 | certRoots, _ = x509.SystemCertPool() 187 | certRoots.AppendCertsFromPEM([]byte(certString)) 188 | defer func() { 189 | certRoots = nil 190 | }() 191 | 192 | // Our test cert happens to be valid for hostname "localhost", 193 | // so here we replace the loopback address with "localhost" while 194 | // conserving the port number. 195 | addrParts := strings.Split(ln.Addr().String(), ":") 196 | port := addrParts[len(addrParts)-1] 197 | result := FullCheckHostname("", "localhost:"+port, testTimeout) 198 | expected := Result{ 199 | Status: 0, 200 | Checks: map[string]*Result{ 201 | Connectivity: {Connectivity, 0, nil, nil}, 202 | STARTTLS: {STARTTLS, 0, nil, nil}, 203 | Certificate: {Certificate, 0, nil, nil}, 204 | Version: {Version, 0, nil, nil}, 205 | }, 206 | } 207 | compareStatuses(t, expected, result) 208 | } 209 | 210 | // Tests that the checker successfully initiates an SMTP connection with mail 211 | // servers that use a greet delay. 
212 | func TestSuccessWithDelayedGreeting(t *testing.T) { 213 | ln, err := net.Listen("tcp", "localhost:0") 214 | if err != nil { 215 | t.Fatal(err) 216 | } 217 | defer ln.Close() 218 | go ServeDelayedGreeting(ln, t) 219 | 220 | client, err := smtpDialWithTimeout(ln.Addr().String(), testTimeout) 221 | if err != nil { 222 | t.Fatal(err) 223 | } 224 | client.Close() 225 | } 226 | 227 | func ServeDelayedGreeting(ln net.Listener, t *testing.T) { 228 | conn, err := ln.Accept() 229 | if err != nil { 230 | t.Fatal(err) 231 | } 232 | defer conn.Close() 233 | 234 | time.Sleep(testTimeout + 100*time.Millisecond) 235 | _, err = conn.Write([]byte("220 localhost ESMTP\n")) 236 | if err != nil { 237 | t.Fatal(err) 238 | } 239 | line, err := bufio.NewReader(conn).ReadString('\n') 240 | if err != nil { 241 | t.Fatal(err) 242 | } 243 | if !strings.Contains(line, "EHLO localhost") { 244 | t.Fatalf("unexpected response from checker: %s", line) 245 | } 246 | 247 | _, err = conn.Write([]byte("250 HELO\n")) 248 | if err != nil { 249 | t.Fatal(err) 250 | } 251 | } 252 | 253 | func TestFailureWithBadHostname(t *testing.T) { 254 | cert, err := tls.X509KeyPair([]byte(certString), []byte(key)) 255 | if err != nil { 256 | t.Fatal(err) 257 | } 258 | ln := smtpListenAndServe(t, &tls.Config{Certificates: []tls.Certificate{cert}}) 259 | defer ln.Close() 260 | 261 | certRoots, _ = x509.SystemCertPool() 262 | certRoots.AppendCertsFromPEM([]byte(certStringHostnameMismatch)) 263 | defer func() { 264 | certRoots = nil 265 | }() 266 | 267 | // Our test cert happens to be valid for hostname "localhost", 268 | // so here we replace the loopback address with "localhost" while 269 | // conserving the port number. 
270 | addrParts := strings.Split(ln.Addr().String(), ":") 271 | port := addrParts[len(addrParts)-1] 272 | result := FullCheckHostname("", "localhost:"+port, testTimeout) 273 | expected := Result{ 274 | Status: 2, 275 | Checks: map[string]*Result{ 276 | Connectivity: {Connectivity, 0, nil, nil}, 277 | STARTTLS: {STARTTLS, 0, nil, nil}, 278 | Certificate: {Certificate, 2, nil, nil}, 279 | Version: {Version, 0, nil, nil}, 280 | }, 281 | } 282 | compareStatuses(t, expected, result) 283 | } 284 | 285 | func TestAdvertisedCiphers(t *testing.T) { 286 | cert, err := tls.X509KeyPair([]byte(certString), []byte(key)) 287 | if err != nil { 288 | t.Fatal(err) 289 | } 290 | 291 | var cipherSuites []uint16 292 | // GetConfigForClient is a callback that lets us alter the TLSConfig 293 | // based on the client hello. Here we just use it to check which ciphers 294 | // are advertised by the client. 295 | // 296 | // Alternatively, we could use the CipherSuites attribute to attempt a 297 | // separate connection with each cipher. 298 | tlsConfig := &tls.Config{ 299 | Certificates: []tls.Certificate{cert}, 300 | GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { 301 | if len(cipherSuites) == 0 { 302 | // Throw out the second connection to the mailserver 303 | // where we intentionally advertised insecure ciphers. 
304 | cipherSuites = info.CipherSuites 305 | } 306 | return &tls.Config{Certificates: []tls.Certificate{cert}}, nil 307 | }, 308 | } 309 | 310 | ln := smtpListenAndServe(t, tlsConfig) 311 | defer ln.Close() 312 | FullCheckHostname("", ln.Addr().String(), testTimeout) 313 | 314 | // Partial list of ciphers we want to support 315 | expectedCipherSuites := []struct { 316 | val uint16 317 | desc string 318 | }{ 319 | {tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"}, 320 | {tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305"}, 321 | } 322 | for _, expected := range expectedCipherSuites { 323 | if !containsCipherSuite(cipherSuites, expected.val) { 324 | t.Errorf("expected check to advertise ciphersuite %s", expected.desc) 325 | } 326 | } 327 | } 328 | 329 | func containsCipherSuite(result []uint16, want uint16) bool { 330 | for _, candidate := range result { 331 | if want == candidate { 332 | return true 333 | } 334 | } 335 | return false 336 | } 337 | 338 | // compareStatuses compares the status for the HostnameResult and each Check with a desired value 339 | func compareStatuses(t *testing.T, expected Result, result HostnameResult) { 340 | if result.Status != expected.Status { 341 | t.Errorf("hostname status = %d, want %d", result.Status, expected.Status) 342 | } 343 | 344 | if len(result.Checks) > len(expected.Checks) { 345 | t.Errorf("result contains too many checks\n expected %v\n want %v", result.Checks, expected.Checks) 346 | } 347 | 348 | for _, c := range expected.Checks { 349 | if got := result.Checks[c.Name].Status; got != c.Status { 350 | t.Errorf("%s status = %d, want %d", c.Name, got, c.Status) 351 | } 352 | } 353 | } 354 | 355 | // smtpListenAndServe creates a test smtp server to run checks on. 356 | // We use this rather than smtpd.ListenAndServe so that we can use net.Listen 357 | // to assign a random available port. 
358 | func smtpListenAndServe(t *testing.T, tlsConfig *tls.Config) net.Listener { 359 | srv := &smtpd.Server{ 360 | Handler: noopHandler, 361 | Hostname: "example.com", 362 | } 363 | srv.TLSConfig = tlsConfig 364 | 365 | ln, err := net.Listen("tcp", "localhost:0") 366 | if err != nil { 367 | t.Fatal(err) 368 | } 369 | 370 | go func() { 371 | if err := srv.Serve(ln); err != nil { 372 | if strings.Contains(err.Error(), "closed") { 373 | return 374 | } 375 | t.Fatal(err) 376 | } 377 | }() 378 | 379 | return ln 380 | } 381 | 382 | func noopHandler(_ net.Addr, _ string, _ []string, _ []byte) {} 383 | 384 | var certString string 385 | var certStringHostnameMismatch string 386 | 387 | const key = `-----BEGIN RSA PRIVATE KEY----- 388 | MIICXQIBAAKBgQC7BhTtrZkgD7Q0fGHHBl4TRrEFO2KmN93MVZdTob2S3nwWsFUo 389 | aP9Jx4WsQ0F+MwP2nKTS52LvTCqPyD9VFp9XS52Mtq6cylK+UTkKAQnSVu14g5dS 390 | 0gAbM914zxO1NFp/9C4iCi0qaKWzPCGLCIEoqkb7+HlYQekBkJHR3Tzq3QIDAQAB 391 | AoGBALL2RuCI1ZYQcOgofYftV+gqJQpUoTldDCiTXpLwmm8H5sXvRg29K0x2WDtW 392 | wDz6pDg//Ji0Qb+qqq+bdr79PsquUon6G+t9LWFQ6F1qD7JRssBr5FPAfWFij2pm 393 | zH61dX/j/kas67W+23H4k0Rc3oExaPF4gecc/EJaQ4Wc5EohAkEA6GaMhlwsONhv 394 | TbW3FIOm54obvLhS0XDrdig8CIl7+x6KSBsHBmLv+MDh/DRywwv5sOR6Sg6HGMAc 395 | 4pNsk6UOXwJBAM4D7HHfqMyuiKDIiAwdjPn/Ux2nlQe05d7iai0nSEVEfneaGX/g 396 | r4C1Gg8VDA6U94XE/S9d60IpUg4DwH9W2EMCQCufxFUcTDjHd+0wZRN2uwfPhvFf 397 | 8DvcZHajitFXbWxwCSkL2b+7JqydGE6NUdWHE/G+ka4BGB7vQPzPC5yTaSUCQAn3 398 | Ap7XdLDB2HX+fSYo38LP6NNMYdcHlv7a8MvSVJqVH5DlcUpQMe0F1YbZO8YQypA7 399 | 4QtDfberi/6Fi/Ac4UUCQQDHf89gtZYZKfeTBMRwaer7yG/UovX2AJSkCB34BGxn 400 | gIxzlen/RRmXtBGCR5G24n08/2AJaMeI/8sJWM8or9cs 401 | -----END RSA PRIVATE KEY-----` 402 | -------------------------------------------------------------------------------- /checker/mta_sts.go: -------------------------------------------------------------------------------- 1 | package checker 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "io/ioutil" 9 | "net" 10 | "net/http" 11 | 
"regexp" 12 | "strconv" 13 | "strings" 14 | "time" 15 | ) 16 | 17 | // MTASTSResult represents the result of a check for inbound MTA-STS support. 18 | type MTASTSResult struct { 19 | *Result 20 | Policy string // Text of MTA-STS policy file 21 | Mode string 22 | MXs []string 23 | } 24 | 25 | // MakeMTASTSResult constructs a base result object and returns its pointer. 26 | func MakeMTASTSResult() *MTASTSResult { 27 | return &MTASTSResult{ 28 | Result: MakeResult(MTASTS), 29 | } 30 | } 31 | 32 | // MarshalJSON prevents MTASTSResult from inheriting the version of MarshalJSON 33 | // implemented by Result. 34 | func (m MTASTSResult) MarshalJSON() ([]byte, error) { 35 | // type FakeMTASTSResult MTASTSResult 36 | type FakeResult Result 37 | return json.Marshal(struct { 38 | FakeResult 39 | Policy string `json:"policy"` 40 | Mode string `json:"mode"` 41 | MXs []string `json:"mxs"` 42 | }{ 43 | FakeResult: FakeResult(*m.Result), 44 | Policy: m.Policy, 45 | Mode: m.Mode, 46 | MXs: m.MXs, 47 | }) 48 | } 49 | 50 | func filterByPrefix(records []string, prefix string) []string { 51 | filtered := []string{} 52 | for _, elem := range records { 53 | if strings.HasPrefix(elem, prefix) { 54 | filtered = append(filtered, elem) 55 | } 56 | } 57 | return filtered 58 | } 59 | 60 | func getKeyValuePairs(record string, lineDelimiter string, 61 | pairDelimiter string) map[string]string { 62 | parsed := make(map[string]string) 63 | for _, line := range strings.Split(record, lineDelimiter) { 64 | split := strings.Split(strings.TrimSpace(line), pairDelimiter) 65 | if len(split) != 2 { 66 | continue 67 | } 68 | key := strings.TrimSpace(split[0]) 69 | value := strings.TrimSpace(split[1]) 70 | if parsed[key] == "" { 71 | parsed[key] = value 72 | } else { 73 | parsed[key] = parsed[key] + " " + value 74 | } 75 | } 76 | return parsed 77 | } 78 | 79 | func checkMTASTSRecord(domain string, timeout time.Duration) *Result { 80 | result := MakeResult(MTASTSText) 81 | ctx, cancel := 
context.WithTimeout(context.Background(), timeout) 82 | defer cancel() 83 | var r net.Resolver 84 | records, err := r.LookupTXT(ctx, fmt.Sprintf("_mta-sts.%s", domain)) 85 | if err != nil { 86 | return result.Failure("Couldn't find an MTA-STS TXT record: %v.", err) 87 | } 88 | return validateMTASTSRecord(records, result) 89 | } 90 | 91 | func validateMTASTSRecord(records []string, result *Result) *Result { 92 | records = filterByPrefix(records, "v=STSv1") 93 | if len(records) != 1 { 94 | return result.Failure("Exactly 1 MTA-STS TXT record required, found %d.", len(records)) 95 | } 96 | record := getKeyValuePairs(records[0], ";", "=") 97 | 98 | idPattern := regexp.MustCompile("^[a-zA-Z0-9]+$") 99 | if !idPattern.MatchString(record["id"]) { 100 | return result.Failure("Invalid MTA-STS TXT record id %s.", record["id"]) 101 | } 102 | return result.Success() 103 | } 104 | 105 | func checkMTASTSPolicyFile(domain string, hostnameResults map[string]HostnameResult, timeout time.Duration) (*Result, string, map[string]string) { 106 | result := MakeResult(MTASTSPolicyFile) 107 | client := &http.Client{ 108 | Timeout: timeout, 109 | // Don't follow redirects. 110 | CheckRedirect: func(req *http.Request, via []*http.Request) error { 111 | return http.ErrUseLastResponse 112 | }, 113 | } 114 | policyURL := fmt.Sprintf("https://mta-sts.%s/.well-known/mta-sts.txt", domain) 115 | resp, err := client.Get(policyURL) 116 | if err != nil { 117 | return result.Failure("Couldn't find policy file at %s.", policyURL), "", map[string]string{} 118 | } 119 | if resp.StatusCode != 200 { 120 | return result.Failure("Couldn't get policy file: %s returned %s.", policyURL, resp.Status), "", map[string]string{} 121 | } 122 | // Media type should be text/plain, ignoring other Content-Type parms. 
123 | // Format: Content-Type := type "/" subtype *[";" parameter] 124 | for _, contentType := range resp.Header["Content-Type"] { 125 | contentType := strings.ToLower(contentType) 126 | if !strings.HasPrefix(contentType, "text/plain") { 127 | result.Warning("The media type specified by your policy file's Content-Type header should be text/plain.") 128 | } 129 | } 130 | defer resp.Body.Close() 131 | // Read up to 64,000 bytes of response body. 132 | body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 64000)) 133 | if err != nil { 134 | return result.Error("Couldn't read policy file: %v.", err), "", map[string]string{} 135 | } 136 | 137 | policy := validateMTASTSPolicyFile(string(body), result) 138 | validateMTASTSMXs(strings.Split(policy["mx"], " "), hostnameResults, result) 139 | return result, string(body), policy 140 | } 141 | 142 | func validateMTASTSPolicyFile(body string, result *Result) map[string]string { 143 | policy := getKeyValuePairs(body, "\n", ":") 144 | 145 | if policy["version"] != "STSv1" { 146 | result.Failure("Your MTA-STS policy file version must be STSv1.") 147 | } 148 | 149 | if policy["mode"] == "" { 150 | result.Failure("Your MTA-STS policy file must specify mode.") 151 | } 152 | if m := policy["mode"]; m == "testing" { 153 | result.Warning("You're still in \"testing\" mode; senders won't enforce TLS when connecting to your mailservers. 
We recommend switching from \"testing\" to \"enforce\" to get the full security benefits of MTA-STS, as long as it hasn't been affecting your deliverability.") 154 | } else if m == "none" { 155 | result.Failure("MTA-STS policy is in \"none\" mode; senders won't enforce TLS when connecting to your mailservers.") 156 | } else if m != "enforce" { 157 | result.Failure("Mode must be one of \"enforce\", \"testing\", or \"none\", got %s", m) 158 | } 159 | 160 | if policy["max_age"] == "" { 161 | result.Failure("Your MTA-STS policy file must specify max_age.") 162 | } 163 | if i, err := strconv.Atoi(policy["max_age"]); err != nil || i <= 0 || i > 31557600 { 164 | result.Failure("MTA-STS max_age must be a positive integer <= 31557600.") 165 | } 166 | 167 | return policy 168 | } 169 | 170 | func validateMTASTSMXs(policyFileMXs []string, dnsMXs map[string]HostnameResult, 171 | result *Result) { 172 | for dnsMX, dnsMXResult := range dnsMXs { 173 | if !dnsMXResult.couldConnect() { 174 | // Ignore hostnames we couldn't connect to, they may be spam traps. 175 | continue 176 | } 177 | if !PolicyMatches(dnsMX, policyFileMXs) { 178 | result.Failure("%s appears in the DNS record but not the MTA-STS policy file", 179 | dnsMX) 180 | } else if !dnsMXResult.couldSTARTTLS() { 181 | result.Failure("%s appears in the DNS record and MTA-STS policy file, but doesn't support STARTTLS", 182 | dnsMX) 183 | } 184 | } 185 | } 186 | 187 | func (c Checker) checkMTASTS(domain string, hostnameResults map[string]HostnameResult) *MTASTSResult { 188 | if c.checkMTASTSOverride != nil { 189 | // Allow the Checker to mock this function. 
190 | return c.checkMTASTSOverride(domain, hostnameResults) 191 | } 192 | result := MakeMTASTSResult() 193 | result.addCheck(checkMTASTSRecord(domain, c.timeout())) 194 | policyResult, policy, policyMap := checkMTASTSPolicyFile(domain, hostnameResults, c.timeout()) 195 | result.addCheck(policyResult) 196 | result.Policy = policy 197 | result.Mode = policyMap["mode"] 198 | result.MXs = strings.Split(policyMap["mx"], " ") 199 | return result 200 | } 201 | -------------------------------------------------------------------------------- /checker/mta_sts_test.go: -------------------------------------------------------------------------------- 1 | package checker 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestMarshalMTASTSJSON(t *testing.T) { 11 | r := MakeMTASTSResult() 12 | m, err := json.Marshal(r) 13 | if err != nil { 14 | t.Fatal(err) 15 | } 16 | if !bytes.Contains(m, []byte("\"policy\":\"")) { 17 | t.Errorf("Marshalled result should contain policy, got %s", string(m)) 18 | } 19 | } 20 | 21 | func TestGetKeyValuePairs(t *testing.T) { 22 | tests := []struct { 23 | txt string 24 | ld string 25 | pd string 26 | want map[string]string 27 | }{ 28 | {"", ";", "=", map[string]string{}}, 29 | {"v=STSv1; foo;", ";", "=", map[string]string{ 30 | "v": "STSv1", 31 | }}, 32 | {"v=STSv1; id=20171114T070707;", ";", "=", map[string]string{ 33 | "v": "STSv1", 34 | "id": "20171114T070707", 35 | }}, 36 | {"version: STSv1\nmode: enforce\nmx: foo.example.com\nmx: bar.example.com\n\n", "\n", ":", map[string]string{ 37 | "version": "STSv1", 38 | "mode": "enforce", 39 | "mx": "foo.example.com bar.example.com", 40 | }}, 41 | } 42 | for _, test := range tests { 43 | got := getKeyValuePairs(test.txt, test.ld, test.pd) 44 | if !reflect.DeepEqual(got, test.want) { 45 | t.Errorf("getKeyValuePairs(%s, %s, %s) = %v, want %v", 46 | test.txt, test.ld, test.pd, got, test.want) 47 | } 48 | } 49 | } 50 | 51 | func TestValidateMTASTSRecord(t 
*testing.T) { 52 | tests := []struct { 53 | txt []string 54 | status Status 55 | }{ 56 | {[]string{"v=STSv1; id=1234", "v=STSv1; id=5678"}, Failure}, 57 | {[]string{"v=STSv1; id=20171114T070707;"}, Success}, 58 | {[]string{"v=STSv1; id=;"}, Failure}, 59 | {[]string{"v=STSv1; id=###;"}, Failure}, 60 | {[]string{"v=spf1 a -all"}, Failure}, 61 | } 62 | for _, test := range tests { 63 | result := validateMTASTSRecord(test.txt, &Result{}) 64 | if result.Status != test.status { 65 | t.Errorf("validateMTASTSRecord(%v) = %v", test.txt, result) 66 | } 67 | } 68 | } 69 | 70 | func TestValidateMTASTSPolicyFile(t *testing.T) { 71 | tests := []struct { 72 | txt string 73 | status Status 74 | }{ 75 | {"version: STSv1\nmode: enforce\nmax_age:100000\nmx: foo.example.com\nmx: bar.example.com\n", Success}, 76 | // Support UTF-8 77 | {"version: STSv1\nmode: enforce\nmax_age:100000\nmx: 🌟.🐢.com\n", Success}, 78 | {"\nmx: foo.example.com\nmx: bar.example.com\n", Failure}, 79 | {"version: STSv1\nmode: enforce\nmax_age:0\nmx: foo.example.com\nmx: bar.example.com\n", Failure}, 80 | {"version: STSv1\nmode: start_turtles\nmax_age:100000\nmx: foo.example.com\nmx: bar.example.com\n", Failure}, 81 | } 82 | for _, test := range tests { 83 | result := &Result{} 84 | validateMTASTSPolicyFile(test.txt, result) 85 | if result.Status != test.status { 86 | t.Errorf("validateMTASTSPolicyFile(%v) = %v", test.txt, result) 87 | } 88 | } 89 | } 90 | 91 | func TestValidateMTASTSMXs(t *testing.T) { 92 | goodHostnameResult := HostnameResult{ 93 | Result: &Result{ 94 | Status: 3, 95 | Checks: map[string]*Result{ 96 | "connectivity": {Connectivity, 0, nil, nil}, 97 | "starttls": {STARTTLS, 0, nil, nil}, 98 | }, 99 | }, 100 | } 101 | noSTARTTLSHostnameResult := HostnameResult{ 102 | Result: &Result{ 103 | Status: 3, 104 | Checks: map[string]*Result{ 105 | "connectivity": {Connectivity, 0, nil, nil}, 106 | "starttls": {STARTTLS, 3, nil, nil}, 107 | }, 108 | }, 109 | } 110 | tests := []struct { 111 | 
policyFileMXs []string 112 | dnsMXs map[string]HostnameResult 113 | status Status 114 | }{ 115 | { 116 | []string{"mail.example.com"}, 117 | map[string]HostnameResult{"mail.example.com": goodHostnameResult}, 118 | Success, 119 | }, 120 | { 121 | []string{"mail.example.com", "extra-entries.are-okay.com"}, 122 | map[string]HostnameResult{"mail.example.com": goodHostnameResult}, 123 | Success, 124 | }, 125 | { 126 | []string{"*.example.com"}, 127 | map[string]HostnameResult{"mail.example.com": goodHostnameResult}, 128 | Success, 129 | }, 130 | { 131 | []string{}, 132 | map[string]HostnameResult{"mail.example.com": goodHostnameResult}, 133 | Failure, 134 | }, 135 | { 136 | []string{"nostarttls.example.com"}, 137 | map[string]HostnameResult{"nostarttls.example.com": noSTARTTLSHostnameResult}, 138 | Failure, 139 | }, 140 | } 141 | for _, test := range tests { 142 | result := &Result{} 143 | validateMTASTSMXs(test.policyFileMXs, test.dnsMXs, result) 144 | if result.Status != test.status { 145 | t.Errorf("validateMTASTSMXs(%v, %v, %v) = %v", test.policyFileMXs, test.dnsMXs, Result{}, result) 146 | } 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /checker/result.go: -------------------------------------------------------------------------------- 1 | package checker 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | ) 7 | 8 | // Status is an enum encoding the status of the overall check. 
type Status int32

// Values for Result Status, in increasing order of severity.
const (
	Success Status = iota // 0
	Warning               // 1
	Failure               // 2
	Error                 // 3
)

// statusText maps each known Status to its display name.
var statusText = map[Status]string{
	Success: "Success",
	Warning: "Warning",
	Failure: "Failure",
	Error:   "Error",
}

// StatusText returns the display name of r's Status, or "" when the status is
// not one of the known values.
func (r Result) StatusText() string {
	if text, known := statusText[r.Status]; known {
		return text
	}
	return ""
}

// SetStatus returns whichever of oldStatus and newStatus is more severe,
// using the priority Error > Failure > Warning > Success.
func SetStatus(oldStatus Status, newStatus Status) Status {
	if oldStatus >= newStatus {
		return oldStatus
	}
	return newStatus
}

// Result is the result of a singular check. It's agnostic to the nature
// of the check performed, and simply stores a reference to the check's name,
// a summary of what the check should do, as well as any error, failure, or
// warning messages associated.
type Result struct {
	Name     string             `json:"name"`
	Status   Status             `json:"status"`
	Messages []string           `json:"messages,omitempty"`
	Checks   map[string]*Result `json:"checks,omitempty"`
}

// MakeResult returns a pointer to a new Result called name, starting at
// Success with an empty (non-nil) message list and check map.
func MakeResult(name string) *Result {
	r := new(Result)
	r.Name = name
	r.Status = Success
	r.Messages = []string{}
	r.Checks = map[string]*Result{}
	return r
}

// Error adds an error message to this check result.
// The Error status will override any other existing status for this check.
// Typically, when a check encounters an error, it stops executing.
64 | func (r *Result) Error(format string, a ...interface{}) *Result { 65 | r.Status = SetStatus(r.Status, Error) 66 | r.Messages = append(r.Messages, fmt.Sprintf("Error: "+format, a...)) 67 | return r 68 | } 69 | 70 | // Failure adds a failure message to this check result. 71 | // The Failure status will override any Status other than Error. 72 | // Whenever Failure is called, the entire check is failed. 73 | func (r *Result) Failure(format string, a ...interface{}) *Result { 74 | r.Status = SetStatus(r.Status, Failure) 75 | r.Messages = append(r.Messages, fmt.Sprintf("Failure: "+format, a...)) 76 | return r 77 | } 78 | 79 | // Warning adds a warning message to this check result. 80 | // The Warning status only supercedes the Success status. 81 | func (r *Result) Warning(format string, a ...interface{}) *Result { 82 | r.Status = SetStatus(r.Status, Warning) 83 | r.Messages = append(r.Messages, fmt.Sprintf("Warning: "+format, a...)) 84 | return r 85 | } 86 | 87 | // Success simply sets the status of Result to a Success. 88 | // Status is set if no other status has been declared on this check. 89 | func (r *Result) Success() *Result { 90 | r.Status = SetStatus(r.Status, Success) 91 | return r 92 | } 93 | 94 | // Returns result of specified check. 95 | // If called before that check occurs, returns false. 96 | func (r *Result) subcheckSucceeded(checkName string) bool { 97 | if result, ok := r.Checks[checkName]; ok { 98 | return result.Status == Success 99 | } 100 | return false 101 | } 102 | 103 | // Wrapping helper function to set the status of this hostname. 
104 | func (r *Result) addCheck(checkResult *Result) { 105 | r.Checks[checkResult.Name] = checkResult 106 | // SetStatus sets Result's status to the most severe of any individual check 107 | r.Status = SetStatus(r.Status, checkResult.Status) 108 | } 109 | 110 | // IDs for checks that can be run 111 | const ( 112 | Connectivity = "connectivity" 113 | STARTTLS = "starttls" 114 | Version = "version" 115 | Certificate = "certificate" 116 | MTASTS = "mta-sts" 117 | MTASTSText = "mta-sts-text" 118 | MTASTSPolicyFile = "mta-sts-policy-file" 119 | PolicyList = "policylist" 120 | ) 121 | 122 | // Text descriptions of checks that can be run 123 | var checkNames = map[string]string{ 124 | Connectivity: "Server connectivity", 125 | STARTTLS: "Support for inbound STARTTLS", 126 | Version: "Secure version of TLS", 127 | Certificate: "Valid certificate", 128 | MTASTS: "Inbound MTA-STS support", 129 | MTASTSText: "Correct MTA-STS DNS record", 130 | MTASTSPolicyFile: "Correct MTA-STS policy file", 131 | PolicyList: "Status on EFF's STARTTLS Everywhere policy list", 132 | } 133 | 134 | // Description returns the full-text name of a check. 135 | func (r Result) Description() string { 136 | return checkNames[r.Name] 137 | } 138 | 139 | // MarshalJSON writes Result to JSON. It adds status_text and description to 140 | // the output. 141 | func (r Result) MarshalJSON() ([]byte, error) { 142 | // FakeResult lets us access the default json.Marshall result for Result. 
143 | type FakeResult Result 144 | return json.Marshal(struct { 145 | FakeResult 146 | StatusText string `json:"status_text,omitempty"` 147 | Description string `json:"description,omitempty"` 148 | }{ 149 | Description: r.Description(), 150 | FakeResult: FakeResult(r), 151 | StatusText: r.StatusText(), 152 | }) 153 | } 154 | -------------------------------------------------------------------------------- /checker/result_test.go: -------------------------------------------------------------------------------- 1 | package checker 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "testing" 7 | ) 8 | 9 | func TestMarshalResultJSON(t *testing.T) { 10 | // Should set description and status_text for CheckResult w/ recognized keys 11 | result := Result{ 12 | Name: "starttls", 13 | Status: Success, 14 | } 15 | marshalled, err := json.Marshal(result) 16 | if err != nil { 17 | t.Fatal(err) 18 | } 19 | if !bytes.Contains(marshalled, []byte("\"status_text\":\"Success\"")) { 20 | t.Errorf("Marshalled result should contain status_text, got %s", string(marshalled)) 21 | } 22 | if !bytes.Contains(marshalled, []byte("\"description\":\"")) { 23 | t.Errorf("Marshalled result should contain description, got %s", string(marshalled)) 24 | } 25 | 26 | // Should survive unrecognized keys 27 | result = Result{ 28 | Name: "foo", 29 | Status: 100, 30 | } 31 | marshalled, _ = json.Marshal(result) 32 | if err != nil { 33 | t.Fatal(err) 34 | } 35 | if bytes.Contains(marshalled, []byte("\"status_text\":\"")) { 36 | t.Errorf("Result with unrecognized keys shouldn't output status_text, got %s", string(marshalled)) 37 | } 38 | if bytes.Contains(marshalled, []byte("\"description\":\"")) { 39 | t.Errorf("Result with unrecognized keys shouldn't output status_text, got %s", string(marshalled)) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /checker/totals.go: -------------------------------------------------------------------------------- 1 | 
package checker

import (
	"encoding/csv"
	"io"
	"log"
	"os"
	"strconv"
	"time"
)

// AggregatedScan compiles aggregated stats across domains.
// Implements ResultHandler.
type AggregatedScan struct {
	Time              time.Time
	Source            string
	Attempted         int
	WithMXs           int
	MTASTSTesting     int
	MTASTSTestingList []string
	MTASTSEnforce     int
	MTASTSEnforceList []string
}

const (
	// TopDomainsSource labels aggregated scans of the top million domains.
	TopDomainsSource = "TOP_DOMAINS"
	// LocalSource labels aggregated scan data for users of the web frontend.
	LocalSource = "LOCAL"
)

// TotalMTASTS returns the number of domains supporting test or enforce mode.
func (a AggregatedScan) TotalMTASTS() int {
	return a.MTASTSTesting + a.MTASTSEnforce
}

// PercentMTASTS returns the percentage of domains with MXs that support
// MTA-STS (testing or enforce mode), as a float between 0 and 100.
// It returns 0 when no scanned domain had MXs.
func (a AggregatedScan) PercentMTASTS() float64 {
	if a.WithMXs == 0 {
		return 0
	}
	return 100 * float64(a.TotalMTASTS()) / float64(a.WithMXs)
}

// HandleDomain adds the result of a single domain scan to aggregated stats.
// Domains without any hostname results are counted as attempted only.
func (a *AggregatedScan) HandleDomain(r DomainResult) {
	a.Attempted++
	// Show progress.
	if a.Attempted%1000 == 0 {
		log.Printf("\n%v\n", a)
		log.Println(a.MTASTSTestingList)
		log.Println(a.MTASTSEnforceList)
	}

	if len(r.HostnameResults) == 0 {
		// No MX records - assume this isn't an email domain.
		return
	}
	a.WithMXs++
	if r.MTASTSResult != nil {
		switch r.MTASTSResult.Mode {
		case "enforce":
			a.MTASTSEnforce++
			a.MTASTSEnforceList = append(a.MTASTSEnforceList, r.Domain)
		case "testing":
			a.MTASTSTesting++
			a.MTASTSTestingList = append(a.MTASTSTestingList, r.Domain)
		}
	}
}

// ResultHandler processes domain results.
// It could print them, aggregate them, write them to the db, etc.
type ResultHandler interface {
	HandleDomain(DomainResult)
}

const defaultPoolSize = 16

// CheckCSV runs the checker on every domain read from column domainColumn of
// domains and passes each DomainResult to resultHandler.
//
// Checks run on a pool of CONNECTION_POOL_SIZE worker goroutines (falling back
// to defaultPoolSize when the variable is unset, invalid or non-positive).
// HandleDomain is only ever called from the calling goroutine, so the handler
// does not need to be safe for concurrent use. CheckCSV returns once every row
// has been checked and handled. A CSV read error other than io.EOF terminates
// the process via log.Fatal. Rows too short to contain domainColumn are
// logged and skipped.
func (c *Checker) CheckCSV(domains *csv.Reader, resultHandler ResultHandler, domainColumn int) {
	poolSize, err := strconv.Atoi(os.Getenv("CONNECTION_POOL_SIZE"))
	if err != nil || poolSize <= 0 {
		poolSize = defaultPoolSize
	}
	work := make(chan string)
	results := make(chan DomainResult)

	// Producer: feed domains from the CSV to the workers, then close work so
	// the workers' range loops terminate.
	go func() {
		for {
			data, err := domains.Read()
			if err != nil {
				if err != io.EOF {
					log.Println("Error reading CSV")
					log.Fatal(err)
				}
				break
			}
			if domainColumn < 0 || domainColumn >= len(data) {
				// Guard the index instead of only checking len(data) > 0, which
				// panicked for short rows whenever domainColumn > 0.
				log.Printf("Skipping CSV row without column %d: %v", domainColumn, data)
				continue
			}
			work <- data[domainColumn]
		}
		close(work)
	}()

	done := make(chan struct{})
	for i := 0; i < poolSize; i++ {
		go func() {
			for domain := range work {
				results <- c.CheckDomain(domain, nil)
			}
			done <- struct{}{}
		}()
	}

	go func() {
		// Close the results channel when all the worker goroutines have finished.
		for i := 0; i < poolSize; i++ {
			<-done
		}
		close(results)
	}()

	for r := range results {
		resultHandler.HandleDomain(r)
	}
}

--------------------------------------------------------------------------------
/checker/totals_test.go:
--------------------------------------------------------------------------------
package checker

import (
	"encoding/csv"
	"strings"
	"testing"
	"time"
)

func TestCheckCSV(t *testing.T) {
	in := "empty\ndomain\ndomain.tld\nnoconnection\nnoconnection2\nnostarttls\n"
	reader := csv.NewReader(strings.NewReader(in))

	c := Checker{
		Cache:               MakeSimpleCache(10 * time.Minute),
		lookupMXOverride:    mockLookupMX,
		CheckHostname:       mockCheckHostname,
		checkMTASTSOverride: mockCheckMTASTS,
	}
	totals := AggregatedScan{}
	c.CheckCSV(reader, &totals, 0)

	if totals.Attempted != 6 {
		t.Errorf("Expected 6 attempted connections, got %d", totals.Attempted)
	}
	if totals.WithMXs != 5 {
		t.Errorf("Expected 5 domains with MXs, got %d", totals.WithMXs)
	}
	if len(totals.MTASTSTestingList) != 5 {
		t.Errorf("Expected 5 domains in MTA-STS testing mode, got %d", len(totals.MTASTSTestingList))
	}
}

--------------------------------------------------------------------------------
/db/Dockerfile:
--------------------------------------------------------------------------------
FROM postgres:10

# Initialize starttls tables
COPY scripts/init_tables.sql /docker-entrypoint-initdb.d/

--------------------------------------------------------------------------------
/db/db.go:
--------------------------------------------------------------------------------
package db

import (
	"flag"
	"os"
	"time"

	"github.com/EFForg/starttls-backend/checker"
	"github.com/EFForg/starttls-backend/models"
	"github.com/EFForg/starttls-backend/stats"
)

| 13 | // Database interface: These are the things that the Database should be able to do. 14 | // Slightly more limited than CRUD for all the schemas. 15 | type Database interface { 16 | // Puts new scandata for domain 17 | PutScan(models.Scan) error 18 | // Retrieves most recent scandata for domain 19 | GetLatestScan(string) (models.Scan, error) 20 | // Retrieves all scandata for domain 21 | GetAllScans(string) ([]models.Scan, error) 22 | // Gets the token for a domain 23 | GetTokenByDomain(string) (string, error) 24 | // Creates a token in the db 25 | PutToken(string) (models.Token, error) 26 | // Uses a token in the db 27 | UseToken(string) (string, error) 28 | // Adds a bounce or complaint notification to the email blacklist. 29 | PutBlacklistedEmail(email string, reason string, timestamp string) error 30 | // Returns true if we've blacklisted an email. 31 | IsBlacklistedEmail(string) (bool, error) 32 | // Retrieves a hostname scan for a particular hostname 33 | GetHostnameScan(string) (checker.HostnameResult, error) 34 | // Enters a hostname scan. 35 | PutHostnameScan(string, checker.HostnameResult) error 36 | // Writes an aggregated scan to the database 37 | PutAggregatedScan(checker.AggregatedScan) error 38 | // Caches stats for the 14 days preceding time.Time 39 | PutLocalStats(time.Time) (checker.AggregatedScan, error) 40 | // Gets counts per day of hosts supporting MTA-STS for a given source. 41 | GetStats(string) (stats.Series, error) 42 | // Upserts domain state. 43 | PutDomain(models.Domain) error 44 | // Retrieves state of a domain 45 | GetDomain(string, models.DomainState) (models.Domain, error) 46 | // Retrieves all domains in a particular state. 47 | GetDomains(models.DomainState) ([]models.Domain, error) 48 | SetStatus(string, models.DomainState) error 49 | RemoveDomain(string, models.DomainState) (models.Domain, error) 50 | ClearTables() error 51 | } 52 | 53 | // Config is a configuration struct for a Database. 
54 | type Config struct { 55 | Port string 56 | DbHost string 57 | DbName string 58 | DbUsername string 59 | DbPass string 60 | DbTokenTable string 61 | DbScanTable string 62 | DbDomainTable string 63 | } 64 | 65 | // Default configuration values. Can be overwritten by env vars of the same name. 66 | var configDefaults = map[string]string{ 67 | "PORT": "8080", 68 | "DB_HOST": "localhost", 69 | "DB_NAME": "starttls", 70 | "DB_USERNAME": "postgres", 71 | "DB_PASSWORD": "postgres", 72 | "TEST_DB_NAME": "starttls_test", 73 | "DB_TOKEN_TABLE": "tokens", 74 | "DB_DOMAIN_TABLE": "domains", 75 | "DB_SCAN_TABLE": "scans", 76 | } 77 | 78 | func getEnvOrDefault(varName string) string { 79 | envVar := os.Getenv(varName) 80 | if len(envVar) == 0 { 81 | envVar = configDefaults[varName] 82 | } 83 | return envVar 84 | } 85 | 86 | // LoadEnvironmentVariables loads relevant environment variables into a 87 | // Config object. 88 | func LoadEnvironmentVariables() (Config, error) { 89 | config := Config{ 90 | Port: getEnvOrDefault("PORT"), 91 | DbTokenTable: getEnvOrDefault("DB_TOKEN_TABLE"), 92 | DbDomainTable: getEnvOrDefault("DB_DOMAIN_TABLE"), 93 | DbScanTable: getEnvOrDefault("DB_SCAN_TABLE"), 94 | DbHost: getEnvOrDefault("DB_HOST"), 95 | DbName: getEnvOrDefault("DB_NAME"), 96 | DbUsername: getEnvOrDefault("DB_USERNAME"), 97 | DbPass: getEnvOrDefault("DB_PASSWORD"), 98 | } 99 | if flag.Lookup("test.v") != nil { 100 | // Avoid accidentally wiping the default db during tests. 101 | config.DbName = getEnvOrDefault("TEST_DB_NAME") 102 | } 103 | return config, nil 104 | } 105 | -------------------------------------------------------------------------------- /db/scripts/init_tables.sql: -------------------------------------------------------------------------------- 1 | -- Create all tables. 
CREATE TABLE IF NOT EXISTS tokens
(
    domain        TEXT NOT NULL PRIMARY KEY,
    token         VARCHAR(255) NOT NULL,
    expires       TIMESTAMP NOT NULL,
    used          BOOLEAN DEFAULT FALSE
);


CREATE TABLE IF NOT EXISTS scans
(
    id            SERIAL PRIMARY KEY,
    domain        TEXT NOT NULL,
    scandata      TEXT NOT NULL,
    timestamp     TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    version       INTEGER DEFAULT 0
);

CREATE TABLE IF NOT EXISTS hostname_scans
(
    id            SERIAL PRIMARY KEY,
    hostname      TEXT NOT NULL,
    timestamp     TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    status        SMALLINT,
    scandata      TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS domains
(
    domain        TEXT NOT NULL,
    email         TEXT NOT NULL,
    data          TEXT NOT NULL,
    last_updated  TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    status        VARCHAR(255) NOT NULL,
    queue_weeks   INTEGER DEFAULT 4,
    testing_start TIMESTAMP,
    mta_sts       BOOLEAN DEFAULT FALSE,
    PRIMARY KEY (domain, status)
);

CREATE TABLE IF NOT EXISTS blacklisted_emails
(
    id            SERIAL PRIMARY KEY,
    email         TEXT NOT NULL,
    reason        TEXT NOT NULL,
    timestamp     TIMESTAMP
);

-- Schema change: add "last_updated" timestamp column if it doesn't exist.

ALTER TABLE domains ADD COLUMN IF NOT EXISTS last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP;

-- Create trigger to ensure last_updated is updated every time
-- the corresponding row changes.

CREATE OR REPLACE FUNCTION update_changetimestamp_column()
RETURNS TRIGGER AS $$
BEGIN
   -- Only bump last_updated when some column actually changed, so no-op
   -- upserts leave the timestamp alone.
   IF row(NEW.*) IS DISTINCT FROM row(OLD.*) THEN
      NEW.last_updated = now();
      RETURN NEW;
   ELSE
      RETURN OLD;
   END IF;
END;
$$ language 'plpgsql';

DROP TRIGGER IF EXISTS update_change_timestamp ON domains;

CREATE TRIGGER update_change_timestamp BEFORE UPDATE
ON domains FOR EACH ROW EXECUTE PROCEDURE
update_changetimestamp_column();

ALTER TABLE scans ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 0;

ALTER TABLE scans ADD COLUMN IF NOT EXISTS mta_sts_mode TEXT DEFAULT '';

ALTER TABLE IF EXISTS domain_totals RENAME TO aggregated_scans;

CREATE TABLE IF NOT EXISTS aggregated_scans
(
    id              SERIAL PRIMARY KEY,
    time            TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    source          TEXT NOT NULL,
    attempted       INTEGER DEFAULT 0,
    with_mxs        INTEGER DEFAULT 0,
    mta_sts_testing INTEGER DEFAULT 0,
    mta_sts_enforce INTEGER DEFAULT 0,
    UNIQUE (time, source)
);

ALTER TABLE domains ADD COLUMN IF NOT EXISTS queue_weeks INTEGER DEFAULT 4;

ALTER TABLE domains ADD COLUMN IF NOT EXISTS testing_start TIMESTAMP;

-- Drop & re-add constraint.
-- IF EXISTS keeps this idempotent migration from aborting when the constraint
-- is missing or carries another name.
BEGIN;
ALTER TABLE domains DROP CONSTRAINT IF EXISTS domains_pkey;
ALTER TABLE domains ADD PRIMARY KEY (domain, status);
COMMIT;

ALTER TABLE IF EXISTS aggregated_scans DROP COLUMN IF EXISTS connected;
ALTER TABLE IF EXISTS aggregated_scans ADD COLUMN IF NOT EXISTS with_mxs INTEGER DEFAULT 0;

ALTER TABLE domains ADD COLUMN IF NOT EXISTS mta_sts BOOLEAN DEFAULT FALSE;

-- A table renamed from domain_totals keeps its original constraint names
-- (ALTER TABLE ... RENAME does not rename constraints), so the default name
-- may not exist here.
BEGIN;
ALTER TABLE aggregated_scans DROP CONSTRAINT IF EXISTS aggregated_scans_time_source_key;
ALTER TABLE aggregated_scans ADD UNIQUE (time, source);
COMMIT;

--------------------------------------------------------------------------------
/db/sqldb.go:
-------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "database/sql" 5 | "encoding/json" 6 | "fmt" 7 | "log" 8 | "math/rand" 9 | "net/url" 10 | "strings" 11 | "time" 12 | 13 | "github.com/EFForg/starttls-backend/checker" 14 | "github.com/EFForg/starttls-backend/models" 15 | "github.com/EFForg/starttls-backend/stats" 16 | 17 | // Imports postgresql driver for database/sql 18 | _ "github.com/lib/pq" 19 | ) 20 | 21 | // Format string for Sql timestamps. 22 | const sqlTimeFormat = "2006-01-02 15:04:05" 23 | 24 | // SQLDatabase is a Database interface backed by postgresql. 25 | type SQLDatabase struct { 26 | cfg Config // Configuration to define the DB connection. 27 | conn *sql.DB // The database connection. 28 | } 29 | 30 | func getConnectionString(cfg Config) string { 31 | connectionString := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=disable", 32 | url.PathEscape(cfg.DbUsername), 33 | url.PathEscape(cfg.DbPass), 34 | url.PathEscape(cfg.DbHost), 35 | url.PathEscape(cfg.DbName)) 36 | return connectionString 37 | } 38 | 39 | // InitSQLDatabase creates a DB connection based on information in a Config, and 40 | // returns a pointer the resulting SQLDatabase object. If connection fails, 41 | // returns an error. 42 | func InitSQLDatabase(cfg Config) (*SQLDatabase, error) { 43 | connectionString := getConnectionString(cfg) 44 | log.Printf("Connecting to Postgres DB ... \n") 45 | conn, err := sql.Open("postgres", connectionString) 46 | if err != nil { 47 | return nil, err 48 | } 49 | return &SQLDatabase{cfg: cfg, conn: conn}, nil 50 | } 51 | 52 | // TOKEN DB FUNCTIONS 53 | 54 | // randToken generates a random token. 55 | func randToken() string { 56 | b := make([]byte, 8) 57 | rand.Read(b) 58 | return fmt.Sprintf("%x", b) 59 | } 60 | 61 | // UseToken sets the `used` flag on a particular email validation token to 62 | // true, and returns the domain that was associated with the token. 
63 | func (db *SQLDatabase) UseToken(tokenStr string) (string, error) { 64 | var domain string 65 | err := db.conn.QueryRow("UPDATE tokens SET used=TRUE WHERE token=$1 AND used=FALSE RETURNING domain", 66 | tokenStr).Scan(&domain) 67 | return domain, err 68 | } 69 | 70 | // GetTokenByDomain gets the token for a domain name. 71 | func (db *SQLDatabase) GetTokenByDomain(domain string) (string, error) { 72 | var token string 73 | err := db.conn.QueryRow("SELECT token FROM tokens WHERE domain=$1", domain).Scan(&token) 74 | if err != nil { 75 | return "", err 76 | } 77 | return token, nil 78 | } 79 | 80 | // PutToken generates and inserts a token into the database for a particular 81 | // domain, and returns the resulting token row. 82 | func (db *SQLDatabase) PutToken(domain string) (models.Token, error) { 83 | token := models.Token{ 84 | Domain: domain, 85 | Token: randToken(), 86 | Expires: time.Now().Add(time.Duration(time.Hour * 72)), 87 | Used: false, 88 | } 89 | _, err := db.conn.Exec("INSERT INTO tokens(domain, token, expires) VALUES($1, $2, $3) "+ 90 | "ON CONFLICT (domain) DO UPDATE SET token=$2, expires=$3, used=FALSE", 91 | domain, token.Token, token.Expires.UTC().Format(sqlTimeFormat)) 92 | if err != nil { 93 | return models.Token{}, err 94 | } 95 | return token, nil 96 | } 97 | 98 | // SCAN DB FUNCTIONS 99 | 100 | // PutScan inserts a new scan for a particular domain into the database. 101 | func (db *SQLDatabase) PutScan(scan models.Scan) error { 102 | // Serialize scanData.Data for insertion into SQLdb! 103 | // @TODO marshall scan adds extra fields - need a custom obj for this 104 | byteArray, err := json.Marshal(scan.Data) 105 | if err != nil { 106 | return err 107 | } 108 | // Extract MTA-STS Mode to column for querying by mode, eg. adoption stats. 109 | // Note, this will include MTA-STS configurations that serve a parse-able 110 | // policy file and define a mode but don't pass full validation. 
111 | mtastsMode := "" 112 | if scan.Data.MTASTSResult != nil { 113 | mtastsMode = scan.Data.MTASTSResult.Mode 114 | } 115 | _, err = db.conn.Exec("INSERT INTO scans(domain, scandata, timestamp, version, mta_sts_mode) VALUES($1, $2, $3, $4, $5)", 116 | scan.Domain, string(byteArray), scan.Timestamp.UTC().Format(sqlTimeFormat), scan.Version, mtastsMode) 117 | return err 118 | } 119 | 120 | // GetStats returns statistics about a MTA-STS adoption from a single 121 | // source domains to check. 122 | func (db *SQLDatabase) GetStats(source string) (stats.Series, error) { 123 | series := stats.Series{} 124 | rows, err := db.conn.Query( 125 | `SELECT time, source, with_mxs, mta_sts_testing, mta_sts_enforce 126 | FROM aggregated_scans 127 | WHERE source=$1 128 | ORDER BY time`, source) 129 | if err != nil { 130 | return series, err 131 | } 132 | defer rows.Close() 133 | for rows.Next() { 134 | var a checker.AggregatedScan 135 | if err := rows.Scan(&a.Time, &a.Source, &a.WithMXs, &a.MTASTSTesting, &a.MTASTSEnforce); err != nil { 136 | return series, err 137 | } 138 | series = append(series, a) 139 | } 140 | return series, nil 141 | } 142 | 143 | // PutLocalStats writes aggregated stats for the 14 days preceding `date` to 144 | // the aggregated_stats table. 
145 | func (db *SQLDatabase) PutLocalStats(date time.Time) (checker.AggregatedScan, error) { 146 | query := ` 147 | SELECT 148 | COUNT(domain) AS total, 149 | COALESCE ( SUM ( 150 | CASE WHEN mta_sts_mode = 'testing' THEN 1 ELSE 0 END 151 | ), 0 ) AS testing, 152 | COALESCE ( SUM ( 153 | CASE WHEN mta_sts_mode = 'enforce' THEN 1 ELSE 0 END 154 | ), 0 ) AS enforce 155 | FROM ( 156 | SELECT DISTINCT ON (domain) domain, timestamp, mta_sts_mode 157 | FROM scans 158 | WHERE timestamp BETWEEN $1 AND $2 159 | ORDER BY domain, timestamp DESC 160 | ) AS latest_domains; 161 | ` 162 | start := date.Add(-14 * 24 * time.Hour) 163 | end := date 164 | a := checker.AggregatedScan{ 165 | Source: checker.LocalSource, 166 | Time: date, 167 | } 168 | err := db.conn.QueryRow(query, start.UTC(), end.UTC()).Scan(&a.WithMXs, &a.MTASTSTesting, &a.MTASTSEnforce) 169 | if err != nil { 170 | return a, err 171 | } 172 | err = db.PutAggregatedScan(a) 173 | return a, err 174 | } 175 | 176 | const mostRecentQuery = ` 177 | SELECT domain, scandata, timestamp, version FROM scans 178 | WHERE timestamp = (SELECT MAX(timestamp) FROM scans WHERE domain=$1) 179 | ` 180 | 181 | // GetLatestScan retrieves the most recent scan performed on a particular email 182 | // domain. 183 | func (db SQLDatabase) GetLatestScan(domain string) (models.Scan, error) { 184 | var rawScanData []byte 185 | result := models.Scan{} 186 | err := db.conn.QueryRow(mostRecentQuery, domain).Scan( 187 | &result.Domain, &rawScanData, &result.Timestamp, &result.Version) 188 | if err != nil { 189 | return result, err 190 | } 191 | err = json.Unmarshal(rawScanData, &result.Data) 192 | return result, err 193 | } 194 | 195 | // GetAllScans retrieves all the scans performed for a particular domain. 
196 | func (db SQLDatabase) GetAllScans(domain string) ([]models.Scan, error) { 197 | rows, err := db.conn.Query( 198 | "SELECT domain, scandata, timestamp, version FROM scans WHERE domain=$1", domain) 199 | if err != nil { 200 | return nil, err 201 | } 202 | defer rows.Close() 203 | scans := []models.Scan{} 204 | for rows.Next() { 205 | var scan models.Scan 206 | var rawScanData []byte 207 | if err := rows.Scan(&scan.Domain, &rawScanData, &scan.Timestamp, &scan.Version); err != nil { 208 | return nil, err 209 | } 210 | err = json.Unmarshal(rawScanData, &scan.Data) 211 | scans = append(scans, scan) 212 | } 213 | return scans, nil 214 | } 215 | 216 | // =============== models.DomainStore impl =============== 217 | 218 | // PutDomain inserts a particular domain into the database. If the domain does 219 | // not yet exist in the database, we initialize it with StateUnconfirmed 220 | // If there is already a domain in the database with StateUnconfirmed, performs 221 | // an update of the fields. 222 | func (db *SQLDatabase) PutDomain(domain models.Domain) error { 223 | _, err := db.conn.Exec("INSERT INTO domains(domain, email, data, status, queue_weeks, mta_sts) "+ 224 | "VALUES($1, $2, $3, $4, $5, $6) "+ 225 | "ON CONFLICT ON CONSTRAINT domains_pkey DO UPDATE SET email=$2, data=$3, queue_weeks=$5", 226 | domain.Name, domain.Email, strings.Join(domain.MXs[:], ","), 227 | models.StateUnconfirmed, domain.QueueWeeks, domain.MTASTS) 228 | return err 229 | } 230 | 231 | // GetDomain retrieves the status and information associated with a particular 232 | // mailserver domain. 
233 | func (db SQLDatabase) GetDomain(domain string, state models.DomainState) (models.Domain, error) { 234 | return db.queryDomain("SELECT %s FROM domains WHERE domain=$1 AND status=$2", domain, state) 235 | } 236 | 237 | // GetDomains retrieves all the domains which match a particular state, 238 | // that are not in MTA_STS mode 239 | func (db SQLDatabase) GetDomains(state models.DomainState) ([]models.Domain, error) { 240 | return db.queryDomainsWhere("status=$1", state) 241 | } 242 | 243 | // GetMTASTSDomains retrieves domains which wish their policy to be queued with their MTASTS. 244 | func (db SQLDatabase) GetMTASTSDomains() ([]models.Domain, error) { 245 | return db.queryDomainsWhere("mta_sts=TRUE") 246 | } 247 | 248 | // SetStatus sets the status of a particular domain object to |state|. 249 | func (db SQLDatabase) SetStatus(domain string, state models.DomainState) error { 250 | var testingStart time.Time 251 | if state == models.StateTesting { 252 | testingStart = time.Now() 253 | } 254 | _, err := db.conn.Exec("UPDATE domains SET status = $1, testing_start = $2 WHERE domain=$3", 255 | state, testingStart, domain) 256 | return err 257 | } 258 | 259 | // RemoveDomain removes a particular domain and returns it. 260 | func (db SQLDatabase) RemoveDomain(domain string, state models.DomainState) (models.Domain, error) { 261 | return db.queryDomain("DELETE FROM domains WHERE domain=$1 AND status=$2 RETURNING %s") 262 | } 263 | 264 | // EMAIL BLACKLIST DB FUNCTIONS 265 | 266 | // PutBlacklistedEmail adds a bounce or complaint notification to the email blacklist. 267 | func (db SQLDatabase) PutBlacklistedEmail(email string, reason string, timestamp string) error { 268 | _, err := db.conn.Exec("INSERT INTO blacklisted_emails(email, reason, timestamp) VALUES($1, $2, $3)", 269 | email, reason, timestamp) 270 | return err 271 | } 272 | 273 | // IsBlacklistedEmail returns true iff we've blacklisted the passed email address for sending. 
274 | func (db SQLDatabase) IsBlacklistedEmail(email string) (bool, error) { 275 | var count int 276 | row := db.conn.QueryRow("SELECT COUNT(*) FROM blacklisted_emails WHERE email=$1", email) 277 | err := row.Scan(&count) 278 | if err != nil { 279 | return false, err 280 | } 281 | return count > 0, nil 282 | } 283 | 284 | func tryExec(database SQLDatabase, commands []string) error { 285 | for _, command := range commands { 286 | if _, err := database.conn.Exec(command); err != nil { 287 | return fmt.Errorf("command failed: %s\nwith error: %v", 288 | command, err.Error()) 289 | } 290 | } 291 | return nil 292 | } 293 | 294 | // ClearTables nukes all the tables. ** Should only be used during testing ** 295 | func (db SQLDatabase) ClearTables() error { 296 | return tryExec(db, []string{ 297 | fmt.Sprintf("DELETE FROM %s", db.cfg.DbDomainTable), 298 | fmt.Sprintf("DELETE FROM %s", db.cfg.DbScanTable), 299 | fmt.Sprintf("DELETE FROM %s", db.cfg.DbTokenTable), 300 | fmt.Sprintf("DELETE FROM %s", "hostname_scans"), 301 | fmt.Sprintf("DELETE FROM %s", "blacklisted_emails"), 302 | fmt.Sprintf("DELETE FROM %s", "aggregated_scans"), 303 | fmt.Sprintf("ALTER SEQUENCE %s_id_seq RESTART WITH 1", db.cfg.DbScanTable), 304 | }) 305 | } 306 | 307 | func (db SQLDatabase) queryDomain(sqlQuery string, args ...interface{}) (models.Domain, error) { 308 | query := fmt.Sprintf(sqlQuery, "domain, email, data, status, last_updated, queue_weeks") 309 | data := models.Domain{} 310 | var rawMXs string 311 | err := db.conn.QueryRow(query, args...).Scan( 312 | &data.Name, &data.Email, &rawMXs, &data.State, &data.LastUpdated, &data.QueueWeeks) 313 | data.MXs = strings.Split(rawMXs, ",") 314 | if len(rawMXs) == 0 { 315 | data.MXs = []string{} 316 | } 317 | return data, err 318 | } 319 | 320 | func (db SQLDatabase) queryDomainsWhere(condition string, args ...interface{}) ([]models.Domain, error) { 321 | query := fmt.Sprintf("SELECT domain, email, data, status, last_updated, queue_weeks FROM domains 
WHERE %s", condition) 322 | rows, err := db.conn.Query(query, args...) 323 | if err != nil { 324 | return nil, err 325 | } 326 | defer rows.Close() 327 | domains := []models.Domain{} 328 | for rows.Next() { 329 | var domain models.Domain 330 | var rawMXs string 331 | if err := rows.Scan(&domain.Name, &domain.Email, &rawMXs, &domain.State, &domain.LastUpdated, &domain.QueueWeeks); err != nil { 332 | return nil, err 333 | } 334 | domain.MXs = strings.Split(rawMXs, ",") 335 | domains = append(domains, domain) 336 | } 337 | return domains, nil 338 | } 339 | 340 | // DomainsToValidate [interface Validator] retrieves domains from the 341 | // DB whose policies should be validated. 342 | func (db SQLDatabase) DomainsToValidate() ([]string, error) { 343 | domains := []string{} 344 | data, err := db.GetDomains(models.StateTesting) 345 | if err != nil { 346 | return domains, err 347 | } 348 | for _, domainInfo := range data { 349 | domains = append(domains, domainInfo.Name) 350 | } 351 | return domains, nil 352 | } 353 | 354 | // HostnamesForDomain [interface Validator] retrieves the hostname policy for 355 | // a particular domain. 356 | func (db SQLDatabase) HostnamesForDomain(domain string) ([]string, error) { 357 | data, err := db.GetDomain(domain, models.StateEnforce) 358 | if err != nil { 359 | data, err = db.GetDomain(domain, models.StateTesting) 360 | } 361 | if err != nil { 362 | return []string{}, err 363 | } 364 | return data.MXs, nil 365 | } 366 | 367 | // GetHostnameScan retrives most recent scan from database. 
368 | func (db *SQLDatabase) GetHostnameScan(hostname string) (checker.HostnameResult, error) { 369 | result := checker.HostnameResult{ 370 | Hostname: hostname, 371 | Result: &checker.Result{}, 372 | } 373 | var rawScanData []byte 374 | err := db.conn.QueryRow(`SELECT timestamp, status, scandata FROM hostname_scans 375 | WHERE hostname=$1 AND 376 | timestamp=(SELECT MAX(timestamp) FROM hostname_scans WHERE hostname=$1)`, 377 | hostname).Scan(&result.Timestamp, &result.Status, &rawScanData) 378 | if err != nil { 379 | return result, err 380 | } 381 | err = json.Unmarshal(rawScanData, &result.Checks) 382 | return result, err 383 | } 384 | 385 | // PutHostnameScan puts this scan into the database. 386 | func (db *SQLDatabase) PutHostnameScan(hostname string, result checker.HostnameResult) error { 387 | data, err := json.Marshal(result.Checks) 388 | if err != nil { 389 | return err 390 | } 391 | _, err = db.conn.Exec(`INSERT INTO hostname_scans(hostname, status, scandata) 392 | VALUES($1, $2, $3)`, hostname, result.Status, string(data)) 393 | return err 394 | } 395 | 396 | // PutAggregatedScan writes and AggregatedScan to the db. 
397 | func (db *SQLDatabase) PutAggregatedScan(a checker.AggregatedScan) error { 398 | _, err := db.conn.Exec(`INSERT INTO 399 | aggregated_scans(time, source, attempted, with_mxs, mta_sts_testing, mta_sts_enforce) 400 | VALUES ($1, $2, $3, $4, $5, $6) 401 | ON CONFLICT (time,source) DO NOTHING`, 402 | a.Time, a.Source, a.Attempted, a.WithMXs, a.MTASTSTesting, a.MTASTSEnforce) 403 | return err 404 | } 405 | -------------------------------------------------------------------------------- /db/sqldb_test.go: -------------------------------------------------------------------------------- 1 | package db_test 2 | 3 | import ( 4 | "log" 5 | "os" 6 | "strings" 7 | "testing" 8 | "time" 9 | 10 | "github.com/EFForg/starttls-backend/checker" 11 | "github.com/EFForg/starttls-backend/db" 12 | "github.com/EFForg/starttls-backend/models" 13 | "github.com/joho/godotenv" 14 | ) 15 | 16 | // Global database object for tests. 17 | var database *db.SQLDatabase 18 | 19 | // Connects to local test db. 20 | func initTestDb() *db.SQLDatabase { 21 | os.Setenv("PRIV_KEY", "./certs/key.pem") 22 | os.Setenv("PUBLIC_KEY", "./certs/cert.pem") 23 | cfg, err := db.LoadEnvironmentVariables() 24 | if err != nil { 25 | log.Fatal(err) 26 | } 27 | database, err := db.InitSQLDatabase(cfg) 28 | if err != nil { 29 | log.Fatal(err) 30 | } 31 | return database 32 | } 33 | 34 | func TestMain(m *testing.M) { 35 | godotenv.Overload("../.env.test") 36 | database = initTestDb() 37 | code := m.Run() 38 | err := database.ClearTables() 39 | if err != nil { 40 | log.Fatal(err) 41 | } 42 | os.Exit(code) 43 | } 44 | 45 | //////////////////////////////// 46 | // ***** Database tests ***** // 47 | //////////////////////////////// 48 | 49 | func TestPutScan(t *testing.T) { 50 | database.ClearTables() 51 | dummyScan := models.Scan{ 52 | Domain: "dummy.com", 53 | Data: checker.DomainResult{Domain: "dummy.com"}, 54 | Timestamp: time.Now(), 55 | Version: 2, 56 | } 57 | err := database.PutScan(dummyScan) 58 | if err != nil { 
59 | t.Fatalf("PutScan failed: %v\n", err) 60 | } 61 | scan, err := database.GetLatestScan("dummy.com") 62 | if err != nil { 63 | t.Fatalf("GetLatestScan failed: %v\n", err) 64 | } 65 | if dummyScan.Domain != scan.Domain || dummyScan.Data.Domain != scan.Data.Domain || 66 | dummyScan.Version != scan.Version || 67 | dummyScan.Timestamp.Unix() != dummyScan.Timestamp.Unix() { 68 | t.Errorf("Expected %v and %v to be the same\n", dummyScan, scan) 69 | } 70 | } 71 | 72 | func TestGetLatestScan(t *testing.T) { 73 | database.ClearTables() 74 | // Add two dummy objects 75 | earlyScan := models.Scan{ 76 | Domain: "dummy.com", 77 | Data: checker.DomainResult{Domain: "dummy.com", Message: "test_before"}, 78 | Timestamp: time.Now(), 79 | } 80 | laterScan := models.Scan{ 81 | Domain: "dummy.com", 82 | Data: checker.DomainResult{Domain: "dummy.com", Message: "test_after"}, 83 | Timestamp: time.Now().Add(time.Duration(time.Hour)), 84 | } 85 | err := database.PutScan(laterScan) 86 | if err != nil { 87 | t.Errorf("PutScan failed: %v\n", err) 88 | } 89 | err = database.PutScan(earlyScan) 90 | if err != nil { 91 | t.Errorf("PutScan failed: %v\n", err) 92 | } 93 | scan, err := database.GetLatestScan("dummy.com") 94 | if err != nil { 95 | t.Errorf("GetLatestScan failed: %v\n", err) 96 | } 97 | if scan.Data.Message != "test_after" { 98 | t.Errorf("Expected GetLatestScan to retrieve most recent scanData: %v", scan) 99 | } 100 | } 101 | 102 | func TestGetAllScans(t *testing.T) { 103 | database.ClearTables() 104 | data, err := database.GetAllScans("dummy.com") 105 | if err != nil { 106 | t.Errorf("GetAllScans failed: %v\n", err) 107 | } 108 | // Retrieving scans for domain that's never been scanned before 109 | if len(data) != 0 { 110 | t.Errorf("Expected GetAllScans to return []") 111 | } 112 | // Add two dummy objects 113 | dummyScan := models.Scan{ 114 | Domain: "dummy.com", 115 | Data: checker.DomainResult{Domain: "dummy.com", Message: "test1"}, 116 | Timestamp: time.Now(), 117 | } 118 | 
err = database.PutScan(dummyScan) 119 | if err != nil { 120 | t.Errorf("PutScan failed: %v\n", err) 121 | } 122 | dummyScan.Data.Message = "test2" 123 | err = database.PutScan(dummyScan) 124 | if err != nil { 125 | t.Errorf("PutScan failed: %v\n", err) 126 | } 127 | data, err = database.GetAllScans("dummy.com") 128 | // Retrieving scans for domain that's been scanned once 129 | if err != nil { 130 | t.Errorf("GetAllScans failed: %v\n", err) 131 | } 132 | if len(data) != 2 { 133 | t.Errorf("Expected GetAllScans to return two items, returned %d\n", len(data)) 134 | } 135 | if data[0].Data.Message != "test1" || data[1].Data.Message != "test2" { 136 | t.Errorf("Expected Data of scan objects to include both test1 and test2") 137 | } 138 | } 139 | 140 | func TestPutGetDomain(t *testing.T) { 141 | database.ClearTables() 142 | data := models.Domain{ 143 | Name: "testing.com", 144 | Email: "admin@testing.com", 145 | } 146 | err := database.PutDomain(data) 147 | if err != nil { 148 | t.Errorf("PutDomain failed: %v\n", err) 149 | } 150 | retrievedData, err := database.GetDomain(data.Name, models.StateUnconfirmed) 151 | if err != nil { 152 | t.Errorf("GetDomain(%s) failed: %v\n", data.Name, err) 153 | } 154 | if retrievedData.Name != data.Name { 155 | t.Errorf("Somehow, GetDomain retrieved the wrong object?") 156 | } 157 | if retrievedData.State != models.StateUnconfirmed { 158 | t.Errorf("Default state should be 'Unconfirmed'") 159 | } 160 | } 161 | 162 | func TestUpsertDomain(t *testing.T) { 163 | database.ClearTables() 164 | data := models.Domain{ 165 | Name: "testing.com", 166 | MXs: []string{"hello1"}, 167 | Email: "admin@testing.com", 168 | } 169 | database.PutDomain(data) 170 | err := database.PutDomain(models.Domain{Name: "testing.com", MXs: []string{"hello_darkness_my_old_friend"}, Email: "actual_admin@testing.com"}) 171 | if err != nil { 172 | t.Errorf("PutDomain(%s) failed: %v\n", data.Name, err) 173 | } 174 | retrievedData, err := database.GetDomain(data.Name, 
models.StateUnconfirmed) 175 | if retrievedData.MXs[0] != "hello_darkness_my_old_friend" || retrievedData.Email != "actual_admin@testing.com" { 176 | t.Errorf("Email and MXs should have been rewritten: %v\n", retrievedData) 177 | } 178 | } 179 | 180 | func TestDomainSetStatus(t *testing.T) { 181 | // TODO 182 | } 183 | 184 | func TestPutUseToken(t *testing.T) { 185 | database.ClearTables() 186 | data, err := database.PutToken("testing.com") 187 | if err != nil { 188 | t.Errorf("PutToken failed: %v\n", err) 189 | } 190 | domain, err := database.UseToken(data.Token) 191 | if err != nil { 192 | t.Errorf("UseToken failed: %v\n", err) 193 | } 194 | if domain != data.Domain { 195 | t.Errorf("UseToken used token for %s instead of %s\n", domain, data.Domain) 196 | } 197 | } 198 | 199 | func TestPutTokenTwice(t *testing.T) { 200 | database.ClearTables() 201 | data, err := database.PutToken("testing.com") 202 | if err != nil { 203 | t.Errorf("PutToken failed: %v\n", err) 204 | } 205 | _, err = database.PutToken("testing.com") 206 | if err != nil { 207 | t.Errorf("PutToken failed: %v\n", err) 208 | } 209 | domain, err := database.UseToken(data.Token) 210 | if domain == data.Domain { 211 | t.Errorf("UseToken should not have succeeded with old token!\n") 212 | } 213 | } 214 | 215 | func TestLastUpdatedFieldUpdates(t *testing.T) { 216 | database.ClearTables() 217 | data := models.Domain{ 218 | Name: "testing.com", 219 | Email: "admin@testing.com", 220 | State: models.StateUnconfirmed, 221 | } 222 | database.PutDomain(data) 223 | retrievedData, _ := database.GetDomain(data.Name, models.StateUnconfirmed) 224 | lastUpdated := retrievedData.LastUpdated 225 | data.State = models.StateTesting 226 | database.PutDomain(models.Domain{Name: data.Name, Email: "new fone who dis"}) 227 | retrievedData, _ = database.GetDomain(data.Name, models.StateUnconfirmed) 228 | if lastUpdated.Equal(retrievedData.LastUpdated) { 229 | t.Errorf("Expected last_updated to be updated on change: %v", 
lastUpdated) 230 | } 231 | } 232 | 233 | func TestLastUpdatedFieldDoesntUpdate(t *testing.T) { 234 | database.ClearTables() 235 | data := models.Domain{ 236 | Name: "testing.com", 237 | Email: "admin@testing.com", 238 | State: models.StateUnconfirmed, 239 | } 240 | database.PutDomain(data) 241 | retrievedData, _ := database.GetDomain(data.Name, models.StateUnconfirmed) 242 | lastUpdated := retrievedData.LastUpdated 243 | database.PutDomain(data) 244 | retrievedData, _ = database.GetDomain(data.Name, models.StateUnconfirmed) 245 | if !lastUpdated.Equal(retrievedData.LastUpdated) { 246 | t.Errorf("Expected last_updated to stay the same if no changes were made") 247 | } 248 | } 249 | 250 | func TestDomainsToValidate(t *testing.T) { 251 | database.ClearTables() 252 | queuedMap := map[string]bool{ 253 | "a": false, "b": true, "c": false, "d": true, 254 | } 255 | for domain, queued := range queuedMap { 256 | if queued { 257 | database.PutDomain(models.Domain{Name: domain, State: models.StateTesting}) 258 | } else { 259 | database.PutDomain(models.Domain{Name: domain}) 260 | } 261 | } 262 | result, err := database.DomainsToValidate() 263 | if err != nil { 264 | t.Fatalf("DomainsToValidate failed: %v\n", err) 265 | } 266 | for _, domain := range result { 267 | if !queuedMap[domain] { 268 | t.Errorf("Did not expect %s to be returned", domain) 269 | } 270 | } 271 | } 272 | 273 | func TestHostnamesForDomain(t *testing.T) { 274 | database.ClearTables() 275 | database.PutDomain(models.Domain{Name: "x", MXs: []string{"x.com", "y.org"}}) 276 | database.PutDomain(models.Domain{Name: "y"}) 277 | database.SetStatus("x", models.StateTesting) 278 | database.SetStatus("y", models.StateTesting) 279 | result, err := database.HostnamesForDomain("x") 280 | if err != nil { 281 | t.Fatalf("HostnamesForDomain failed: %v\n", err) 282 | } 283 | if len(result) != 2 || result[0] != "x.com" || result[1] != "y.org" { 284 | t.Errorf("Expected two hostnames, x.com and y.org\n") 285 | } 286 | result, 
err = database.HostnamesForDomain("y") 287 | if err != nil { 288 | t.Fatalf("HostnamesForDomain failed: %v\n", err) 289 | } 290 | if len(result) > 0 { 291 | t.Errorf("Expected no hostnames to be returned, got %s\n", result[0]) 292 | } 293 | } 294 | 295 | func TestPutAndIsBlacklistedEmail(t *testing.T) { 296 | database.ClearTables() 297 | 298 | // Add an e-mail address to the blacklist. 299 | err := database.PutBlacklistedEmail("fail@example.com", "bounce", "2017-07-21T18:47:13.498Z") 300 | if err != nil { 301 | t.Errorf("PutBlacklistedEmail failed: %v\n", err) 302 | } 303 | 304 | // Check that the email address was blacklisted. 305 | blacklisted, err := database.IsBlacklistedEmail("fail@example.com") 306 | if err != nil { 307 | t.Errorf("IsBlacklistedEmail failed: %v\n", err) 308 | } 309 | if !blacklisted { 310 | t.Errorf("fail@example.com should be blacklisted, but wasn't") 311 | } 312 | 313 | // Check that an un-added email address is not blacklisted. 314 | blacklisted, err = database.IsBlacklistedEmail("good@example.com") 315 | if err != nil { 316 | t.Errorf("IsBlacklistedEmail failed: %v\n", err) 317 | } 318 | if blacklisted { 319 | t.Errorf("good@example.com should not be blacklisted, but was") 320 | } 321 | } 322 | 323 | func TestGetHostnameScan(t *testing.T) { 324 | database.ClearTables() 325 | checksMap := make(map[string]*checker.Result) 326 | checksMap["test"] = &checker.Result{} 327 | now := time.Now() 328 | database.PutHostnameScan("hello", 329 | checker.HostnameResult{ 330 | Timestamp: now, 331 | Hostname: "hello", 332 | Result: &checker.Result{Status: 1, Checks: checksMap}, 333 | }, 334 | ) 335 | result, err := database.GetHostnameScan("hello") 336 | if err != nil { 337 | t.Errorf("Expected hostname scan to return without errors") 338 | } 339 | if now == result.Timestamp { 340 | t.Errorf("unexpected gap between written timestamp %s and read timestamp %s", now, result.Timestamp) 341 | } 342 | if result.Status != 1 || checksMap["test"].Name != 
result.Checks["test"].Name { 343 | t.Errorf("Expected hostname scan to return correct data") 344 | } 345 | } 346 | 347 | func dateMustParse(date string, t *testing.T) time.Time { 348 | const shortForm = "2006-Jan-02" 349 | parsed, err := time.Parse(shortForm, date) 350 | if err != nil { 351 | t.Fatal(err) 352 | } 353 | return parsed 354 | } 355 | 356 | func TestGetStats(t *testing.T) { 357 | database.ClearTables() 358 | may1 := dateMustParse("2019-May-01", t) 359 | may2 := dateMustParse("2019-May-02", t) 360 | data := []checker.AggregatedScan{ 361 | checker.AggregatedScan{ 362 | Time: may1, 363 | Source: checker.TopDomainsSource, 364 | Attempted: 5, 365 | WithMXs: 4, 366 | MTASTSTesting: 2, 367 | MTASTSEnforce: 1, 368 | }, 369 | checker.AggregatedScan{ 370 | Time: may2, 371 | Source: checker.TopDomainsSource, 372 | Attempted: 10, 373 | WithMXs: 8, 374 | MTASTSTesting: 1, 375 | MTASTSEnforce: 3, 376 | }, 377 | } 378 | for _, a := range data { 379 | err := database.PutAggregatedScan(a) 380 | if err != nil { 381 | t.Fatal(err) 382 | } 383 | } 384 | result, err := database.GetStats(checker.TopDomainsSource) 385 | if err != nil { 386 | t.Fatal(err) 387 | } 388 | if result[0].TotalMTASTS() != 3 || result[1].TotalMTASTS() != 4 { 389 | t.Errorf("Incorrect MTA-STS stats, got %v", result) 390 | } 391 | } 392 | 393 | func TestPutLocalStats(t *testing.T) { 394 | database.ClearTables() 395 | a, err := database.PutLocalStats(time.Now()) 396 | if err != nil { 397 | t.Fatal(err) 398 | } 399 | if a.PercentMTASTS() != 0 { 400 | t.Errorf("Expected PercentMTASTS with no recent scans to be 0, got %v", 401 | a.PercentMTASTS()) 402 | } 403 | day := time.Hour * 24 404 | today := time.Now() 405 | lastWeek := today.Add(-6 * day) 406 | s := models.Scan{ 407 | Domain: "example1.com", 408 | Data: checker.NewSampleDomainResult("example1.com"), 409 | Timestamp: lastWeek, 410 | } 411 | database.PutScan(s) 412 | a, err = database.PutLocalStats(time.Now()) 413 | if err != nil { 414 | t.Fatal(err) 
415 | } 416 | if a.PercentMTASTS() != 100 { 417 | t.Errorf("Expected PercentMTASTS with one recent scan to be 100, got %v", 418 | a.PercentMTASTS()) 419 | } 420 | } 421 | 422 | func TestGetLocalStats(t *testing.T) { 423 | database.ClearTables() 424 | day := time.Hour * 24 425 | today := time.Now() 426 | lastWeek := today.Add(-6 * day) 427 | 428 | // Two recent scans from example1.com 429 | // The most recent scan shows no MTA-STS support. 430 | s := models.Scan{ 431 | Domain: "example1.com", 432 | Data: checker.NewSampleDomainResult("example1.com"), 433 | Timestamp: lastWeek.Add(1 * day), 434 | } 435 | database.PutScan(s) 436 | s.Timestamp = lastWeek.Add(3 * day) 437 | s.Data.MTASTSResult.Mode = "" 438 | database.PutScan(s) 439 | 440 | // Add another recent scan, from a second domain. 441 | s = models.Scan{ 442 | Domain: "example2.com", 443 | Data: checker.NewSampleDomainResult("example2.com"), 444 | Timestamp: lastWeek.Add(2 * day), 445 | } 446 | database.PutScan(s) 447 | 448 | // Add a third scan to check that floats are outputted correctly. 449 | s = models.Scan{ 450 | Domain: "example3.com", 451 | Data: checker.NewSampleDomainResult("example2.com"), 452 | Timestamp: lastWeek.Add(6 * day), 453 | } 454 | database.PutScan(s) 455 | 456 | // Write stats to the database for all the windows we want to check. 
457 | for i := 0; i < 7; i++ { 458 | database.PutLocalStats(lastWeek.Add(day * time.Duration(i))) 459 | } 460 | 461 | stats, err := database.GetStats(checker.LocalSource) 462 | if err != nil { 463 | t.Fatal(err) 464 | } 465 | 466 | // Validate result 467 | expPcts := []float64{0, 100, 100, 50, 50, 50, 100 * 2 / float64(3)} 468 | if len(expPcts) != 7 { 469 | t.Errorf("Expected 7 stats, got\n %v\n", stats) 470 | } 471 | for i, got := range stats { 472 | if got.PercentMTASTS() != expPcts[i] { 473 | t.Errorf("\nExpected %v%%\nGot %v\n (%v%%)", expPcts[i], got, got.PercentMTASTS()) 474 | } 475 | } 476 | } 477 | 478 | func TestGetMTASTSDomains(t *testing.T) { 479 | database.ClearTables() 480 | database.PutDomain(models.Domain{Name: "unicorns"}) 481 | database.PutDomain(models.Domain{Name: "mta-sts-x", MTASTS: true}) 482 | database.PutDomain(models.Domain{Name: "mta-sts-y", MTASTS: true}) 483 | database.PutDomain(models.Domain{Name: "regular"}) 484 | domains, err := database.GetMTASTSDomains() 485 | if err != nil { 486 | t.Fatalf("GetMTASTSDomains() failed: %v", err) 487 | } 488 | if len(domains) != 2 { 489 | t.Errorf("Expected GetMTASTSDomains() to return 2 elements") 490 | } 491 | for _, domain := range domains { 492 | if !strings.HasPrefix(domain.Name, "mta-sts") { 493 | t.Errorf("GetMTASTSDomains returned %s when it wasn't supposed to", domain.Name) 494 | } 495 | } 496 | } 497 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '2.1' 2 | services: 3 | postgres: 4 | build: db/ 5 | env_file: 6 | - .env 7 | healthcheck: 8 | test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-postgres}"] 9 | environment: 10 | POSTGRES_DB: $DB_NAME 11 | POSTGRES_USER: $DB_USERNAME 12 | POSTGRES_PASSWORD: $DB_PASSWORD 13 | postgres_test: 14 | build: db/ 15 | healthcheck: 16 | test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-postgres}"] 17 | 
env_file: 18 | - .env.test 19 | environment: 20 | POSTGRES_DB: starttls_test 21 | POSTGRES_USER: postgres 22 | POSTGRES_PASSWORD: password 23 | app: 24 | build: . 25 | volumes: 26 | - .:/go/src/github.com/EFForg/starttls-backend 27 | ports: 28 | - 8080:8080 29 | depends_on: 30 | postgres: 31 | condition: service_healthy 32 | postgres_test: 33 | condition: service_healthy 34 | env_file: 35 | - .env 36 | -------------------------------------------------------------------------------- /email/email.go: -------------------------------------------------------------------------------- 1 | package email 2 | 3 | import ( 4 | "crypto/tls" 5 | "encoding/json" 6 | "fmt" 7 | "log" 8 | "net/smtp" 9 | "strings" 10 | 11 | "github.com/EFForg/starttls-backend/db" 12 | "github.com/EFForg/starttls-backend/models" 13 | "github.com/EFForg/starttls-backend/util" 14 | ) 15 | 16 | type blacklistStore interface { 17 | PutBlacklistedEmail(email string, reason string, timestamp string) error 18 | IsBlacklistedEmail(string) (bool, error) 19 | } 20 | 21 | // Config stores variables needed to submit emails for sending, as well as 22 | // to generate the templates. 23 | type Config struct { 24 | auth smtp.Auth 25 | username string 26 | password string 27 | submissionHostname string 28 | port string 29 | sender string 30 | website string // Needed to generate email template text. 31 | database blacklistStore 32 | } 33 | 34 | // MakeConfigFromEnv initializes our email config object with 35 | // environment variables. 
36 | func MakeConfigFromEnv(database db.Database) (Config, error) { 37 | // create config 38 | varErrs := util.Errors{} 39 | c := Config{ 40 | username: util.RequireEnv("SMTP_USERNAME", &varErrs), 41 | password: util.RequireEnv("SMTP_PASSWORD", &varErrs), 42 | submissionHostname: util.RequireEnv("SMTP_ENDPOINT", &varErrs), 43 | port: util.RequireEnv("SMTP_PORT", &varErrs), 44 | sender: util.RequireEnv("SMTP_FROM_ADDRESS", &varErrs), 45 | website: util.RequireEnv("FRONTEND_WEBSITE_LINK", &varErrs), 46 | database: database, 47 | } 48 | if len(varErrs) > 0 { 49 | return c, varErrs 50 | } 51 | log.Printf("Establishing auth connection with SMTP server %s", c.submissionHostname) 52 | // create auth 53 | client, err := smtp.Dial(fmt.Sprintf("%s:%s", c.submissionHostname, c.port)) 54 | if err != nil { 55 | return c, err 56 | } 57 | defer client.Close() 58 | err = client.StartTLS(&tls.Config{ServerName: c.submissionHostname}) 59 | if err != nil { 60 | return c, fmt.Errorf("SMTP server doesn't support STARTTLS") 61 | } 62 | ok, auths := client.Extension("AUTH") 63 | if !ok { 64 | return c, fmt.Errorf("remote SMTP server doesn't support any authentication mechanisms") 65 | } 66 | if strings.Contains(auths, "PLAIN") { 67 | c.auth = smtp.PlainAuth("", c.username, c.password, c.submissionHostname) 68 | } else if strings.Contains(auths, "CRAM-MD5") { 69 | c.auth = smtp.CRAMMD5Auth(c.username, c.password) 70 | } else { 71 | return c, fmt.Errorf("SMTP server doesn't support PLAIN or CRAM-MD5 authentication") 72 | } 73 | return c, nil 74 | } 75 | 76 | // ValidationAddress Returns default validation address for this domain submission. 
77 | func ValidationAddress(domain *models.Domain) string { 78 | return fmt.Sprintf("postmaster@%s", domain.Name) 79 | } 80 | 81 | func validationEmailText(domain string, contactEmail string, hostnames []string, token string, website string) string { 82 | return fmt.Sprintf(validationEmailTemplate, 83 | domain, strings.Join(hostnames[:], ", "), website, token, contactEmail) 84 | } 85 | 86 | // SendValidation sends a validation e-mail for the domain outlined by domainInfo. 87 | // The validation link is generated using a token. 88 | func (c Config) SendValidation(domain *models.Domain, token string) error { 89 | emailContent := validationEmailText(domain.Name, domain.Email, domain.MXs, token, 90 | c.website) 91 | return c.sendEmail(validationEmailSubject, emailContent, ValidationAddress(domain)) 92 | } 93 | 94 | func (c Config) sendEmail(subject string, body string, address string) error { 95 | blacklisted, err := c.database.IsBlacklistedEmail(address) 96 | if err != nil { 97 | return err 98 | } 99 | if blacklisted { 100 | return fmt.Errorf("address %s is blacklisted", address) 101 | } 102 | message := fmt.Sprintf("From: %s\nTo: %s\nSubject: %s\n\n%s", 103 | c.sender, address, subject, body) 104 | if c.submissionHostname == "" { 105 | log.Println("Warning: email host not configured, not sending email") 106 | log.Println(message) 107 | return nil 108 | } 109 | return smtp.SendMail(fmt.Sprintf("%s:%s", c.submissionHostname, c.port), 110 | c.auth, 111 | c.sender, []string{address}, []byte(message)) 112 | } 113 | 114 | // Recipients lists the email addresses that have triggered a bounce or complaint. 115 | type Recipients []struct { 116 | EmailAddress string `json:"emailAddress"` 117 | } 118 | 119 | // BlacklistRequest represents a submission for a particular email address to be blacklisted. 
120 | type BlacklistRequest struct { 121 | Reason string 122 | Timestamp string 123 | Recipients Recipients 124 | Raw string 125 | } 126 | 127 | // UnmarshalJSON wrangles the JSON posted by AWS SNS into something easier to access 128 | // and generalized across notification types. 129 | func (r *BlacklistRequest) UnmarshalJSON(b []byte) error { 130 | // We need to start by unmarshalling Message into a string because the field contains stringified JSON. 131 | // See email_test.go for examples. 132 | var wrapper struct { 133 | Message string 134 | Timestamp string 135 | } 136 | if err := json.Unmarshal(b, &wrapper); err != nil { 137 | return fmt.Errorf("failed to load notification wrapper: %v", err) 138 | } 139 | 140 | type Complaint struct { 141 | *Recipients `json:"complainedRecipients"` 142 | } 143 | 144 | type Bounce struct { 145 | *Recipients `json:"bouncedRecipients"` 146 | } 147 | 148 | // We'll unmarshall the list of bounced or complained emails into 149 | // &recipients. Only one of Complaint or Bounce will contain data, so we can 150 | // reuse &recipients to capture whichever field holds the list. 
151 | var recipients Recipients 152 | msg := struct { 153 | NotificationType string `json:"notificationType"` 154 | Complaint `json:"complaint"` 155 | Bounce `json:"bounce"` 156 | }{ 157 | Complaint: Complaint{Recipients: &recipients}, 158 | Bounce: Bounce{Recipients: &recipients}, 159 | } 160 | 161 | if err := json.Unmarshal([]byte(wrapper.Message), &msg); err != nil { 162 | return fmt.Errorf("failed to load notification message: %v", err) 163 | } 164 | 165 | *r = BlacklistRequest{ 166 | Raw: wrapper.Message, 167 | Timestamp: wrapper.Timestamp, 168 | Reason: msg.NotificationType, 169 | Recipients: recipients, 170 | } 171 | return nil 172 | } 173 | -------------------------------------------------------------------------------- /email/email_test.go: -------------------------------------------------------------------------------- 1 | package email 2 | 3 | import ( 4 | "os" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/EFForg/starttls-backend/util" 9 | ) 10 | 11 | type mockBlacklistStore struct { 12 | blacklist map[string]bool 13 | } 14 | 15 | func (b *mockBlacklistStore) PutBlacklistedEmail(email string, reason string, timestamp string) error { 16 | b.blacklist[email] = true 17 | return nil 18 | } 19 | 20 | func (b *mockBlacklistStore) IsBlacklistedEmail(email string) (bool, error) { 21 | return b.blacklist[email], nil 22 | } 23 | 24 | func newMockStore() *mockBlacklistStore { 25 | return &mockBlacklistStore{ 26 | blacklist: make(map[string]bool), 27 | } 28 | } 29 | 30 | func TestValidationEmailText(t *testing.T) { 31 | content := validationEmailText("example.com", "contact@example.com", []string{"mx.example.com, .mx.example.com"}, "abcd", "https://fake.starttls-everywhere.website") 32 | if !strings.Contains(content, "https://fake.starttls-everywhere.website/validate?abcd") { 33 | t.Errorf("E-mail formatted incorrectly.") 34 | } 35 | } 36 | 37 | func shouldPanic(t *testing.T, message string) { 38 | if r := recover(); r == nil { 39 | t.Errorf(message) 40 | } 41 | } 
42 | 43 | func TestRequireMissingEnvPanics(t *testing.T) { 44 | varErrs := util.Errors{} 45 | util.RequireEnv("FAKE_ENV_VAR", &varErrs) 46 | if len(varErrs) == 0 { 47 | t.Errorf("should have received an error") 48 | } 49 | } 50 | 51 | func TestRequireEnvConfig(t *testing.T) { 52 | requiredVars := map[string]string{ 53 | "SMTP_USERNAME": "", 54 | "SMTP_PASSWORD": "", 55 | "SMTP_ENDPOINT": "", 56 | "SMTP_PORT": "", 57 | "SMTP_FROM_ADDRESS": "", 58 | "FRONTEND_WEBSITE_LINK": ""} 59 | for varName := range requiredVars { 60 | requiredVars[varName] = os.Getenv(varName) 61 | os.Setenv(varName, "") 62 | } 63 | _, err := MakeConfigFromEnv(nil) 64 | if err == nil { 65 | t.Errorf("should have received multiple error from unset env vars") 66 | } 67 | for varName, varValue := range requiredVars { 68 | os.Setenv(varName, varValue) 69 | } 70 | } 71 | 72 | func TestSendEmailToBlacklistedAddressFails(t *testing.T) { 73 | mockStore := newMockStore() 74 | err := mockStore.PutBlacklistedEmail("fail@example.com", "bounce", "2017-07-21T18:47:13.498Z") 75 | if err != nil { 76 | t.Errorf("PutBlacklistedEmail failed: %v\n", err) 77 | } 78 | c := &Config{database: mockStore} 79 | err = c.sendEmail("Subject", "Body", "fail@example.com") 80 | if err == nil || !strings.Contains(err.Error(), "blacklisted") { 81 | t.Error("attempting to send mail to blacklisted address should fail") 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /email/template.go: -------------------------------------------------------------------------------- 1 | package email 2 | 3 | const validationEmailSubject = "Email validation for STARTTLS Policy List submission" 4 | const validationEmailTemplate = ` 5 | Hey there! 6 | 7 | It looks like you requested *%[1]s* to be added to the STARTTLS Policy List, with hostnames %[2]s and contact email %[5]s. If this was you, visit 8 | 9 | %[3]s/validate?%[4]s 10 | 11 | to confirm! 
If this wasn't you, please let us know at starttls-policy@eff.org. 12 | 13 | Once you confirm your email address, your domain will be queued for addition some time in the next couple of weeks. We will continue to run validation checks (%[3]s/policy-list#add) against your email server until then. *%[1]s* will be added to the STARTTLS Policy List as long as it has continued to pass our tests! 14 | 15 | Remember to read our guidelines (%[3]s/policy-list) about the requirements your mailserver must meet, and continue to meet, in order to stay on the list. If your mailserver ceases to meet these requirements at any point and is at risk of facing deliverability issues, we will notify you through this email address. 16 | 17 | We also recommend signing up for the STARTTLS Everywhere mailing list at https://lists.eff.org/mailman/listinfo/starttls-everywhere in order to stay up to date on new features, changes to policies, and updates to the project. (This is a low-volume mailing list.) 18 | 19 | Thanks for helping us secure email for everyone :) 20 | ` 21 | -------------------------------------------------------------------------------- /entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$DB_MIGRATE" = "true" ]; then 4 | # Perform ouststanding DB migrations 5 | PGPASSWORD=$DB_PASSWORD psql -h $DB_HOST -U $DB_USERNAME $DB_NAME -f ./db/scripts/init_tables.sql 6 | fi 7 | 8 | exec "$@" 9 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/EFForg/starttls-backend 2 | 3 | go 1.11 4 | 5 | require ( 6 | github.com/certifi/gocertifi v0.0.0-20190506164543-d2eda7129713 // indirect 7 | github.com/davecgh/go-spew v1.1.1 // indirect 8 | github.com/getsentry/raven-go v0.2.0 9 | github.com/gorilla/handlers v1.4.0 10 | github.com/joho/godotenv v1.3.0 11 | github.com/lib/pq 
v1.1.1 12 | github.com/mhale/smtpd v0.0.0-20181125220505-3c4c908952b8 13 | github.com/pkg/errors v0.8.1 // indirect 14 | github.com/pmezard/go-difflib v1.0.0 // indirect 15 | github.com/stretchr/testify v1.2.2 // indirect 16 | github.com/ulule/limiter v2.2.2+incompatible 17 | golang.org/x/net v0.0.0-20190611141213-3f473d35a33a 18 | ) 19 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/certifi/gocertifi v0.0.0-20190506164543-d2eda7129713 h1:UNOqI3EKhvbqV8f1Vm3NIwkrhq388sGCeAH2Op7w0rc= 2 | github.com/certifi/gocertifi v0.0.0-20190506164543-d2eda7129713/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= 3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/getsentry/raven-go v0.2.0 h1:no+xWJRb5ZI7eE8TWgIq1jLulQiIoLG0IfYxv5JYMGs= 6 | github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= 7 | github.com/gorilla/handlers v1.4.0 h1:XulKRWSQK5uChr4pEgSE4Tc/OcmnU9GJuSwdog/tZsA= 8 | github.com/gorilla/handlers v1.4.0/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ= 9 | github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= 10 | github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= 11 | github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4= 12 | github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= 13 | github.com/mhale/smtpd v0.0.0-20181125220505-3c4c908952b8 h1:DuLRJOD3tr0rbrwDXXw5mw8YRPl70y8RbFpUtCjzOkU= 14 | github.com/mhale/smtpd v0.0.0-20181125220505-3c4c908952b8/go.mod h1:qqKwvL5sfYgFxcMy96Kjx3TCorMfDaQBvmEL2nvdidc= 15 | github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= 16 | github.com/pkg/errors v0.8.1/go.mod 
h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 17 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 18 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 19 | github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= 20 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 21 | github.com/ulule/limiter v2.2.2+incompatible h1:1lk9jesmps1ziYHHb4doL7l5hFkYYYA3T8dkNyw7ffY= 22 | github.com/ulule/limiter v2.2.2+incompatible/go.mod h1:VJx/ZNGmClQDS5F6EmsGqK8j3jz1qJYZ6D9+MdAD+kw= 23 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 24 | golang.org/x/net v0.0.0-20190611141213-3f473d35a33a h1:+KkCgOMgnKSgenxTBoiwkMqTiouMIy/3o8RLdmSbGoY= 25 | golang.org/x/net v0.0.0-20190611141213-3f473d35a33a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 26 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 27 | golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= 28 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 29 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "io/ioutil" 6 | "log" 7 | "net/http" 8 | "os" 9 | "os/signal" 10 | "strings" 11 | "time" 12 | 13 | "github.com/EFForg/starttls-backend/api" 14 | "github.com/EFForg/starttls-backend/db" 15 | "github.com/EFForg/starttls-backend/email" 16 | "github.com/EFForg/starttls-backend/policy" 17 | "github.com/EFForg/starttls-backend/stats" 18 | "github.com/EFForg/starttls-backend/util" 19 | "github.com/EFForg/starttls-backend/validator" 20 | 21 | "github.com/getsentry/raven-go" 22 | _ "github.com/joho/godotenv/autoload" 23 | ) 24 | 25 | // 
ServePublicEndpoints serves all public HTTP endpoints. 26 | func ServePublicEndpoints(a *api.API, cfg *db.Config) { 27 | mux := http.NewServeMux() 28 | mainHandler := a.RegisterHandlers(mux) 29 | 30 | portString, err := util.ValidPort(cfg.Port) 31 | if err != nil { 32 | log.Fatal(err) 33 | } 34 | 35 | server := http.Server{ 36 | Addr: portString, 37 | Handler: mainHandler, 38 | } 39 | 40 | exited := make(chan struct{}) 41 | go func() { 42 | sigint := make(chan os.Signal, 1) 43 | signal.Notify(sigint, os.Interrupt) 44 | <-sigint 45 | 46 | if err := server.Shutdown(context.Background()); err != nil { 47 | log.Printf("HTTP server Shutdown: %v", err) 48 | } 49 | close(exited) 50 | }() 51 | 52 | log.Fatal(server.ListenAndServe()) 53 | <-exited 54 | } 55 | 56 | // Loads a map of domains (effectively a set for fast lookup) to blacklist. 57 | // if `DOMAIN_BLACKLIST` is not set, returns an empty map. 58 | func loadDontScan() map[string]bool { 59 | filepath := os.Getenv("DOMAIN_BLACKLIST") 60 | if len(filepath) == 0 { 61 | return make(map[string]bool) 62 | } 63 | data, err := ioutil.ReadFile(filepath) 64 | if err != nil { 65 | log.Fatal(err) 66 | } 67 | domainlist := strings.Split(string(data), "\n") 68 | domainset := make(map[string]bool) 69 | for _, domain := range domainlist { 70 | if len(domain) > 0 { 71 | domainset[domain] = true 72 | } 73 | } 74 | return domainset 75 | } 76 | 77 | func main() { 78 | raven.SetDSN(os.Getenv("SENTRY_URL")) 79 | 80 | cfg, err := db.LoadEnvironmentVariables() 81 | if err != nil { 82 | log.Fatal(err) 83 | } 84 | db, err := db.InitSQLDatabase(cfg) 85 | if err != nil { 86 | log.Fatal(err) 87 | } 88 | emailConfig, err := email.MakeConfigFromEnv(db) 89 | if err != nil { 90 | log.Printf("couldn't connect to mailserver: %v", err) 91 | log.Println("======NOT SENDING EMAIL======") 92 | } 93 | list := policy.MakeUpdatedList() 94 | a := api.API{ 95 | Database: db, 96 | List: list, 97 | DontScan: loadDontScan(), 98 | Emailer: emailConfig, 99 | } 100 | 
a.ParseTemplates("views") 101 | if os.Getenv("VALIDATE_LIST") == "1" { 102 | log.Println("[Starting list validator]") 103 | go validator.ValidateRegularly("Live policy list", list, 24*time.Hour) 104 | } 105 | if os.Getenv("VALIDATE_QUEUED") == "1" { 106 | log.Println("[Starting queued validator]") 107 | go validator.ValidateRegularly("Testing domains", db, 24*time.Hour) 108 | } 109 | go stats.UpdateRegularly(db, time.Hour) 110 | ServePublicEndpoints(&a, &cfg) 111 | } 112 | -------------------------------------------------------------------------------- /models/domain.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "time" 7 | 8 | "github.com/EFForg/starttls-backend/checker" 9 | ) 10 | 11 | /* Domain represents an email domain's TLS policy. 12 | * 13 | * If there's a Domain object for a particular email domain in "Enforce" mode, 14 | * that email domain's policy is fixed and cannot be changed. 15 | */ 16 | 17 | // Domain stores the preload state of a single domain. 18 | type Domain struct { 19 | Name string `json:"domain"` // Domain that is preloaded 20 | Email string `json:"-"` // Contact e-mail for Domain 21 | MXs []string `json:"mxs"` // MXs that are valid for this domain 22 | MTASTS bool `json:"mta_sts"` 23 | State DomainState `json:"state"` 24 | LastUpdated time.Time `json:"last_updated"` 25 | TestingStart time.Time `json:"-"` 26 | QueueWeeks int `json:"queue_weeks"` 27 | } 28 | 29 | // domainStore is a simple interface for fetching and adding domain objects. 30 | type domainStore interface { 31 | PutDomain(Domain) error 32 | GetDomain(string, DomainState) (Domain, error) 33 | GetDomains(DomainState) ([]Domain, error) 34 | SetStatus(string, DomainState) error 35 | RemoveDomain(string, DomainState) (Domain, error) 36 | } 37 | 38 | // DomainState represents the state of a single domain. 
39 | type DomainState string 40 | 41 | // Possible values for DomainState 42 | const ( 43 | StateUnknown = "unknown" // Domain was never submitted, so we don't know. 44 | StateUnconfirmed = "unvalidated" // Administrator has not yet confirmed their intention to add the domain. 45 | StateTesting = "queued" // Queued for addition at next addition date pending continued validation 46 | StateFailed = "failed" // Requested to be queued, but failed verification. 47 | StateEnforce = "added" // On the list. 48 | ) 49 | 50 | type policyList interface { 51 | HasDomain(string) bool 52 | } 53 | 54 | // IsQueueable returns true if a domain can be submitted for validation and 55 | // queueing to the STARTTLS Everywhere Policy List. 56 | // A successful scan should already have been submitted for this domain, 57 | // and it should not already be on the policy list. 58 | // Returns (queuability, error message, and most recent scan) 59 | func (d *Domain) IsQueueable(domains domainStore, scans scanStore, list policyList) (bool, string, Scan) { 60 | scan, err := scans.GetLatestScan(d.Name) 61 | if err != nil { 62 | return false, "We haven't scanned this domain yet. " + 63 | "Please use the STARTTLS checker to scan your domain's " + 64 | "STARTTLS configuration so we can validate your submission", scan 65 | } 66 | if scan.Data.Status != 0 { 67 | return false, "Domain hasn't passed our STARTTLS security checks", scan 68 | } 69 | if list.HasDomain(d.Name) { 70 | return false, "Domain is already on the policy list!", scan 71 | } 72 | if _, err := domains.GetDomain(d.Name, StateEnforce); err == nil { 73 | return false, "Domain is already on the policy list!", scan 74 | } 75 | // Domains without submitted MTA-STS support must match provided mx patterns. 
76 | if !d.MTASTS { 77 | for _, hostname := range scan.Data.PreferredHostnames { 78 | if !checker.PolicyMatches(hostname, d.MXs) { 79 | return false, fmt.Sprintf("Hostnames %v do not match policy %v", scan.Data.PreferredHostnames, d.MXs), scan 80 | } 81 | } 82 | } else if !scan.SupportsMTASTS() { 83 | return false, "Domain does not correctly implement MTA-STS.", scan 84 | } 85 | return true, "", scan 86 | } 87 | 88 | // PopulateFromScan updates a Domain's fields based on a scan of that domain. 89 | func (d *Domain) PopulateFromScan(scan Scan) { 90 | // We should only trust MTA-STS info from a successful MTA-STS check. 91 | if d.MTASTS && scan.SupportsMTASTS() { 92 | // If the domain's MXs are missing, we can take them from the scan's 93 | // PreferredHostnames, which must be a subset of those listed in the 94 | // MTA-STS policy file. 95 | if len(d.MXs) == 0 { 96 | d.MXs = scan.Data.MTASTSResult.MXs 97 | } 98 | } 99 | } 100 | 101 | // InitializeWithToken adds this domain to the given DomainStore and initializes a validation token 102 | // for the addition. The newly generated Token is returned. 103 | func (d *Domain) InitializeWithToken(store domainStore, tokens tokenStore) (string, error) { 104 | if err := store.PutDomain(*d); err != nil { 105 | return "", err 106 | } 107 | token, err := tokens.PutToken(d.Name) 108 | if err != nil { 109 | return "", err 110 | } 111 | return token.Token, nil 112 | } 113 | 114 | // PolicyListCheck checks the policy list status of this particular domain. 
115 | func (d *Domain) PolicyListCheck(store domainStore, list policyList) *checker.Result { 116 | result := checker.Result{Name: checker.PolicyList} 117 | if list.HasDomain(d.Name) { 118 | return result.Success() 119 | } 120 | domain, err := GetDomain(store, d.Name) 121 | if err != nil { 122 | return result.Failure("Domain %s is not on the policy list.", d.Name) 123 | } 124 | if domain.State == StateEnforce { 125 | log.Println("Warning: Domain was StateEnforce in DB but was not found on the policy list.") 126 | return result.Success() 127 | } 128 | if domain.State == StateTesting { 129 | return result.Warning("Domain %s is queued to be added to the policy list.", d.Name) 130 | } 131 | if domain.State == StateUnconfirmed { 132 | return result.Failure("The policy addition request for %s is waiting on email validation", d.Name) 133 | } 134 | return result.Failure("Domain %s is not on the policy list.", d.Name) 135 | } 136 | 137 | // AsyncPolicyListCheck performs PolicyListCheck asynchronously. 138 | // domainStore and policyList should be safe for concurrent use. 139 | func (d Domain) AsyncPolicyListCheck(store domainStore, list policyList) <-chan checker.Result { 140 | result := make(chan checker.Result) 141 | go func() { result <- *d.PolicyListCheck(store, list) }() 142 | return result 143 | } 144 | 145 | // GetDomain retrieves Domain with the most "important" state. 146 | // At any given time, there can only be one domain that's either StateEnforce 147 | // or StateTesting. If that domain exists in the store, return that one. 148 | // Otherwise, look for a Domain policy in the unconfirmed state. 
149 | func GetDomain(store domainStore, name string) (Domain, error) { 150 | domain, err := store.GetDomain(name, StateEnforce) 151 | if err == nil { 152 | return domain, nil 153 | } 154 | domain, err = store.GetDomain(name, StateTesting) 155 | if err == nil { 156 | return domain, nil 157 | } 158 | domain, err = store.GetDomain(name, StateUnconfirmed) 159 | if err == nil { 160 | return domain, nil 161 | } 162 | return store.GetDomain(name, StateFailed) 163 | } 164 | -------------------------------------------------------------------------------- /models/domain_test.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/EFForg/starttls-backend/checker" 9 | ) 10 | 11 | type mockDomainStore struct { 12 | domain Domain 13 | domains []Domain 14 | err error 15 | } 16 | 17 | func (m *mockDomainStore) PutDomain(d Domain) error { 18 | m.domain = d 19 | return m.err 20 | } 21 | 22 | func (m *mockDomainStore) SetStatus(d string, status DomainState) error { 23 | m.domain.State = status 24 | return m.err 25 | } 26 | 27 | func (m *mockDomainStore) GetDomain(d string, state DomainState) (Domain, error) { 28 | domain := m.domain 29 | if state != domain.State { 30 | return m.domain, errors.New("") 31 | } 32 | return m.domain, nil 33 | } 34 | 35 | func (m *mockDomainStore) GetDomains(_ DomainState) ([]Domain, error) { 36 | return m.domains, m.err 37 | } 38 | 39 | func (m *mockDomainStore) RemoveDomain(d string, state DomainState) (Domain, error) { 40 | domain := m.domain 41 | if state != domain.State { 42 | return m.domain, errors.New("") 43 | } 44 | return m.domain, nil 45 | } 46 | 47 | type mockList struct { 48 | hasDomain bool 49 | } 50 | 51 | func (m mockList) HasDomain(string) bool { return m.hasDomain } 52 | 53 | type mockScanStore struct { 54 | scan Scan 55 | err error 56 | } 57 | 58 | func (m mockScanStore) GetLatestScan(string) (Scan, error) { 
return m.scan, m.err } 59 | 60 | func TestIsQueueable(t *testing.T) { 61 | // With supplied hostnames 62 | d := Domain{ 63 | Name: "example.com", 64 | Email: "me@example.com", 65 | MXs: []string{".example.com"}, 66 | } 67 | goodScan := Scan{ 68 | Data: checker.DomainResult{ 69 | PreferredHostnames: []string{"mx1.example.com", "mx2.example.com"}, 70 | MTASTSResult: checker.MakeMTASTSResult(), 71 | }, 72 | } 73 | failedScan := Scan{ 74 | Data: checker.DomainResult{Status: checker.DomainFailure}, 75 | } 76 | wrongMXsScan := Scan{ 77 | Data: checker.DomainResult{ 78 | PreferredHostnames: []string{"mx1.nomatch.example.com"}, 79 | }, 80 | } 81 | var testCases = []struct { 82 | name string 83 | scan Scan 84 | scanErr error 85 | state DomainState 86 | onList bool 87 | ok bool 88 | msg string 89 | }{ 90 | {name: "Unadded domain with passing scan should be queueable", 91 | scan: goodScan, scanErr: nil, onList: false, 92 | ok: true, msg: ""}, 93 | {name: "Domain on policy list should not be queueable", 94 | scan: goodScan, scanErr: nil, onList: true, 95 | ok: false, msg: "already on the policy list"}, 96 | {name: "Enforced domain should not be queueable", 97 | scan: goodScan, scanErr: nil, onList: false, state: StateEnforce, 98 | ok: false, msg: "already on the policy list"}, 99 | {name: "Domain with failing scan should not be queueable", 100 | scan: failedScan, scanErr: nil, onList: false, 101 | ok: false, msg: "hasn't passed"}, 102 | {name: "Domain without scan should not be queueable", 103 | scan: goodScan, scanErr: errors.New(""), onList: false, 104 | ok: false, msg: "haven't scanned"}, 105 | {name: "Domain with mismatched hostnames should not be queueable", 106 | scan: wrongMXsScan, scanErr: nil, onList: false, 107 | ok: false, msg: "do not match policy"}, 108 | } 109 | for _, tc := range testCases { 110 | domainStore := mockDomainStore{domain: Domain{State: tc.state}} 111 | ok, msg, _ := d.IsQueueable(&domainStore, mockScanStore{tc.scan, tc.scanErr}, 
mockList{tc.onList}) 112 | if ok != tc.ok { 113 | t.Error(tc.name) 114 | } 115 | if !strings.Contains(msg, tc.msg) { 116 | t.Errorf("IsQueueable message should contain %s, got %s", tc.msg, msg) 117 | } 118 | } 119 | // With MTA-STS 120 | d = Domain{ 121 | Name: "example.com", 122 | Email: "me@example.com", 123 | MTASTS: true, 124 | } 125 | domainStore := mockDomainStore{err: errors.New("")} 126 | ok, msg, _ := d.IsQueueable(&domainStore, mockScanStore{goodScan, nil}, mockList{false}) 127 | if !ok { 128 | t.Error("Unadded domain with passing scan should be queueable, got " + msg) 129 | } 130 | noMTASTSScan := Scan{ 131 | Data: checker.DomainResult{ 132 | MTASTSResult: &checker.MTASTSResult{ 133 | Result: &checker.Result{ 134 | Status: checker.Failure, 135 | }, 136 | }, 137 | }, 138 | } 139 | ok, msg, _ = d.IsQueueable(&domainStore, mockScanStore{noMTASTSScan, nil}, mockList{false}) 140 | if ok || !strings.Contains(msg, "MTA-STS") { 141 | t.Error("Domain without MTA-STS or hostnames should not be queueable, got " + msg) 142 | } 143 | } 144 | 145 | func TestPopulateFromScan(t *testing.T) { 146 | d := Domain{ 147 | Name: "example.com", 148 | Email: "me@example.com", 149 | MTASTS: true, 150 | } 151 | s := Scan{ 152 | Data: checker.DomainResult{ 153 | MTASTSResult: checker.MakeMTASTSResult(), 154 | }, 155 | } 156 | s.Data.MTASTSResult.MXs = []string{"mx1.example.com", "mx2.example.com"} 157 | d.PopulateFromScan(s) 158 | for i, mx := range s.Data.MTASTSResult.MXs { 159 | if mx != d.MXs[i] { 160 | t.Errorf("Expected MXs to match scan, got %s", d.MXs) 161 | } 162 | } 163 | } 164 | 165 | func TestPolicyCheck(t *testing.T) { 166 | var testCases = []struct { 167 | name string 168 | onList bool 169 | state DomainState 170 | inDB bool 171 | expected checker.Status 172 | }{ 173 | {"Domain on the list should return success", true, StateEnforce, false, checker.Success}, 174 | {"Domain in DB as enforce should return success", false, StateEnforce, true, checker.Success}, 175 | 
{"Domain queued should return a warning", false, StateTesting, true, checker.Warning}, 176 | {"Unconfirmed domain should return a failure", false, StateUnconfirmed, true, checker.Failure}, 177 | {"Domain not currently in the DB or on the list should return a failure", false, StateUnconfirmed, false, checker.Failure}, 178 | } 179 | for _, tc := range testCases { 180 | domainObj := Domain{Name: "example.com", State: tc.state} 181 | var dbErr error 182 | if !tc.inDB { 183 | dbErr = errors.New("") 184 | } 185 | result := domainObj.PolicyListCheck(&mockDomainStore{domain: domainObj, err: dbErr}, mockList{tc.onList}) 186 | if result.Status != tc.expected { 187 | t.Error(tc.name) 188 | } 189 | } 190 | } 191 | 192 | func TestInitializeWithToken(t *testing.T) { 193 | mockToken := mockTokenStore{domain: "domain", err: nil} 194 | domainObj := Domain{Name: "example.com"} 195 | // domainStore returns error 196 | _, err := domainObj.InitializeWithToken(&mockDomainStore{domain: domainObj, err: errors.New("")}, &mockToken) 197 | if err == nil { 198 | t.Error("Expected InitializeWithToken to forward error message from DB") 199 | } 200 | if mockToken.token != nil { 201 | t.Error("Token should not have been set if domain not found") 202 | } 203 | _, err = domainObj.InitializeWithToken(&mockDomainStore{domain: domainObj}, &mockTokenStore{err: errors.New("")}) 204 | if err == nil { 205 | t.Error("Expected InitializeWithToken to forward error message from DB") 206 | } 207 | domainObj.InitializeWithToken(&mockDomainStore{domain: domainObj, err: nil}, &mockToken) 208 | if mockToken.token == nil { 209 | t.Error("Token should have been set for domain") 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /models/scan.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/EFForg/starttls-backend/checker" 7 | ) 8 | 9 | // ScanVersion is the version 
of the Scan API that the binary is currently using. 10 | const ScanVersion = 1 11 | 12 | // Scan stores the result of a scan of a domain 13 | type Scan struct { 14 | Domain string `json:"domain"` // Input domain 15 | Data checker.DomainResult `json:"scandata"` // Scan results from starttls-checker 16 | Timestamp time.Time `json:"timestamp"` // Time at which this scan was conducted 17 | Version uint32 `json:"version"` // Version counter 18 | } 19 | 20 | type scanStore interface { 21 | GetLatestScan(string) (Scan, error) 22 | } 23 | 24 | // CanAddToPolicyList returns true if the domain owner should be prompted to 25 | // add their domain to the STARTTLS Everywhere Policy List. 26 | func (s Scan) CanAddToPolicyList() bool { 27 | if policyResult, ok := s.Data.ExtraResults[checker.PolicyList]; ok { 28 | return s.Data.Status == checker.DomainSuccess && 29 | policyResult.Status == checker.Failure 30 | } 31 | return false 32 | } 33 | 34 | // SupportsMTASTS returns true if the Scan's MTA-STS check passed. 35 | func (s Scan) SupportsMTASTS() bool { 36 | if s.Data.MTASTSResult == nil { 37 | return false 38 | } 39 | return s.Data.MTASTSResult.Status == checker.Success || s.Data.MTASTSResult.Status == checker.Warning 40 | } 41 | -------------------------------------------------------------------------------- /models/token.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import "time" 4 | 5 | // Token stores the state of an email verification token. 6 | type Token struct { 7 | Domain string `json:"domain"` // Domain for which we're verifying the e-mail. 8 | Token string `json:"token"` // Token that we're expecting. 9 | Expires time.Time `json:"expires"` // When this token expires. 10 | Used bool `json:"used"` // Whether this token was used. 11 | } 12 | 13 | // tokenStore is the interface for performing actions with tokens. 
14 | type tokenStore interface { 15 | PutToken(string) (Token, error) 16 | UseToken(string) (string, error) 17 | } 18 | 19 | // Redeem redeems this Token, and updates its entry in the associated domain and token 20 | // database stores. Returns the domain name that this token was generated for. 21 | func (t *Token) Redeem(store domainStore, tokens tokenStore) (ret string, userErr error, dbErr error) { 22 | domain, err := tokens.UseToken(t.Token) 23 | if err != nil { 24 | return domain, err, nil 25 | } 26 | domainData, err := store.GetDomain(domain, StateUnconfirmed) 27 | if err != nil { 28 | return domain, nil, err 29 | } 30 | domainOnList, err := GetDomain(store, domainData.Name) 31 | if err != nil { 32 | return domain, nil, err 33 | } 34 | if domainOnList.State != StateUnconfirmed { 35 | store.RemoveDomain(domainData.Name, domainOnList.State) 36 | } 37 | err = store.SetStatus(domainData.Name, StateTesting) 38 | return domain, nil, err 39 | } 40 | -------------------------------------------------------------------------------- /models/token_test.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | type mockTokenStore struct { 9 | token *Token 10 | domain string 11 | err error 12 | } 13 | 14 | func (m *mockTokenStore) PutToken(domain string) (Token, error) { 15 | m.token = &Token{Domain: domain, Token: "token"} 16 | return *m.token, m.err 17 | } 18 | 19 | func (m *mockTokenStore) UseToken(token string) (string, error) { 20 | return m.domain, m.err 21 | } 22 | 23 | func TestRedeemToken(t *testing.T) { 24 | domains := mockDomainStore{domain: Domain{Name: "anything", State: StateUnconfirmed}, err: nil} 25 | token := Token{Token: "token"} 26 | domain, userErr, dbErr := token.Redeem(&domains, &mockTokenStore{domain: "anything", err: nil}) 27 | if domain != "anything" || userErr != nil || dbErr != nil { 28 | t.Error("Expected token redeem to succeed") 29 | } 30 | 
if domains.domain.State != StateTesting { 31 | t.Error("Expected PutDomain to have upgraded domain State") 32 | } 33 | } 34 | 35 | func TestRedeemTokenFailures(t *testing.T) { 36 | token := Token{Token: "token"} 37 | _, userErr, _ := token.Redeem(&mockDomainStore{err: nil}, &mockTokenStore{err: errors.New("")}) 38 | if userErr == nil { 39 | t.Error("Errors reported from the token store should be interpreted as usage error (token already used, or doesn't exist)") 40 | } 41 | _, _, dbErr := token.Redeem(&mockDomainStore{err: errors.New("")}, &mockTokenStore{err: nil}) 42 | if dbErr == nil { 43 | t.Error("Errors reported from the domain store should be interpreted as a hard failure") 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /policy/policy.go: -------------------------------------------------------------------------------- 1 | package policy 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "net/http" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | // policyURL is the default URL from which to fetch the policy JSON. 14 | const policyURL = "https://dl.eff.org/starttls-everywhere/policy.json" 15 | 16 | // TLSPolicy dictates the policy for a particular email domain. 17 | type TLSPolicy struct { 18 | PolicyAlias string `json:"policy-alias,omitempty"` 19 | Mode string `json:"mode,omitempty"` 20 | MXs []string `json:"mxs,omitempty"` 21 | } 22 | 23 | // List is a raw representation of the policy list. 24 | type List struct { 25 | Timestamp time.Time `json:"timestamp"` 26 | Expires time.Time `json:"expires"` 27 | Version string `json:"version"` 28 | Author string `json:"author"` 29 | PolicyAliases map[string]TLSPolicy `json:"policy-aliases"` 30 | Policies map[string]TLSPolicy `json:"policies"` 31 | } 32 | 33 | // Add adds a particular domain's policy to the list. 
34 | func (l *List) Add(domain string, policy TLSPolicy) { 35 | l.Policies[domain] = policy 36 | } 37 | 38 | // get retrieves the TLSPolicy for a domain, and resolves 39 | // aliases if they exist. 40 | func (l *List) get(domain string) (TLSPolicy, error) { 41 | policy, ok := l.Policies[domain] 42 | if !ok { 43 | return TLSPolicy{}, fmt.Errorf("policy for domain %s doesn't exist", domain) 44 | } 45 | if len(policy.PolicyAlias) > 0 { 46 | policy, ok = l.PolicyAliases[policy.PolicyAlias] 47 | if !ok { 48 | return TLSPolicy{}, fmt.Errorf("policy alias for domain %s doesn't exist", domain) 49 | } 50 | } 51 | return policy, nil 52 | } 53 | 54 | // UpdatedList wraps a list that is updated from a remote 55 | // policyURL every hour. Safe for concurrent calls to `Get`. 56 | type UpdatedList struct { 57 | mu sync.RWMutex 58 | *List 59 | } 60 | 61 | // DomainsToValidate [interface Validator] retrieves domains from the 62 | // DB whose policies should be validated. 63 | func (l *UpdatedList) DomainsToValidate() ([]string, error) { 64 | l.mu.RLock() 65 | defer l.mu.RUnlock() 66 | domains := []string{} 67 | for domain := range l.Policies { 68 | domains = append(domains, domain) 69 | } 70 | return domains, nil 71 | } 72 | 73 | // HostnamesForDomain [interface Validator] retrieves the hostname policy for 74 | // a particular domain. 75 | func (l *UpdatedList) HostnamesForDomain(domain string) ([]string, error) { 76 | policy, err := l.Get(domain) 77 | if err != nil { 78 | return []string{}, err 79 | } 80 | return policy.MXs, nil 81 | } 82 | 83 | // Get safely reads from the underlying policy list and returns a TLSPolicy for a domain 84 | func (l *UpdatedList) Get(domain string) (TLSPolicy, error) { 85 | l.mu.RLock() 86 | defer l.mu.RUnlock() 87 | return l.get(domain) 88 | } 89 | 90 | // HasDomain returns true if a domain is present on the policy list. 
91 | func (l *UpdatedList) HasDomain(domain string) bool { 92 | _, err := l.Get(domain) 93 | return err == nil 94 | } 95 | 96 | // Raw returns a raw List struct, copied from the underlying one 97 | func (l *UpdatedList) Raw() List { 98 | l.mu.RLock() 99 | defer l.mu.RUnlock() 100 | list := *l.List 101 | list.Timestamp = l.Timestamp 102 | list.Expires = l.Expires 103 | list.PolicyAliases = make(map[string]TLSPolicy) 104 | for alias, policy := range l.PolicyAliases { 105 | list.PolicyAliases[alias] = policy.clone() 106 | } 107 | list.Policies = make(map[string]TLSPolicy) 108 | for domain, policy := range l.Policies { 109 | list.Policies[domain] = policy.clone() 110 | } 111 | return list 112 | } 113 | 114 | func (p TLSPolicy) clone() TLSPolicy { 115 | policy := p 116 | policy.MXs = make([]string, 0) 117 | for _, mx := range p.MXs { 118 | policy.MXs = append(policy.MXs, mx) 119 | } 120 | return policy 121 | } 122 | 123 | // fetchListFn returns a new policy list. It can be used to update UpdatedList 124 | type fetchListFn func() (List, error) 125 | 126 | // Retrieve and parse List from policyURL 127 | func fetchListHTTP() (List, error) { 128 | resp, err := http.Get(policyURL) 129 | if err != nil { 130 | return List{}, err 131 | } 132 | defer resp.Body.Close() 133 | body, err := ioutil.ReadAll(resp.Body) 134 | var policyList List 135 | err = json.Unmarshal(body, &policyList) 136 | if err != nil { 137 | return List{}, err 138 | } 139 | return policyList, nil 140 | } 141 | 142 | // Get a new policy list and safely assign it the UpdatedList 143 | func (l *UpdatedList) update(fetch fetchListFn) { 144 | newList, err := fetch() 145 | if err != nil { 146 | log.Printf("Error updating policy list: %s\n", err) 147 | } else { 148 | l.mu.Lock() 149 | l.List = &newList 150 | l.mu.Unlock() 151 | } 152 | } 153 | 154 | // makeUpdatedList constructs an UpdatedList object and launches a 155 | // thread to continually update it. 
Accepts a fetchListFn to allow 156 | // stubbing http request to remote policy list. 157 | func makeUpdatedList(fetch fetchListFn, updateFrequency time.Duration) *UpdatedList { 158 | l := UpdatedList{List: &List{}} 159 | l.update(fetch) 160 | 161 | go func() { 162 | for { 163 | l.update(fetch) 164 | time.Sleep(updateFrequency) 165 | } 166 | }() 167 | return &l 168 | } 169 | 170 | // MakeUpdatedList wraps makeUpdatedList to use FetchListHTTP by default to update policy list 171 | func MakeUpdatedList() *UpdatedList { 172 | return makeUpdatedList(fetchListHTTP, time.Hour) 173 | } 174 | -------------------------------------------------------------------------------- /policy/policy_test.go: -------------------------------------------------------------------------------- 1 | package policy 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | var mockList = List{ 11 | Policies: map[string]TLSPolicy{ 12 | "eff.org": TLSPolicy{Mode: "testing"}, 13 | }, 14 | } 15 | 16 | func mockFetchHTTP() (List, error) { 17 | return mockList, nil 18 | } 19 | 20 | func mockErroringFetchHTTP() (List, error) { 21 | return List{}, fmt.Errorf("something went wrong") 22 | } 23 | 24 | func TestGetPolicy(t *testing.T) { 25 | list := makeUpdatedList(mockFetchHTTP, time.Hour) 26 | 27 | policy, err := list.Get("not-on-the-List.com") 28 | if err == nil { 29 | t.Error("Getting the policy for an unListed domain should return an error") 30 | } 31 | 32 | policy, err = list.Get("eff.org") 33 | if err != nil { 34 | t.Errorf("Unexpected error while getting policy: %s", err) 35 | } 36 | if !reflect.DeepEqual(policy, mockList.Policies["eff.org"]) { 37 | t.Errorf("Expected policy for eff.org to be %v, got %v", mockList.Policies["eff.org"], policy) 38 | } 39 | } 40 | 41 | func TestHasDomain(t *testing.T) { 42 | list := makeUpdatedList(mockFetchHTTP, time.Hour) 43 | 44 | if list.HasDomain("not-on-the-List.com") { 45 | t.Error("Calling HasDomain for an unListed domain should return 
false") 46 | } 47 | 48 | if !list.HasDomain("eff.org") { 49 | t.Error("Calling HasDomain for a Listed domain should return true") 50 | } 51 | } 52 | 53 | func TestFailedListUpdate(t *testing.T) { 54 | list := makeUpdatedList(mockErroringFetchHTTP, time.Hour) 55 | _, err := list.Get("eff.org") 56 | if err == nil { 57 | t.Errorf("Get should return an error if fetching the List fails") 58 | } 59 | } 60 | 61 | func TestListUpdate(t *testing.T) { 62 | var updatedList = List{Policies: map[string]TLSPolicy{}} 63 | list := makeUpdatedList(func() (List, error) { return updatedList, nil }, time.Second) 64 | _, err := list.Get("example.com") 65 | if err == nil { 66 | t.Error("Getting the policy for an unListed domain should return an error") 67 | } 68 | // Update the List! 69 | updatedList.Policies["example.com"] = TLSPolicy{Mode: "testing"} 70 | time.Sleep(time.Second * 2) 71 | policy, err := list.Get("example.com") 72 | if err != nil { 73 | t.Errorf("Unexpected error while getting policy: %s", err) 74 | } 75 | if !reflect.DeepEqual(policy, updatedList.Policies["example.com"]) { 76 | t.Errorf("Expected policy for example.com to be %v, got %v", mockList.Policies["eff.org"], policy) 77 | } 78 | } 79 | 80 | func TestDomainsToValidate(t *testing.T) { 81 | var updatedList = List{Policies: map[string]TLSPolicy{ 82 | "eff.org": TLSPolicy{}, 83 | "example.com": TLSPolicy{}, 84 | }} 85 | list := makeUpdatedList(func() (List, error) { return updatedList, nil }, time.Second) 86 | domains, err := list.DomainsToValidate() 87 | if err != nil { 88 | t.Fatalf("Encoutnered %v", err) 89 | } 90 | if len(updatedList.Policies) != len(domains) { 91 | t.Fatalf("Expected domains to validate to match policy list, got %s", domains) 92 | } 93 | for _, domain := range domains { 94 | if _, exists := updatedList.Policies[domain]; !exists { 95 | t.Fatalf("Expected domains to validate to match policy list, got %s", domains) 96 | } 97 | } 98 | } 99 | 100 | func TestHostnamesForDomain(t *testing.T) { 101 | 
hostnames := []string{"a", "b", "c"} 102 | var updatedList = List{Policies: map[string]TLSPolicy{ 103 | "eff.org": TLSPolicy{MXs: hostnames}}} 104 | list := makeUpdatedList(func() (List, error) { return updatedList, nil }, time.Second) 105 | returned, err := list.HostnamesForDomain("eff.org") 106 | if err != nil { 107 | t.Fatalf("Encountered %v", err) 108 | } 109 | if !reflect.DeepEqual(returned, hostnames) { 110 | t.Errorf("Expected %s, got %s", hostnames, returned) 111 | } 112 | } 113 | 114 | func TestCloneDoesntChangeOriginal(t *testing.T) { 115 | var updatedList = List{ 116 | Version: "3", 117 | Policies: map[string]TLSPolicy{ 118 | "eff.org": TLSPolicy{MXs: []string{"a"}}}} 119 | list := makeUpdatedList(func() (List, error) { return updatedList, nil }, time.Hour) 120 | newList := list.Raw() 121 | // Change new list 122 | newList.Version = "5" 123 | effPolicy := newList.Policies["eff.org"] 124 | effPolicy.MXs = []string{"a", "b"} 125 | list.mu.RLock() 126 | defer list.mu.RUnlock() 127 | if list.Version == "5" || len(list.Policies["eff.org"].MXs) > 1 { 128 | t.Errorf("Expected original to remain unchanged after changing copy") 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /stats/stats.go: -------------------------------------------------------------------------------- 1 | package stats 2 | 3 | import ( 4 | "bufio" 5 | "encoding/json" 6 | "fmt" 7 | "log" 8 | "net/http" 9 | "os" 10 | "time" 11 | 12 | "github.com/EFForg/starttls-backend/checker" 13 | raven "github.com/getsentry/raven-go" 14 | ) 15 | 16 | // Store wraps storage for MTA-STS adoption statistics. 17 | type Store interface { 18 | PutAggregatedScan(checker.AggregatedScan) error 19 | PutLocalStats(time.Time) (checker.AggregatedScan, error) 20 | GetStats(string) (Series, error) 21 | } 22 | 23 | // Import imports aggregated scans from a remote server to the datastore. 24 | // Expected format is JSONL (newline-separated JSON objects). 
25 | func Import(store Store) error { 26 | statsURL := os.Getenv("REMOTE_STATS_URL") 27 | resp, err := http.Get(statsURL) 28 | if err != nil { 29 | return err 30 | } 31 | defer resp.Body.Close() 32 | 33 | s := bufio.NewScanner(resp.Body) 34 | for s.Scan() { 35 | var a checker.AggregatedScan 36 | err := json.Unmarshal(s.Bytes(), &a) 37 | if err != nil { 38 | return err 39 | } 40 | a.Source = checker.TopDomainsSource 41 | err = store.PutAggregatedScan(a) 42 | if err != nil { 43 | return err 44 | } 45 | } 46 | if err := s.Err(); err != nil { 47 | return err 48 | } 49 | return nil 50 | } 51 | 52 | // Update imports aggregated scans and updates our cache table of local scans. 53 | // Log any errors. 54 | func Update(store Store) { 55 | err := Import(store) 56 | if err != nil { 57 | err = fmt.Errorf("Failed to import top domains stats: %v", err) 58 | log.Println(err) 59 | raven.CaptureError(err, nil) 60 | } 61 | // Cache stats for the previous day at midnight. This ensures that we capture 62 | // full days and maintain regularly intervals. 63 | _, err = store.PutLocalStats(time.Now().UTC().Truncate(24 * time.Hour)) 64 | if err != nil { 65 | err = fmt.Errorf("Failed to update local stats: %v", err) 66 | log.Println(err) 67 | raven.CaptureError(err, nil) 68 | } 69 | } 70 | 71 | // UpdateRegularly runs Import to import aggregated stats from a remote server at regular intervals. 72 | func UpdateRegularly(store Store, interval time.Duration) { 73 | for { 74 | Update(store) 75 | <-time.After(interval) 76 | } 77 | } 78 | 79 | // Series represents some statistic as it changes over time. 80 | // This will likely be updated when we know what format our frontend charting 81 | // library prefers. 82 | type Series []checker.AggregatedScan 83 | 84 | // MarshalJSON marshals a Series to the format expected by chart.js. 
85 | // See https://www.chartjs.org/docs/latest/ 86 | func (s Series) MarshalJSON() ([]byte, error) { 87 | type xyPt struct { 88 | X time.Time `json:"x"` 89 | Y float64 `json:"y"` 90 | } 91 | xySeries := make([]xyPt, 0) 92 | for _, a := range s { 93 | var y float64 94 | if a.Source != checker.TopDomainsSource { 95 | y = a.PercentMTASTS() 96 | } else { 97 | // Top million scans have too few MTA-STS domains to use a percent, 98 | // display a raw total instead. 99 | y = float64(a.TotalMTASTS()) 100 | } 101 | xySeries = append(xySeries, xyPt{X: a.Time, Y: y}) 102 | } 103 | return json.Marshal(xySeries) 104 | } 105 | 106 | // Get retrieves MTA-STS adoption statistics for user-initiated scans and scans 107 | // of the top million domains over time. 108 | func Get(store Store) (result map[string]Series, err error) { 109 | result = make(map[string]Series) 110 | sources := []string{checker.TopDomainsSource, checker.LocalSource} 111 | for _, source := range sources { 112 | series, err := store.GetStats(source) 113 | if err != nil { 114 | return result, err 115 | } 116 | result[source] = series 117 | } 118 | return result, err 119 | } 120 | -------------------------------------------------------------------------------- /stats/stats_test.go: -------------------------------------------------------------------------------- 1 | package stats 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "net/http/httptest" 7 | "os" 8 | "testing" 9 | "time" 10 | 11 | "github.com/EFForg/starttls-backend/checker" 12 | ) 13 | 14 | type mockAgScanStore []checker.AggregatedScan 15 | 16 | func (m *mockAgScanStore) PutAggregatedScan(agScan checker.AggregatedScan) error { 17 | *m = append(*m, agScan) 18 | return nil 19 | } 20 | 21 | func (m *mockAgScanStore) PutLocalStats(date time.Time) (checker.AggregatedScan, error) { 22 | a := checker.AggregatedScan{ 23 | Source: checker.LocalSource, 24 | Time: date, 25 | } 26 | *m = append(*m, a) 27 | return a, nil 28 | } 29 | 30 | func (m 
*mockAgScanStore) GetStats(source string) (Series, error) { 31 | return Series{}, nil 32 | } 33 | 34 | func TestImport(t *testing.T) { 35 | agScans := []checker.AggregatedScan{ 36 | checker.AggregatedScan{ 37 | Time: time.Now().Add(-24 * time.Hour), 38 | Attempted: 4, 39 | WithMXs: 3, 40 | MTASTSTesting: 2, 41 | MTASTSEnforce: 1, 42 | }, 43 | checker.AggregatedScan{ 44 | Time: time.Now(), 45 | Attempted: 8, 46 | WithMXs: 7, 47 | MTASTSTesting: 6, 48 | MTASTSEnforce: 5, 49 | }, 50 | } 51 | ts := httptest.NewServer( 52 | http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 53 | enc := json.NewEncoder(w) 54 | enc.Encode(agScans[0]) 55 | enc.Encode(agScans[1]) 56 | }), 57 | ) 58 | defer ts.Close() 59 | os.Setenv("REMOTE_STATS_URL", ts.URL) 60 | store := mockAgScanStore{} 61 | err := Import(&store) 62 | if err != nil { 63 | t.Fatal(err) 64 | } 65 | for i, want := range agScans { 66 | got := store[i] 67 | // Times must be compared with Time.Equal, so we can't reflect.DeepEqual. 68 | if !want.Time.Equal(got.Time) { 69 | t.Errorf("\nExpected\n %v\nGot\n %v", agScans, store) 70 | } 71 | if want.PercentMTASTS() != got.PercentMTASTS() { 72 | t.Errorf("\nExpected\n %v\nGot\n %v", agScans, store) 73 | } 74 | if got.Source != checker.TopDomainsSource { 75 | t.Errorf("Expected source for imported domains to be %s", checker.TopDomainsSource) 76 | } 77 | } 78 | } 79 | 80 | func TestUpdate(t *testing.T) { 81 | store := mockAgScanStore{} 82 | Update(&store) 83 | a := store[0] 84 | // Confirm that date is trucated correctly 85 | if a.Time.Hour() != 0 || a.Time.Minute() != 0 { 86 | t.Errorf("Expected date to be truncated, got %v", a.Time) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /util/util.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | "regexp" 8 | "strconv" 9 | "strings" 10 | ) 11 | 12 | // Match domain names 
according to RFC 1035 13 | // * Neither suffix nor prefix; should not end or start with `.` 14 | const matchDNS = `^([a-zA-Z0-9_]{1}[a-zA-Z0-9_-]{0,62}){1}(\.[a-zA-Z0-9_]{1}[a-zA-Z0-9_-]{0,62})*$` 15 | 16 | // ValidDomainName returns true if given name is a valid FQDN. 17 | func ValidDomainName(s string) bool { 18 | if len(s) < 1 || !strings.Contains(s, ".") { 19 | return false 20 | } 21 | ok, err := regexp.MatchString(matchDNS, s) 22 | if err != nil { 23 | log.Printf("Regex for DNS matching failed with error %v", err) 24 | return false 25 | } 26 | return ok 27 | } 28 | 29 | // ValidPort normalizes a portstring like "80" to ":80". 30 | func ValidPort(port string) (string, error) { 31 | if _, err := strconv.Atoi(port); err != nil { 32 | return "", fmt.Errorf("Given portstring %s is invalid", port) 33 | } 34 | return fmt.Sprintf(":%s", port), nil 35 | } 36 | 37 | // Errors composites multiple errors. 38 | type Errors []error 39 | 40 | // Error composites the messages from all contained errors. 41 | func (e Errors) Error() string { 42 | if len(e) == 1 { 43 | return e[0].Error() 44 | } 45 | msg := "multiple errors:" 46 | for _, err := range e { 47 | msg += "\n" + err.Error() 48 | } 49 | return msg 50 | } 51 | 52 | // Add adds another error to this composite. 53 | func (e Errors) Add(err error) Errors { 54 | if err != nil { 55 | return append(e, err) 56 | } 57 | return e 58 | } 59 | 60 | // RequireEnv retrieves environment variable varName. If not set as env 61 | // variable, panic and exit. 62 | // varName is the OS environment variable name. 63 | // errors is a composite errors object to add to if a variable is not set. 
64 | func RequireEnv(varName string, errors *Errors) string { 65 | envVar := os.Getenv(varName) 66 | if len(envVar) == 0 { 67 | *errors = errors.Add(fmt.Errorf("expected environment variable %s to be set", varName)) 68 | } 69 | return envVar 70 | } 71 | -------------------------------------------------------------------------------- /util/util_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import "testing" 4 | 5 | func TestInvalidPort(t *testing.T) { 6 | portString, err := ValidPort("8000") 7 | if err != nil { 8 | t.Fatalf("Should not have errored on valid string: %v", err) 9 | } 10 | if portString != ":8000" { 11 | t.Fatalf("Expected portstring be :8000 instead of %s", portString) 12 | } 13 | portString, err = ValidPort("80a") 14 | if err == nil { 15 | t.Fatalf("Expected error on invalid port") 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /validator/validator.go: -------------------------------------------------------------------------------- 1 | package validator 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "time" 7 | 8 | "github.com/EFForg/starttls-backend/checker" 9 | "github.com/getsentry/raven-go" 10 | ) 11 | 12 | // DomainPolicyStore is an interface for any back-end that 13 | // stores a map of domains to its "policy" (in this case, just the 14 | // expected hostnames). 15 | type DomainPolicyStore interface { 16 | DomainsToValidate() ([]string, error) 17 | HostnamesForDomain(string) ([]string, error) 18 | } 19 | 20 | // Called with failure by defaault. 
21 | func reportToSentry(name string, domain string, result checker.DomainResult) { 22 | raven.CaptureMessageAndWait("Validation failed for previously validated domain", 23 | map[string]string{ 24 | "validatorName": name, 25 | "domain": result.Domain, 26 | "status": fmt.Sprintf("%d", result.Status), 27 | }, 28 | result) 29 | } 30 | 31 | type checkPerformer func(string, []string) checker.DomainResult 32 | type resultCallback func(string, string, checker.DomainResult) 33 | 34 | // Validator runs checks regularly against domain policies. This structure 35 | // defines the configurations. 36 | type Validator struct { 37 | // Name: Required with which to refer to this validator. Appears in log files and 38 | // error reports. 39 | Name string 40 | // Store: Required-- store from which the validator fetches policies to validate. 41 | Store DomainPolicyStore 42 | // Interval: optional; time at which validator should re-run. 43 | // If not set, default interval is 1 day. 44 | Interval time.Duration 45 | // OnFailure: optional. Called when a particular policy validation fails. Defaults to 46 | // a sentry report. 47 | OnFailure resultCallback 48 | // OnSuccess: optional. Called when a particular policy validation succeeds. 49 | OnSuccess resultCallback 50 | // checkPerformer: performs the check. 
51 | checkPerformer checkPerformer 52 | } 53 | 54 | func (v *Validator) checkPolicy(domain string, hostnames []string) checker.DomainResult { 55 | if v.checkPerformer == nil { 56 | c := checker.Checker{ 57 | Cache: checker.MakeSimpleCache(time.Hour), 58 | } 59 | v.checkPerformer = c.CheckDomain 60 | } 61 | return v.checkPerformer(domain, hostnames) 62 | } 63 | 64 | func (v *Validator) interval() time.Duration { 65 | if v.Interval != 0 { 66 | return v.Interval 67 | } 68 | return time.Hour * 24 69 | } 70 | 71 | func (v *Validator) policyFailed(name string, domain string, result checker.DomainResult) { 72 | if v.OnFailure != nil { 73 | v.OnFailure(name, domain, result) 74 | } 75 | reportToSentry(name, domain, result) 76 | } 77 | 78 | func (v *Validator) policyPassed(name string, domain string, result checker.DomainResult) { 79 | if v.OnSuccess != nil { 80 | v.OnSuccess(name, domain, result) 81 | } 82 | } 83 | 84 | // Run starts the endless loop of validations. The first validation happens after the given 85 | // Interval. Validation failures induce `policyFailed`, and successes cause `policyPassed`. 
86 | func (v *Validator) Run() { 87 | for { 88 | <-time.After(v.interval()) 89 | log.Printf("[%s validator] starting regular validation", v.Name) 90 | domains, err := v.Store.DomainsToValidate() 91 | if err != nil { 92 | log.Printf("[%s validator] Could not retrieve domains: %v", v.Name, err) 93 | continue 94 | } 95 | for _, domain := range domains { 96 | hostnames, err := v.Store.HostnamesForDomain(domain) 97 | if err != nil { 98 | log.Printf("[%s validator] Could not retrieve policy for domain %s: %v", v.Name, domain, err) 99 | continue 100 | } 101 | result := v.checkPolicy(domain, hostnames) 102 | if result.Status != 0 { 103 | log.Printf("[%s validator] %s failed; sending report", v.Name, domain) 104 | v.policyFailed(v.Name, domain, result) 105 | } else { 106 | v.policyPassed(v.Name, domain, result) 107 | } 108 | } 109 | } 110 | } 111 | 112 | // ValidateRegularly regularly runs checker.CheckDomain against a Domain- 113 | // Hostname map. Interval specifies the interval to wait between each run. 114 | // Failures are reported to Sentry. 
115 | func ValidateRegularly(name string, store DomainPolicyStore, interval time.Duration) { 116 | v := Validator{ 117 | Name: name, 118 | Store: store, 119 | Interval: interval, 120 | } 121 | v.Run() 122 | } 123 | -------------------------------------------------------------------------------- /validator/validator_test.go: -------------------------------------------------------------------------------- 1 | package validator 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/EFForg/starttls-backend/checker" 8 | ) 9 | 10 | type mockDomainPolicyStore struct { 11 | hostnames map[string][]string 12 | } 13 | 14 | func (m mockDomainPolicyStore) DomainsToValidate() ([]string, error) { 15 | domains := []string{} 16 | for domain := range m.hostnames { 17 | domains = append(domains, domain) 18 | } 19 | return domains, nil 20 | } 21 | 22 | func (m mockDomainPolicyStore) HostnamesForDomain(domain string) ([]string, error) { 23 | return m.hostnames[domain], nil 24 | } 25 | 26 | func noop(_ string, _ string, _ checker.DomainResult) {} 27 | 28 | func TestRegularValidationValidates(t *testing.T) { 29 | called := make(chan bool) 30 | fakeChecker := func(domain string, hostnames []string) checker.DomainResult { 31 | called <- true 32 | return checker.DomainResult{} 33 | } 34 | mock := mockDomainPolicyStore{ 35 | hostnames: map[string][]string{"a": []string{"hostname"}}} 36 | v := Validator{Store: mock, Interval: 100 * time.Millisecond, checkPerformer: fakeChecker, OnFailure: noop} 37 | go v.Run() 38 | 39 | select { 40 | case <-called: 41 | return 42 | case <-time.After(time.Second): 43 | t.Errorf("Checker wasn't called on hostname!") 44 | } 45 | } 46 | 47 | func TestRegularValidationReportsErrors(t *testing.T) { 48 | reports := make(chan string) 49 | fakeChecker := func(domain string, hostnames []string) checker.DomainResult { 50 | if domain == "fail" || domain == "error" { 51 | return checker.DomainResult{Status: 5} 52 | } 53 | return checker.DomainResult{Status: 0} 54 | 
} 55 | fakeReporter := func(name string, domain string, result checker.DomainResult) { 56 | reports <- domain 57 | } 58 | successReports := make(chan string) 59 | fakeSuccessReporter := func(name string, domain string, result checker.DomainResult) { 60 | successReports <- domain 61 | } 62 | mock := mockDomainPolicyStore{ 63 | hostnames: map[string][]string{ 64 | "fail": []string{"hostname"}, 65 | "error": []string{"hostname"}, 66 | "normal": []string{"hostname"}}} 67 | v := Validator{Store: mock, Interval: 100 * time.Millisecond, checkPerformer: fakeChecker, 68 | OnFailure: fakeReporter, OnSuccess: fakeSuccessReporter, 69 | } 70 | go v.Run() 71 | recvd := make(map[string]bool) 72 | numRecvd := 0 73 | for numRecvd < 4 { 74 | select { 75 | case report := <-successReports: 76 | if report != "normal" { 77 | t.Errorf("Didn't expect %s to succeed", report) 78 | } 79 | case report := <-reports: 80 | recvd[report] = true 81 | numRecvd++ 82 | case <-time.After(time.Second): 83 | t.Errorf("Timed out waiting for reports") 84 | } 85 | } 86 | if _, ok := recvd["fail"]; !ok { 87 | t.Errorf("Expected fail to be reported") 88 | } 89 | if _, ok := recvd["error"]; !ok { 90 | t.Errorf("Expected error to be reported") 91 | } 92 | if _, ok := recvd["normal"]; ok { 93 | t.Errorf("Didn't expect normal to be reported as failure") 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /views/default.html.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 | STARTTLS Everywhere 4 | 5 | 6 | {{ if ne .StatusCode 200 }} 7 |

{{ .StatusText }}

8 | {{ end }} 9 | 10 | {{ if ne .Message "" }} 11 |

{{ .Message }}

12 | {{ end }} 13 | 14 |

{{ .Response }}

15 | 16 | 17 | -------------------------------------------------------------------------------- /views/scan.html.tmpl: -------------------------------------------------------------------------------- 1 | 2 | 3 |

Scan results for {{ .Response.Domain }}

4 | You're viewing unstyled results. You can enable Javascript to view styled content. 5 | 6 |

Summary

7 | {{ if eq .Response.Data.Status 0 }} 8 |

Congratulations, your domain passed all checks.

9 | {{ else if eq .Response.Data.Status 1 }} 10 |

Your domain passed all checks with some warnings. See below for details.

11 | {{ else }} 12 |

There were some problems with your domain. See below for details.

13 | {{ end }} 14 | 15 |

{{ .Response.Data.Message }}

16 | 17 |

STARTTLS Everywhere Policy List

18 | {{ with index .Response.Data.ExtraResults "policylist" }} 19 | {{ .Description }}: {{ .StatusText }} 20 |
    21 | {{ range $_, $message := .Messages }} 22 |
  • {{ $message }}
  • 23 | {{ end }} 24 |
25 | {{ end }} 26 | {{ if .Response.CanAddToPolicyList }} 27 | Add your email domain to the STARTTLS Everywhere Policy List 28 | {{ end }} 29 | 30 |

Mailboxes

31 | {{ range $hostname, $hostnameResult := .Response.Data.HostnameResults }} 32 |

{{ $hostname }}

33 |
    34 | {{ range $_, $r := $hostnameResult.Checks }} 35 |
  • 36 | {{ $r.Description }}: {{ $r.StatusText }} 37 |
      38 | {{ range $_, $message := $r.Messages }} 39 |
    • {{ $message }}
    • 40 | {{ end }} 41 |
    42 |
  • 43 | {{ end }} 44 |
45 | {{ end }} 46 | 47 | 48 | --------------------------------------------------------------------------------