├── .gitignore
├── go.mod
├── go.sum
├── test
│   ├── example.go
│   ├── generate_test.go
│   ├── parse_test.go
│   └── example_enums.go
├── utils.go
├── README.md
├── cmd
│   └── stringenum
│       └── main.go
├── generator.go
└── parser.go

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea/
3 | 
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/therne/stringenum
2 | 
3 | go 1.14
4 | 
5 | require github.com/spf13/pflag v1.0.5
6 | 
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
2 | github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
3 | 
--------------------------------------------------------------------------------
/test/example.go:
--------------------------------------------------------------------------------
1 | //go:generate go run ../cmd/stringenum ExampleEnum
2 | package test
3 | 
4 | type ExampleEnum string
5 | 
6 | const (
7 | 	Enum1 = ExampleEnum("hello")
8 | 	Enum2 = ExampleEnum("world")
9 | 	Enum3 = ExampleEnum("goos")
10 | )
11 | 
12 | const Enum4 ExampleEnum = "geese"
13 | 
14 | const irrelevantConstDecl = 1234
15 | 
16 | var irrelevantVarDecl = 5678
17 | 
--------------------------------------------------------------------------------
/test/generate_test.go:
--------------------------------------------------------------------------------
1 | package test
2 | 
3 | import (
4 | 	"strings"
5 | 	"testing"
6 | 
7 | 	"github.com/therne/stringenum"
8 | )
9 | 
10 | func TestGenerateCode(t *testing.T) {
11 | 	def := stringenum.ParsedFile{
12 | 		PackageName: "example",
13 | 		Enums: []stringenum.EnumDesc{
14 | 			{
15 | 				Type: "Helvetica",
16 | 				Values: map[string]string{
17 | 					"Neue": "neue",
18 | 					"Bold": "bold",
19 | 				},
20 | 			},
21 | 			{
22 | 				Type: "NotoSans",
23 | 				Values: map[string]string{
24 | 					"Nerf": "nerf",
25 | 					"CJK":  "chJpKr",
26 | 				},
27 | 			},
28 | 		},
29 | 	}
30 | 	code, err := stringenum.GenerateCode(def)
31 | 	if err != nil {
32 | 		t.Fatalf("Error occurred on stringenum.GenerateCode: %v", err)
33 | 	}
34 | 	for _, want := range []string{"package example", "var HelveticaValues", "var NotoSansValues"} {
35 | 		if !strings.Contains(code, want) {
36 | 			t.Errorf("generated code does not contain %q:\n%s", want, code)
37 | 		}
38 | 	}
39 | }
40 | 
--------------------------------------------------------------------------------
/utils.go:
--------------------------------------------------------------------------------
1 | package stringenum
2 | 
3 | import (
4 | 	"fmt"
5 | 	"strings"
6 | 	"text/template"
7 | )
8 | 
9 | // templateFunctions are the extra functions available inside code templates.
10 | var templateFunctions = template.FuncMap{
11 | 	"StringJoin": strings.Join,
12 | }
13 | 
14 | // defineTemplate parses code as a text/template after trimming leading newlines
15 | // and spaces. It panics if code fails to parse, so it is meant only for
16 | // templates fixed at compile time, such as package-level variables.
17 | func defineTemplate(code string) *template.Template {
18 | 	code = strings.TrimLeft(code, "\n ")
19 | 	return template.Must(template.New("").Funcs(templateFunctions).Parse(code))
20 | }
21 | 
22 | // DumpWithLine returns code with every line prefixed by its 1-based line
23 | // number and with tabs replaced by spaces, so that errors reported against
24 | // generated source are easy to locate.
25 | func DumpWithLine(code string) string {
26 | 	var b strings.Builder
27 | 	for i, line := range strings.Split(code, "\n") {
28 | 		fmt.Fprintf(&b, " %3d | %s\n", i+1, strings.ReplaceAll(line, "\t", " "))
29 | 	}
30 | 	return b.String()
31 | }
32 | 
--------------------------------------------------------------------------------
/test/parse_test.go:
--------------------------------------------------------------------------------
1 | package test
2 | 
3 | import (
4 | 	"os"
5 | 	"reflect"
6 | 	"testing"
7 | 
8 | 	"github.com/therne/stringenum"
9 | )
10 | 
11 | func TestParser(t *testing.T) {
12 | 	pwd, err := os.Getwd()
13 | 	if err != nil {
14 | 		t.Fatalf("Unable to get working directory: %v", err)
15 | 	}
16 | 	parsed, err := stringenum.Parse(pwd, stringenum.ParsingOptions{Types: []string{"ExampleEnum"}})
17 | 	if err != nil {
18 | 		t.Fatalf("Expected stringenum.Parse not to return an error, but returned:\n\t%v", err)
19 | 	}
20 | 	result, ok := parsed["example.go"]
21 | 	if len(parsed) != 1 || !ok {
22 | 		t.Fatalf("Expected example.go to be parsed only, but found %v", parsed)
23 | 	}
24 | 	if result.PackageName != "test" {
25 | 		t.Fatalf("Expected parsed package name to be 'test', but was '%v'", result.PackageName)
26 | 	}
27 | 	if len(result.Enums) != 1 {
28 | 		t.Fatalf("Expected parsed enum count to be 1, but was %v", len(result.Enums))
29 | 	}
30 | 	expectedValues := map[string]string{
31 | 		"Enum1": "hello",
32 | 		"Enum2": "world",
33 | 		"Enum3": "goos",
34 | 		"Enum4": "geese",
35 | 	}
36 | 	if !reflect.DeepEqual(result.Enums[0].Values, expectedValues) {
37 | 		t.Fatalf("Wrong parsed enums: found %v", result.Enums[0].Values)
38 | 	}
39 | }
40 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | stringenum
2 | ==========
3 | 
4 | A Go tool that auto-generates serialization / validation methods for enum types whose underlying type is `string`.
5 | 
6 | #### Features
7 | 
8 | * Supports JSON and YAML serialization
9 | * Implements the common `Validator` interface (a `Validate() error` method)
10 | * **Does not support default values yet.** However, it works well with
11 |   third-party default-value modules like [creasty/defaults](https://github.com/creasty/defaults),
12 |   because your types are basically strings!
13 | 
14 | 
15 | ## Installation
16 | 
17 | You need to install `stringenum` to generate the enum stub code.
18 | 
19 | ```
20 | $ go get github.com/therne/stringenum/...
21 | ```
22 | 
23 | ## Usage
24 | 
25 | At the top of the source file that defines your type, add a `//go:generate` directive that runs `stringenum`:
26 | 
27 | ```diff
28 | + //go:generate stringenum Kind
29 |  package mytype
30 |  
31 |  type Kind string
32 |  
33 |  const (
34 |      Apple  = Kind("apple")
35 |      Google = Kind("google")
36 |  )
37 | ```
38 | 
39 | Then, run `go generate` to generate the stub code:
40 | 
41 | ```
42 | $ go generate ./...
43 | ```
44 | 
45 | ## Generated Values and Methods
46 | 
47 | For an enum type `<Type>`, the following are generated:
48 | 
49 | * `<Type>Values`: A list of all available values of the enum.
50 | * `<Type>FromValue(string)`: Converts a string into the enum. An error is returned if the given string is not defined on the enum.
51 | * `IsValid()`: Returns false if the value is not defined on the enum.
52 | * `Validate()`: Returns an error if the value is not defined on the enum.
53 | * `MarshalText` / `UnmarshalText`: Implement the `encoding.TextMarshaler` / `encoding.TextUnmarshaler` interfaces for JSON / YAML serialization.
54 | * `String()`: Casts the enum into a `string`. Implements the `fmt.Stringer` interface.
55 | 
56 | ## License: MIT
--------------------------------------------------------------------------------
/test/example_enums.go:
--------------------------------------------------------------------------------
1 | // Code generated by stringenum - DO NOT EDIT.
2 | 
3 | package test
4 | 
5 | import "fmt"
6 | 
7 | // ExampleEnumValues contains all possible values of ExampleEnum.
8 | var ExampleEnumValues = []ExampleEnum{
9 | 	Enum1,
10 | 	Enum2,
11 | 	Enum3,
12 | 	Enum4,
13 | }
14 | 
15 | // ExampleEnumFromValue returns the ExampleEnum for the given value, or an error
16 | // if the value is not defined on ExampleEnum.
17 | func ExampleEnumFromValue(s string) (ExampleEnum, error) {
18 | 	v := ExampleEnum(s)
19 | 	if !v.IsValid() {
20 | 		return "", fmt.Errorf("%s is not a valid ExampleEnum", s)
21 | 	}
22 | 	return v, nil
23 | }
24 | 
25 | // IsValid reports whether v is one of the values defined on ExampleEnum.
26 | func (v ExampleEnum) IsValid() bool {
27 | 	for _, val := range ExampleEnumValues {
28 | 		if val == v {
29 | 			return true
30 | 		}
31 | 	}
32 | 	return false
33 | }
34 | 
35 | // Validate returns an error if v is not one of the values defined on ExampleEnum.
36 | func (v ExampleEnum) Validate() error {
37 | 	if _, err := ExampleEnumFromValue(string(v)); err != nil {
38 | 		return fmt.Errorf("%w. possible values are: %s", err, "\"hello\", \"world\", \"goos\", \"geese\"")
39 | 	}
40 | 	return nil
41 | }
42 | 
43 | // String returns a string value of ExampleEnum.
44 | func (v ExampleEnum) String() string {
45 | 	return string(v)
46 | }
47 | 
48 | // MarshalText implements encoding.TextMarshaler interface which is compatible with JSON, YAML.
49 | func (v ExampleEnum) MarshalText() ([]byte, error) {
50 | 	return []byte(v), nil
51 | }
52 | 
53 | // UnmarshalText implements encoding.TextUnmarshaler interface which is compatible with JSON, YAML.
54 | func (v *ExampleEnum) UnmarshalText(d []byte) error {
55 | 	vv, err := ExampleEnumFromValue(string(d))
56 | 	if err != nil {
57 | 		return err
58 | 	}
59 | 	*v = vv
60 | 	return nil
61 | }
62 | 
--------------------------------------------------------------------------------
/cmd/stringenum/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"fmt"
5 | 	"io/ioutil"
6 | 	"log"
7 | 	"os"
8 | 	"path/filepath"
9 | 	"strings"
10 | 
11 | 	"github.com/spf13/pflag"
12 | 	"github.com/therne/stringenum"
13 | )
14 | 
15 | func main() {
16 | 	// Report errors as "stringenum: ..." without timestamps so they read well
17 | 	// in `go generate` output.
18 | 	log.SetFlags(0)
19 | 	log.SetPrefix("stringenum: ")
20 | 
21 | 	var (
22 | 		parsingOpts  stringenum.ParsingOptions
23 | 		outputSuffix string
24 | 	)
25 | 	pflag.StringVarP(&outputSuffix, "output-file-suffix", "o", "enums", "suffix appended to generated source file names (e.g. srcName_enums.go)")
26 | 	pflag.Parse()
27 | 
28 | 	if pflag.NArg() < 1 {
29 | 		log.Fatalln("usage: stringenum [-o suffix] <type> [<type>...]")
30 | 	}
31 | 	parsingOpts.Types = pflag.Args()
32 | 
33 | 	pwd, err := os.Getwd()
34 | 	if err != nil {
35 | 		log.Fatalf("unable to locate working directory: %v", err)
36 | 	}
37 | 	parsedFiles, err := stringenum.Parse(pwd, parsingOpts)
38 | 	if err != nil {
39 | 		log.Fatalln(err)
40 | 	}
41 | 	found := make(map[string]bool)
42 | 	for fileName, parsedFile := range parsedFiles {
43 | 		code, err := stringenum.GenerateCode(parsedFile)
44 | 		if err != nil {
45 | 			// Exit non-zero so that `go generate` reports the failure.
46 | 			log.Fatalf("error generating code for %s: %v", fileName, err)
47 | 		}
48 | 
49 | 		outputFile := filepath.Join(pwd, fmt.Sprintf("%s_%s.go", strings.TrimSuffix(fileName, ".go"), outputSuffix))
50 | 		if err := ioutil.WriteFile(outputFile, []byte(code), 0644); err != nil {
51 | 			log.Fatalf("error writing generated source %s: %v", outputFile, err)
52 | 		}
53 | 		fmt.Println("generated:", outputFile)
54 | 		for _, enum := range parsedFile.Enums {
55 | 			found[enum.Type] = true
56 | 		}
57 | 	}
58 | 	for _, name := range parsingOpts.Types {
59 | 		if !found[name] {
60 | 			log.Fatalln("enum", name, "not found in source files.")
61 | 		}
62 | 	}
63 | }
64 | 
--------------------------------------------------------------------------------
/generator.go:
--------------------------------------------------------------------------------
1 | package stringenum
2 | 
3 | import (
4 | 	"bytes"
5 | 	"fmt"
6 | 	gofmt "go/format"
7 | )
8 | 
9 | // generatedSrcTmpl renders a Go source file declaring helper values and methods
10 | // for every enum of a ParsedFile. Ranging over EnumDesc.Values visits constants
11 | // in sorted name order, the same order EnumDesc.QuotedValues uses, so the
12 | // generated output is deterministic.
13 | var generatedSrcTmpl = defineTemplate(`
14 | // Code generated by stringenum - DO NOT EDIT.
15 | 
16 | package {{.PackageName}}
17 | 
18 | import "fmt"
19 | 
20 | {{range .Enums}}
21 | 
22 | // {{.Type}}Values contains all possible values of {{.Type}}.
23 | var {{.Type}}Values = []{{.Type}}{
24 | 	{{range $name, $value := .Values -}}
25 | 	{{ $name }},
26 | 	{{end}}
27 | }
28 | 
29 | // {{.Type}}FromValue returns the {{.Type}} for the given value, or an error
30 | // if the value is not defined on {{.Type}}.
31 | func {{.Type}}FromValue(s string) ({{.Type}}, error) {
32 | 	v := {{.Type}}(s)
33 | 	if !v.IsValid() {
34 | 		return "", fmt.Errorf("%s is not a valid {{.Type}}", s)
35 | 	}
36 | 	return v, nil
37 | }
38 | 
39 | // IsValid reports whether v is one of the values defined on {{.Type}}.
40 | func (v {{.Type}}) IsValid() bool {
41 | 	for _, val := range {{.Type}}Values {
42 | 		if val == v {
43 | 			return true
44 | 		}
45 | 	}
46 | 	return false
47 | }
48 | 
49 | // Validate returns an error if v is not one of the values defined on {{.Type}}.
50 | func (v {{.Type}}) Validate() error {
51 | 	if _, err := {{.Type}}FromValue(string(v)); err != nil {
52 | 		{{ $possibleValues := StringJoin .QuotedValues ", " -}}
53 | 		return fmt.Errorf("%w. possible values are: %s", err, {{ printf "%q" $possibleValues }})
54 | 	}
55 | 	return nil
56 | }
57 | 
58 | // String returns a string value of {{.Type}}.
59 | func (v {{.Type}}) String() string {
60 | 	return string(v)
61 | }
62 | 
63 | // MarshalText implements encoding.TextMarshaler interface which is compatible with JSON, YAML.
64 | func (v {{.Type}}) MarshalText() ([]byte, error) {
65 | 	return []byte(v), nil
66 | }
67 | 
68 | // UnmarshalText implements encoding.TextUnmarshaler interface which is compatible with JSON, YAML.
69 | func (v *{{.Type}}) UnmarshalText(d []byte) error {
70 | 	vv, err := {{.Type}}FromValue(string(d))
71 | 	if err != nil {
72 | 		return err
73 | 	}
74 | 	*v = vv
75 | 	return nil
76 | }
77 | 
78 | {{end}}
79 | `)
80 | 
81 | // GenerateCode renders the enum helpers for p and returns them as
82 | // gofmt-formatted Go source. If the rendered source cannot be formatted, the
83 | // returned error includes the source with line numbers for debugging.
84 | func GenerateCode(p ParsedFile) (string, error) {
85 | 	var buf bytes.Buffer
86 | 	if err := generatedSrcTmpl.Execute(&buf, p); err != nil {
87 | 		return "", err
88 | 	}
89 | 	code, err := gofmt.Source(buf.Bytes())
90 | 	if err != nil {
91 | 		return "", fmt.Errorf("go fmt: %w. source code was:\n%s", err, DumpWithLine(buf.String()))
92 | 	}
93 | 	return string(code), nil
94 | }
95 | 
--------------------------------------------------------------------------------
/parser.go:
--------------------------------------------------------------------------------
1 | package stringenum
2 | 
3 | import (
4 | 	"fmt"
5 | 	"go/ast"
6 | 	"go/parser"
7 | 	"go/token"
8 | 	"go/types"
9 | 	"io/ioutil"
10 | 	"path/filepath"
11 | 	"sort"
12 | 	"strconv"
13 | 	"strings"
14 | )
15 | 
16 | // ParsingOptions configures which enum types Parse collects. Types lists the
17 | // names of the string-based enum types to look for.
18 | type ParsingOptions struct {
19 | 	Types []string
20 | }
21 | 
22 | // ParsedFile holds the package name of a Go source file and the requested
23 | // enums declared in it.
24 | type ParsedFile struct {
25 | 	PackageName string
26 | 	Enums       []EnumDesc
27 | }
28 | 
29 | // EnumDesc describes an enum: Type is the type name and Values maps each
30 | // constant name to its unquoted string value.
31 | type EnumDesc struct {
32 | 	Type   string
33 | 	Values map[string]string
34 | }
35 | 
36 | // sortedNames returns the constant names of e in ascending order. This is the
37 | // order in which text/template ranges over e.Values, so lists built from it
38 | // line up with the generated code and do not vary between runs.
39 | func (e EnumDesc) sortedNames() []string {
40 | 	names := make([]string, 0, len(e.Values))
41 | 	for name := range e.Values {
42 | 		names = append(names, name)
43 | 	}
44 | 	sort.Strings(names)
45 | 	return names
46 | }
47 | 
48 | // EnumValues returns the enum's string values ordered by constant name.
49 | func (e EnumDesc) EnumValues() []string {
50 | 	names := e.sortedNames()
51 | 	values := make([]string, len(names))
52 | 	for i, name := range names {
53 | 		values[i] = e.Values[name]
54 | 	}
55 | 	return values
56 | }
57 | 
58 | // QuotedValues returns EnumValues with each value quoted as a Go string literal.
59 | func (e EnumDesc) QuotedValues() []string {
60 | 	values := e.EnumValues()
61 | 	for i, v := range values {
62 | 		values[i] = strconv.Quote(v)
63 | 	}
64 | 	return values
65 | }
66 | 
67 | // Parse scans the non-test .go files directly inside srcDir and returns, keyed
68 | // by file name, the enums listed in opt.Types that each file declares. Files
69 | // that declare none of them are omitted.
70 | func Parse(srcDir string, opt ParsingOptions) (map[string]ParsedFile, error) {
71 | 	files, err := ioutil.ReadDir(srcDir)
72 | 	if err != nil {
73 | 		return nil, err
74 | 	}
75 | 	parsedFiles := make(map[string]ParsedFile)
76 | 	fset := token.NewFileSet()
77 | 	for _, f := range files {
78 | 		name := f.Name()
79 | 		if f.IsDir() || !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") {
80 | 			continue
81 | 		}
82 | 		// Resolve the file against srcDir, which need not be the working directory.
83 | 		src, err := parser.ParseFile(fset, filepath.Join(srcDir, name), nil, 0)
84 | 		if err != nil {
85 | 			return nil, err
86 | 		}
87 | 		v := newAstVisitor(opt)
88 | 		ast.Walk(v, src)
89 | 		if v.Error != nil {
90 | 			return nil, fmt.Errorf("%s: %w", name, v.Error)
91 | 		}
92 | 		if result := v.Result(); len(result.Enums) > 0 {
93 | 			parsedFiles[name] = result
94 | 		}
95 | 	}
96 | 	return parsedFiles, nil
97 | }
98 | 
99 | // astVisitor collects string constants of the requested enum types from the
100 | // top-level declarations of a single file. Error is set when a requested type
101 | // is not declared as a string type.
102 | type astVisitor struct {
103 | 	Error error
104 | 
105 | 	packageName   string
106 | 	intermediates map[string]*EnumDesc
107 | 	options       ParsingOptions
108 | }
109 | 
110 | func newAstVisitor(opt ParsingOptions) *astVisitor {
111 | 	return &astVisitor{
112 | 		intermediates: make(map[string]*EnumDesc),
113 | 		options:       opt,
114 | 	}
115 | }
116 | 
117 | // Visit implements ast.Visitor. It descends from the *ast.File node into the
118 | // top-level declarations only, so function bodies are never inspected.
119 | func (a *astVisitor) Visit(node ast.Node) ast.Visitor {
120 | 	switch n := node.(type) {
121 | 	case *ast.File:
122 | 		a.packageName = n.Name.Name
123 | 		return a
124 | 	case *ast.GenDecl:
125 | 		switch n.Tok {
126 | 		case token.TYPE:
127 | 			a.checkTypeDecl(n)
128 | 		case token.CONST:
129 | 			a.collectConsts(n)
130 | 		}
131 | 	}
132 | 	return nil
133 | }
134 | 
135 | // isTarget reports whether typ is one of the requested enum types.
136 | func (a *astVisitor) isTarget(typ string) bool {
137 | 	for _, target := range a.options.Types {
138 | 		if typ == target {
139 | 			return true
140 | 		}
141 | 	}
142 | 	return false
143 | }
144 | 
145 | // checkTypeDecl records an error if a requested enum type declared in decl is
146 | // not declared as `string`. Every spec of a grouped declaration is checked, and
147 | // only the first error is kept.
148 | func (a *astVisitor) checkTypeDecl(decl *ast.GenDecl) {
149 | 	for _, spec := range decl.Specs {
150 | 		typeSpec, ok := spec.(*ast.TypeSpec)
151 | 		if !ok || !a.isTarget(typeSpec.Name.Name) {
152 | 			continue
153 | 		}
154 | 		if ident, ok := typeSpec.Type.(*ast.Ident); ok && ident.Name == "string" {
155 | 			continue
156 | 		}
157 | 		if a.Error == nil {
158 | 			a.Error = fmt.Errorf("expected type %s to be string, but was %s", typeSpec.Name.Name, types.ExprString(typeSpec.Type))
159 | 		}
160 | 	}
161 | }
162 | 
163 | // collectConsts records every constant in decl that is initialised with a
164 | // string literal, either typed (const A T = "a") or converted (const A = T("a")).
165 | // Names and values are paired positionally, so const A, B T = "a", "b" works too.
166 | func (a *astVisitor) collectConsts(decl *ast.GenDecl) {
167 | 	for _, spec := range decl.Specs {
168 | 		valueSpec, ok := spec.(*ast.ValueSpec)
169 | 		if !ok || len(valueSpec.Names) != len(valueSpec.Values) {
170 | 			continue
171 | 		}
172 | 		for i, name := range valueSpec.Names {
173 | 			if name.Name == "_" {
174 | 				// Blank constants cannot be referenced from generated code.
175 | 				continue
176 | 			}
177 | 			if typ, value, ok := stringConst(valueSpec.Type, valueSpec.Values[i]); ok {
178 | 				a.addResult(typ, name.Name, value)
179 | 			}
180 | 		}
181 | 	}
182 | }
183 | 
184 | // stringConst extracts the type name and unquoted value of a constant whose
185 | // value is a string literal or a single-argument conversion of one. declType
186 | // is the constant's explicit type, or nil if omitted. ok is false when the
187 | // constant does not have that shape.
188 | func stringConst(declType, value ast.Expr) (typ, str string, ok bool) {
189 | 	if call, isCall := unparen(value).(*ast.CallExpr); isCall && len(call.Args) == 1 {
190 | 		// const Name = EnumType("value")
191 | 		value = call.Args[0]
192 | 		if declType == nil {
193 | 			declType = call.Fun
194 | 		}
195 | 	}
196 | 	lit, isLit := unparen(value).(*ast.BasicLit)
197 | 	if !isLit || lit.Kind != token.STRING {
198 | 		return "", "", false
199 | 	}
200 | 	ident, isIdent := unparen(declType).(*ast.Ident)
201 | 	if !isIdent {
202 | 		return "", "", false
203 | 	}
204 | 	s, err := strconv.Unquote(lit.Value)
205 | 	if err != nil {
206 | 		return "", "", false
207 | 	}
208 | 	return ident.Name, s, true
209 | }
210 | 
211 | // unparen strips any parentheses enclosing e, e.g. (Kind) becomes Kind.
212 | func unparen(e ast.Expr) ast.Expr {
213 | 	for {
214 | 		p, ok := e.(*ast.ParenExpr)
215 | 		if !ok {
216 | 			return e
217 | 		}
218 | 		e = p.X
219 | 	}
220 | }
221 | 
222 | // addResult records constant name with the given value under enum type typ.
223 | func (a *astVisitor) addResult(typ, name, value string) {
224 | 	enum := a.intermediates[typ]
225 | 	if enum == nil {
226 | 		enum = &EnumDesc{
227 | 			Type:   typ,
228 | 			Values: make(map[string]string),
229 | 		}
230 | 		a.intermediates[typ] = enum
231 | 	}
232 | 	enum.Values[name] = value
233 | }
234 | 
235 | // Result returns the collected enums restricted to the requested types, in the
236 | // order they were requested. A type requested more than once is returned once,
237 | // since generating it twice would produce duplicate declarations.
238 | func (a *astVisitor) Result() (res ParsedFile) {
239 | 	res.PackageName = a.packageName
240 | 	seen := make(map[string]bool)
241 | 	for _, enumName := range a.options.Types {
242 | 		enum, ok := a.intermediates[enumName]
243 | 		if !ok || seen[enumName] {
244 | 			continue
245 | 		}
246 | 		seen[enumName] = true
247 | 		res.Enums = append(res.Enums, *enum)
248 | 	}
249 | 	return res
250 | }
251 | 
--------------------------------------------------------------------------------