├── .github
└── workflows
│ ├── ci.yml
│ └── release.yml
├── .golangci.yml
├── .goreleaser.yml
├── COPYING
├── LICENSE-MIT
├── README.md
├── UNLICENSE
├── bin
├── .go-1.23.4.pkg
├── .golangci-lint-1.61.0.pkg
├── .goreleaser-1.26.2.pkg
├── README.hermit.md
├── activate-hermit
├── go
├── gofmt
├── golangci-lint
├── goreleaser
├── hermit
└── hermit.hcl
├── check.go
├── check_test.go
├── cmd
└── go-check-sumtype
│ └── main.go
├── config.go
├── decl.go
├── def.go
├── doc.go
├── go.mod
├── go.sum
├── help_test.go
├── renovate.json5
├── run.go
├── scripts
└── go-check-sumtype
└── testdata
└── sum.go
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 | branches:
4 | - master
5 | pull_request:
6 | name: CI
7 | jobs:
8 | test:
9 | name: Test
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: Checkout code
13 | uses: actions/checkout@v3
14 | - name: Init Hermit
15 | run: ./bin/hermit env -r >> $GITHUB_ENV
16 | - name: Test
17 | run: go test ./...
18 | lint:
19 | name: Lint
20 | runs-on: ubuntu-latest
21 | steps:
22 | - name: Checkout code
23 | uses: actions/checkout@v3
24 | - name: Init Hermit
25 | run: ./bin/hermit env -r >> $GITHUB_ENV
26 | - name: golangci-lint
27 | run: golangci-lint run
28 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 | on:
3 | push:
4 | tags:
5 | - 'v*'
6 | jobs:
7 | release:
8 | permissions: write-all
9 | name: Release
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v3
13 | with:
14 | fetch-depth: 0
15 | - run: ./bin/hermit env --raw >> $GITHUB_ENV
16 | - run: goreleaser release
17 | env:
18 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
19 |
--------------------------------------------------------------------------------
/.golangci.yml:
--------------------------------------------------------------------------------
1 | run:
2 | tests: true
3 |
4 | output:
5 | print-issued-lines: false
6 |
7 | linters:
8 | enable-all: true
9 | disable:
10 | - cyclop
11 | - depguard
12 | - dupl
13 | - dupword
14 | - err113
15 | - errorlint
16 | - exhaustive
17 | - exhaustruct
18 | - exportloopref
19 | - forcetypeassert
20 | - funlen
21 | - gci
22 | - gochecknoglobals
23 | - gocognit
24 | - goconst
25 | - gocyclo
26 | - godot
27 | - godox
28 | - gofumpt
29 | - govet
30 | - ireturn
31 | - lll
32 | - maintidx
33 | - mnd
34 | - mnd
35 | - musttag
36 | - nestif
37 | - nilnil
38 | - nlreturn
39 | - nolintlint
40 | - nonamedreturns
41 | - paralleltest
42 | - perfsprint
43 | - predeclared
44 | - revive
45 | - stylecheck
46 | - testableexamples
47 | - testpackage
48 | - thelper
49 | - varnamelen
50 | - wrapcheck
51 | - wsl
52 |
53 | linters-settings:
54 | govet:
55 | enable:
56 | - shadow
57 | gocyclo:
58 | min-complexity: 10
59 | dupl:
60 | threshold: 100
61 | goconst:
62 | min-len: 8
63 | min-occurrences: 3
64 | forbidigo:
65 | exclude-godoc-examples: false
66 | #forbid:
67 | # - (Must)?NewLexer$
68 |
69 | issues:
70 | max-issues-per-linter: 0
71 | max-same-issues: 0
72 | exclude-use-default: false
73 | exclude-dirs:
74 | - _examples
75 | exclude:
76 | # Captured by errcheck.
77 | - "^(G104|G204):"
78 | # Very commonly not checked.
79 | - 'Error return value of .(.*\.Help|.*\.MarkFlagRequired|(os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*printf?|os\.(Un)?Setenv). is not checked'
80 | - 'exported method (.*\.MarshalJSON|.*\.UnmarshalJSON|.*\.EntityURN|.*\.GoString|.*\.Pos) should have comment or be unexported'
81 | - "composite literal uses unkeyed fields"
82 | - 'declaration of "err" shadows declaration'
83 | - "should not use dot imports"
84 | - "Potential file inclusion via variable"
85 | - "should have comment or be unexported"
86 | - "comment on exported var .* should be of the form"
87 | - "at least one file in a package should have a package comment"
88 | - "string literal contains the Unicode"
89 | - "methods on the same type should have the same receiver name"
90 | - "_TokenType_name should be _TokenTypeName"
91 | - "`_TokenType_map` should be `_TokenTypeMap`"
92 | - "rewrite if-else to switch statement"
93 |
--------------------------------------------------------------------------------
/.goreleaser.yml:
--------------------------------------------------------------------------------
1 | project_name: go-check-sumtype
2 | release:
3 | github:
4 | owner: alecthomas
5 | name: go-check-sumtype
6 | env:
7 | - CGO_ENABLED=0
8 | builds:
9 | - goos:
10 | - linux
11 | - darwin
12 | - windows
13 | goarch:
14 | - arm64
15 | - amd64
16 | - "386"
17 | goarm:
18 | - "6"
19 | main: ./cmd/go-check-sumtype
20 | binary: go-check-sumtype
21 | archives:
22 | -
23 | format: tar.gz
24 | name_template: '{{ .Binary }}-{{ .Version }}-{{ .Os }}-{{ .Arch }}{{ if .Arm }}v{{
25 | .Arm }}{{ end }}'
26 | files:
27 | - COPYING
28 | - README*
29 | snapshot:
30 | name_template: SNAPSHOT-{{ .Commit }}
31 | checksum:
32 | name_template: '{{ .ProjectName }}-{{ .Version }}-checksums.txt'
33 |
--------------------------------------------------------------------------------
/COPYING:
--------------------------------------------------------------------------------
1 | This project is dual-licensed under the Unlicense and MIT licenses.
2 |
3 | You may use this code under the terms of either license.
4 |
--------------------------------------------------------------------------------
/LICENSE-MIT:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2015 Andrew Gallant
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **Note: This is a fork of the great project [go-sumtype](https://github.com/BurntSushi/go-sumtype) by BurntSushi.**
2 | **The original seems largely unmaintained, and the changes in this fork are backwards incompatible.**
3 |
4 | # go-check-sumtype [![CI](https://github.com/alecthomas/go-check-sumtype/actions/workflows/ci.yml/badge.svg)](https://github.com/alecthomas/go-check-sumtype/actions/workflows/ci.yml)
5 | A simple utility for running exhaustiveness checks on type switch statements.
6 | Exhaustiveness checks are only run on interfaces that are declared to be
7 | "sum types."
8 |
9 | Dual-licensed under MIT or the [UNLICENSE](http://unlicense.org).
10 |
11 | This work was inspired by our code at
12 | [Diffeo](https://diffeo.com).
13 |
14 | ## Installation
15 |
16 | ```sh
17 | $ go install github.com/alecthomas/go-check-sumtype/cmd/go-check-sumtype@latest
18 | ```
19 |
20 | For usage info, just run the command:
21 |
22 | ```
23 | $ go-check-sumtype
24 | ```
25 |
26 | Typical usage might look like this:
27 |
28 | ```
29 | $ go-check-sumtype $(go list ./... | grep -v vendor)
30 | ```
31 |
32 | ## Usage
33 |
34 | `go-check-sumtype` takes a list of Go package paths or files and looks for sum type
35 | declarations in each package/file provided. Exhaustiveness checks are then
36 | performed for each use of a declared sum type in a type switch statement.
37 | Namely, `go-check-sumtype` will report an error for any type switch statement that
38 | either lacks a `default` clause or does not account for all possible variants.
39 |
40 | Declarations are provided in comments like so:
41 |
42 | ```
43 | //sumtype:decl
44 | type MySumType interface { ... }
45 | ```
46 |
47 | `MySumType` must be an interface, and it must be *sealed*. That is, part of its interface definition
48 | contains an unexported method.
49 |
50 | `go-check-sumtype` will produce an error if any of the above is not true.
51 |
52 | For valid declarations, `go-check-sumtype` will look for all occurrences in which a
53 | value of type `MySumType` participates in a type switch statement. In those
54 | occurrences, it will attempt to detect whether the type switch is exhaustive
55 | or not. If it's not, `go-check-sumtype` will report an error. For example, running
56 | `go-check-sumtype` on this source file:
57 |
58 | ```go
59 | package main
60 |
61 | //sumtype:decl
62 | type MySumType interface {
63 | sealed()
64 | }
65 |
66 | type VariantA struct{}
67 |
68 | func (*VariantA) sealed() {}
69 |
70 | type VariantB struct{}
71 |
72 | func (*VariantB) sealed() {}
73 |
74 | func main() {
75 | switch MySumType(nil).(type) {
76 | case *VariantA:
77 | }
78 | }
79 | ```
80 |
81 | produces the following:
82 |
83 | ```
84 | $ go-check-sumtype mysumtype.go
85 | mysumtype.go:18:2: exhaustiveness check failed for sum type 'MySumType': missing cases for VariantB
86 | ```
87 |
88 | Adding either a `default` clause or a clause to handle `*VariantB` will cause
89 | exhaustive checks to pass. To prevent `default` clauses from automatically
90 | passing checks, set the `-default-signifies-exhaustive=false` flag.
91 |
92 | As a special case, if the type switch statement contains a `default` clause
93 | that always panics, then exhaustiveness checks are still performed.
94 |
95 | By default, `go-check-sumtype` will not include shared interfaces in the exhaustiveness check.
96 | This can be changed by setting the `-include-shared-interfaces=true` flag.
97 | When this flag is set, `go-check-sumtype` will not require that all concrete structs
98 | are listed in the switch statement, as long as the switch statement is exhaustive
99 | with respect to interfaces the structs implement.
100 |
101 | ## Details and motivation
102 |
103 | Sum types are otherwise known as discriminated unions. That is, a sum type is
104 | a finite set of disjoint values. In type systems that support sum types, the
105 | language will guarantee that if one has a sum type `T`, then its value must
106 | be one of its variants.
107 |
108 | Go's type system does not support sum types. A typical proxy for representing
109 | sum types in Go is to use an interface with an unexported method and define
110 | each variant of the sum type in the same package to satisfy said interface.
111 | This guarantees that the set of types that satisfy the interface is closed
112 | at compile time. Performing case analysis on these types is then done with
113 | a type switch statement, e.g., `switch x.(type) { ... }`. Each clause of the
114 | type switch corresponds to a *variant* of the sum type. The downside of this
115 | approach is that Go's type system is not aware of the set of variants, so it
116 | cannot tell you whether case analysis over a sum type is complete or not.
117 |
118 | The `go-check-sumtype` command recognizes this pattern, but it needs a small amount
119 | of help to recognize which interfaces should be treated as sum types, which
120 | is why the `//sumtype:decl` annotation is required. `go-check-sumtype` will
121 | figure out all of the variants of a sum type by finding the set of types
122 | defined in the same package that satisfy the interface specified by the
123 | declaration.
124 |
125 | The `go-check-sumtype` command will prove its worth when you need to add a variant
126 | to an existing sum type. Running `go-check-sumtype` will tell you immediately which
127 | case analyses need to be updated to account for the new variant.
128 |
--------------------------------------------------------------------------------
/UNLICENSE:
--------------------------------------------------------------------------------
1 | This is free and unencumbered software released into the public domain.
2 |
3 | Anyone is free to copy, modify, publish, use, compile, sell, or
4 | distribute this software, either in source code form or as a compiled
5 | binary, for any purpose, commercial or non-commercial, and by any
6 | means.
7 |
8 | In jurisdictions that recognize copyright laws, the author or authors
9 | of this software dedicate any and all copyright interest in the
10 | software to the public domain. We make this dedication for the benefit
11 | of the public at large and to the detriment of our heirs and
12 | successors. We intend this dedication to be an overt act of
13 | relinquishment in perpetuity of all present and future rights to this
14 | software under copyright law.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
23 |
24 | For more information, please refer to <http://unlicense.org/>
25 |
--------------------------------------------------------------------------------
/bin/.go-1.23.4.pkg:
--------------------------------------------------------------------------------
1 | hermit
--------------------------------------------------------------------------------
/bin/.golangci-lint-1.61.0.pkg:
--------------------------------------------------------------------------------
1 | hermit
--------------------------------------------------------------------------------
/bin/.goreleaser-1.26.2.pkg:
--------------------------------------------------------------------------------
1 | hermit
--------------------------------------------------------------------------------
/bin/README.hermit.md:
--------------------------------------------------------------------------------
1 | # Hermit environment
2 |
3 | This is a [Hermit](https://github.com/cashapp/hermit) bin directory.
4 |
5 | The symlinks in this directory are managed by Hermit and will automatically
6 | download and install Hermit itself as well as packages. These packages are
7 | local to this environment.
8 |
--------------------------------------------------------------------------------
/bin/activate-hermit:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # This file must be used with "source bin/activate-hermit" from bash or zsh.
3 | # You cannot run it directly
4 | #
5 | # THIS FILE IS GENERATED; DO NOT MODIFY
6 |
7 | if [ "${BASH_SOURCE-}" = "$0" ]; then
8 | echo "You must source this script: \$ source $0" >&2
9 | exit 33
10 | fi
11 |
12 | BIN_DIR="$(dirname "${BASH_SOURCE[0]:-${(%):-%x}}")"
13 | if "${BIN_DIR}/hermit" noop > /dev/null; then
14 | eval "$("${BIN_DIR}/hermit" activate "${BIN_DIR}/..")"
15 |
16 | if [ -n "${BASH-}" ] || [ -n "${ZSH_VERSION-}" ]; then
17 | hash -r 2>/dev/null
18 | fi
19 |
20 | echo "Hermit environment $("${HERMIT_ENV}"/bin/hermit env HERMIT_ENV) activated"
21 | fi
22 |
--------------------------------------------------------------------------------
/bin/go:
--------------------------------------------------------------------------------
1 | .go-1.23.4.pkg
--------------------------------------------------------------------------------
/bin/gofmt:
--------------------------------------------------------------------------------
1 | .go-1.23.4.pkg
--------------------------------------------------------------------------------
/bin/golangci-lint:
--------------------------------------------------------------------------------
1 | .golangci-lint-1.61.0.pkg
--------------------------------------------------------------------------------
/bin/goreleaser:
--------------------------------------------------------------------------------
1 | .goreleaser-1.26.2.pkg
--------------------------------------------------------------------------------
/bin/hermit:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # THIS FILE IS GENERATED; DO NOT MODIFY
4 |
5 | set -eo pipefail
6 |
7 | export HERMIT_USER_HOME=~
8 |
9 | if [ -z "${HERMIT_STATE_DIR}" ]; then
10 | case "$(uname -s)" in
11 | Darwin)
12 | export HERMIT_STATE_DIR="${HERMIT_USER_HOME}/Library/Caches/hermit"
13 | ;;
14 | Linux)
15 | export HERMIT_STATE_DIR="${XDG_CACHE_HOME:-${HERMIT_USER_HOME}/.cache}/hermit"
16 | ;;
17 | esac
18 | fi
19 |
20 | export HERMIT_DIST_URL="${HERMIT_DIST_URL:-https://github.com/cashapp/hermit/releases/download/stable}"
21 | HERMIT_CHANNEL="$(basename "${HERMIT_DIST_URL}")"
22 | export HERMIT_CHANNEL
23 | export HERMIT_EXE=${HERMIT_EXE:-${HERMIT_STATE_DIR}/pkg/hermit@${HERMIT_CHANNEL}/hermit}
24 |
25 | if [ ! -x "${HERMIT_EXE}" ]; then
26 | echo "Bootstrapping ${HERMIT_EXE} from ${HERMIT_DIST_URL}" 1>&2
27 | INSTALL_SCRIPT="$(mktemp)"
28 | # This value must match that of the install script
29 | INSTALL_SCRIPT_SHA256="180e997dd837f839a3072a5e2f558619b6d12555cd5452d3ab19d87720704e38"
30 | if [ "${INSTALL_SCRIPT_SHA256}" = "BYPASS" ]; then
31 | curl -fsSL "${HERMIT_DIST_URL}/install.sh" -o "${INSTALL_SCRIPT}"
32 | else
33 | # Install script is versioned by its sha256sum value
34 | curl -fsSL "${HERMIT_DIST_URL}/install-${INSTALL_SCRIPT_SHA256}.sh" -o "${INSTALL_SCRIPT}"
35 | # Verify install script's sha256sum
36 | openssl dgst -sha256 "${INSTALL_SCRIPT}" | \
37 | awk -v EXPECTED="$INSTALL_SCRIPT_SHA256" \
38 | '$2!=EXPECTED {print "Install script sha256 " $2 " does not match " EXPECTED; exit 1}'
39 | fi
40 | /bin/bash "${INSTALL_SCRIPT}" 1>&2
41 | fi
42 |
43 | exec "${HERMIT_EXE}" --level=fatal exec "$0" -- "$@"
44 |
--------------------------------------------------------------------------------
/bin/hermit.hcl:
--------------------------------------------------------------------------------
1 | env = {
2 | "PATH": "${HERMIT_ENV}/scripts:${PATH}",
3 | }
4 |
--------------------------------------------------------------------------------
/check.go:
--------------------------------------------------------------------------------
1 | package gochecksumtype
2 |
3 | import (
4 | "fmt"
5 | "go/ast"
6 | "go/token"
7 | "go/types"
8 | "sort"
9 | "strings"
10 |
11 | "golang.org/x/tools/go/packages"
12 | )
13 |
14 | // inexhaustiveError is returned from check for each occurrence of inexhaustive
15 | // case analysis in a Go type switch statement.
16 | type inexhaustiveError struct {
17 | Position token.Position
18 | Def sumTypeDef
19 | Missing []types.Object
20 | }
21 |
22 | func (e inexhaustiveError) Pos() token.Position { return e.Position }
23 | func (e inexhaustiveError) Error() string {
24 | return fmt.Sprintf(
25 | "%s: exhaustiveness check failed for sum type %q (from %s): missing cases for %s",
26 | e.Pos(), e.Def.Decl.TypeName, e.Def.Decl.Pos, strings.Join(e.Names(), ", "))
27 | }
28 |
29 | // Names returns a sorted list of names corresponding to the missing variant
30 | // cases.
31 | func (e inexhaustiveError) Names() []string {
32 | list := make([]string, 0, len(e.Missing))
33 | for _, o := range e.Missing {
34 | list = append(list, o.Name())
35 | }
36 | sort.Strings(list)
37 | return list
38 | }
39 |
40 | // check does exhaustiveness checking for the given sum type definitions in the
41 | // given package. Every instance of inexhaustive case analysis is returned.
42 | func check(pkg *packages.Package, defs []sumTypeDef, config Config) []error {
43 | var errs []error
44 | for _, astfile := range pkg.Syntax {
45 | ast.Inspect(astfile, func(n ast.Node) bool {
46 | swtch, ok := n.(*ast.TypeSwitchStmt)
47 | if !ok {
48 | return true
49 | }
50 | if err := checkSwitch(pkg, defs, swtch, config); err != nil {
51 | errs = append(errs, err)
52 | }
53 | return true
54 | })
55 | }
56 | return errs
57 | }
58 |
59 | // checkSwitch performs an exhaustiveness check on the given type switch
60 | // statement. If the type switch is used on a sum type and does not cover
61 | // all variants of that sum type, then an error is returned indicating which
62 | // variants were missed.
63 | //
64 | // Note that if the type switch contains a non-panicing default case, then
65 | // exhaustiveness checks are disabled.
66 | func checkSwitch(
67 | pkg *packages.Package,
68 | defs []sumTypeDef,
69 | swtch *ast.TypeSwitchStmt,
70 | config Config,
71 | ) error {
72 | def, missing := missingVariantsInSwitch(pkg, defs, swtch, config)
73 | if len(missing) > 0 {
74 | return inexhaustiveError{
75 | Position: pkg.Fset.Position(swtch.Pos()),
76 | Def: *def,
77 | Missing: missing,
78 | }
79 | }
80 | return nil
81 | }
82 |
83 | // missingVariantsInSwitch returns a list of missing variants corresponding to
84 | // the given switch statement. The corresponding sum type definition is also
85 | // returned. (If no sum type definition could be found, then no exhaustiveness
86 | // checks are performed, and therefore, no missing variants are returned.)
87 | func missingVariantsInSwitch(
88 | pkg *packages.Package,
89 | defs []sumTypeDef,
90 | swtch *ast.TypeSwitchStmt,
91 | config Config,
92 | ) (*sumTypeDef, []types.Object) {
93 | asserted := findTypeAssertExpr(swtch)
94 | ty := pkg.TypesInfo.TypeOf(asserted)
95 | if ty == nil {
96 | panic(fmt.Sprintf("no type found for asserted expression: %v", asserted))
97 | }
98 |
99 | def := findDef(defs, ty)
100 | if def == nil {
101 | // We couldn't find a corresponding sum type, so there's
102 | // nothing we can do to check it.
103 | return nil, nil
104 | }
105 | variantExprs, hasDefault := switchVariants(swtch)
106 | if config.DefaultSignifiesExhaustive && hasDefault && !defaultClauseAlwaysPanics(swtch) {
107 | // A catch-all case defeats all exhaustiveness checks.
108 | return def, nil
109 | }
110 | variantTypes := make([]types.Type, 0, len(variantExprs))
111 | for _, expr := range variantExprs {
112 | variantTypes = append(variantTypes, pkg.TypesInfo.TypeOf(expr))
113 | }
114 | return def, def.missing(variantTypes, config.IncludeSharedInterfaces)
115 | }
116 |
// switchVariants flattens the case expressions of a type switch into a single
// slice, in source order. Clauses listing several types contribute each of
// them. hasDefault reports whether a clause with a nil expression list (the
// default clause) was present. exprs is nil when no non-default case exists.
func switchVariants(swtch *ast.TypeSwitchStmt) (exprs []ast.Expr, hasDefault bool) {
	for i := range swtch.Body.List {
		cc := swtch.Body.List[i].(*ast.CaseClause)
		if cc.List != nil {
			exprs = append(exprs, cc.List...)
			continue
		}
		hasDefault = true
	}
	return exprs, hasDefault
}
130 |
// defaultClauseAlwaysPanics reports whether the default clause of the given
// type switch consists of exactly one statement that is a call to an
// identifier named "panic". This is a purely syntactic, best-effort test: any
// other shape of default clause (more statements, a selector call such as
// log.Panic, and so on) yields false.
//
// NOTE(review): because only the identifier's name is inspected, a local
// function that shadows the builtin panic would also be accepted.
//
// It panics if the switch has no default clause.
func defaultClauseAlwaysPanics(swtch *ast.TypeSwitchStmt) bool {
	var dflt *ast.CaseClause
	for i := 0; i < len(swtch.Body.List) && dflt == nil; i++ {
		if cc := swtch.Body.List[i].(*ast.CaseClause); cc.List == nil {
			dflt = cc
		}
	}
	if dflt == nil {
		panic("switch statement has no default clause")
	}

	if len(dflt.Body) != 1 {
		return false
	}
	stmt, isExpr := dflt.Body[0].(*ast.ExprStmt)
	if !isExpr {
		return false
	}
	call, isCall := stmt.X.(*ast.CallExpr)
	if !isCall {
		return false
	}
	ident, isIdent := call.Fun.(*ast.Ident)
	return isIdent && ident.Name == "panic"
}
167 |
// findTypeAssertExpr returns the operand of the `.(type)` assertion in a type
// switch guard. It handles both guard forms: `switch v := x.(type)` (an
// assignment) and `switch x.(type)` (a bare expression statement).
func findTypeAssertExpr(swtch *ast.TypeSwitchStmt) ast.Expr {
	var guard ast.Expr
	switch stmt := swtch.Assign.(type) {
	case *ast.AssignStmt:
		guard = stmt.Rhs[0]
	default:
		guard = swtch.Assign.(*ast.ExprStmt).X
	}
	return guard.(*ast.TypeAssertExpr).X
}
179 |
180 | // findDef returns the sum type definition corresponding to the given type. If
181 | // no such sum type definition exists, then nil is returned.
182 | func findDef(defs []sumTypeDef, needle types.Type) *sumTypeDef {
183 | for i := range defs {
184 | def := &defs[i]
185 | if types.Identical(needle.Underlying(), def.Ty) {
186 | return def
187 | }
188 | }
189 | return nil
190 | }
191 |
--------------------------------------------------------------------------------
/check_test.go:
--------------------------------------------------------------------------------
1 | package gochecksumtype
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/alecthomas/assert/v2"
7 | )
8 |
9 | // TestMissingOne tests that we detect a single missing variant.
10 | func TestMissingOne(t *testing.T) {
11 | code := `
12 | package gochecksumtype
13 |
14 | //sumtype:decl
15 | type T interface { sealed() }
16 |
17 | type A struct {}
18 | func (a *A) sealed() {}
19 |
20 | type B struct {}
21 | func (b *B) sealed() {}
22 |
23 | func main() {
24 | switch T(nil).(type) {
25 | case *A:
26 | }
27 | }
28 | `
29 | pkgs := setupPackages(t, code)
30 |
31 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
32 | assert.Equal(t, 1, len(errs))
33 | assert.Equal(t, []string{"B"}, missingNames(t, errs[0]))
34 | }
35 |
36 | // TestMissingTwo tests that we detect two missing variants.
37 | func TestMissingTwo(t *testing.T) {
38 | code := `
39 | package gochecksumtype
40 |
41 | //sumtype:decl
42 | type T interface { sealed() }
43 |
44 | type A struct {}
45 | func (a *A) sealed() {}
46 |
47 | type B struct {}
48 | func (b *B) sealed() {}
49 |
50 | type C struct {}
51 | func (c *C) sealed() {}
52 |
53 | func main() {
54 | switch T(nil).(type) {
55 | case *A:
56 | }
57 | }
58 | `
59 | pkgs := setupPackages(t, code)
60 |
61 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
62 | assert.Equal(t, 1, len(errs))
63 | assert.Equal(t, []string{"B", "C"}, missingNames(t, errs[0]))
64 | }
65 |
66 | // TestMissingOneWithPanic tests that we detect a single missing variant even
67 | // if we have a trivial default case that panics.
68 | func TestMissingOneWithPanic(t *testing.T) {
69 | code := `
70 | package gochecksumtype
71 |
72 | //sumtype:decl
73 | type T interface { sealed() }
74 |
75 | type A struct {}
76 | func (a *A) sealed() {}
77 |
78 | type B struct {}
79 | func (b *B) sealed() {}
80 |
81 | func main() {
82 | switch T(nil).(type) {
83 | case *A:
84 | default:
85 | panic("unreachable")
86 | }
87 | }
88 | `
89 | pkgs := setupPackages(t, code)
90 |
91 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
92 | assert.Equal(t, 1, len(errs))
93 | assert.Equal(t, []string{"B"}, missingNames(t, errs[0]))
94 | }
95 |
96 | // TestNoMissing tests that we correctly detect exhaustive case analysis.
97 | func TestNoMissing(t *testing.T) {
98 | code := `
99 | package gochecksumtype
100 |
101 | //sumtype:decl
102 | type T interface { sealed() }
103 |
104 | type A struct {}
105 | func (a *A) sealed() {}
106 |
107 | type B struct {}
108 | func (b *B) sealed() {}
109 |
110 | type C struct {}
111 | func (c *C) sealed() {}
112 |
113 | func main() {
114 | switch T(nil).(type) {
115 | case *A, *B, *C:
116 | }
117 | }
118 | `
119 | pkgs := setupPackages(t, code)
120 |
121 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
122 | assert.Equal(t, 0, len(errs))
123 | }
124 |
125 | // TestNoMissingDefaultWithDefaultSignifiesExhaustive tests that even if we have a missing variant, a default
126 | // case should thwart exhaustiveness checking when Config.DefaultSignifiesExhaustive is true.
127 | func TestNoMissingDefaultWithDefaultSignifiesExhaustive(t *testing.T) {
128 | code := `
129 | package gochecksumtype
130 |
131 | //sumtype:decl
132 | type T interface { sealed() }
133 |
134 | type A struct {}
135 | func (a *A) sealed() {}
136 |
137 | type B struct {}
138 | func (b *B) sealed() {}
139 |
140 | func main() {
141 | switch T(nil).(type) {
142 | case *A:
143 | default:
144 | println("legit catch all goes here")
145 | }
146 | }
147 | `
148 | pkgs := setupPackages(t, code)
149 |
150 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
151 | assert.Equal(t, 0, len(errs))
152 | }
153 |
154 | // TestNoMissingDefaultAndDefaultDoesNotSignifiesExhaustive tests that even if we have a missing variant, a default
155 | // case should thwart exhaustiveness checking when Config.DefaultSignifiesExhaustive is false.
156 | func TestNoMissingDefaultAndDefaultDoesNotSignifiesExhaustive(t *testing.T) {
157 | code := `
158 | package gochecksumtype
159 |
160 | //sumtype:decl
161 | type T interface { sealed() }
162 |
163 | type A struct {}
164 | func (a *A) sealed() {}
165 |
166 | type B struct {}
167 | func (b *B) sealed() {}
168 |
169 | func main() {
170 | switch T(nil).(type) {
171 | case *A:
172 | default:
173 | println("legit catch all goes here")
174 | }
175 | }
176 | `
177 | pkgs := setupPackages(t, code)
178 |
179 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: false})
180 | assert.Equal(t, 1, len(errs))
181 | assert.Equal(t, []string{"B"}, missingNames(t, errs[0]))
182 | }
183 |
184 | // TestNotSealed tests that we report an error if one tries to declare a sum
185 | // type with an unsealed interface.
186 | func TestNotSealed(t *testing.T) {
187 | code := `
188 | package gochecksumtype
189 |
190 | //sumtype:decl
191 | type T interface {}
192 |
193 | func main() {}
194 | `
195 | pkgs := setupPackages(t, code)
196 |
197 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
198 | assert.Equal(t, 1, len(errs))
199 | assert.Equal(t, "T", errs[0].(unsealedError).Decl.TypeName)
200 | }
201 |
202 | // TestNotInterface tests that we report an error if one tries to declare a sum
203 | // type that doesn't correspond to an interface.
204 | func TestNotInterface(t *testing.T) {
205 | code := `
206 | package gochecksumtype
207 |
208 | //sumtype:decl
209 | type T struct {}
210 |
211 | func main() {}
212 | `
213 | pkgs := setupPackages(t, code)
214 |
215 | errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
216 | assert.Equal(t, 1, len(errs))
217 | assert.Equal(t, "T", errs[0].(notInterfaceError).Decl.TypeName)
218 | }
219 |
220 | // TestSubTypeInSwitch tests that if a shared interface is declared in the switch
221 | // statement, we don't report an error if structs that implement the interface are not explicitly
222 | // declared in the switch statement.
223 | func TestSubTypeInSwitch(t *testing.T) {
224 | code := `
225 | package gochecksumtype
226 |
227 | //sumtype:decl
228 | type T1 interface { sealed1() }
229 | type T2 interface {
230 | T1
231 | sealed2()
232 | }
233 |
234 |
235 | type A struct {}
236 | func (a *A) sealed1() {}
237 |
238 | type B struct {}
239 | func (b *B) sealed1() {}
240 | func (b *B) sealed2() {}
241 |
242 | type C struct {}
243 | func (c *C) sealed1() {}
244 | func (c *C) sealed2() {}
245 |
246 | func main() {
247 | switch T1(nil).(type) {
248 | case *A:
249 | case T2:
250 | }
251 | }
252 | `
253 | pkgs := setupPackages(t, code)
254 |
255 | errs := Run(pkgs, Config{IncludeSharedInterfaces: true})
256 | assert.Equal(t, 0, len(errs))
257 | }
258 |
259 | // TestAllLeavesInSwitch tests that we do not report an error if a switch statement
260 | // covers all leaves of the sum type, even if any SubTypes are not explicitly covered
261 | func TestAllLeavesInSwitch(t *testing.T) {
262 | code := `
263 | package gochecksumtype
264 |
265 | //sumtype:decl
266 | type T1 interface { sealed1() }
267 | type T2 interface {
268 | T1
269 | sealed2()
270 | }
271 |
272 |
273 | type A struct {}
274 | func (a *A) sealed1() {}
275 |
276 | type B struct {}
277 | func (b *B) sealed1() {}
278 | func (b *B) sealed2() {}
279 |
280 | type C struct {}
281 | func (c *C) sealed1() {}
282 | func (c *C) sealed2() {}
283 |
284 | func main() {
285 | switch T1(nil).(type) {
286 | case *A:
287 | case *B:
288 | case *C:
289 | }
290 | }
291 | `
292 | pkgs := setupPackages(t, code)
293 |
294 | errs := Run(pkgs, Config{})
295 | assert.Equal(t, 0, len(errs))
296 | }
297 |
298 | func missingNames(t *testing.T, err error) []string {
299 | t.Helper()
300 | ierr, ok := err.(inexhaustiveError)
301 | assert.True(t, ok, "error was not inexhaustiveError: %T", err)
302 | return ierr.Names()
303 | }
304 |
--------------------------------------------------------------------------------
/cmd/go-check-sumtype/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "flag"
5 | "log"
6 | "os"
7 | "strings"
8 |
9 | gochecksumtype "github.com/alecthomas/go-check-sumtype"
10 | "golang.org/x/tools/go/packages"
11 | )
12 |
13 | func main() {
14 | log.SetFlags(0)
15 |
16 | defaultSignifiesExhaustive := flag.Bool(
17 | "default-signifies-exhaustive",
18 | true,
19 | "Presence of \"default\" case in switch statements satisfies exhaustiveness, if all members are not listed.",
20 | )
21 |
22 | includeSharedInterfaces := flag.Bool(
23 | "include-shared-interfaces",
24 | false,
25 | "Include shared interfaces in the exhaustiviness check.",
26 | )
27 |
28 | flag.Parse()
29 | if flag.NArg() < 1 {
30 | log.Fatalf("Usage: sumtype \n")
31 | }
32 | args := os.Args[flag.NFlag()+1:]
33 |
34 | config := gochecksumtype.Config{
35 | DefaultSignifiesExhaustive: *defaultSignifiesExhaustive,
36 | IncludeSharedInterfaces: *includeSharedInterfaces,
37 | }
38 |
39 | conf := &packages.Config{
40 | Mode: packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedTypes | packages.NeedTypesSizes |
41 | packages.NeedImports | packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles,
42 | // Unfortunately, it appears including the test packages in
43 | // this lint makes it difficult to do exhaustiveness checking.
44 | // Namely, it appears that compiling the test version of a
45 | // package introduces distinct types from the normal version
46 | // of the package, which will always result in inexhaustive
47 | // errors whenever a package both defines a sum type and has
48 | // tests. (Specifically, using `package name`. Using `package
49 | // name_test` is OK.)
50 | //
51 | // It's not clear what the best way to fix this is. :-(
52 | Tests: false,
53 | }
54 | pkgs, err := packages.Load(conf, args...)
55 | if err != nil {
56 | log.Fatal(err)
57 | }
58 | if errs := gochecksumtype.Run(pkgs, config); len(errs) > 0 {
59 | var list []string
60 | for _, err := range errs {
61 | list = append(list, err.Error())
62 | }
63 | log.Fatal(strings.Join(list, "\n"))
64 | }
65 | }
66 |
--------------------------------------------------------------------------------
/config.go:
--------------------------------------------------------------------------------
1 | package gochecksumtype
2 |
// Config controls how Run judges type switches over declared sum types.
type Config struct {
	// DefaultSignifiesExhaustive, when true, treats a type switch that has a
	// "default" case as exhaustive even if not every variant is listed.
	DefaultSignifiesExhaustive bool
	// IncludeSharedInterfaces, when true, lets a case on an interface count as
	// covering every variant that implements it in the exhaustiveness check.
	// A switch then need not list all concrete structs, as long as it is
	// exhaustive with respect to the interfaces those structs implement.
	IncludeSharedInterfaces bool
}
9 |
--------------------------------------------------------------------------------
/decl.go:
--------------------------------------------------------------------------------
1 | package gochecksumtype
2 |
3 | import (
4 | "go/ast"
5 | "go/token"
6 | "strings"
7 |
8 | "golang.org/x/tools/go/packages"
9 | )
10 |
11 | // sumTypeDecl is a declaration of a sum type in a Go source file.
12 | type sumTypeDecl struct {
13 | // The package path that contains this decl.
14 | Package *packages.Package
15 | // The type named by this decl.
16 | TypeName string
17 | // Position where the declaration was found.
18 | Pos token.Position
19 | }
20 |
21 | // Location returns a short string describing where this declaration was found.
22 | func (d sumTypeDecl) Location() string {
23 | return d.Pos.String()
24 | }
25 |
26 | // findSumTypeDecls searches every package given for sum type declarations of
27 | // the form `sumtype:decl`.
28 | func findSumTypeDecls(pkgs []*packages.Package) ([]sumTypeDecl, error) {
29 | var decls []sumTypeDecl
30 | var retErr error
31 | for _, pkg := range pkgs {
32 | for _, file := range pkg.Syntax {
33 | ast.Inspect(file, func(node ast.Node) bool {
34 | if node == nil {
35 | return true
36 | }
37 | decl, ok := node.(*ast.GenDecl)
38 | if !ok || decl.Doc == nil {
39 | return true
40 | }
41 | var tspec *ast.TypeSpec
42 | for _, spec := range decl.Specs {
43 | ts, ok := spec.(*ast.TypeSpec)
44 | if !ok {
45 | continue
46 | }
47 | tspec = ts
48 | }
49 | for _, line := range decl.Doc.List {
50 | if !strings.HasPrefix(line.Text, "//sumtype:decl") {
51 | continue
52 | }
53 | pos := pkg.Fset.Position(decl.Pos())
54 | if tspec == nil {
55 | retErr = notFoundError{Decl: sumTypeDecl{Package: pkg, Pos: pos}}
56 | return false
57 | }
58 | pos = pkg.Fset.Position(tspec.Pos())
59 | decl := sumTypeDecl{Package: pkg, TypeName: tspec.Name.Name, Pos: pos}
60 | debugf("found sum type decl: %s.%s", decl.Package.PkgPath, decl.TypeName)
61 | decls = append(decls, decl)
62 | break
63 | }
64 | return true
65 | })
66 | }
67 | }
68 | return decls, retErr
69 | }
70 |
--------------------------------------------------------------------------------
/def.go:
--------------------------------------------------------------------------------
1 | package gochecksumtype
2 |
3 | import (
4 | "flag"
5 | "fmt"
6 | "go/token"
7 | "go/types"
8 | "log"
9 | )
10 |
// debug toggles verbose logging of declaration and variant discovery.
//
// NOTE(review): this flag is registered on the global flag.CommandLine by a
// library package, so every program importing it gains a -debug flag (and the
// flag package panics if that program defines its own "debug" flag).
var debug = flag.Bool("debug", false, "enable debug logging")

// debugf logs a formatted message through the standard logger, but only when
// the -debug flag is set.
func debugf(format string, args ...interface{}) {
	if !*debug {
		return
	}
	log.Printf(format, args...)
}
18 |
// Error is an error that also carries the source position it refers to.
// Run returns its results as plain error values; the declaration errors in
// this file (unsealedError, notFoundError, notInterfaceError) implement Error,
// so callers can recover the position with a type assertion.
type Error interface {
	Pos() token.Position
	error
}
24 |
25 | // unsealedError corresponds to a declared sum type whose interface is not
26 | // sealed. A sealed interface requires at least one unexported method.
27 | type unsealedError struct {
28 | Decl sumTypeDecl
29 | }
30 |
31 | func (e unsealedError) Pos() token.Position { return e.Decl.Pos }
32 | func (e unsealedError) Error() string {
33 | return fmt.Sprintf(
34 | "%s: interface '%s' is not sealed "+
35 | "(sealing requires at least one unexported method)",
36 | e.Decl.Location(), e.Decl.TypeName)
37 | }
38 |
39 | // notFoundError corresponds to a declared sum type whose type definition
40 | // could not be found in the same Go package.
41 | type notFoundError struct {
42 | Decl sumTypeDecl
43 | }
44 |
45 | func (e notFoundError) Pos() token.Position { return e.Decl.Pos }
46 | func (e notFoundError) Error() string {
47 | return fmt.Sprintf("%s: type '%s' is not defined", e.Decl.Location(), e.Decl.TypeName)
48 | }
49 |
50 | // notInterfaceError corresponds to a declared sum type that does not
51 | // correspond to an interface.
52 | type notInterfaceError struct {
53 | Decl sumTypeDecl
54 | }
55 |
56 | func (e notInterfaceError) Pos() token.Position { return e.Decl.Pos }
57 | func (e notInterfaceError) Error() string {
58 | return fmt.Sprintf("%s: type '%s' is not an interface", e.Decl.Location(), e.Decl.TypeName)
59 | }
60 |
61 | // sumTypeDef corresponds to the definition of a Go interface that is
62 | // interpreted as a sum type. Its variants are determined by finding all types
63 | // that implement said interface in the same package.
64 | type sumTypeDef struct {
65 | Decl sumTypeDecl
66 | Ty *types.Interface
67 | Variants []types.Object
68 | }
69 |
70 | // findSumTypeDefs attempts to find a Go type definition for each of the given
71 | // sum type declarations. If no such sum type definition could be found for
72 | // any of the given declarations, then an error is returned.
73 | func findSumTypeDefs(decls []sumTypeDecl) ([]sumTypeDef, []error) {
74 | defs := make([]sumTypeDef, 0, len(decls))
75 | var errs []error
76 | for _, decl := range decls {
77 | def, err := newSumTypeDef(decl.Package.Types, decl)
78 | if err != nil {
79 | errs = append(errs, err)
80 | continue
81 | }
82 | if def == nil {
83 | errs = append(errs, notFoundError{decl})
84 | continue
85 | }
86 | defs = append(defs, *def)
87 | }
88 | return defs, errs
89 | }
90 |
91 | // newSumTypeDef attempts to extract a sum type definition from a single
92 | // package. If no such type corresponds to the given decl, then this function
93 | // returns a nil def and a nil error.
94 | //
95 | // If the decl corresponds to a type that isn't an interface containing at
96 | // least one unexported method, then this returns an error.
97 | func newSumTypeDef(pkg *types.Package, decl sumTypeDecl) (*sumTypeDef, error) {
98 | obj := pkg.Scope().Lookup(decl.TypeName)
99 | if obj == nil {
100 | return nil, nil
101 | }
102 | iface, ok := obj.Type().Underlying().(*types.Interface)
103 | if !ok {
104 | return nil, notInterfaceError{decl}
105 | }
106 | hasUnexported := false
107 | for i := range iface.NumMethods() {
108 | if !iface.Method(i).Exported() {
109 | hasUnexported = true
110 | break
111 | }
112 | }
113 | if !hasUnexported {
114 | return nil, unsealedError{decl}
115 | }
116 | def := &sumTypeDef{
117 | Decl: decl,
118 | Ty: iface,
119 | }
120 | debugf("searching for variants of %s.%s\n", pkg.Path(), decl.TypeName)
121 | for _, name := range pkg.Scope().Names() {
122 | obj, ok := pkg.Scope().Lookup(name).(*types.TypeName)
123 | if !ok {
124 | continue
125 | }
126 | ty := obj.Type()
127 | if types.Identical(ty.Underlying(), iface) {
128 | continue
129 | }
130 | // Skip generic types.
131 | if named, ok := ty.(*types.Named); ok && named.TypeParams() != nil {
132 | continue
133 | }
134 | if types.Implements(ty, iface) || types.Implements(types.NewPointer(ty), iface) {
135 | debugf(" found variant: %s.%s\n", pkg.Path(), obj.Name())
136 | def.Variants = append(def.Variants, obj)
137 | }
138 | }
139 | return def, nil
140 | }
141 |
142 | func (def *sumTypeDef) String() string {
143 | return def.Decl.TypeName
144 | }
145 |
146 | // missing returns a list of variants in this sum type that are not in the
147 | // given list of types.
148 | func (def *sumTypeDef) missing(tys []types.Type, includeSharedInterfaces bool) []types.Object {
149 | // TODO(ag): This is O(n^2). Fix that. /shrug
150 | var missing []types.Object
151 | for _, v := range def.Variants {
152 | found := false
153 | varty := indirect(v.Type())
154 | for _, ty := range tys {
155 | ty = indirect(ty)
156 | if types.Identical(varty, ty) {
157 | found = true
158 | break
159 | }
160 | if includeSharedInterfaces && implements(varty, ty) {
161 | found = true
162 | break
163 | }
164 | }
165 | if !found && !isInterface(varty) {
166 | // we do not include interfaces extending the sumtype, as the
167 | // all implementations of those interfaces are already covered
168 | // by the sumtype.
169 | missing = append(missing, v)
170 | }
171 | }
172 | return missing
173 | }
174 |
// isInterface reports whether ty, after stripping any pointer indirections,
// has an interface as its underlying type.
func isInterface(ty types.Type) bool {
	switch indirect(ty).Underlying().(type) {
	case *types.Interface:
		return true
	default:
		return false
	}
}

// indirect strips every level of *types.Pointer from ty, so **T becomes T.
// Anything that is not itself a *types.Pointer (including a named type whose
// underlying type is a pointer) is returned unchanged.
func indirect(ty types.Type) types.Type {
	for {
		ptr, ok := ty.(*types.Pointer)
		if !ok {
			return ty
		}
		ty = ptr.Elem()
	}
}
188 |
// implements reports whether varty, or a pointer to it, satisfies
// interfaceType. It reports false when interfaceType's underlying type is not
// an interface.
func implements(varty, interfaceType types.Type) bool {
	iface, isIface := interfaceType.Underlying().(*types.Interface)
	if !isIface {
		return false
	}
	if types.Implements(varty, iface) {
		return true
	}
	return types.Implements(types.NewPointer(varty), iface)
}
196 |
--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
1 | /*
2 | sumtype takes a list of Go package paths or files and looks for sum type
3 | declarations in each package/file provided. Exhaustiveness checks are then
4 | performed for each use of a declared sum type in a type switch statement.
Namely, sumtype will report an error for any type switch statement that does
not account for all possible variants, unless it has a default clause and
default clauses are configured to signify exhaustiveness.
7 |
8 | Declarations are provided in comments like so:
9 |
10 | //sumtype:decl
11 | type MySumType interface { ... }
12 |
MySumType must be *sealed*. That is, its interface definition must include at
least one unexported method.
15 |
16 | sumtype will produce an error if any of the above is not true.
17 |
18 | For valid declarations, sumtype will look for all occurrences in which a
19 | value of type MySumType participates in a type switch statement. In those
20 | occurrences, it will attempt to detect whether the type switch is exhaustive
21 | or not. If it's not, sumtype will report an error. For example:
22 |
23 | $ cat mysumtype.go
24 | package gochecksumtype
25 |
26 | //sumtype:decl
27 | type MySumType interface {
28 | sealed()
29 | }
30 |
31 | type VariantA struct{}
32 |
33 | func (a *VariantA) sealed() {}
34 |
35 | type VariantB struct{}
36 |
37 | func (b *VariantB) sealed() {}
38 |
39 | func main() {
40 | switch MySumType(nil).(type) {
41 | case *VariantA:
42 | }
43 | }
44 | $ sumtype mysumtype.go
45 | mysumtype.go:18:2: exhaustiveness check failed for sum type 'MySumType': missing cases for VariantB
46 |
47 | Adding either a default clause or a clause to handle *VariantB will cause
48 | exhaustive checks to pass.
49 |
50 | As a special case, if the type switch statement contains a default clause
51 | that always panics, then exhaustiveness checks are still performed.
52 | */
53 | package gochecksumtype
54 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/alecthomas/go-check-sumtype
2 |
3 | go 1.22.0
4 |
5 | require (
6 | github.com/alecthomas/assert/v2 v2.11.0
7 | golang.org/x/tools v0.28.0
8 | )
9 |
10 | require (
11 | github.com/alecthomas/repr v0.4.0 // indirect
12 | github.com/hexops/gotextdiff v1.0.3 // indirect
13 | golang.org/x/mod v0.22.0 // indirect
14 | golang.org/x/sync v0.10.0 // indirect
15 | )
16 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
2 | github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
3 | github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
4 | github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
5 | github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
6 | github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
7 | golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
8 | golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
9 | golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ=
10 | golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
11 | golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
12 | golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
13 | golang.org/x/tools v0.27.0 h1:qEKojBykQkQ4EynWy4S8Weg69NumxKdn40Fce3uc/8o=
14 | golang.org/x/tools v0.27.0/go.mod h1:sUi0ZgbwW9ZPAq26Ekut+weQPR5eIM6GQLQ1Yjm1H0Q=
15 | golang.org/x/tools v0.28.0 h1:WuB6qZ4RPCQo5aP3WdKZS7i595EdWqWR8vqJTlwTVK8=
16 | golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw=
17 |
--------------------------------------------------------------------------------
/help_test.go:
--------------------------------------------------------------------------------
1 | package gochecksumtype
2 |
3 | import (
4 | "os"
5 | "path/filepath"
6 | "testing"
7 |
8 | "golang.org/x/tools/go/packages"
9 | )
10 |
11 | func setupPackages(t *testing.T, code string) []*packages.Package {
12 | srcPath := filepath.Join(t.TempDir(), "src.go")
13 | if err := os.WriteFile(srcPath, []byte(code), 0600); err != nil {
14 | t.Fatal(err)
15 | }
16 | pkgs, err := tycheckAll([]string{srcPath})
17 | if err != nil {
18 | t.Fatal(err)
19 | }
20 | return pkgs
21 | }
22 |
23 | func tycheckAll(args []string) ([]*packages.Package, error) {
24 | conf := &packages.Config{
25 | Mode: packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedTypes | packages.NeedTypesSizes |
26 | packages.NeedImports | packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles,
27 | // Unfortunately, it appears including the test packages in
28 | // this lint makes it difficult to do exhaustiveness checking.
29 | // Namely, it appears that compiling the test version of a
30 | // package introduces distinct types from the normal version
31 | // of the package, which will always result in inexhaustive
32 | // errors whenever a package both defines a sum type and has
33 | // tests. (Specifically, using `package name`. Using `package
34 | // name_test` is OK.)
35 | //
36 | // It's not clear what the best way to fix this is. :-(
37 | Tests: false,
38 | }
39 | pkgs, err := packages.Load(conf, args...)
40 | if err != nil {
41 | return nil, err
42 | }
43 | return pkgs, nil
44 | }
45 |
--------------------------------------------------------------------------------
/renovate.json5:
--------------------------------------------------------------------------------
1 | {
2 | $schema: "https://docs.renovatebot.com/renovate-schema.json",
3 | extends: [
4 | "config:recommended",
5 | ":semanticCommits",
6 | ":semanticCommitTypeAll(chore)",
7 | ":semanticCommitScope(deps)",
8 | "group:allNonMajor",
9 | "schedule:earlyMondays", // Run once a week.
10 | ],
11 | packageRules: [
12 | {
13 | matchPackageNames: ["golangci-lint"],
14 | matchManagers: ["hermit"],
15 | enabled: false,
16 | },
17 | ],
18 | }
19 |
--------------------------------------------------------------------------------
/run.go:
--------------------------------------------------------------------------------
1 | package gochecksumtype
2 |
3 | import "golang.org/x/tools/go/packages"
4 |
5 | // Run sumtype checking on the given packages.
6 | func Run(pkgs []*packages.Package, config Config) []error {
7 | var errs []error
8 |
9 | decls, err := findSumTypeDecls(pkgs)
10 | if err != nil {
11 | return []error{err}
12 | }
13 |
14 | defs, defErrs := findSumTypeDefs(decls)
15 | errs = append(errs, defErrs...)
16 | if len(defs) == 0 {
17 | return errs
18 | }
19 |
20 | for _, pkg := range pkgs {
21 | if pkgErrs := check(pkg, defs, config); pkgErrs != nil {
22 | errs = append(errs, pkgErrs...)
23 | }
24 | }
25 | return errs
26 | }
27 |
--------------------------------------------------------------------------------
/scripts/go-check-sumtype:
--------------------------------------------------------------------------------
#!/bin/bash
# Build the command named after this script from ./cmd/<name> into
# build/devel/ and exec it with all of this script's arguments.
set -euo pipefail
# Resolve the repository root to an absolute path. The build below runs
# after `cd "$basedir"`, so a relative "$dest" (e.g. when invoked as
# ./go-check-sumtype from scripts/) would point at the wrong directory.
basedir="$(cd "$(dirname "$0")/.." && pwd)"
name="$(basename "$0")"
dest="${basedir}/build/devel"
mkdir -p "$dest"
(cd "${basedir}" && ./bin/go build -ldflags="-s -w -buildid=" -o "$dest/${name}" "./cmd/${name}") && exec "$dest/${name}" "$@"
8 |
--------------------------------------------------------------------------------
/testdata/sum.go:
--------------------------------------------------------------------------------
1 | package testdata
2 |
// Sum is the sum type exercised by this fixture. Its only method is
// unexported, which is what makes it sealed and therefore acceptable as a
// sum type.
//
//sumtype:decl
type Sum interface{ sum() }

// A is a non-generic variant of Sum.
type A struct{}

func (A) sum() {}

// B is a non-generic variant of Sum.
type B struct{}

func (B) sum() {}

// C is a generic type implementing Sum. newSumTypeDef skips generic types
// when collecting variants, so switches over Sum should not need a case for
// C; presumably this fixture exists to check exactly that.
type C[T any] struct{}

func (C[T]) sum() {}
17 |
// SumSwitch lists a case for each non-generic variant of Sum (A and B) and has
// no default clause. It has no case for the generic type C, which the checker
// skips when collecting variants, so this switch is presumably meant to be
// reported as exhaustive.
func SumSwitch(x Sum) {
	switch x.(type) {
	case A:
	case B:
	}
}
24 |
--------------------------------------------------------------------------------