├── .github └── workflows │ ├── ci.yml │ └── release.yml ├── .golangci.yml ├── .goreleaser.yml ├── COPYING ├── LICENSE-MIT ├── README.md ├── UNLICENSE ├── bin ├── .go-1.23.4.pkg ├── .golangci-lint-1.61.0.pkg ├── .goreleaser-1.26.2.pkg ├── README.hermit.md ├── activate-hermit ├── go ├── gofmt ├── golangci-lint ├── goreleaser ├── hermit └── hermit.hcl ├── check.go ├── check_test.go ├── cmd └── go-check-sumtype │ └── main.go ├── config.go ├── decl.go ├── def.go ├── doc.go ├── go.mod ├── go.sum ├── help_test.go ├── renovate.json5 ├── run.go ├── scripts └── go-check-sumtype └── testdata └── sum.go /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - master 5 | pull_request: 6 | name: CI 7 | jobs: 8 | test: 9 | name: Test 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout code 13 | uses: actions/checkout@v3 14 | - name: Init Hermit 15 | run: ./bin/hermit env -r >> $GITHUB_ENV 16 | - name: Test 17 | run: go test ./... 
18 | lint: 19 | name: Lint 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: Checkout code 23 | uses: actions/checkout@v3 24 | - name: Init Hermit 25 | run: ./bin/hermit env -r >> $GITHUB_ENV 26 | - name: golangci-lint 27 | run: golangci-lint run 28 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | push: 4 | tags: 5 | - 'v*' 6 | jobs: 7 | release: 8 | permissions: write-all 9 | name: Release 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | with: 14 | fetch-depth: 0 15 | - run: ./bin/hermit env --raw >> $GITHUB_ENV 16 | - run: goreleaser release 17 | env: 18 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 19 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | tests: true 3 | 4 | output: 5 | print-issued-lines: false 6 | 7 | linters: 8 | enable-all: true 9 | disable: 10 | - cyclop 11 | - depguard 12 | - dupl 13 | - dupword 14 | - err113 15 | - errorlint 16 | - exhaustive 17 | - exhaustruct 18 | - exportloopref 19 | - forcetypeassert 20 | - funlen 21 | - gci 22 | - gochecknoglobals 23 | - gocognit 24 | - goconst 25 | - gocyclo 26 | - godot 27 | - godox 28 | - gofumpt 29 | - govet 30 | - ireturn 31 | - lll 32 | - maintidx 33 | - mnd 34 | - mnd 35 | - musttag 36 | - nestif 37 | - nilnil 38 | - nlreturn 39 | - nolintlint 40 | - nonamedreturns 41 | - paralleltest 42 | - perfsprint 43 | - predeclared 44 | - revive 45 | - stylecheck 46 | - testableexamples 47 | - testpackage 48 | - thelper 49 | - varnamelen 50 | - wrapcheck 51 | - wsl 52 | 53 | linters-settings: 54 | govet: 55 | enable: 56 | - shadow 57 | gocyclo: 58 | min-complexity: 10 59 | dupl: 60 | threshold: 100 61 | goconst: 62 | min-len: 8 63 | min-occurrences: 3 64 | 
forbidigo: 65 | exclude-godoc-examples: false 66 | #forbid: 67 | # - (Must)?NewLexer$ 68 | 69 | issues: 70 | max-issues-per-linter: 0 71 | max-same-issues: 0 72 | exclude-use-default: false 73 | exclude-dirs: 74 | - _examples 75 | exclude: 76 | # Captured by errcheck. 77 | - "^(G104|G204):" 78 | # Very commonly not checked. 79 | - 'Error return value of .(.*\.Help|.*\.MarkFlagRequired|(os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*printf?|os\.(Un)?Setenv). is not checked' 80 | - 'exported method (.*\.MarshalJSON|.*\.UnmarshalJSON|.*\.EntityURN|.*\.GoString|.*\.Pos) should have comment or be unexported' 81 | - "composite literal uses unkeyed fields" 82 | - 'declaration of "err" shadows declaration' 83 | - "should not use dot imports" 84 | - "Potential file inclusion via variable" 85 | - "should have comment or be unexported" 86 | - "comment on exported var .* should be of the form" 87 | - "at least one file in a package should have a package comment" 88 | - "string literal contains the Unicode" 89 | - "methods on the same type should have the same receiver name" 90 | - "_TokenType_name should be _TokenTypeName" 91 | - "`_TokenType_map` should be `_TokenTypeMap`" 92 | - "rewrite if-else to switch statement" 93 | -------------------------------------------------------------------------------- /.goreleaser.yml: -------------------------------------------------------------------------------- 1 | project_name: go-check-sumtype 2 | release: 3 | github: 4 | owner: alecthomas 5 | name: go-check-sumtype 6 | env: 7 | - CGO_ENABLED=0 8 | builds: 9 | - goos: 10 | - linux 11 | - darwin 12 | - windows 13 | goarch: 14 | - arm64 15 | - amd64 16 | - "386" 17 | goarm: 18 | - "6" 19 | main: ./cmd/go-check-sumtype 20 | binary: go-check-sumtype 21 | archives: 22 | - 23 | format: tar.gz 24 | name_template: '{{ .Binary }}-{{ .Version }}-{{ .Os }}-{{ .Arch }}{{ if .Arm }}v{{ 25 | .Arm }}{{ end }}' 26 | files: 27 | - COPYING 28 | - README* 29 | snapshot: 30 | name_template: 
SNAPSHOT-{{ .Commit }} 31 | checksum: 32 | name_template: '{{ .ProjectName }}-{{ .Version }}-checksums.txt' 33 | -------------------------------------------------------------------------------- /COPYING: -------------------------------------------------------------------------------- 1 | This project is dual-licensed under the Unlicense and MIT licenses. 2 | 3 | You may use this code under the terms of either license. 4 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Andrew Gallant 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Note: This is a fork of the great project [go-sumtype](https://github.com/BurntSushi/go-sumtype) by BurntSushi.** 2 | **The original seems largely unmaintained, and the changes in this fork are backwards incompatible.** 3 | 4 | # go-check-sumtype [![CI](https://github.com/alecthomas/go-check-sumtype/actions/workflows/ci.yml/badge.svg)](https://github.com/alecthomas/go-check-sumtype/actions/workflows/ci.yml) 5 | A simple utility for running exhaustiveness checks on type switch statements. 6 | Exhaustiveness checks are only run on interfaces that are declared to be 7 | "sum types." 8 | 9 | Dual-licensed under MIT or the [UNLICENSE](http://unlicense.org). 10 | 11 | This work was inspired by our code at 12 | [Diffeo](https://diffeo.com). 13 | 14 | ## Installation 15 | 16 | ``` 17 | $ go install github.com/alecthomas/go-check-sumtype/cmd/go-check-sumtype@latest 18 | ``` 19 | 20 | For usage info, just run the command: 21 | 22 | ``` 23 | $ go-check-sumtype 24 | ``` 25 | 26 | Typical usage might look like this: 27 | 28 | ``` 29 | $ go-check-sumtype $(go list ./... | grep -v vendor) 30 | ``` 31 | 32 | ## Usage 33 | 34 | `go-check-sumtype` takes a list of Go package paths or files and looks for sum type 35 | declarations in each package/file provided. Exhaustiveness checks are then 36 | performed for each use of a declared sum type in a type switch statement. 37 | Namely, `go-check-sumtype` will report an error for any type switch statement that 38 | does not account for all possible variants and lacks a `default` clause. 39 | 40 | Declarations are provided in comments like so: 41 | 42 | ``` 43 | //sumtype:decl 44 | type MySumType interface { ... } 45 | ``` 46 | 47 | `MySumType` must be *sealed*. That is, part of its interface definition 48 | contains an unexported method.
49 | 50 | `go-check-sumtype` will produce an error if any of the above is not true. 51 | 52 | For valid declarations, `go-check-sumtype` will look for all occurrences in which a 53 | value of type `MySumType` participates in a type switch statement. In those 54 | occurrences, it will attempt to detect whether the type switch is exhaustive 55 | or not. If it's not, `go-check-sumtype` will report an error. For example, running 56 | `go-check-sumtype` on this source file: 57 | 58 | ```go 59 | package main 60 | 61 | //sumtype:decl 62 | type MySumType interface { 63 | sealed() 64 | } 65 | 66 | type VariantA struct{} 67 | 68 | func (*VariantA) sealed() {} 69 | 70 | type VariantB struct{} 71 | 72 | func (*VariantB) sealed() {} 73 | 74 | func main() { 75 | switch MySumType(nil).(type) { 76 | case *VariantA: 77 | } 78 | } 79 | ``` 80 | 81 | produces the following: 82 | 83 | ``` 84 | $ go-check-sumtype mysumtype.go 85 | mysumtype.go:18:2: exhaustiveness check failed for sum type 'MySumType': missing cases for VariantB 86 | ``` 87 | 88 | Adding either a `default` clause or a clause to handle `*VariantB` will cause 89 | exhaustiveness checks to pass. To prevent `default` clauses from automatically 90 | passing checks, set the `-default-signifies-exhaustive=false` flag. 91 | 92 | As a special case, if the type switch statement contains a `default` clause 93 | whose body is a single `panic(...)` call, then exhaustiveness checks are still performed. 94 | 95 | By default, `go-check-sumtype` will not include shared interfaces in the exhaustiveness check. 96 | This can be changed by setting the `-include-shared-interfaces=true` flag. 97 | When this flag is set, `go-check-sumtype` will not require that all concrete structs 98 | are listed in the switch statement, as long as the switch statement is exhaustive 99 | with respect to interfaces the structs implement. 100 | 101 | ## Details and motivation 102 | 103 | Sum types are otherwise known as discriminated unions.
That is, a sum type is 104 | a finite set of disjoint values. In type systems that support sum types, the 105 | language will guarantee that if one has a sum type `T`, then its value must 106 | be one of its variants. 107 | 108 | Go's type system does not support sum types. A typical proxy for representing 109 | sum types in Go is to use an interface with an unexported method and define 110 | each variant of the sum type in the same package to satisfy said interface. 111 | This guarantees that the set of types that satisfy the interface is closed 112 | at compile time. Performing case analysis on these types is then done with 113 | a type switch statement, e.g., `switch x.(type) { ... }`. Each clause of the 114 | type switch corresponds to a *variant* of the sum type. The downside of this 115 | approach is that Go's type system is not aware of the set of variants, so it 116 | cannot tell you whether case analysis over a sum type is complete or not. 117 | 118 | The `go-check-sumtype` command recognizes this pattern, but it needs a small amount 119 | of help to recognize which interfaces should be treated as sum types, which 120 | is why the `//sumtype:decl` annotation is required. `go-check-sumtype` will 121 | figure out all of the variants of a sum type by finding the set of types 122 | defined in the same package that satisfy the interface specified by the 123 | declaration. 124 | 125 | The `go-check-sumtype` command will prove its worth when you need to add a variant 126 | to an existing sum type. Running `go-check-sumtype` will tell you immediately which 127 | case analyses need to be updated to account for the new variant. 128 | -------------------------------------------------------------------------------- /UNLICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 
2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 
23 | 24 | For more information, please refer to <http://unlicense.org/> 25 | -------------------------------------------------------------------------------- /bin/.go-1.23.4.pkg: -------------------------------------------------------------------------------- 1 | hermit -------------------------------------------------------------------------------- /bin/.golangci-lint-1.61.0.pkg: -------------------------------------------------------------------------------- 1 | hermit -------------------------------------------------------------------------------- /bin/.goreleaser-1.26.2.pkg: -------------------------------------------------------------------------------- 1 | hermit -------------------------------------------------------------------------------- /bin/README.hermit.md: -------------------------------------------------------------------------------- 1 | # Hermit environment 2 | 3 | This is a [Hermit](https://github.com/cashapp/hermit) bin directory. 4 | 5 | The symlinks in this directory are managed by Hermit and will automatically 6 | download and install Hermit itself as well as packages. These packages are 7 | local to this environment. 8 | -------------------------------------------------------------------------------- /bin/activate-hermit: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This file must be used with "source bin/activate-hermit" from bash or zsh.
3 | # You cannot run it directly 4 | # 5 | # THIS FILE IS GENERATED; DO NOT MODIFY 6 | 7 | if [ "${BASH_SOURCE-}" = "$0" ]; then 8 | echo "You must source this script: \$ source $0" >&2 9 | exit 33 10 | fi 11 | 12 | BIN_DIR="$(dirname "${BASH_SOURCE[0]:-${(%):-%x}}")" 13 | if "${BIN_DIR}/hermit" noop > /dev/null; then 14 | eval "$("${BIN_DIR}/hermit" activate "${BIN_DIR}/..")" 15 | 16 | if [ -n "${BASH-}" ] || [ -n "${ZSH_VERSION-}" ]; then 17 | hash -r 2>/dev/null 18 | fi 19 | 20 | echo "Hermit environment $("${HERMIT_ENV}"/bin/hermit env HERMIT_ENV) activated" 21 | fi 22 | -------------------------------------------------------------------------------- /bin/go: -------------------------------------------------------------------------------- 1 | .go-1.23.4.pkg -------------------------------------------------------------------------------- /bin/gofmt: -------------------------------------------------------------------------------- 1 | .go-1.23.4.pkg -------------------------------------------------------------------------------- /bin/golangci-lint: -------------------------------------------------------------------------------- 1 | .golangci-lint-1.61.0.pkg -------------------------------------------------------------------------------- /bin/goreleaser: -------------------------------------------------------------------------------- 1 | .goreleaser-1.26.2.pkg -------------------------------------------------------------------------------- /bin/hermit: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # THIS FILE IS GENERATED; DO NOT MODIFY 4 | 5 | set -eo pipefail 6 | 7 | export HERMIT_USER_HOME=~ 8 | 9 | if [ -z "${HERMIT_STATE_DIR}" ]; then 10 | case "$(uname -s)" in 11 | Darwin) 12 | export HERMIT_STATE_DIR="${HERMIT_USER_HOME}/Library/Caches/hermit" 13 | ;; 14 | Linux) 15 | export HERMIT_STATE_DIR="${XDG_CACHE_HOME:-${HERMIT_USER_HOME}/.cache}/hermit" 16 | ;; 17 | esac 18 | fi 19 | 20 | export 
HERMIT_DIST_URL="${HERMIT_DIST_URL:-https://github.com/cashapp/hermit/releases/download/stable}" 21 | HERMIT_CHANNEL="$(basename "${HERMIT_DIST_URL}")" 22 | export HERMIT_CHANNEL 23 | export HERMIT_EXE=${HERMIT_EXE:-${HERMIT_STATE_DIR}/pkg/hermit@${HERMIT_CHANNEL}/hermit} 24 | 25 | if [ ! -x "${HERMIT_EXE}" ]; then 26 | echo "Bootstrapping ${HERMIT_EXE} from ${HERMIT_DIST_URL}" 1>&2 27 | INSTALL_SCRIPT="$(mktemp)" 28 | # This value must match that of the install script 29 | INSTALL_SCRIPT_SHA256="180e997dd837f839a3072a5e2f558619b6d12555cd5452d3ab19d87720704e38" 30 | if [ "${INSTALL_SCRIPT_SHA256}" = "BYPASS" ]; then 31 | curl -fsSL "${HERMIT_DIST_URL}/install.sh" -o "${INSTALL_SCRIPT}" 32 | else 33 | # Install script is versioned by its sha256sum value 34 | curl -fsSL "${HERMIT_DIST_URL}/install-${INSTALL_SCRIPT_SHA256}.sh" -o "${INSTALL_SCRIPT}" 35 | # Verify install script's sha256sum 36 | openssl dgst -sha256 "${INSTALL_SCRIPT}" | \ 37 | awk -v EXPECTED="$INSTALL_SCRIPT_SHA256" \ 38 | '$2!=EXPECTED {print "Install script sha256 " $2 " does not match " EXPECTED; exit 1}' 39 | fi 40 | /bin/bash "${INSTALL_SCRIPT}" 1>&2 41 | fi 42 | 43 | exec "${HERMIT_EXE}" --level=fatal exec "$0" -- "$@" 44 | -------------------------------------------------------------------------------- /bin/hermit.hcl: -------------------------------------------------------------------------------- 1 | env = { 2 | "PATH": "${HERMIT_ENV}/scripts:${PATH}", 3 | } 4 | -------------------------------------------------------------------------------- /check.go: -------------------------------------------------------------------------------- 1 | package gochecksumtype 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/token" 7 | "go/types" 8 | "sort" 9 | "strings" 10 | 11 | "golang.org/x/tools/go/packages" 12 | ) 13 | 14 | // inexhaustiveError is returned from check for each occurrence of inexhaustive 15 | // case analysis in a Go type switch statement. 
16 | type inexhaustiveError struct { 17 | Position token.Position 18 | Def sumTypeDef 19 | Missing []types.Object 20 | } 21 | 22 | func (e inexhaustiveError) Pos() token.Position { return e.Position } 23 | func (e inexhaustiveError) Error() string { 24 | return fmt.Sprintf( 25 | "%s: exhaustiveness check failed for sum type %q (from %s): missing cases for %s", 26 | e.Pos(), e.Def.Decl.TypeName, e.Def.Decl.Pos, strings.Join(e.Names(), ", ")) 27 | } 28 | 29 | // Names returns a sorted list of names corresponding to the missing variant 30 | // cases. 31 | func (e inexhaustiveError) Names() []string { 32 | list := make([]string, 0, len(e.Missing)) 33 | for _, o := range e.Missing { 34 | list = append(list, o.Name()) 35 | } 36 | sort.Strings(list) 37 | return list 38 | } 39 | 40 | // check does exhaustiveness checking for the given sum type definitions in the 41 | // given package. Every instance of inexhaustive case analysis is returned. 42 | func check(pkg *packages.Package, defs []sumTypeDef, config Config) []error { 43 | var errs []error 44 | for _, astfile := range pkg.Syntax { 45 | ast.Inspect(astfile, func(n ast.Node) bool { 46 | swtch, ok := n.(*ast.TypeSwitchStmt) 47 | if !ok { 48 | return true 49 | } 50 | if err := checkSwitch(pkg, defs, swtch, config); err != nil { 51 | errs = append(errs, err) 52 | } 53 | return true 54 | }) 55 | } 56 | return errs 57 | } 58 | 59 | // checkSwitch performs an exhaustiveness check on the given type switch 60 | // statement. If the type switch is used on a sum type and does not cover 61 | // all variants of that sum type, then an error is returned indicating which 62 | // variants were missed. 63 | // 64 | // Note that if the type switch contains a non-panicing default case, then 65 | // exhaustiveness checks are disabled. 
66 | func checkSwitch( 67 | pkg *packages.Package, 68 | defs []sumTypeDef, 69 | swtch *ast.TypeSwitchStmt, 70 | config Config, 71 | ) error { 72 | def, missing := missingVariantsInSwitch(pkg, defs, swtch, config) 73 | if len(missing) > 0 { 74 | return inexhaustiveError{ 75 | Position: pkg.Fset.Position(swtch.Pos()), 76 | Def: *def, 77 | Missing: missing, 78 | } 79 | } 80 | return nil 81 | } 82 | 83 | // missingVariantsInSwitch returns a list of missing variants corresponding to 84 | // the given switch statement. The corresponding sum type definition is also 85 | // returned. (If no sum type definition could be found, then no exhaustiveness 86 | // checks are performed, and therefore, no missing variants are returned.) 87 | func missingVariantsInSwitch( 88 | pkg *packages.Package, 89 | defs []sumTypeDef, 90 | swtch *ast.TypeSwitchStmt, 91 | config Config, 92 | ) (*sumTypeDef, []types.Object) { 93 | asserted := findTypeAssertExpr(swtch) 94 | ty := pkg.TypesInfo.TypeOf(asserted) 95 | if ty == nil { 96 | panic(fmt.Sprintf("no type found for asserted expression: %v", asserted)) 97 | } 98 | 99 | def := findDef(defs, ty) 100 | if def == nil { 101 | // We couldn't find a corresponding sum type, so there's 102 | // nothing we can do to check it. 103 | return nil, nil 104 | } 105 | variantExprs, hasDefault := switchVariants(swtch) 106 | if config.DefaultSignifiesExhaustive && hasDefault && !defaultClauseAlwaysPanics(swtch) { 107 | // A catch-all case defeats all exhaustiveness checks. 108 | return def, nil 109 | } 110 | variantTypes := make([]types.Type, 0, len(variantExprs)) 111 | for _, expr := range variantExprs { 112 | variantTypes = append(variantTypes, pkg.TypesInfo.TypeOf(expr)) 113 | } 114 | return def, def.missing(variantTypes, config.IncludeSharedInterfaces) 115 | } 116 | 117 | // switchVariants returns all case expressions found in a type switch. This 118 | // includes expressions from cases that have a list of expressions. 
func switchVariants(swtch *ast.TypeSwitchStmt) (exprs []ast.Expr, hasDefault bool) {
	for _, stmt := range swtch.Body.List {
		cc := stmt.(*ast.CaseClause)
		if cc.List != nil {
			exprs = append(exprs, cc.List...)
			continue
		}
		// A clause with no expression list is the default clause.
		hasDefault = true
	}
	return exprs, hasDefault
}

// defaultClauseAlwaysPanics reports whether the default clause of swtch
// consists of exactly one statement: a call to an identifier named panic.
//
// The check is purely syntactic and deliberately narrow. A clause that panics
// in a more elaborate way, for example after other statements or through a
// helper function, is reported as not always panicking.
//
// It panics if swtch has no default clause.
func defaultClauseAlwaysPanics(swtch *ast.TypeSwitchStmt) bool {
	var deflt *ast.CaseClause
	for _, stmt := range swtch.Body.List {
		if cc := stmt.(*ast.CaseClause); cc.List == nil {
			deflt = cc
			break
		}
	}
	if deflt == nil {
		panic("switch statement has no default clause")
	}
	if len(deflt.Body) != 1 {
		return false
	}
	only, isExpr := deflt.Body[0].(*ast.ExprStmt)
	if !isExpr {
		return false
	}
	call, isCall := only.X.(*ast.CallExpr)
	if !isCall {
		return false
	}
	callee, isIdent := call.Fun.(*ast.Ident)
	return isIdent && callee.Name == "panic"
}

// findTypeAssertExpr returns the operand x of the type switch guard in swtch,
// whether the guard is written `switch x.(type)` or `switch v := x.(type)`.
func findTypeAssertExpr(swtch *ast.TypeSwitchStmt) ast.Expr {
	var guard ast.Expr
	switch s := swtch.Assign.(type) {
	case *ast.AssignStmt:
		guard = s.Rhs[0]
	default:
		guard = swtch.Assign.(*ast.ExprStmt).X
	}
	return guard.(*ast.TypeAssertExpr).X
}

// findDef returns the sum type definition corresponding to the given type. If
// no such sum type definition exists, then nil is returned.
182 | func findDef(defs []sumTypeDef, needle types.Type) *sumTypeDef { 183 | for i := range defs { 184 | def := &defs[i] 185 | if types.Identical(needle.Underlying(), def.Ty) { 186 | return def 187 | } 188 | } 189 | return nil 190 | } 191 | -------------------------------------------------------------------------------- /check_test.go: -------------------------------------------------------------------------------- 1 | package gochecksumtype 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/alecthomas/assert/v2" 7 | ) 8 | 9 | // TestMissingOne tests that we detect a single missing variant. 10 | func TestMissingOne(t *testing.T) { 11 | code := ` 12 | package gochecksumtype 13 | 14 | //sumtype:decl 15 | type T interface { sealed() } 16 | 17 | type A struct {} 18 | func (a *A) sealed() {} 19 | 20 | type B struct {} 21 | func (b *B) sealed() {} 22 | 23 | func main() { 24 | switch T(nil).(type) { 25 | case *A: 26 | } 27 | } 28 | ` 29 | pkgs := setupPackages(t, code) 30 | 31 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true}) 32 | assert.Equal(t, 1, len(errs)) 33 | assert.Equal(t, []string{"B"}, missingNames(t, errs[0])) 34 | } 35 | 36 | // TestMissingTwo tests that we detect two missing variants. 37 | func TestMissingTwo(t *testing.T) { 38 | code := ` 39 | package gochecksumtype 40 | 41 | //sumtype:decl 42 | type T interface { sealed() } 43 | 44 | type A struct {} 45 | func (a *A) sealed() {} 46 | 47 | type B struct {} 48 | func (b *B) sealed() {} 49 | 50 | type C struct {} 51 | func (c *C) sealed() {} 52 | 53 | func main() { 54 | switch T(nil).(type) { 55 | case *A: 56 | } 57 | } 58 | ` 59 | pkgs := setupPackages(t, code) 60 | 61 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true}) 62 | assert.Equal(t, 1, len(errs)) 63 | assert.Equal(t, []string{"B", "C"}, missingNames(t, errs[0])) 64 | } 65 | 66 | // TestMissingOneWithPanic tests that we detect a single missing variant even 67 | // if we have a trivial default case that panics. 
68 | func TestMissingOneWithPanic(t *testing.T) { 69 | code := ` 70 | package gochecksumtype 71 | 72 | //sumtype:decl 73 | type T interface { sealed() } 74 | 75 | type A struct {} 76 | func (a *A) sealed() {} 77 | 78 | type B struct {} 79 | func (b *B) sealed() {} 80 | 81 | func main() { 82 | switch T(nil).(type) { 83 | case *A: 84 | default: 85 | panic("unreachable") 86 | } 87 | } 88 | ` 89 | pkgs := setupPackages(t, code) 90 | 91 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true}) 92 | assert.Equal(t, 1, len(errs)) 93 | assert.Equal(t, []string{"B"}, missingNames(t, errs[0])) 94 | } 95 | 96 | // TestNoMissing tests that we correctly detect exhaustive case analysis. 97 | func TestNoMissing(t *testing.T) { 98 | code := ` 99 | package gochecksumtype 100 | 101 | //sumtype:decl 102 | type T interface { sealed() } 103 | 104 | type A struct {} 105 | func (a *A) sealed() {} 106 | 107 | type B struct {} 108 | func (b *B) sealed() {} 109 | 110 | type C struct {} 111 | func (c *C) sealed() {} 112 | 113 | func main() { 114 | switch T(nil).(type) { 115 | case *A, *B, *C: 116 | } 117 | } 118 | ` 119 | pkgs := setupPackages(t, code) 120 | 121 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true}) 122 | assert.Equal(t, 0, len(errs)) 123 | } 124 | 125 | // TestNoMissingDefaultWithDefaultSignifiesExhaustive tests that even if we have a missing variant, a default 126 | // case should thwart exhaustiveness checking when Config.DefaultSignifiesExhaustive is true. 
127 | func TestNoMissingDefaultWithDefaultSignifiesExhaustive(t *testing.T) { 128 | code := ` 129 | package gochecksumtype 130 | 131 | //sumtype:decl 132 | type T interface { sealed() } 133 | 134 | type A struct {} 135 | func (a *A) sealed() {} 136 | 137 | type B struct {} 138 | func (b *B) sealed() {} 139 | 140 | func main() { 141 | switch T(nil).(type) { 142 | case *A: 143 | default: 144 | println("legit catch all goes here") 145 | } 146 | } 147 | ` 148 | pkgs := setupPackages(t, code) 149 | 150 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true}) 151 | assert.Equal(t, 0, len(errs)) 152 | } 153 | 154 | // TestNoMissingDefaultAndDefaultDoesNotSignifiesExhaustive tests that even if we have a missing variant, a default 155 | // case should thwart exhaustiveness checking when Config.DefaultSignifiesExhaustive is false. 156 | func TestNoMissingDefaultAndDefaultDoesNotSignifiesExhaustive(t *testing.T) { 157 | code := ` 158 | package gochecksumtype 159 | 160 | //sumtype:decl 161 | type T interface { sealed() } 162 | 163 | type A struct {} 164 | func (a *A) sealed() {} 165 | 166 | type B struct {} 167 | func (b *B) sealed() {} 168 | 169 | func main() { 170 | switch T(nil).(type) { 171 | case *A: 172 | default: 173 | println("legit catch all goes here") 174 | } 175 | } 176 | ` 177 | pkgs := setupPackages(t, code) 178 | 179 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: false}) 180 | assert.Equal(t, 1, len(errs)) 181 | assert.Equal(t, []string{"B"}, missingNames(t, errs[0])) 182 | } 183 | 184 | // TestNotSealed tests that we report an error if one tries to declare a sum 185 | // type with an unsealed interface. 
186 | func TestNotSealed(t *testing.T) { 187 | code := ` 188 | package gochecksumtype 189 | 190 | //sumtype:decl 191 | type T interface {} 192 | 193 | func main() {} 194 | ` 195 | pkgs := setupPackages(t, code) 196 | 197 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true}) 198 | assert.Equal(t, 1, len(errs)) 199 | assert.Equal(t, "T", errs[0].(unsealedError).Decl.TypeName) 200 | } 201 | 202 | // TestNotInterface tests that we report an error if one tries to declare a sum 203 | // type that doesn't correspond to an interface. 204 | func TestNotInterface(t *testing.T) { 205 | code := ` 206 | package gochecksumtype 207 | 208 | //sumtype:decl 209 | type T struct {} 210 | 211 | func main() {} 212 | ` 213 | pkgs := setupPackages(t, code) 214 | 215 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true}) 216 | assert.Equal(t, 1, len(errs)) 217 | assert.Equal(t, "T", errs[0].(notInterfaceError).Decl.TypeName) 218 | } 219 | 220 | // TestSubTypeInSwitch tests that if a shared interface is declared in the switch 221 | // statement, we don't report an error if structs that implement the interface are not explicitly 222 | // declared in the switch statement. 
223 | func TestSubTypeInSwitch(t *testing.T) { 224 | code := ` 225 | package gochecksumtype 226 | 227 | //sumtype:decl 228 | type T1 interface { sealed1() } 229 | type T2 interface { 230 | T1 231 | sealed2() 232 | } 233 | 234 | 235 | type A struct {} 236 | func (a *A) sealed1() {} 237 | 238 | type B struct {} 239 | func (b *B) sealed1() {} 240 | func (b *B) sealed2() {} 241 | 242 | type C struct {} 243 | func (c *C) sealed1() {} 244 | func (c *C) sealed2() {} 245 | 246 | func main() { 247 | switch T1(nil).(type) { 248 | case *A: 249 | case T2: 250 | } 251 | } 252 | ` 253 | pkgs := setupPackages(t, code) 254 | 255 | errs := Run(pkgs, Config{IncludeSharedInterfaces: true}) 256 | assert.Equal(t, 0, len(errs)) 257 | } 258 | 259 | // TestAllLeavesInSwitch tests that we do not report an error if a switch statement 260 | // covers all leaves of the sum type, even if any SubTypes are not explicitly covered 261 | func TestAllLeavesInSwitch(t *testing.T) { 262 | code := ` 263 | package gochecksumtype 264 | 265 | //sumtype:decl 266 | type T1 interface { sealed1() } 267 | type T2 interface { 268 | T1 269 | sealed2() 270 | } 271 | 272 | 273 | type A struct {} 274 | func (a *A) sealed1() {} 275 | 276 | type B struct {} 277 | func (b *B) sealed1() {} 278 | func (b *B) sealed2() {} 279 | 280 | type C struct {} 281 | func (c *C) sealed1() {} 282 | func (c *C) sealed2() {} 283 | 284 | func main() { 285 | switch T1(nil).(type) { 286 | case *A: 287 | case *B: 288 | case *C: 289 | } 290 | } 291 | ` 292 | pkgs := setupPackages(t, code) 293 | 294 | errs := Run(pkgs, Config{}) 295 | assert.Equal(t, 0, len(errs)) 296 | } 297 | 298 | func missingNames(t *testing.T, err error) []string { 299 | t.Helper() 300 | ierr, ok := err.(inexhaustiveError) 301 | assert.True(t, ok, "error was not inexhaustiveError: %T", err) 302 | return ierr.Names() 303 | } 304 | -------------------------------------------------------------------------------- /cmd/go-check-sumtype/main.go: 
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | "os" 7 | "strings" 8 | 9 | gochecksumtype "github.com/alecthomas/go-check-sumtype" 10 | "golang.org/x/tools/go/packages" 11 | ) 12 | 13 | func main() { 14 | log.SetFlags(0) 15 | 16 | defaultSignifiesExhaustive := flag.Bool( 17 | "default-signifies-exhaustive", 18 | true, 19 | "Presence of \"default\" case in switch statements satisfies exhaustiveness, if all members are not listed.", 20 | ) 21 | 22 | includeSharedInterfaces := flag.Bool( 23 | "include-shared-interfaces", 24 | false, 25 | "Include shared interfaces in the exhaustiveness check.", 26 | ) 27 | 28 | flag.Parse() 29 | if flag.NArg() < 1 { 30 | log.Fatalf("Usage: %s [flags] packages...\n", os.Args[0]) 31 | } 32 | args := flag.Args() // positional args left after flag parsing; unlike os.Args[flag.NFlag()+1:], correct for repeated flags and "--" 33 | 34 | config := gochecksumtype.Config{ 35 | DefaultSignifiesExhaustive: *defaultSignifiesExhaustive, 36 | IncludeSharedInterfaces: *includeSharedInterfaces, 37 | } 38 | 39 | conf := &packages.Config{ 40 | Mode: packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedTypes | packages.NeedTypesSizes | 41 | packages.NeedImports | packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles, 42 | // Unfortunately, it appears including the test packages in 43 | // this lint makes it difficult to do exhaustiveness checking. 44 | // Namely, it appears that compiling the test version of a 45 | // package introduces distinct types from the normal version 46 | // of the package, which will always result in inexhaustive 47 | // errors whenever a package both defines a sum type and has 48 | // tests. (Specifically, using `package name`. Using `package 49 | // name_test` is OK.) 50 | // 51 | // It's not clear what the best way to fix this is. :-( 52 | Tests: false, 53 | } 54 | pkgs, err := packages.Load(conf, args...) 
55 | if err != nil { 56 | log.Fatal(err) 57 | } 58 | if errs := gochecksumtype.Run(pkgs, config); len(errs) > 0 { 59 | list := make([]string, 0, len(errs)) 60 | for _, e := range errs { 61 | list = append(list, e.Error()) 62 | } 63 | log.Fatal(strings.Join(list, "\n")) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package gochecksumtype 2 | 3 | // Config controls the exhaustiveness rules Run applies to type switches over declared sum types. 4 | type Config struct { 5 | // DefaultSignifiesExhaustive, if true, accepts a switch that has a "default" case as exhaustive even when not every variant is listed. 6 | DefaultSignifiesExhaustive bool 7 | // IncludeSharedInterfaces in the exhaustiveness check. If true, we do not need to list all concrete structs, as long 8 | // as the switch statement is exhaustive with respect to interfaces the structs implement. 9 | IncludeSharedInterfaces bool 10 | } 11 | -------------------------------------------------------------------------------- /decl.go: -------------------------------------------------------------------------------- 1 | package gochecksumtype 2 | 3 | import ( 4 | "go/ast" 5 | "go/token" 6 | "strings" 7 | 8 | "golang.org/x/tools/go/packages" 9 | ) 10 | 11 | // sumTypeDecl is a declaration of a sum type in a Go source file. 12 | type sumTypeDecl struct { 13 | // The package that contains this decl. 14 | Package *packages.Package 15 | // The type named by this decl. 16 | TypeName string 17 | // Position where the declaration was found. 18 | Pos token.Position 19 | } 20 | 21 | // Location returns a short string describing where this declaration was found. 22 | func (d sumTypeDecl) Location() string { 23 | return d.Pos.String() 24 | } 25 | 26 | // findSumTypeDecls searches every package given for sum type declarations of 27 | // the form `sumtype:decl`. 
28 | func findSumTypeDecls(pkgs []*packages.Package) ([]sumTypeDecl, error) { 29 | var decls []sumTypeDecl 30 | var retErr error 31 | for _, pkg := range pkgs { 32 | for _, file := range pkg.Syntax { 33 | ast.Inspect(file, func(node ast.Node) bool { 34 | decl, ok := node.(*ast.GenDecl) 35 | if !ok { 36 | // Not a declaration (this includes Inspect's nil post-order call): keep descending. 37 | return true 38 | } 39 | // A directive in the GenDecl's own doc comment applies to the last 40 | // TypeSpec of the declaration (the only one for an ungrouped `type T ...`). 41 | var tspec *ast.TypeSpec 42 | for _, spec := range decl.Specs { 43 | if ts, ok := spec.(*ast.TypeSpec); ok { 44 | tspec = ts 45 | } 46 | } 47 | declDirective := hasSumTypeDirective(decl.Doc) 48 | if declDirective { 49 | if tspec == nil { 50 | // The directive is attached to a declaration that declares no type. 51 | pos := pkg.Fset.Position(decl.Pos()) 52 | retErr = notFoundError{Decl: sumTypeDecl{Package: pkg, Pos: pos}} 53 | return false 54 | } 55 | decls = append(decls, newSumTypeDecl(pkg, tspec)) 56 | } 57 | // Inside a grouped `type ( ... )` block, go/parser attaches each spec's 58 | // leading comment to TypeSpec.Doc rather than GenDecl.Doc, so check those too. 59 | for _, spec := range decl.Specs { 60 | ts, ok := spec.(*ast.TypeSpec) 61 | if !ok || !hasSumTypeDirective(ts.Doc) { 62 | continue 63 | } 64 | if declDirective && ts == tspec { 65 | continue // already recorded from the GenDecl's doc comment 66 | } 67 | decls = append(decls, newSumTypeDecl(pkg, ts)) 68 | } 69 | return true 70 | }) 71 | } 72 | } 73 | return decls, retErr 74 | } 75 | 76 | // hasSumTypeDirective reports whether doc contains a line starting with `//sumtype:decl`. 77 | func hasSumTypeDirective(doc *ast.CommentGroup) bool { 78 | if doc == nil { 79 | return false 80 | } 81 | for _, line := range doc.List { 82 | if strings.HasPrefix(line.Text, "//sumtype:decl") { 83 | return true 84 | } 85 | } 86 | return false 87 | } 88 | 89 | // newSumTypeDecl builds the sum type declaration for ts, positioned at the type spec. 90 | func newSumTypeDecl(pkg *packages.Package, ts *ast.TypeSpec) sumTypeDecl { 91 | decl := sumTypeDecl{Package: pkg, TypeName: ts.Name.Name, Pos: pkg.Fset.Position(ts.Pos())} 92 | debugf("found sum type decl: %s.%s", decl.Package.PkgPath, decl.TypeName) 93 | return decl 94 | } 95 | -------------------------------------------------------------------------------- /def.go: -------------------------------------------------------------------------------- 1 | package gochecksumtype 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "go/token" 7 | "go/types" 8 | "log" 9 | ) 10 | 11 | var debug = flag.Bool("debug", false, "enable debug logging") 12 | // debugf logs via log.Printf, but only when the -debug flag is set. 13 | func debugf(format string, args ...any) { 14 | if *debug { 15 | log.Printf(format, args...) 16 | } 17 | } 18 | 19 | // Error is an error as returned by Run(), together with the source position it refers to. 20 | type Error interface { 21 | error 22 | Pos() token.Position 23 | } 24 | 25 | // unsealedError corresponds to a declared sum type whose interface is not 26 | // sealed. A sealed interface requires at least one unexported method. 
27 | type unsealedError struct { 28 | Decl sumTypeDecl // the offending sum type declaration 29 | } 30 | 31 | func (e unsealedError) Pos() token.Position { return e.Decl.Pos } 32 | func (e unsealedError) Error() string { 33 | return fmt.Sprintf( 34 | "%s: interface '%s' is not sealed "+ 35 | "(sealing requires at least one unexported method)", 36 | e.Decl.Location(), e.Decl.TypeName) 37 | } 38 | 39 | // notFoundError corresponds to a declared sum type whose type definition 40 | // could not be found in the same Go package. Decl.TypeName is empty when the sumtype:decl directive was not attached to a type declaration. 41 | type notFoundError struct { 42 | Decl sumTypeDecl // the declaration that could not be resolved 43 | } 44 | 45 | func (e notFoundError) Pos() token.Position { return e.Decl.Pos } 46 | func (e notFoundError) Error() string { 47 | return fmt.Sprintf("%s: type '%s' is not defined", e.Decl.Location(), e.Decl.TypeName) 48 | } 49 | 50 | // notInterfaceError corresponds to a declared sum type that does not 51 | // correspond to an interface, i.e. whose underlying type is not an interface type. 52 | type notInterfaceError struct { 53 | Decl sumTypeDecl // the declaration naming a non-interface type 54 | } 55 | 56 | func (e notInterfaceError) Pos() token.Position { return e.Decl.Pos } 57 | func (e notInterfaceError) Error() string { 58 | return fmt.Sprintf("%s: type '%s' is not an interface", e.Decl.Location(), e.Decl.TypeName) 59 | } 60 | 61 | // sumTypeDef corresponds to the definition of a Go interface that is 62 | // interpreted as a sum type. Its variants are determined by finding all types 63 | // that implement said interface in the same package. 64 | type sumTypeDef struct { 65 | Decl sumTypeDecl // the declaration this definition was resolved from 66 | Ty *types.Interface // the underlying interface of the declared type 67 | Variants []types.Object // non-generic package-scope types (other than the sum type itself) implementing Ty directly or via a pointer 68 | } 69 | 70 | // findSumTypeDefs attempts to find a Go type definition for each of the given 71 | // sum type declarations. Declarations that cannot be resolved are skipped and 72 | // contribute one error each; the definitions that did resolve are still returned. 
73 | func findSumTypeDefs(decls []sumTypeDecl) ([]sumTypeDef, []error) { 74 | defs := make([]sumTypeDef, 0, len(decls)) 75 | var errs []error 76 | for _, decl := range decls { 77 | def, err := newSumTypeDef(decl.Package.Types, decl) 78 | if err != nil { 79 | errs = append(errs, err) 80 | continue 81 | } 82 | if def == nil { // no type named decl.TypeName in the package scope 83 | errs = append(errs, notFoundError{decl}) 84 | continue 85 | } 86 | defs = append(defs, *def) 87 | } 88 | return defs, errs 89 | } 90 | 91 | // newSumTypeDef attempts to extract a sum type definition from a single 92 | // package. If no such type corresponds to the given decl, then this function 93 | // returns a nil def and a nil error. 94 | // 95 | // If the decl corresponds to a type that isn't an interface containing at 96 | // least one unexported method, then this returns an error (notInterfaceError or unsealedError). 97 | func newSumTypeDef(pkg *types.Package, decl sumTypeDecl) (*sumTypeDef, error) { 98 | obj := pkg.Scope().Lookup(decl.TypeName) 99 | if obj == nil { 100 | return nil, nil 101 | } 102 | iface, ok := obj.Type().Underlying().(*types.Interface) 103 | if !ok { 104 | return nil, notInterfaceError{decl} 105 | } 106 | hasUnexported := false // sealed: at least one unexported method 107 | for i := range iface.NumMethods() { 108 | if !iface.Method(i).Exported() { 109 | hasUnexported = true 110 | break 111 | } 112 | } 113 | if !hasUnexported { 114 | return nil, unsealedError{decl} 115 | } 116 | def := &sumTypeDef{ 117 | Decl: decl, 118 | Ty: iface, 119 | } 120 | debugf("searching for variants of %s.%s\n", pkg.Path(), decl.TypeName) 121 | for _, name := range pkg.Scope().Names() { // Names is sorted, so variants are collected in name order 122 | obj, ok := pkg.Scope().Lookup(name).(*types.TypeName) 123 | if !ok { 124 | continue 125 | } 126 | ty := obj.Type() 127 | if types.Identical(ty.Underlying(), iface) { // the sum type itself (or a type with an identical underlying interface) 128 | continue 129 | } 130 | // Skip generic types (named types with type parameters); a case clause can only name their instantiations. 
131 | if named, ok := ty.(*types.Named); ok && named.TypeParams() != nil { 132 | continue 133 | } 134 | if types.Implements(ty, iface) || types.Implements(types.NewPointer(ty), iface) { // value or pointer method set 135 | debugf(" found variant: %s.%s\n", pkg.Path(), obj.Name()) 136 | def.Variants = append(def.Variants, obj) 137 | } 138 | } 139 | return def, nil 140 | } 141 | // String returns the declared type name of the sum type. 142 | func (def *sumTypeDef) String() string { 143 | return def.Decl.TypeName 144 | } 145 | 146 | // missing returns the variants of this sum type that are not covered by tys; pointer types are compared by their element types. 147 | // If includeSharedInterfaces is set, a variant that (directly or via a pointer) implements an interface type listed in tys also counts as covered. 148 | func (def *sumTypeDef) missing(tys []types.Type, includeSharedInterfaces bool) []types.Object { 149 | // TODO(ag): This is O(n^2). Fix that. /shrug 150 | var missing []types.Object 151 | for _, v := range def.Variants { 152 | found := false 153 | varty := indirect(v.Type()) 154 | for _, ty := range tys { 155 | ty = indirect(ty) 156 | if types.Identical(varty, ty) { 157 | found = true 158 | break 159 | } 160 | if includeSharedInterfaces && implements(varty, ty) { 161 | found = true 162 | break 163 | } 164 | } 165 | if !found && !isInterface(varty) { 166 | // Interfaces extending the sum type are never reported as missing: 167 | // their implementations are variants in their own right and are 168 | // checked individually. 169 | missing = append(missing, v) 170 | } 171 | } 172 | return missing 173 | } 174 | // isInterface reports whether ty, after stripping pointer indirections, has an interface as its underlying type. 175 | func isInterface(ty types.Type) bool { 176 | underlying := indirect(ty).Underlying() 177 | _, ok := underlying.(*types.Interface) 178 | return ok 179 | } 180 | 181 | // indirect dereferences through an arbitrary number of pointer types. 
182 | func indirect(ty types.Type) types.Type { 183 | if ty, ok := ty.(*types.Pointer); ok { 184 | return indirect(ty.Elem()) 185 | } 186 | return ty 187 | } 188 | 189 | // implements reports whether varty, or a pointer to it, implements interfaceType. It returns false when interfaceType's underlying type is not an interface. 190 | func implements(varty, interfaceType types.Type) bool { 191 | underlying := interfaceType.Underlying() 192 | if interf, ok := underlying.(*types.Interface); ok { 193 | return types.Implements(varty, interf) || types.Implements(types.NewPointer(varty), interf) 194 | } 195 | return false 196 | } 197 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | sumtype takes a list of Go package paths or files and looks for sum type 3 | declarations in each package/file provided. Exhaustiveness checks are then 4 | performed for each use of a declared sum type in a type switch statement. 5 | Namely, sumtype will report an error for any type switch statement that does 6 | not account for all possible variants and lacks a default clause. 7 | 8 | Declarations are provided in comments like so: 9 | 10 | //sumtype:decl 11 | type MySumType interface { ... } 12 | 13 | MySumType must be *sealed*. That is, its interface definition must contain at 14 | least one unexported method. 15 | 16 | sumtype will produce an error if any of the above is not true. 17 | 18 | For valid declarations, sumtype will look for all occurrences in which a 19 | value of type MySumType participates in a type switch statement. In those 20 | occurrences, it will attempt to detect whether the type switch is exhaustive 21 | or not. If it's not, sumtype will report an error. 
For example: 22 | 23 | $ cat mysumtype.go 24 | package gochecksumtype 25 | 26 | //sumtype:decl 27 | type MySumType interface { 28 | sealed() 29 | } 30 | 31 | type VariantA struct{} 32 | 33 | func (a *VariantA) sealed() {} 34 | 35 | type VariantB struct{} 36 | 37 | func (b *VariantB) sealed() {} 38 | 39 | func main() { 40 | switch MySumType(nil).(type) { 41 | case *VariantA: 42 | } 43 | } 44 | $ sumtype mysumtype.go 45 | mysumtype.go:18:2: exhaustiveness check failed for sum type 'MySumType': missing cases for VariantB 46 | 47 | Adding either a default clause or a clause to handle *VariantB will cause 48 | exhaustive checks to pass (a default clause counts only while -default-signifies-exhaustive is enabled, as it is by default). 49 | 50 | As a special case, if the type switch statement contains a default clause 51 | that always panics, then exhaustiveness checks are still performed. 52 | */ 53 | package gochecksumtype 54 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/alecthomas/go-check-sumtype 2 | 3 | go 1.22.0 4 | 5 | require ( 6 | github.com/alecthomas/assert/v2 v2.11.0 7 | golang.org/x/tools v0.28.0 8 | ) 9 | 10 | require ( 11 | github.com/alecthomas/repr v0.4.0 // indirect 12 | github.com/hexops/gotextdiff v1.0.3 // indirect 13 | golang.org/x/mod v0.22.0 // indirect 14 | golang.org/x/sync v0.10.0 // indirect 15 | ) 16 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= 2 | github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= 3 | github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= 4 | github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= 5 | github.com/hexops/gotextdiff v1.0.3 
h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= 6 | github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= 7 | golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= 8 | golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= 9 | golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= 10 | golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 11 | golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= 12 | golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 13 | golang.org/x/tools v0.27.0 h1:qEKojBykQkQ4EynWy4S8Weg69NumxKdn40Fce3uc/8o= 14 | golang.org/x/tools v0.27.0/go.mod h1:sUi0ZgbwW9ZPAq26Ekut+weQPR5eIM6GQLQ1Yjm1H0Q= 15 | golang.org/x/tools v0.28.0 h1:WuB6qZ4RPCQo5aP3WdKZS7i595EdWqWR8vqJTlwTVK8= 16 | golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw= 17 | -------------------------------------------------------------------------------- /help_test.go: -------------------------------------------------------------------------------- 1 | package gochecksumtype 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "golang.org/x/tools/go/packages" 9 | ) 10 | 11 | // setupPackages writes code to src.go in a fresh temporary directory and loads it 12 | // with tycheckAll, failing the test if either step returns an error. Errors recorded 13 | // on the loaded packages themselves (pkg.Errors) are not checked here. 14 | func setupPackages(t *testing.T, code string) []*packages.Package { 15 | t.Helper() 16 | srcPath := filepath.Join(t.TempDir(), "src.go") 17 | if err := os.WriteFile(srcPath, []byte(code), 0600); err != nil { 18 | t.Fatal(err) 19 | } 20 | pkgs, err := tycheckAll([]string{srcPath}) 21 | if err != nil { 22 | t.Fatal(err) 23 | } 24 | return pkgs 25 | } 26 | 27 | // tycheckAll loads the given patterns with syntax and full type information, using 28 | // the same packages.Config as cmd/go-check-sumtype (test variants excluded). 29 | func tycheckAll(args []string) ([]*packages.Package, error) { 30 | conf := &packages.Config{ 31 | Mode: packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedTypes | packages.NeedTypesSizes | 32 | packages.NeedImports | packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles, 33 | // Unfortunately, it appears including the test packages in 
34 | // this lint makes it difficult to do exhaustiveness checking. 35 | // Namely, it appears that compiling the test version of a 36 | // package introduces distinct types from the normal version 37 | // of the package, which will always result in inexhaustive 38 | // errors whenever a package both defines a sum type and has 39 | // tests. (Specifically, using `package name`. Using `package 40 | // name_test` is OK.) 41 | // 42 | // It's not clear what the best way to fix this is. :-( 43 | Tests: false, 44 | } 45 | pkgs, err := packages.Load(conf, args...) 46 | if err != nil { 47 | return nil, err 48 | } 49 | return pkgs, nil 50 | } 51 | -------------------------------------------------------------------------------- /renovate.json5: -------------------------------------------------------------------------------- 1 | { 2 | $schema: "https://docs.renovatebot.com/renovate-schema.json", 3 | extends: [ 4 | "config:recommended", 5 | ":semanticCommits", 6 | ":semanticCommitTypeAll(chore)", 7 | ":semanticCommitScope(deps)", 8 | "group:allNonMajor", 9 | "schedule:earlyMondays", // Run once a week. 10 | ], 11 | packageRules: [ 12 | { 13 | matchPackageNames: ["golangci-lint"], 14 | matchManagers: ["hermit"], 15 | enabled: false, 16 | }, 17 | ], 18 | } 19 | -------------------------------------------------------------------------------- /run.go: -------------------------------------------------------------------------------- 1 | package gochecksumtype 2 | 3 | import "golang.org/x/tools/go/packages" 4 | 5 | // Run performs sum type checking on the given packages and returns every error found; if scanning for sum type declarations fails, only that error is returned. 6 | func Run(pkgs []*packages.Package, config Config) []error { 7 | var errs []error 8 | 9 | decls, err := findSumTypeDecls(pkgs) 10 | if err != nil { 11 | return []error{err} 12 | } 13 | 14 | defs, defErrs := findSumTypeDefs(decls) 15 | errs = append(errs, defErrs...) 
16 | if len(defs) == 0 { 17 | return errs 18 | } 19 | 20 | for _, pkg := range pkgs { 21 | if pkgErrs := check(pkg, defs, config); pkgErrs != nil { 22 | errs = append(errs, pkgErrs...) 23 | } 24 | } 25 | return errs 26 | } 27 | -------------------------------------------------------------------------------- /scripts/go-check-sumtype: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euo pipefail 3 | basedir="$(dirname "$0")/.." 4 | name="$(basename "$0")" 5 | dest="${basedir}/build/devel" 6 | mkdir -p "$dest" 7 | (cd "${basedir}" && ./bin/go build -ldflags="-s -w -buildid=" -o "$dest/${name}" "./cmd/${name}") && exec "$dest/${name}" "$@" 8 | -------------------------------------------------------------------------------- /testdata/sum.go: -------------------------------------------------------------------------------- 1 | package testdata 2 | 3 | //sumtype:decl 4 | type Sum interface{ sum() } 5 | 6 | type A struct{} 7 | 8 | func (A) sum() {} 9 | 10 | type B struct{} 11 | 12 | func (B) sum() {} 13 | 14 | type C[T any] struct{} 15 | 16 | func (C[T]) sum() {} 17 | 18 | func SumSwitch(x Sum) { 19 | switch x.(type) { 20 | case A: 21 | case B: 22 | } 23 | } 24 | --------------------------------------------------------------------------------