├── .gitignore
├── types
│   └── const.go
├── go.mod
├── test
│   ├── add.go
│   ├── sum
│   │   └── sum.go
│   ├── max
│   │   └── max.go
│   ├── min
│   │   └── min.go
│   ├── delete
│   │   └── delete.go
│   ├── find
│   │   └── find.go
│   ├── map
│   │   └── map.go
│   └── reduce
│       └── main.go
├── enum
│   ├── add.go
│   ├── reduce_test.go
│   ├── map.go
│   ├── util.go
│   ├── find.go
│   ├── sum.go
│   ├── delete.go
│   ├── max.go
│   ├── min.go
│   └── reduce.go
├── Makefile
├── utils
│   └── utils.go
├── .github
│   └── workflows
│       └── go.yml
├── .golangci.yml
├── translator
│   ├── parse_fun.go
│   ├── expr.go
│   ├── var.go
│   └── gen_fun_call.go
├── README_CN.md
├── go.sum
├── README.md
├── fileoperations
│   ├── genFunc.go
│   └── replace.go
└── main.go

/.gitignore:
--------------------------------------------------------------------------------
1 | betterGo
2 | vendor*
3 | utils/enum
--------------------------------------------------------------------------------
/types/const.go:
--------------------------------------------------------------------------------
1 | package types
2 |
3 | // Marker strings returned by translator's reflectType for expression kinds
4 | // whose concrete Go type has to be resolved in a second step.
5 | const (
6 | 	// BasicLitStr marks an *ast.BasicLit (a literal such as 1, 2.5 or "s").
7 | 	BasicLitStr = "BasicLit"
8 | 	// CallExprStr marks an *ast.CallExpr (a call such as make([]int, 3)).
9 | 	CallExprStr = "CallExpr"
10 | )
11 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/PioneerIncubator/betterGo
2 |
3 | go 1.13
4 |
5 | require (
6 | 	github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect
7 | 	github.com/sirupsen/logrus v1.8.1
8 | 	github.com/urfave/cli/v2 v2.2.0
9 | 	golang.org/x/tools v0.0.0-20200108203644-89082a384178
10 | )
11 |
--------------------------------------------------------------------------------
/test/add.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/PioneerIncubator/betterGo/enum"
7 | )
8 |
9 | func useAdd() {
10 | 	a, b := 1, 2
11 | 	out := enum.Add(a, b).(int)
12 |
13 | 	expect := a + b
14 | 	if expect != out {
15 | 		fmt.Printf("expected:%d, out:%d\n", expect, out)
16 | 		return
17 | 	}
18 | 	fmt.Println("success, expect:", expect)
19 | }
20 |
--------------------------------------------------------------------------------
/enum/add.go:
--------------------------------------------------------------------------------
1 | package enum
2 |
3 | import (
4 | 	log "github.com/sirupsen/logrus"
5 | )
6 |
7 | // Add returns a + b for int or float64 operands.
8 | // b is type-asserted to a's dynamic type, so operands of different types
9 | // panic; any other operand type is reported through log.Fatal.
10 | func Add(a, b interface{}) interface{} {
11 | 	switch typeAB := a.(type) {
12 | 	default:
13 | 		log.WithFields(log.Fields{
14 | 			"type": typeAB,
15 | 		}).Fatal("Unexpected type!")
16 | 		return nil
17 | 	case int:
18 | 		return a.(int) + b.(int)
19 | 	case float64:
20 | 		return a.(float64) + b.(float64)
21 | 	}
22 | }
23 |
--------------------------------------------------------------------------------
/enum/reduce_test.go:
--------------------------------------------------------------------------------
1 | package enum
2 |
3 | import (
4 | 	"testing"
5 | )
6 |
7 | func mul(a, b int) int {
8 | 	return a * b
9 | }
10 |
11 | func TestReduce(t *testing.T) {
12 | 	a := make([]int, 10)
13 | 	for i := range a {
14 | 		a[i] = i + 1
15 | 	}
16 | 	// Compute 10!
17 | 	out := Reduce(a, mul, 1).(int)
18 | 	expect := 1
19 | 	for i := range a {
20 | 		expect *= a[i]
21 | 	}
22 | 	if expect != out {
23 | 		t.Fatalf("expected %d got %d", expect, out)
24 | 	}
25 | }
26 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: deps build lint test mock cloc unit-test testonly
2 |
3 |
4 | deps:
5 | 	env GO111MODULE=on go mod download
6 | 	env GO111MODULE=on go mod vendor
7 |
8 |
9 | build: deps
10 |
11 | lint:
12 | 	golangci-lint run
13 |
14 | test: lint deps unit-test
15 |
16 | testonly: deps unit-test
17 |
18 | cloc:
19 | 	cloc --exclude-dir=vendor,3rdmocks,mocks,tools --not-match-f=test .
20 |
21 | unit-test:
22 | 	go vet `go list ./... | grep -v '/vendor/' | grep -v '/tools'`
23 | 	go test -count=1 -cover ./...
24 |
--------------------------------------------------------------------------------
/test/sum/sum.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/PioneerIncubator/betterGo/enum"
7 | )
8 |
9 | // TestSumInts checks that enum.Sum of an []int equals the int 10.
10 | func TestSumInts() {
11 | 	origin := []int{1, 2, 3, 4}
12 | 	flag := enum.Sum(origin)
13 | 	if flag == 10 {
14 | 		fmt.Println("success")
15 | 	} else {
16 | 		fmt.Printf("failed, got %v (%T)\n", flag, flag)
17 | 	}
18 | }
19 |
20 | // TestSumFloats checks that enum.Sum of a []float64 equals the float64 10.0.
21 | func TestSumFloats() {
22 | 	origin := []float64{1.0, 2.0, 3.0, 4.0}
23 | 	flag := enum.Sum(origin)
24 | 	if flag == 10.0 {
25 | 		fmt.Println("success")
26 | 	} else {
27 | 		fmt.Printf("failed, got %v (%T)\n", flag, flag)
28 | 	}
29 | }
30 |
31 | func main() {
32 | 	TestSumInts()
33 | 	TestSumFloats()
34 | }
35 |
--------------------------------------------------------------------------------
/test/max/max.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/PioneerIncubator/betterGo/enum"
7 | )
8 |
9 | func TestMaxInts() {
10 | 	origin := []int{2, 3, 4, 5}
11 | 	output := enum.Max(origin)
12 | 	if output == 5 {
13 | 		fmt.Printf("success, output is %v \n", output)
14 | 	} else {
15 | 		fmt.Printf("failed, got %v (%T) \n", output, output)
16 | 	}
17 | }
18 |
19 | func TestMaxFloats() {
20 | 	origin := []float64{2.0, 3.0, 4.0, 5.0}
21 | 	output := enum.Max(origin)
22 | 	if output == 5.0 {
23 | 		fmt.Printf("success, output is %v \n", output)
24 | 	} else {
25 | 		fmt.Printf("failed, got %v (%T) \n", output, output)
26 | 	}
27 | }
28 |
29 | func main() {
30 | 	TestMaxInts()
31 | 	TestMaxFloats()
32 | }
33 |
--------------------------------------------------------------------------------
/test/min/min.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/PioneerIncubator/betterGo/enum"
7 | )
8 |
9 | func TestMinInts() {
10 | 	origin := []int{2, 3, 4, 5}
11 | 	output := enum.Min(origin)
12 | 	if output == 2 {
13 | 		fmt.Printf("success, output is %v \n", output)
14 | 	} else {
15 | 		fmt.Printf("failed, got %v (%T) \n", output, output)
16 | 	}
17 | }
18 |
19 | func TestMinFloats() {
20 | 	origin := []float64{2.0, 3.0, 4.0, 5.0}
21 | 	output := enum.Min(origin)
22 | 	if output == 2.0 {
23 | 		fmt.Printf("success, output is %v \n", output)
24 | 	} else {
25 | 		fmt.Printf("failed, got %v (%T) \n", output, output)
26 | 	}
27 | }
28 |
29 | func main() {
30 | 	TestMinInts()
31 | 	TestMinFloats()
32 | }
33 |
--------------------------------------------------------------------------------
/utils/utils.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | 	"strconv"
5 | 	"strings"
6 | )
7 |
8 | // IncrementString appends or bumps a numeric suffix on str.
9 | //
10 | // If str ends with separator followed by an integer (e.g. "argname_3"), that
11 | // integer is increased by first ("argname_4" when first is 1). Otherwise
12 | // separator and first are appended ("argname" becomes "argname_1").
13 | // An empty separator defaults to "_" and a non-positive first defaults to 1.
14 | func IncrementString(str string, separator string, first int) string {
15 | 	if separator == "" {
16 | 		separator = "_"
17 | 	}
18 | 	if first <= 0 {
19 | 		first = 1
20 | 	}
21 |
22 | 	// Split on the last separator so a prefix that itself contains the
23 | 	// separator ("my_arg_2") keeps its counter; a non-numeric suffix is
24 | 	// treated as part of the name instead of aborting the program.
25 | 	if idx := strings.LastIndex(str, separator); idx >= 0 {
26 | 		if i, err := strconv.Atoi(str[idx+len(separator):]); err == nil {
27 | 			return str[:idx] + separator + strconv.Itoa(i+first)
28 | 		}
29 | 	}
30 | 	return str + separator + strconv.Itoa(first)
31 | }
32 |
--------------------------------------------------------------------------------
/test/delete/delete.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/PioneerIncubator/betterGo/enum"
7 | )
8 |
9 | func main() {
10 | 	TestDeleteInts()
11 | 	TestDeleteStrings()
12 | }
13 |
14 | func TestDeleteStrings() {
15 | 	origin := []string{"ab", "cdefg", "hijk"}
16 | 	flag := enum.Delete(origin, func(x string) bool { return len(x) > 2 })
17 | 	if flag {
18 | 		fmt.Println("success")
19 | 	} else {
20 | 		fmt.Println("failed")
21 | 	}
22 | }
23 |
24 | func TestDeleteInts() {
25 | 	origin := []int{2, 3, 4, 5}
26 | 	flag := enum.Delete(origin, func(x int) bool { return x%2 == 0 })
27 | 	if flag {
28 | 		fmt.Println("success")
29 | 	} else {
30 | 		fmt.Println("failed")
31 | 	}
32 | }
33 |
--------------------------------------------------------------------------------
/enum/map.go:
--------------------------------------------------------------------------------
1 | package enum
2 |
3 | import (
4 | 	"reflect"
5 |
6 | 	log "github.com/sirupsen/logrus"
7 | )
8 |
9 | // Map replaces every element of slice, in place, with anonymousFunc(element).
10 | //
11 | // slice must be a slice and anonymousFunc must have type func(T) T, where T is
12 | // the slice's element type; otherwise Map reports the problem via log.Fatal.
13 | // Elements are written back through the slice, so the caller's backing array
14 | // is modified. An empty slice is a no-op.
15 | func Map(slice, anonymousFunc interface{}) {
16 | 	in := reflect.ValueOf(slice)
17 | 	if in.Kind() != reflect.Slice {
18 | 		log.Fatal("Input is not slice")
19 | 	}
20 | 	n := in.Len()
21 | 	if n == 0 {
22 | 		return
23 | 	}
24 |
25 | 	elemType := in.Type().Elem()
26 | 	fn := reflect.ValueOf(anonymousFunc)
27 | 	// Validate the whole signature: a wrong parameter or result type would
28 | 	// otherwise panic inside reflect's Call or Set.
29 | 	if !goodFunc(fn, elemType, elemType) {
30 | 		str := elemType.String()
31 | 		log.Fatal("Function must be of type func(" + str + ") " + str)
32 | 	}
33 | 	var ins [1]reflect.Value
34 | 	for i := 0; i < n; i++ {
35 | 		ins[0] = in.Index(i)
36 | 		in.Index(i).Set(fn.Call(ins[:])[0])
37 | 	}
38 | }
39 |
--------------------------------------------------------------------------------
/test/find/find.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/PioneerIncubator/betterGo/enum"
7 | )
8 |
9 | func main() {
10 | 	TestFindInts()
11 | 	TestFindStrings()
12 | }
13 |
14 | func TestFindStrings() {
15 | 	origin := []string{"ab", "cdefg", "hijk"}
16 | 	output := enum.Find(origin, func(x string) bool { return len(x) > 2 })
17 | 	if output == "cdefg" {
18 | 		fmt.Printf("success, output is %s \n", output)
19 | 	} else {
20 | 		fmt.Printf("failed, got %v \n", output)
21 | 	}
22 | }
23 |
24 | func TestFindInts() {
25 | 	origin := []int{2, 3, 4, 5}
26 | 	output := enum.Find(origin, func(x int) bool { return x%2 == 0 })
27 | 	if output == 2 {
28 | 		fmt.Printf("success, output is %d \n", output)
29 | 	} else {
30 | 		fmt.Printf("failed, got %v \n", output)
31 | 	}
32 | }
33 |
--------------------------------------------------------------------------------
/enum/util.go:
--------------------------------------------------------------------------------
1 | package enum
2 |
3 | import "reflect"
4 |
5 | // goodFunc reports whether fn is a function whose signature matches types.
6 | // The last entry of types is the single expected result type and the others
7 | // are the expected parameter types, compared for exact equality; a nil last
8 | // entry accepts any result type.
9 | func goodFunc(fn reflect.Value, types ...reflect.Type) bool {
10 | 	if fn.Kind() != reflect.Func {
11 | 		return false
12 | 	}
13 | 	ft := fn.Type()
14 | 	// Last type is the result, the rest are the parameters.
15 | 	numIn := len(types) - 1
16 | 	if ft.NumIn() != numIn || ft.NumOut() != 1 {
17 | 		return false
18 | 	}
19 | 	for i := 0; i < numIn; i++ {
20 | 		if ft.In(i) != types[i] {
21 | 			return false
22 | 		}
23 | 	}
24 | 	outType := types[numIn]
25 | 	return outType == nil || ft.Out(0) == outType
26 | }
27 |
--------------------------------------------------------------------------------
/enum/find.go:
--------------------------------------------------------------------------------
1 | package enum
2 |
3 | import (
4 | 	"reflect"
5 |
6 | 	log "github.com/sirupsen/logrus"
7 | )
8 |
9 | // Find returns the first element of slice for which anonymousFunc returns
10 | // true, or nil when the slice is empty or no element matches.
11 | //
12 | // slice must be a slice and anonymousFunc must have type func(T) bool, where T
13 | // is the slice's element type; otherwise Find reports the problem via
14 | // log.Fatal. The result must be type-asserted back to T by the caller.
15 | func Find(slice, anonymousFunc interface{}) interface{} {
16 | 	in := reflect.ValueOf(slice)
17 | 	if in.Kind() != reflect.Slice {
18 | 		log.Fatal("Input is not slice")
19 | 	}
20 | 	n := in.Len()
21 | 	if n == 0 {
22 | 		return nil
23 | 	}
24 |
25 | 	// Get slice element's type
26 | 	elemType := in.Type().Elem()
27 | 	fn := reflect.ValueOf(anonymousFunc)
28 | 	// Check the whole signature, not just that fn is a function: a wrong
29 | 	// parameter type or a non-bool result would otherwise panic in Call/Bool.
30 | 	if !goodFunc(fn, elemType, reflect.TypeOf(true)) {
31 | 		str := elemType.String()
32 | 		log.Fatal("Function must be of type func(" + str + ") bool")
33 | 	}
34 |
35 | 	var ins [1]reflect.Value
36 | 	for i := 0; i < n; i++ {
37 | 		ins[0] = in.Index(i)
38 | 		if fn.Call(ins[:])[0].Bool() {
39 | 			return ins[0].Interface()
40 | 		}
41 | 	}
42 | 	return nil
43 | }
44 |
--------------------------------------------------------------------------------
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
1 | name: Go
2 |
3 | on:
4 |   push:
5 |     branches: [ master ]
6 |   pull_request:
7 |     branches: [ master ]
8 |
9 | jobs:
10 |
11 |   build:
12 |     name: Build
13 |     runs-on: ubuntu-latest
14 |     steps:
15 |
16 |     - name: Set up Go 1.x
17 |       uses: actions/setup-go@v2
18 |       with:
19 |         go-version: ^1.13
20 |       id: go
21 |
22 |     - name: Check out code into the Go module directory
23 |       uses: actions/checkout@v2
24 |
25 |     - name: Get dependencies
26 |       run: |
27 |         go get -v -t -d ./...
28 |         if [ -f Gopkg.toml ]; then
29 |             curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh
30 |             dep ensure
31 |         fi
32 |
33 |     - name: Build
34 |       run: go build -v .
35 |
36 |     - name: Test
37 |       # Test every package except the example programs under test/, which are
38 |       # translator inputs rather than go test suites.
39 |       run: go test -v $(go list ./... | grep -v '/test')
40 |
--------------------------------------------------------------------------------
/enum/sum.go:
--------------------------------------------------------------------------------
1 | package enum
2 |
3 | import (
4 | 	"reflect"
5 |
6 | 	log "github.com/sirupsen/logrus"
7 | )
8 |
9 | // Sum returns the sum of the elements of slice, which must be a []int or a
10 | // []float64. The result has the slice's element type (int or float64), so it
11 | // can be compared with, or type-asserted to, that type directly.
12 | // An empty slice yields nil; any other input is reported via log.Fatal.
13 | func Sum(slice interface{}) interface{} {
14 | 	in := reflect.ValueOf(slice)
15 | 	if in.Kind() != reflect.Slice {
16 | 		log.Fatal("Input is not slice")
17 | 	}
18 | 	if in.Len() == 0 {
19 | 		return nil
20 | 	}
21 |
22 | 	switch s := slice.(type) {
23 | 	case []int:
24 | 		sum := 0
25 | 		for _, v := range s {
26 | 			sum += v
27 | 		}
28 | 		return sum
29 | 	case []float64:
30 | 		sum := 0.0
31 | 		for _, v := range s {
32 | 			sum += v
33 | 		}
34 | 		return sum
35 | 	default:
36 | 		log.WithFields(log.Fields{
37 | 			"type": s,
38 | 		}).Fatal("Unexpected type!")
39 | 		return nil
40 | 	}
41 | }
42 |
--------------------------------------------------------------------------------
/enum/delete.go:
-------------------------------------------------------------------------------- 1 | package enum 2 | 3 | import ( 4 | "reflect" 5 | 6 | log "github.com/sirupsen/logrus" 7 | ) 8 | 9 | func Delete(slice, anonymousFunc interface{}) bool { 10 | in := reflect.ValueOf(slice) 11 | if in.Kind() != reflect.Slice { 12 | log.Fatal("Input is not slice") 13 | } 14 | n := in.Len() 15 | if n == 0 { 16 | return false 17 | } 18 | 19 | // Get slice element's type 20 | elemType := in.Type().Elem() 21 | fn := reflect.ValueOf(anonymousFunc) 22 | if fn.Kind() != reflect.Func { 23 | str := elemType.String() 24 | log.Fatal("Function must be of type func(" + str + ", " + str + ")" + str) 25 | } 26 | 27 | var ins [1]reflect.Value 28 | 29 | count := 0 30 | for i := 0; i < n; i++ { 31 | ins[0] = in.Index(i) 32 | if fn.Call(ins[:])[0].Bool() { 33 | in.Index(count).Set(ins[0]) 34 | count++ 35 | } 36 | } 37 | in = in.Slice(0, count) 38 | return true 39 | } 40 | -------------------------------------------------------------------------------- /test/map/map.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/PioneerIncubator/betterGo/enum" 7 | ) 8 | 9 | func randomFn(origin int) int { 10 | origin += 1 11 | var b = 10 12 | return origin * b 13 | } 14 | 15 | func varTestFn(origin int) int { 16 | a := 10 17 | return a * origin 18 | } 19 | 20 | func testMap(origin, expect []int, fn func(int) int) { 21 | enum.Map(origin, fn) 22 | flag := true 23 | for i := range origin { 24 | if !(expect[i] == origin[i]) { 25 | flag = false 26 | break 27 | } 28 | } 29 | if flag { 30 | fmt.Println("success, expect:", expect) 31 | } else { 32 | fmt.Printf("expected:%d, out:%d", expect, origin) 33 | } 34 | 35 | } 36 | 37 | func main() { 38 | origin := []int{2, 4, 6} 39 | expect := []int{30, 50, 70} 40 | testMap(origin, expect, randomFn) 41 | 42 | origin = []int{2} 43 | expect = []int{20} 44 | testMap(origin, expect, varTestFn) 
45 | } 46 | -------------------------------------------------------------------------------- /enum/max.go: -------------------------------------------------------------------------------- 1 | package enum 2 | 3 | import ( 4 | "reflect" 5 | 6 | log "github.com/sirupsen/logrus" 7 | ) 8 | 9 | func Max(slice interface{}) interface{} { 10 | in := reflect.ValueOf(slice) 11 | if in.Kind() != reflect.Slice { 12 | log.Fatal("Input is not slice") 13 | } 14 | n := in.Len() 15 | if n == 0 { 16 | return nil 17 | } 18 | 19 | switch sliceType := slice.(type) { 20 | default: 21 | log.WithFields(log.Fields{ 22 | "type": sliceType, 23 | }).Fatal("Unexpected type!") 24 | return nil 25 | 26 | //reflect.value only return int64 and float64 27 | case []int: 28 | var maxVal int64 29 | maxVal = in.Index(0).Int() 30 | for i := 1; i < n; i++ { 31 | if in.Index(i).Int() > maxVal { 32 | maxVal = in.Index(i).Int() 33 | } 34 | } 35 | return maxVal 36 | 37 | case []float64: 38 | var maxVal float64 39 | maxVal = in.Index(0).Float() 40 | for i := 1; i < n; i++ { 41 | if in.Index(i).Float() > maxVal { 42 | maxVal = in.Index(i).Float() 43 | } 44 | } 45 | return maxVal 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /enum/min.go: -------------------------------------------------------------------------------- 1 | package enum 2 | 3 | import ( 4 | "reflect" 5 | 6 | log "github.com/sirupsen/logrus" 7 | ) 8 | 9 | func Min(slice interface{}) interface{} { 10 | in := reflect.ValueOf(slice) 11 | if in.Kind() != reflect.Slice { 12 | log.Fatal("Input is not slice") 13 | } 14 | n := in.Len() 15 | if n == 0 { 16 | return nil 17 | } 18 | 19 | switch sliceType := slice.(type) { 20 | default: 21 | log.WithFields(log.Fields{ 22 | "type": sliceType, 23 | }).Fatal("Unexpected type!") 24 | return nil 25 | 26 | //reflect.value only return int64 and float64 27 | case []int: 28 | var minVal int64 29 | minVal = in.Index(0).Int() 30 | for i := 1; i < n; i++ { 31 | if 
in.Index(i).Int() < minVal { 32 | minVal = in.Index(i).Int() 33 | } 34 | } 35 | return minVal 36 | 37 | case []float64: 38 | var minVal float64 39 | minVal = in.Index(0).Float() 40 | for i := 1; i < n; i++ { 41 | if in.Index(i).Float() < minVal { 42 | minVal = in.Index(i).Float() 43 | } 44 | } 45 | return minVal 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /test/reduce/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/PioneerIncubator/betterGo/enum" 7 | ) 8 | 9 | func mul(a, b int) (c int) { 10 | c = a * b 11 | return 12 | } 13 | 14 | func testINT() { 15 | a, b := make([]int, 10), 12 16 | for i := range a { 17 | a[i] = i + 1 18 | } 19 | // Compute 10! 20 | out := enum.Reduce(a, mul, 1).(int) 21 | expect := 1 22 | for i := range a { 23 | expect *= a[i] 24 | } 25 | if expect != out { 26 | fmt.Printf("expected %d got %d , b %d", expect, out, b) 27 | } 28 | fmt.Println("success, ", expect) 29 | } 30 | 31 | func main() { 32 | testINT() 33 | 34 | c, d := make([]float32, 10), 12.3 35 | for i := range c { 36 | c[i] = float32(i) + 1.1 37 | } 38 | // Compute 10! 
39 | floatOut := enum.Reduce(c, func(x, y float32) (z float32) { 40 | z = x * y 41 | return 42 | }, 1.0).(float32) 43 | var floatExpect float32 = 1.0 44 | for i := range c { 45 | floatExpect *= c[i] 46 | } 47 | if floatExpect != floatOut { 48 | fmt.Printf("expected %f got %f , b %f", floatExpect, floatOut, d) 49 | } 50 | fmt.Println("success, ", floatExpect) 51 | // var arrayInt = []int{1, 2, 3} 52 | // lambda := 53 | // enum.Map(arrayInt, func(a int) int { 54 | // return a + 1 55 | // }) 56 | } 57 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | skip-dirs: 3 | - test 4 | - vendor 5 | 6 | linters-settings: 7 | maligned: 8 | suggest-new: true 9 | gocritic: 10 | disabled-checks: 11 | - singleCaseSwitch 12 | 13 | linters: 14 | disable-all: true 15 | enable: 16 | - bodyclose 17 | - deadcode 18 | - depguard 19 | - dogsled 20 | - gochecknoinits 21 | - goconst 22 | - gocritic 23 | - gocyclo 24 | - gofmt 25 | - goimports 26 | - golint 27 | - gomnd 28 | - goprintffuncname 29 | - gosec 30 | - gosimple 31 | - govet 32 | - ineffassign 33 | - interfacer 34 | - misspell 35 | - nakedret 36 | - nolintlint 37 | - rowserrcheck 38 | - scopelint 39 | - staticcheck 40 | - structcheck 41 | - typecheck 42 | - unconvert 43 | - unparam 44 | - unused 45 | - varcheck 46 | - gocognit 47 | - asciicheck 48 | - nestif 49 | - errcheck 50 | - dupl 51 | #Consider this 52 | # - godox 53 | # - funlen 54 | # - lll 55 | # - gochecknoglobals 56 | # don't enable: 57 | # - whitespace 58 | # - goerr113 59 | # - godot 60 | # - maligned 61 | # - prealloc 62 | # - testpackage 63 | # - wsl 64 | # - stylecheck 65 | -------------------------------------------------------------------------------- /enum/reduce.go: -------------------------------------------------------------------------------- 1 | package enum 2 | 3 | import ( 4 | "reflect" 5 | 6 | log 
"github.com/sirupsen/logrus" 7 | ) 8 | 9 | // Reduce computes the reduction of the pair function across the elements of 10 | // the slice. (If the types of the slice and function do not correspond, Reduce 11 | // panics.) For instance, if the slice contains successive integers starting at 12 | // 1 and the function is multiply, the result will be the factorial function. 13 | // If the slice is empty, Reduce returns zero; if it has only one element, it 14 | // returns that element. The return value must be type-asserted by the caller 15 | // back to the element type of the slice. Example: 16 | // func multiply(a, b int) int { return a*b } 17 | // a := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} 18 | // factorial := Reduce(a, multiply, 1).(int) 19 | func Reduce(slice, pairFunction, zero interface{}) interface{} { 20 | in := reflect.ValueOf(slice) 21 | if in.Kind() != reflect.Slice { 22 | log.Fatal("Input is not slice") 23 | 24 | } 25 | n := in.Len() 26 | switch n { 27 | case 0: 28 | return zero 29 | case 1: 30 | return in.Index(0) 31 | } 32 | elemType := in.Type().Elem() 33 | fn := reflect.ValueOf(pairFunction) 34 | if !goodFunc(fn, elemType, elemType, elemType) { 35 | str := elemType.String() 36 | log.Fatal("Function must be of type func(" + str + ", " + str + ") " + str) 37 | } 38 | // Do the first two by hand to prime the pump. 39 | var ins [2]reflect.Value 40 | ins[0] = in.Index(0) 41 | ins[1] = in.Index(1) 42 | out := fn.Call(ins[:])[0] 43 | // Run from index 2 to the end. 
44 | for i := 2; i < n; i++ { 45 | ins[0] = out 46 | ins[1] = in.Index(i) 47 | out = fn.Call(ins[:])[0] 48 | } 49 | return out.Interface() 50 | } 51 | -------------------------------------------------------------------------------- /translator/parse_fun.go: -------------------------------------------------------------------------------- 1 | package translator 2 | 3 | import ( 4 | "go/ast" 5 | "go/token" 6 | 7 | log "github.com/sirupsen/logrus" 8 | ) 9 | 10 | func getFunParamListRawStr(fset *token.FileSet, ret *ast.FuncDecl) string { 11 | paramsStr := "func (" 12 | for _, v := range ret.Type.Params.List { 13 | exprType := GetExprStr(fset, v.Type) 14 | for j := 0; j < len(v.Names); j++ { 15 | // func mul(a, b int) int 16 | paramsStr = paramsStr + exprType + ", " 17 | } 18 | } 19 | paramsStr = paramsStr[:len(paramsStr)-2] + " )" 20 | return paramsStr 21 | } 22 | 23 | func getFunRetListRawStr(fset *token.FileSet, ret *ast.FuncDecl) string { 24 | retStr := "" 25 | // NOTE (ret1 int,ret2 double ) 26 | if ret.Type.Results == nil { // could be no return value 27 | return retStr 28 | } 29 | for _, v := range ret.Type.Results.List { 30 | // fmt.Println("[getFunRetListRawStr] result v", v) 31 | exprType := GetExprStr(fset, v.Type) 32 | if len(v.Names) == 0 { 33 | retStr = exprType 34 | } else { 35 | for j := 0; j < len(v.Names); j++ { 36 | retStr += exprType 37 | } 38 | } 39 | } 40 | return retStr 41 | } 42 | 43 | func DecorateParamName(name string) string { 44 | return "BETTERGOPARAM" + name 45 | } 46 | 47 | func recordParamType(fset *token.FileSet, ret *ast.FuncType) { 48 | for _, v := range ret.Params.List { 49 | exprType := GetExprStr(fset, v.Type) 50 | for _, name := range v.Names { 51 | nameStr := GetExprStr(fset, name) 52 | log.WithFields(log.Fields{ 53 | "name": nameStr, 54 | "record": DecorateParamName(nameStr), 55 | "value": exprType, 56 | }).Info("Mapping param name and param type") 57 | variableType[DecorateParamName(nameStr)] = exprType 58 | } 59 | } 60 | 61 | } 
62 | 63 | func GetFuncType(fset *token.FileSet, ret *ast.FuncDecl) (string, string) { 64 | paramsStr := getFunParamListRawStr(fset, ret) 65 | retStr := getFunRetListRawStr(fset, ret) 66 | recordParamType(fset, ret.Type) 67 | 68 | log.WithFields(log.Fields{ 69 | "name": ret.Name.Name, 70 | "value": paramsStr + retStr, 71 | }).Info("Mapping param name and param type") 72 | variableType[ret.Name.Name] = paramsStr + retStr 73 | return paramsStr + retStr, retStr 74 | // fmt.Println("[FuncDecl] Type.Results", ret.Type.Results.List) 75 | // if ret.Tok == token.DEFINE { Results 76 | // recordDefineVarType(fset, ret) 77 | // } 78 | } 79 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | [English](https://github.com/PioneerIncubator/betterGo/blob/master/README.md) / 中文 2 | 3 | # betterGo 4 | 5 | betterGo实现了我认为Go所缺失的部分 6 | 7 | ## Real Generic 8 | 9 | 为用户提供了可以直接用在代码中的真正的`interface{}`。 10 | 11 | 在部署之前,仅需要使用`translator`生成确定类型的代码,这种方式并不会影响你的代码性能。 12 | 13 | 下面是已经实现的所有泛型函数: 14 | 15 | * `enum.Reduce` 16 | * `enum.Map` 17 | * `enum.Delete`: Delete slice's first element for which fun returns a truthy value. 18 | * `enum.Find`: Returns slice's first element for which fun returns a truthy value. 19 | 20 | ### 实现 21 | 22 | 使用Go AST来分析你使用泛型函数的代码,生成确定类型的函数并替换掉你原先的调用语句 23 | 24 | ### 实际上所做的事 25 | 26 | ![](https://pic1.zhimg.com/50/v2-dd2dc3bc72b058b85774ee804a521165_hd.webp) 27 | 28 | ### [视频介绍](https://www.bilibili.com/video/BV1oT4y1j7L2?from=search&seid=10341891677310558379) 29 | 30 | ### 背景 31 | 32 | 现在的Go语言不支持泛型(像C++中的template、Java中的interface) 33 | 34 | 目前,为实现泛型的需求,在Go语言中往往有如下几种方式[1](#refer-anchor-1): 35 | 36 | > 1. Interface (with method) 37 | > 优点:无需三方库,代码干净而且通用。 38 | > 缺点:需要一些额外的代码量,以及也许没那么夸张的运行时开销。 39 | > 2. Use type assertions 40 | > 优点:无需三方库,代码干净。 41 | > 缺点:需要执行类型断言,接口转换的运行时开销,没有编译时类型检查。 42 | > 3. 
Reflection 43 | > 优点:干净 44 | > 缺点:相当大的运行时开销,没有编译时类型检查。 45 | > 4. Code generation 46 | > 优点:非常干净的代码(取决工具),编译时类型检查(有些工具甚至允许编写针对通用代码模板的测试),没有运行时开销。 47 | > 缺点:构建需要第三方工具,如果一个模板为不同的目标类型多次实例化,编译后二进制文件较大。 48 | 49 | `betterGo`就是通过`code generation`来实现泛型 50 | 51 | ### 如何使用 52 | 53 | 如果你想使用betterGo来通过自动生成代码的方式实现泛型,可以看下面的例子: 54 | 55 | 在项目中包含了测试用例,例如,需要使用泛型的代码是`test/map/map.go`,如果想用`interface{}` 的函数就是`enum.Map` 这样子用。 56 | 57 | 如果想生成具体类型的函数,就运行这行命令:`go run main.go -w -f test/map/map.go` 58 | 59 | 然后你发现 `test/map/map.go` 改变了,`enum.Map` 变成了: `enum.MapOriginFn(origin, fn)` 60 | 61 | 然后你看项目目录下生成了: `utils/enum/map.go`,就是具体类型的函数 62 | 63 | ### 参与项目 64 | 65 | 如果想和我们一起完成项目的开发,可以直接看代码,找到`AST`相关的包,尝试理解相关函数的作用,很容易就可以理解这个项目以及代码了。 66 | 67 | 如果想从理论出发的话,可以简单看看这本书:https://github.com/chai2010/go-ast-book ,其实他也就是把`AST`包里的代码简单讲讲。 68 | 69 | 想参与具体开发可以参考项目接下来的[TODO List](https://github.com/PioneerIncubator/betterGo/issues/31) 70 | 71 | ### 技术思路 72 | 73 | 1. 导入需要操作的文件/目录 74 | 75 | 2. 通过AST进行语法分析 76 | 77 | AST能分析出每条语句的性质,如: 78 | 79 | - `GenDecl` (一般声明):包括import、常量声明、变量声明、类型声明 80 | - `AssignStmt`(赋值语句):包括赋值语句和短的变量声明(a := 1) 81 | - `FuncDecl`(函数声明) 82 | - `TypeAssertExpr`(类型断言) 83 | - `CallExpr`(函数调用语句) 84 | 85 | 3. 当分析到包含变量的值/类型的语句时(`AssignStmt`、`FuncDecl`)会对变量的值和类型进行记录,并建立二者之间的映射关系,以便于在后续环节中能够通过变量名获取变量的类型 86 | 87 | 4. 当发现函数调用语句(`CallExpr`)时,会检查该函数是否为我们提供的函数,如果是,则通过上一步中记录的参数名对应的类型生成专门处理该类型的一份代码,并存储到指定路径下(如果之前已经生成过相同类型的代码则不重复生成) 88 | 89 | 5. 将原代码中的原来的函数调用语句替换成新的函数调用语句,使其调用上一步中新生成的函数,并更新import的包 90 | 91 | ### Reference 92 | 93 |
[1] Go有什麽泛型的实现方法? - 达的回答 - 知乎 94 | -------------------------------------------------------------------------------- /translator/expr.go: -------------------------------------------------------------------------------- 1 | package translator 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "go/ast" 7 | "go/printer" 8 | "go/token" 9 | 10 | "github.com/PioneerIncubator/betterGo/utils" 11 | log "github.com/sirupsen/logrus" 12 | ) 13 | 14 | func ExtractParamsTypeAndName(fset *token.FileSet, listOfArgs []ast.Expr) (string, []string, []string) { 15 | var paramsType string 16 | var listOfArgVarNames []string 17 | var listOfArgTypes []string 18 | argname := "argname" 19 | argsNum := len(listOfArgs) 20 | for i, arg := range listOfArgs { 21 | argname = utils.IncrementString(argname, "", 1) 22 | switch x := arg.(type) { 23 | case *ast.BasicLit: 24 | argVarName := x.Value 25 | listOfArgVarNames = append(listOfArgVarNames, argVarName) 26 | listOfArgTypes = append(listOfArgTypes, GetBasicLitType(x)) 27 | paramsType = fmt.Sprintf("%s %s %s", paramsType, argname, GetBasicLitType(x)) 28 | case *ast.Ident: 29 | argVarName := x.Name 30 | listOfArgVarNames = append(listOfArgVarNames, argVarName) 31 | listOfArgTypes = append(listOfArgTypes, variableType[argVarName]) 32 | var argVarType string 33 | if paramType, ok := variableType[DecorateParamName(argVarName)]; ok { 34 | argVarType = paramType 35 | } else { 36 | argVarType = variableType[argVarName] 37 | } 38 | log.WithFields(log.Fields{ 39 | "name": argVarName, 40 | "type": argVarType, 41 | }).Info("Find an identifier from function parameters") 42 | paramsType = fmt.Sprintf("%s %s %s", paramsType, argname, argVarType) 43 | case *ast.FuncLit: 44 | argDeclar, retDeclar := "", "" 45 | for _, v := range x.Type.Params.List { 46 | lenNames := len(v.Names) 47 | if argDeclar == "" { 48 | lenNames-- 49 | argDeclar = GetExprStr(fset, v.Type) 50 | } 51 | for i := 0; i < lenNames; i++ { 52 | argDeclar = fmt.Sprintf("%s, %s", argDeclar, GetExprStr(fset, 
v.Type)) 53 | } 54 | } 55 | for _, v := range x.Type.Results.List { 56 | lenNames := len(v.Names) 57 | if retDeclar == "" { 58 | lenNames-- 59 | retDeclar = GetExprStr(fset, v.Type) 60 | } 61 | for i := 0; i < lenNames; i++ { 62 | retDeclar = fmt.Sprintf("%s,%s", retDeclar, GetExprStr(fset, v.Type)) 63 | } 64 | } 65 | 66 | var lambdaTypeStr string 67 | if len(x.Type.Results.List) == 1 { 68 | lambdaTypeStr = fmt.Sprintf("func(%s) %s", argDeclar, retDeclar) 69 | } else { 70 | lambdaTypeStr = fmt.Sprintf("func(%s)(%s)", argDeclar, retDeclar) 71 | } 72 | paramsType = fmt.Sprintf("%s %s %s", paramsType, argname, lambdaTypeStr) 73 | listOfArgVarNames = append(listOfArgVarNames, "lambda") 74 | listOfArgTypes = append(listOfArgTypes, lambdaTypeStr) 75 | default: 76 | log.Error("Unknown type: ", x) 77 | } 78 | 79 | if i != argsNum-1 { 80 | paramsType += "," 81 | } 82 | } 83 | listOfArgTypes = append(listOfArgTypes, assertType) 84 | return paramsType, listOfArgVarNames, listOfArgTypes 85 | } 86 | 87 | func GetExprStr(fset *token.FileSet, expr interface{}) string { 88 | name := new(bytes.Buffer) 89 | err := printer.Fprint(name, fset, expr) 90 | if err != nil { 91 | log.Fatal(err) 92 | } 93 | return name.String() 94 | } 95 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 2 | github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= 3 | github.com/cpuguy83/go-md2man/v2 v2.0.0 h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng4PGlyM= 4 | github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 7 
| github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 8 | github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= 9 | github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 10 | github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= 11 | github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= 12 | github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= 13 | github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= 14 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 15 | github.com/urfave/cli/v2 v2.2.0 h1:JTTnM6wKzdA0Jqodd966MVj4vWbbquZykeX1sKbe2C4= 16 | github.com/urfave/cli/v2 v2.2.0/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= 17 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 18 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 19 | golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= 20 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 21 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 22 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 23 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 24 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 25 | golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4= 26 | golang.org/x/sys 
v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
27 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
28 | golang.org/x/tools v0.0.0-20200108203644-89082a384178 h1:f5gMxb6FbpY48csegk9UPd7IAHVrBD013CU7N4pWzoE=
29 | golang.org/x/tools v0.0.0-20200108203644-89082a384178/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
30 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
31 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
32 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
33 |
--------------------------------------------------------------------------------
/translator/var.go:
--------------------------------------------------------------------------------
1 | package translator
2 |
3 | import (
4 | 	"go/ast"
5 | 	"go/token"
6 |
7 | 	"github.com/PioneerIncubator/betterGo/types"
8 | 	log "github.com/sirupsen/logrus"
9 | )
10 |
11 | var (
12 | 	// variableType maps a variable name from the analysed source to the Go
13 | 	// type string recorded for it by RecordAssignVarType or RecordDeclVarType.
14 | 	variableType = map[string]string{}
15 |
16 | 	// assertPassCnt becomes 1 once RecordAssertType has been called; nothing
17 | 	// in this file resets it.
18 | 	assertPassCnt = 0
19 |
20 | 	// assertType is the target type of the most recently recorded type
21 | 	// assertion, e.g. "int" for `enum.Reduce(a, mul, 1).(int)`.
22 | 	assertType = ""
23 | )
24 |
25 | // RecordAssertType stores input as the current type-assertion target and
26 | // marks that an assertion has been seen.
27 | func RecordAssertType(input string) {
28 | 	log.WithField("assertType", input).Info("Recording assert type")
29 | 	assertType = input
30 | 	assertPassCnt = 1
31 | }
32 |
33 | // GetAssertType returns the type string last passed to RecordAssertType,
34 | // or "" if none has been recorded.
35 | func GetAssertType() string {
36 | 	return assertType
37 | }
38 |
39 | // RecordAssignVarType records the type of each variable on the left-hand
40 | // side of an assignment such as `a, s := 1, make([]int, 3)`. Only
41 | // assignments with as many values as variables are handled. Basic literals
42 | // map to their default Go type and `make(T, ...)` to T; any other call is
43 | // recorded as types.CallExprStr.
44 | func RecordAssignVarType(fset *token.FileSet, ret *ast.AssignStmt) {
45 | 	if len(ret.Lhs) != len(ret.Rhs) {
46 | 		return
47 | 	}
48 | 	for i, lhs := range ret.Lhs {
49 | 		rhs := ret.Rhs[i]
50 | 		assignVar := reflectType(fset, lhs)
51 | 		assignType := reflectType(fset, rhs)
52 | 		switch assignType {
53 | 		case types.CallExprStr:
54 | 			// make's first argument is exactly the type of the value it builds.
55 | 			call := rhs.(*ast.CallExpr)
56 | 			if GetExprStr(fset, call.Fun) == "make" && len(call.Args) > 0 {
57 | 				assignType = GetExprStr(fset, call.Args[0])
58 | 			}
59 | 		case types.BasicLitStr:
60 | 			assignType = GetBasicLitType(rhs.(*ast.BasicLit))
61 | 		}
62 |
63 | 		log.WithFields(log.Fields{
64 | 			"name":  assignVar,
65 | 			"value": assignType,
66 | 		}).Info("Mapping variable name and variable type")
67 | 		variableType[assignVar] = assignType
68 | 	}
69 | }
70 |
71 | // GetBasicLitType returns the default Go type of a basic literal ("int",
72 | // "float64", "complex128", "rune" or "string"), or "" for an unknown kind.
73 | func GetBasicLitType(expr *ast.BasicLit) string {
74 | 	switch expr.Kind {
75 | 	case token.INT:
76 | 		return "int"
77 | 	case token.FLOAT:
78 | 		return "float64"
79 | 	case token.IMAG:
80 | 		return "complex128"
81 | 	case token.CHAR:
82 | 		// Go has no "char" type; an untyped rune constant defaults to rune.
83 | 		return "rune"
84 | 	case token.STRING:
85 | 		return "string"
86 | 	}
87 | 	return ""
88 | }
89 |
90 | // RecordDeclVarType records the type of each name declared by a var/const
91 | // spec. An explicit type (`var x float64 = 1`) takes precedence over the
92 | // value; otherwise the type is derived from the name's own value. Specs
93 | // without a usable per-name value (`var a, b = f()`, implicit iota
94 | // repetition) are skipped instead of indexing past ret.Values.
95 | func RecordDeclVarType(fset *token.FileSet, ret *ast.ValueSpec) {
96 | 	for i, declVar := range ret.Names {
97 | 		var declVarType string
98 | 		switch {
99 | 		case ret.Type != nil:
100 | 			declVarType = GetExprStr(fset, ret.Type)
101 | 		case len(ret.Values) == len(ret.Names):
102 | 			value := ret.Values[i]
103 | 			declVarType = reflectType(fset, value)
104 | 			if declVarType == types.BasicLitStr {
105 | 				declVarType = GetBasicLitType(value.(*ast.BasicLit))
106 | 			}
107 | 		default:
108 | 			continue
109 | 		}
110 | 		log.WithFields(log.Fields{
111 | 			"name":  declVar,
112 | 			"value": declVarType,
113 | 		}).Info("Mapping variable name and variable type")
114 | 		variableType[declVar.Name] = declVarType
115 | 	}
116 | }
117 |
118 | // reflectType classifies an AST node for type recording: identifiers yield
119 | // their name, slice/array types their source text (e.g. "[]int"), calls
120 | // types.CallExprStr, basic literals types.BasicLitStr, anything else "".
121 | func reflectType(fset *token.FileSet, arg interface{}) string {
122 | 	switch x := arg.(type) {
123 | 	case *ast.Ident:
124 | 		return x.Name
125 | 	case *ast.ArrayType:
126 | 		// Print the whole type so the element type is kept; a bare "[]" is
127 | 		// not a usable parameter type in generated code.
128 | 		return GetExprStr(fset, x)
129 | 	case *ast.CallExpr:
130 | 		return types.CallExprStr
131 | 	case *ast.BasicLit:
132 | 		return types.BasicLitStr
133 | 	}
134 | 	return ""
135 | }
136 |
--------------------------------------------------------------------------------
/translator/gen_fun_call.go:
--------------------------------------------------------------------------------
1 | package translator
2 |
3 | import (
4 | 	"fmt"
5 | 	"go/ast"
6 | 	"go/token"
7 | 	"strings"
8 | )
9 |
10 | // extractParamsName derives the suffix of a generated function's name from
11 | // the call's arguments: each identifier contributes its title-cased name and
12 | // each basic literal its title-cased Go type, so (a, mul, 1) yields
13 | // "AMulInt". Arguments of other kinds contribute nothing.
14 | func extractParamsName(listOfArgs []ast.Expr) string {
15 | 	var name strings.Builder
16 | 	for _, arg := range listOfArgs {
17 | 		switch x := arg.(type) {
18 | 		case *ast.BasicLit:
19 | 			name.WriteString(strings.Title(GetBasicLitType(x)))
20 | 		case *ast.Ident:
21 | 			name.WriteString(strings.Title(x.Name))
22 | 		}
23 | 	}
24 | 	return name.String()
25 | }
26 |
27 | // genFunctionBody returns the body template of the enum helper funName
28 | // ("Reduce", "Add", ...), or "" for an unknown name. Templates refer to the
29 | // generated parameters as argname_1, argname_2, ...; e.g. enum.Reduce(a,
30 | // mul, 1).(int) over an []int is meant to become roughly
31 | //
32 | //	func ReduceAMulInt(argname_1 []int, argname_2 func(int, int) int, argname_3 int) int
33 | //
34 | // Sum, Max and Min return the literal 0 for an empty slice: unlike nil it
35 | // compiles for every numeric result type.
36 | func genFunctionBody(funName string) string {
37 | 	switch funName {
38 | 	case "Reduce":
39 | 		// Left fold seeded with argname_3, so an empty slice yields argname_3.
40 | 		return `
41 | 	out := argname_3
42 | 	for _, next := range argname_1 {
43 | 		out = argname_2(out, next)
44 | 	}
45 | 	return out
46 | `
47 | 	case "Add":
48 | 		return `
49 | 	return argname_1 + argname_2
50 | `
51 | 	case "Map":
52 | 		// Replaces every element, in place, by argname_2 applied to it.
53 | 		return `
54 | 	lenSlice := len(argname_1)
55 | 	if lenSlice == 0 {
56 | 		return
57 | 	}
58 | 	for i := range argname_1 {
59 | 		argname_1[i] = argname_2(argname_1[i])
60 | 	}
61 | `
62 | 	case "Delete":
63 | 		// Moves the elements for which argname_2 returns true to the front.
64 | 		// NOTE(review): reassigning argname_1 does not shrink the caller's
65 | 		// slice; confirm this matches the intended enum.Delete semantics.
66 | 		return `
67 | 	lenSlice := len(argname_1)
68 | 	if lenSlice == 0 {
69 | 		return false
70 | 	}
71 | 	count := 0
72 | 	for i := range argname_1 {
73 | 		if argname_2(argname_1[i]) {
74 | 			argname_1[count] = argname_1[i]
75 | 			count++
76 | 		}
77 | 	}
78 | 	argname_1 = argname_1[:count]
79 | 	return true
80 | `
81 | 	case "Find":
82 | 		// Returns the first element for which argname_2 returns true.
83 | 		// NOTE(review): `return nil` only compiles for nilable element types.
84 | 		return `
85 | 	lenSlice := len(argname_1)
86 | 	if lenSlice == 0 {
87 | 		return nil
88 | 	}
89 | 	for i := range argname_1 {
90 | 		if argname_2(argname_1[i]) {
91 | 			return argname_1[i]
92 | 		}
93 | 	}
94 | 	return nil
95 | `
96 | 	case "Sum":
97 | 		// Adds the element values (`for i := range s` would yield indices).
98 | 		return `
99 | 	if len(argname_1) == 0 {
100 | 		return 0
101 | 	}
102 | 	sum := argname_1[0]
103 | 	for _, v := range argname_1[1:] {
104 | 		sum += v
105 | 	}
106 | 	return sum
107 | `
108 | 	case "Max":
109 | 		return `
110 | 	if len(argname_1) == 0 {
111 | 		return 0
112 | 	}
113 | 	maxValue := argname_1[0]
114 | 	for _, v := range argname_1[1:] {
115 | 		if v > maxValue {
116 | 			maxValue = v
117 | 		}
118 | 	}
119 | 	return maxValue
120 | `
121 | 	case "Min":
122 | 		return `
123 | 	if len(argname_1) == 0 {
124 | 		return 0
125 | 	}
126 | 	minValue := argname_1[0]
127 | 	for _, v := range argname_1[1:] {
128 | 		if v < minValue {
129 | 			minValue = v
130 | 		}
131 | 	}
132 | 	return minValue
133 | `
134 | 	}
135 | 	return ""
136 | }
137 |
138 | // GenEnumFunctionDecl generates a type-specialised declaration for a call to
139 | // one of the enum helpers. funName is the callee as written (e.g.
140 | // "enum.Reduce") and listOfArgs are the call's arguments. It returns the new
141 | // function's name (helper name plus the extractParamsName suffix) and its
142 | // source. The declaration carries the result type assertType only after a
143 | // type assertion has been recorded via RecordAssertType.
144 | func GenEnumFunctionDecl(fset *token.FileSet, funName string, listOfArgs []ast.Expr) (string, string) {
145 | 	paramsTypeDecl, _, _ := ExtractParamsTypeAndName(fset, listOfArgs)
146 | 	switch funName {
147 | 	case "enum.Reduce", "enum.Add", "enum.Map", "enum.Delete",
148 | 		"enum.Find", "enum.Sum", "enum.Max", "enum.Min":
149 | 		funName = strings.TrimPrefix(funName, "enum.")
150 | 	}
151 | 	functionBody := genFunctionBody(funName)
152 |
153 | 	funName += extractParamsName(listOfArgs)
154 | 	var functionDecl string
155 | 	if assertPassCnt == 1 {
156 | 		// TODO: use something better than a package-level variable to carry the assertion type.
157 | 		functionDecl = fmt.Sprintf("func %s(%s) %s {\n%s\n}", funName, paramsTypeDecl, assertType, functionBody)
158 | 	} else {
159 | 		functionDecl = fmt.Sprintf("func %s(%s) {\n%s\n}", funName, paramsTypeDecl, functionBody)
160 | 	}
161 | 	return funName, functionDecl
162 | }
163 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | English / [中文](https://github.com/PioneerIncubator/betterGo/blob/master/README_CN.md)
2 |
3 | # betterGo
4 |
5 | betterGo implement parts that I think Golang missed
6 |
7 | ## Real Generic
8 |
9 | Provide the real interface{} to you so that you can use it in your code.
10 | Before deployment, just use translator to generate specific type code, in which way will not affect your performance.
11 |
12 | Here are all generic functions:
13 | * `enum.Reduce`
14 | * `enum.Map`
15 | * `enum.Delete`: Delete slice's first element for which fun returns a truthy value.
16 | * `enum.Find`: Returns slice's first element for which fun returns a truthy value.
17 |
18 | ### Implementation
19 |
20 | Use go ast to analyse your code where using generic functions, generate specific function for your types and replace your original call expressions.
21 |
22 | ### What I actually do
23 |
24 | ![](https://pic1.zhimg.com/50/v2-dd2dc3bc72b058b85774ee804a521165_hd.webp)
25 |
26 |
27 |
28 | I do this shit for you :P
29 |
30 | ### Background
31 |
32 | The current Go doesn't have generics (like template in c++, interface in Java).
33 |
34 | The following ways are often used in Go in order to implement generics:
35 |
36 | 1. Interface (with method)
37 |
38 | Pros: No third-party libraries required, clean and universal code.
39 |
40 | Cons: Requires some additional amount of code, and perhaps a less dramatic runtime overhead.
41 |
42 | 2. Use type assertions
43 |
44 | Pros: No third-party libraries required, clean and universal code.
45 | 46 | Cons: Requires executing type assertions, incurs runtime overhead for interface conversions, and offers no compile-time type checking. 47 | 48 | 3. Reflection 49 | 50 | Pros: Clean code. 51 | 52 | Cons: Considerable runtime overhead, and no compile-time type checking. 53 | 54 | 4. Code generation 55 | 56 | Pros: Extremely clean code, compile-time type checking, no runtime overhead. 57 | 58 | Cons: Requires third-party libraries, and produces larger compiled binaries. 59 | 60 | `betterGo` implements generics through `code generation`. 61 | 62 | ### Usage 63 | 64 | If you want to use `betterGo` to implement generics by automatically generating code, have a look at the following example: 65 | 66 | The project contains test cases. For example, `test/map/map.go` is code that needs generics: to use the `interface{}`-based function there, it simply calls `enum.Map`. 67 | 68 | To generate a function for the specific types used there, run: `go run main.go -w -f test/map/map.go` 69 | 70 | You will then find that `test/map/map.go` has changed: `enum.Map` has become `enum.MapOriginFn(origin, fn)`. 71 | 72 | You will also find a newly generated file, `utils/enum/map.go`, which contains that type-specific function. 73 | 74 | ### Contributing 75 | 76 | If you want to work with us on this project, you can read the code directly: find the packages related to the `AST` and try to understand the relevant functions, which will make the project and its code easy to follow. 77 | 78 | If you prefer to start with the theory, look up some material on the Go `AST` and study it briefly. 79 | 80 | ### Implementation method 81 | 82 | 1. Load the file/directory to be manipulated 83 | 84 | 2. Analyse the syntax with the `AST` 85 | 86 | The `AST` reveals the kind of each statement, such as: 87 | 88 | - `GenDecl` (General Declaration): Includes import, constant declaration, variable declaration, type declaration. 
89 | - `AssignStmt` (Assignment Statement): Includes assignment statements and short variable declarations (a := 1). 90 | - `FuncDecl` (Function Declaration) 91 | - `TypeAssertExpr` (Type Assertion Expression) 92 | - `CallExpr` (Call Expression) 93 | 94 | 3. When a statement containing the value/type of a variable (`AssignStmt`, `FuncDecl`) is analyzed, the variable's name and type are recorded in a mapping, so that later steps can look up a variable's type by its name. 95 | 96 | 4. When a function call expression (`CallExpr`) is found, betterGo checks whether the function is one it provides. If so, it generates code specialised for the argument types recorded in the previous step and stores it at the specified path (if code for the same types has already been generated, it is not generated again). 97 | 98 | 5. Replace the original function call expression with a new one that calls the function generated in the previous step, and update the import statement. 
99 |
100 |
--------------------------------------------------------------------------------
/fileoperations/genFunc.go:
--------------------------------------------------------------------------------
1 | package fileoperations
2 |
3 | import (
4 | 	"bufio"
5 | 	"fmt"
6 | 	"go/format"
7 | 	"os"
8 | 	"path/filepath"
9 | 	"regexp"
10 | 	"strings"
11 |
12 | 	log "github.com/sirupsen/logrus"
13 | )
14 |
15 | // maxLineSize bounds the length of one line read by newLineScanner; it is
16 | // far above bufio.Scanner's 64 KiB default.
17 | const maxLineSize = 16 * 1024 * 1024
18 |
19 | // funcNameRe matches the start of a declaration such as "func AddAB(".
20 | var funcNameRe = regexp.MustCompile(`func \w+\(`)
21 |
22 | // newLineScanner returns a scanner over f that yields one line per token,
23 | // without its "\n" or "\r\n" terminator, and accepts lines up to
24 | // maxLineSize bytes instead of splitting them.
25 | func newLineScanner(f *os.File) *bufio.Scanner {
26 | 	scanner := bufio.NewScanner(f)
27 | 	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
28 | 	return scanner
29 | }
30 |
31 | // checkFileExists reports whether filePath exists. Stat errors other than
32 | // "does not exist" are treated as existing.
33 | func checkFileExists(filePath string) bool {
34 | 	_, err := os.Stat(filePath)
35 | 	return !os.IsNotExist(err)
36 | }
37 |
38 | // CheckFuncExists reports whether filePath already contains a generated
39 | // function with the signature described by listOfArgTypes and returns its name.
40 | //
41 | // listOfArgTypes holds the parameter types in order followed by the result
42 | // type (possibly "") as its last element. The file is searched for a line
43 | // containing `argname_1 T1, ..., argname_n Tn) R {`.
44 | func CheckFuncExists(filePath string, listOfArgTypes []string) (bool, string) {
45 | 	if !checkFileExists(filePath) {
46 | 		return false, ""
47 | 	}
48 |
49 | 	var target string
50 | 	switch length := len(listOfArgTypes); length {
51 | 	case 0:
52 | 		log.Fatal("Error:There is no argument in listOfArgTypes")
53 | 	case 1:
54 | 		// A single entry is searched for as the first parameter's type.
55 | 		target = fmt.Sprintf("argname_%d %s", 1, listOfArgTypes[0])
56 | 	default:
57 | 		params := make([]string, 0, length-1)
58 | 		for i, argType := range listOfArgTypes[:length-1] {
59 | 			params = append(params, fmt.Sprintf("argname_%d %s", i+1, argType))
60 | 		}
61 | 		// Anchor on the result type and the opening brace so that a result
62 | 		// of "int" does not match a declaration returning "int64".
63 | 		signatureEnd := ") {"
64 | 		if result := listOfArgTypes[length-1]; result != "" {
65 | 			signatureEnd = ") " + result + " {"
66 | 		}
67 | 		target = strings.Join(params, ", ") + signatureEnd
68 | 	}
69 | 	target = regexp.QuoteMeta(target)
70 |
71 | 	log.WithFields(log.Fields{
72 | 		"target":   target,
73 | 		"filePath": filePath,
74 | 	}).Info("The target function is being checked from the file for existence")
75 | 	return matchFunc(filePath, target)
76 | }
77 |
78 | // matchFunc scans filePath for the first declaration line matching the
79 | // regular expression origin and returns true with the declared function's
80 | // name; it returns false, "" when no such line exists.
81 | func matchFunc(filePath, origin string) (bool, string) {
82 | 	re, err := regexp.Compile(origin)
83 | 	if err != nil {
84 | 		log.Fatal(err)
85 | 	}
86 | 	f, err := os.Open(filePath) // read-only: the file is never written here
87 | 	if err != nil {
88 | 		log.Fatal(err)
89 | 	}
90 | 	defer func() {
91 | 		if err := f.Close(); err != nil {
92 | 			log.Error(err)
93 | 		}
94 | 	}()
95 |
96 | 	scanner := newLineScanner(f)
97 | 	for scanner.Scan() {
98 | 		line := scanner.Bytes()
99 | 		if !re.Match(line) {
100 | 			continue
101 | 		}
102 | 		funcName := getFuncNameFromLine(line)
103 | 		if funcName == "" {
104 | 			// The match is not on a declaration line, so there is no name to reuse.
105 | 			continue
106 | 		}
107 | 		log.WithFields(log.Fields{
108 | 			"prevFuncName": funcName,
109 | 		}).Warn("Function has been generated before!")
110 | 		return true, funcName
111 | 	}
112 | 	if err := scanner.Err(); err != nil {
113 | 		log.Fatal(err)
114 | 	}
115 | 	return false, ""
116 | }
117 |
118 | // getFuncNameFromLine extracts the function name from a declaration line
119 | // such as "func AddAB(argname_1 int, argname_2 int) int {" ("AddAB"). It
120 | // returns "" when the line contains no "func Name(" declaration.
121 | func getFuncNameFromLine(line []byte) string {
122 | 	match := funcNameRe.Find(line)
123 | 	if match == nil {
124 | 		return ""
125 | 	}
126 | 	// Drop the leading "func " and the trailing "(".
127 | 	return string(match[len("func ") : len(match)-1])
128 | }
129 |
130 | // ensureDirExists creates the directory that will hold filePath, including
131 | // missing parents (e.g. "./utils" for "./utils/enum/reduce.go"); an
132 | // existing directory is not an error.
133 | func ensureDirExists(filePath string) error {
134 | 	return os.MkdirAll(filepath.Dir(filePath), 0777)
135 | }
136 |
137 | // ensureFileExists opens filePath for appending, creating it and its
138 | // directory when missing, and reports whether the file existed beforehand.
139 | // Failing to open the file is fatal.
140 | func ensureFileExists(filePath string) (*os.File, bool, error) {
141 | 	if err := ensureDirExists(filePath); err != nil {
142 | 		log.Error(err)
143 | 	}
144 | 	exist := checkFileExists(filePath)
145 | 	f, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666)
146 | 	if err != nil {
147 | 		log.Fatal(err)
148 | 	}
149 | 	return f, exist, err
150 | }
151 |
152 | // WriteFuncToFile gofmt-formats input and appends it to filePath. When the
153 | // file is new, packageName (a full "package x" clause) is written first.
154 | // Input that cannot be formatted is written as-is instead of being dropped.
155 | func WriteFuncToFile(filePath, packageName string, input []byte) error {
156 | 	if formatted, err := format.Source(input); err != nil {
157 | 		// format.Source returns no output on error; keep the raw source.
158 | 		log.Error(err)
159 | 	} else {
160 | 		input = formatted
161 | 	}
162 |
163 | 	f, exist, err := ensureFileExists(filePath)
164 | 	if err != nil {
165 | 		return err
166 | 	}
167 | 	defer func() {
168 | 		if err := f.Close(); err != nil {
169 | 			log.Error(err)
170 | 		}
171 | 	}()
172 |
173 | 	writer := bufio.NewWriter(f)
174 | 	if !exist {
175 | 		if _, err := writer.WriteString(packageName + "\n"); err != nil {
176 | 			return err
177 | 		}
178 | 	}
179 | 	if _, err := writer.Write(input); err != nil {
180 | 		return err
181 | 	}
182 | 	return writer.Flush()
183 | }
184 |
--------------------------------------------------------------------------------
/fileoperations/replace.go:
--------------------------------------------------------------------------------
1 | package fileoperations
2 |
3 | import (
4 | 	"bufio"
5 | 	"fmt"
6 | 	"os"
7 | 	"path/filepath"
8 | 	"regexp"
9 | 	"strings"
10 |
11 | 	log "github.com/sirupsen/logrus"
12 | )
13 |
14 | // GenCallExpr builds the source text of the call funName(listOfArgs...).
15 | // For the call being replaced (isNew == false) it appends ".(assertType)"
16 | // when assertType is non-empty and returns the text regexp-quoted, ready to
17 | // be used as a search pattern. For the replacement call (isNew == true) it
18 | // returns the plain call text.
19 | func GenCallExpr(funName, assertType string, listOfArgs []string, isNew bool) string {
20 | 	callExpr := funName + "(" + strings.Join(listOfArgs, ", ") + ")"
21 | 	if isNew {
22 | 		return callExpr
23 | 	}
24 | 	if assertType != "" {
25 | 		callExpr = fmt.Sprintf("%s.(%s)", callExpr, assertType)
26 | 	}
27 | 	return regexp.QuoteMeta(callExpr)
28 | }
29 |
30 | // ReplaceOriginFuncByFile replaces every match of the pattern origin in file
31 | // with target and, if anything was replaced, rewrites the betterGo import of
32 | // that package to the generated utils package. Failures to read or write
33 | // the file are fatal.
34 | func ReplaceOriginFuncByFile(file, origin, target string) {
35 | 	output, needHandle, err := readFile(file, origin, target)
36 | 	if err != nil {
37 | 		log.Fatal(err)
38 | 	}
39 | 	if !needHandle {
40 | 		log.WithFields(log.Fields{
41 | 			"originCallExpr": origin,
42 | 		}).Error("Replace function call expression failed, the expr to be replaced was not found!")
43 | 		return
44 | 	}
45 | 	if err := writeCallExprToFile(file, output); err != nil {
46 | 		log.Fatal(err)
47 | 	}
48 | 	log.WithFields(log.Fields{
49 | 		"originCallExpr": origin,
50 | 		"targetCallExpr": target,
51 | 	}).Info("Replace function call expression successfully")
52 |
53 | 	// Rewrite the import statement. This assumes the working directory lies
54 | 	// under $GOPATH/src, so stripping that prefix yields the project's import
55 | 	// path; a Getwd failure only yields a wrong import path.
56 | 	dir, _ := os.Getwd()
57 | 	gopath := fmt.Sprintf("%s/src/", os.Getenv("GOPATH"))
58 | 	// origin is regexp-quoted ("enum\.Reduce\(..."), so the package name is
59 | 	// followed by a backslash that has to be trimmed.
60 | 	pkgName := strings.TrimRight(strings.Split(origin, ".")[0], "\\")
61 | 	oldPath := strings.ReplaceAll(dir, gopath, "")
62 | 	oldImport := fmt.Sprintf("\"github.com/PioneerIncubator/betterGo/%s\"", pkgName)
63 | 	newImport := fmt.Sprintf("%s \"%s/utils/%s\"", pkgName, oldPath, pkgName)
64 | 	replaceOriginImport(file, oldImport, newImport)
65 | }
66 |
67 | // ReplaceOriginFuncByDir applies ReplaceOriginFuncByFile to every
68 | // non-directory entry found under path.
69 | func ReplaceOriginFuncByDir(path, origin, target string) {
70 | 	for _, file := range getFiles(path) {
71 | 		log.WithFields(log.Fields{
72 | 			"fileName": file,
73 | 		}).Info("The function call expression is being replaced")
74 | 		ReplaceOriginFuncByFile(file, origin, target)
75 | 	}
76 | }
77 |
78 | // replaceOriginImport replaces the literal import text origin with target
79 | // in file and logs whether the import statement was found.
80 | func replaceOriginImport(file, origin, target string) {
81 | 	pattern := regexp.QuoteMeta(origin)
82 | 	output, needHandle, err := readFile(file, pattern, target)
83 | 	if err != nil {
84 | 		log.Fatal(err)
85 | 	}
86 | 	if !needHandle {
87 | 		log.WithFields(log.Fields{
88 | 			"originImportStmt": pattern,
89 | 		}).Error("Replace import statement failed, the stmt to be replaced was not found!")
90 | 		return
91 | 	}
92 | 	if err := writeCallExprToFile(file, output); err != nil {
93 | 		log.Fatal(err)
94 | 	}
95 | 	log.WithFields(log.Fields{
96 | 		"originImportStmt": pattern,
97 | 		"targetImportStmt": target,
98 | 	}).Info("Replace import statement successfully")
99 | }
100 |
101 | // readFile reads filePath line by line and replaces every match of the
102 | // regular expression origin with the literal text target. It returns the
103 | // resulting content, with every line terminated by "\n", and whether any
104 | // line matched.
105 | func readFile(filePath, origin, target string) ([]byte, bool, error) {
106 | 	re, err := regexp.Compile(origin)
107 | 	if err != nil {
108 | 		return nil, false, err
109 | 	}
110 | 	f, err := os.Open(filePath)
111 | 	if err != nil {
112 | 		return nil, false, err
113 | 	}
114 | 	defer func() {
115 | 		if err := f.Close(); err != nil {
116 | 			log.Error(err)
117 | 		}
118 | 	}()
119 |
120 | 	needHandle := false
121 | 	output := make([]byte, 0)
122 | 	scanner := newLineScanner(f)
123 | 	for scanner.Scan() {
124 | 		line := scanner.Bytes()
125 | 		if re.Match(line) {
126 | 			log.WithFields(log.Fields{
127 | 				"filePath":  filePath,
128 | 				"statement": origin,
129 | 			}).Info("The statement was found from the file")
130 | 			// target is literal source text: "$" in it must not be expanded.
131 | 			line = re.ReplaceAllLiteral(line, []byte(target))
132 | 			needHandle = true
133 | 		}
134 | 		output = append(output, line...)
135 | 		output = append(output, '\n')
136 | 	}
137 | 	if err := scanner.Err(); err != nil {
138 | 		return nil, needHandle, err
139 | 	}
140 | 	return output, needHandle, nil
141 | }
142 |
143 | // writeCallExprToFile overwrites the existing file filePath with input.
144 | func writeCallExprToFile(filePath string, input []byte) error {
145 | 	f, err := os.OpenFile(filePath, os.O_WRONLY|os.O_TRUNC, 0600)
146 | 	if err != nil {
147 | 		return err
148 | 	}
149 | 	defer func() {
150 | 		if err := f.Close(); err != nil {
151 | 			log.Error(err)
152 | 		}
153 | 	}()
154 | 	writer := bufio.NewWriter(f)
155 | 	if _, err := writer.Write(input); err != nil {
156 | 		return err
157 | 	}
158 | 	return writer.Flush()
159 | }
160 |
161 | // getFiles returns the paths of all non-directory entries under root,
162 | // walking subdirectories recursively. A walk error is logged and the
163 | // files collected so far are returned.
164 | func getFiles(root string) []string {
165 | 	files := make([]string, 0)
166 | 	err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
167 | 		if info == nil {
168 | 			return err
169 | 		}
170 | 		if !info.IsDir() {
171 | 			files = append(files, path)
172 | 		}
173 | 		return nil
174 | 	})
175 | 	if err != nil {
176 | 		log.WithFields(log.Fields{
177 | 			"err": err,
178 | 		}).Error("Error getting file from directory")
179 | 	}
180 | 	return files
181 | }
182 |
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"fmt"
5 | 	"go/ast"
6 | 	"go/parser"
7 | 	"go/token"
8 | 	"os"
9 | 	"path/filepath"
10 | 	"strings"
11 |
12 | 	"github.com/PioneerIncubator/betterGo/fileoperations"
13 | 	"github.com/PioneerIncubator/betterGo/translator"
14 | 	log "github.com/sirupsen/logrus"
15 | 	"github.com/urfave/cli/v2"
16 | 	"golang.org/x/tools/go/ast/astutil"
17 | )
18 |
19 | // enumPkgPrefix qualifies the calls that betterGo specialises.
20 | const enumPkgPrefix = "enum."
21 |
22 | // replaceOriginFunc rewrites, in filePath (or in every file under it when
23 | // isDir is set), the call callFunExpr(args...) together with its recorded
24 | // type assertion into a call to newFunName in the same package.
25 | func replaceOriginFunc(fset *token.FileSet, ret *ast.CallExpr, callFunExpr, newFunName, filePath string, isDir bool) {
26 | 	pkgName := strings.Split(callFunExpr, ".")[0]
27 | 	newFunName = fmt.Sprintf("%s.%s", pkgName, newFunName)
28 | 	_, args, _ := translator.ExtractParamsTypeAndName(fset, ret.Args)
29 |
30 | 	assertType := translator.GetAssertType()
31 | 	originStr := fileoperations.GenCallExpr(callFunExpr, assertType, args, false)
32 | 	targetStr := fileoperations.GenCallExpr(newFunName, assertType, args, true)
33 |
34 | 	// Make relative paths explicitly relative to the working directory, but
35 | 	// leave absolute ones alone ("./" + "/abs/x.go" would become relative).
36 | 	if !filepath.IsAbs(filePath) {
37 | 		filePath = "./" + filePath
38 | 	}
39 | 	if isDir {
40 | 		fileoperations.ReplaceOriginFuncByDir(filePath, originStr, targetStr)
41 | 		return
42 | 	}
43 | 	fileoperations.ReplaceOriginFuncByFile(filePath, originStr, targetStr)
44 | }
45 |
46 | // genTargetFuncImplement appends funDeclStr to ./utils/<pkg>/<fun>.go
47 | // unless a function with the same signature was generated there before.
48 | // In that case it returns true and the existing function's name; otherwise
49 | // it writes the declaration and returns false, "".
50 | func genTargetFuncImplement(fset *token.FileSet, ret *ast.CallExpr, callFunExpr, funDeclStr string) (bool, string) {
51 | 	s := strings.Split(callFunExpr, ".")
52 | 	pkgName, funName := s[0], s[1]
53 | 	filePath := fmt.Sprintf("./utils/%s/%s", pkgName, strings.ToLower(funName+".go"))
54 |
55 | 	_, _, listOfArgTypes := translator.ExtractParamsTypeAndName(fset, ret.Args)
56 | 	if funcExists, prevFuncName := fileoperations.CheckFuncExists(filePath, listOfArgTypes); funcExists {
57 | 		return true, prevFuncName
58 | 	}
59 |
60 | 	buffer := []byte("\n" + funDeclStr)
61 | 	pkgStatement := fmt.Sprintf("package %s", pkgName)
62 | 	if err := fileoperations.WriteFuncToFile(filePath, pkgStatement, buffer); err != nil {
63 | 		log.Fatal(err)
64 | 	}
65 | 	return false, ""
66 | }
67 |
68 | // loopASTNode walks every function declaration in node, recording variable
69 | // and type-assertion information through the translator package and
70 | // building a specialised declaration for each call to an enum.* helper.
71 | // With rewriteAndGen set, each declaration is written under ./utils and the
72 | // call in filePath is rewritten to use it.
73 | func loopASTNode(fset *token.FileSet, node *ast.File, filePath string, isDir, rewriteAndGen bool) {
74 | 	for _, decl := range node.Decls {
75 | 		fn, ok := decl.(*ast.FuncDecl)
76 | 		if !ok {
77 | 			continue
78 | 		}
79 | 		astutil.Apply(fn, func(cr *astutil.Cursor) bool {
80 | 			visitNode(fset, cr.Node(), filePath, isDir, rewriteAndGen)
81 | 			return true
82 | 		}, nil)
83 | 	}
84 | }
85 |
86 | // visitNode records or acts on a single node visited by loopASTNode.
87 | func visitNode(fset *token.FileSet, n ast.Node, filePath string, isDir, rewriteAndGen bool) {
88 | 	switch ret := n.(type) {
89 | 	case *ast.GenDecl:
90 | 		log.WithFields(log.Fields{
91 | 			"astNodeType": ret,
92 | 		}).Info("Find a general declaration")
93 | 	case *ast.ValueSpec:
94 | 		log.WithFields(log.Fields{
95 | 			"astNodeType": ret,
96 | 		}).Info("Find a constant or variable declaration")
97 | 		translator.RecordDeclVarType(fset, ret)
98 | 	case *ast.AssignStmt:
99 | 		log.WithFields(log.Fields{
100 | 			"astNodeType": ret,
101 | 		}).Info("Find an assign statement")
102 | 		if ret.Tok == token.DEFINE { // a := 12
103 | 			translator.RecordAssignVarType(fset, ret)
104 | 		}
105 | 	case *ast.FuncDecl:
106 | 		if ret.Name.Name != "main" {
107 | 			log.WithFields(log.Fields{
108 | 				"astNodeType": ret,
109 | 				"funcName":    ret.Name.Name,
110 | 			}).Info("Find a function declaration")
111 | 			translator.GetFuncType(fset, ret)
112 | 		}
113 | 	case *ast.TypeAssertExpr:
114 | 		// In `out := enum.Reduce(a, mul, 1).(int)` the assertion is visited
115 | 		// before the call it wraps, so its type is recorded first.
116 | 		log.WithFields(log.Fields{
117 | 			"astNodeType": ret,
118 | 		}).Info("Find a type assert expression")
119 | 		if ret.Type == nil {
120 | 			// `x.(type)` in a type switch names no type; printing nil is fatal.
121 | 			return
122 | 		}
123 | 		translator.RecordAssertType(translator.GetExprStr(fset, ret.Type))
124 | 	case *ast.CallExpr:
125 | 		handleEnumCall(fset, ret, filePath, isDir, rewriteAndGen)
126 | 	}
127 | }
128 |
129 | // handleEnumCall builds a specialised function for a call to an enum.*
130 | // helper and, when rewriteAndGen is set, writes it out and rewires the call.
131 | // Calls to anything else are ignored.
132 | func handleEnumCall(fset *token.FileSet, ret *ast.CallExpr, filePath string, isDir, rewriteAndGen bool) {
133 | 	funName := translator.GetExprStr(fset, ret.Fun)
134 | 	// Match the qualifier exactly: a substring test also catches calls such
135 | 	// as enumerate(x), whose name has no "pkg.Fun" form to split later.
136 | 	if !strings.HasPrefix(funName, enumPkgPrefix) {
137 | 		return
138 | 	}
139 | 	newFunName, funDeclStr := translator.GenEnumFunctionDecl(fset, funName, ret.Args)
140 | 	log.WithFields(log.Fields{
141 | 		"astNodeType":    ret,
142 | 		"newFunName":     newFunName,
143 | 		"newFunDeclStmt": funDeclStr,
144 | 	}).Info("Find a call expression")
145 | 	if !rewriteAndGen {
146 | 		return
147 | 	}
148 | 	// Reuse a function with the same signature if one was generated before.
149 | 	if funcExists, prevFuncName := genTargetFuncImplement(fset, ret, funName, funDeclStr); funcExists {
150 | 		newFunName = prevFuncName
151 | 	}
152 | 	replaceOriginFunc(fset, ret, funName, newFunName, filePath, isDir)
153 | }
154 |
155 | // loopASTFile parses the Go file at filePath and processes it with
156 | // loopASTNode. A parse error is fatal.
157 | func loopASTFile(filePath string, rewriteAndGen bool) {
158 | 	fset := token.NewFileSet()
159 | 	node, err := parser.ParseFile(fset, filePath, nil, parser.ParseComments)
160 | 	if err != nil {
161 | 		log.WithFields(log.Fields{
162 | 			"filePath": filePath,
163 | 			"err":      err,
164 | 		}).Fatal("Parse file fail")
165 | 	}
166 | 	loopASTNode(fset, node, filePath, false, rewriteAndGen)
167 | }
168 |
169 | // loopASTDir parses every Go package in the directory filePath and
170 | // processes each file with loopASTNode. A parse error is fatal.
171 | func loopASTDir(filePath string, rewriteAndGen bool) {
172 | 	fset := token.NewFileSet()
173 | 	pkgs, err := parser.ParseDir(fset, filePath, nil, parser.ParseComments)
174 | 	if err != nil {
175 | 		log.WithFields(log.Fields{
176 | 			"dirPath": filePath,
177 | 			"err":     err,
178 | 		}).Fatal("Parse dir fail")
179 | 	}
180 | 	for _, pkg := range pkgs {
181 | 		for _, fileNode := range pkg.Files {
182 | 			loopASTNode(fset, fileNode, filePath, true, rewriteAndGen)
183 | 		}
184 | 	}
185 | }
186 |
187 | // main runs the CLI: -f processes one file, -d a directory, and -w also writes the generated functions and rewrites the calls.
188 | func main() {
189 | 	log.SetFormatter(&log.TextFormatter{})
190 | 	log.SetOutput(os.Stdout)
191 | 	log.SetLevel(log.InfoLevel)
192 |
193 | 	app := &cli.App{
194 | 		Flags: []cli.Flag{
195 | 			&cli.StringFlag{
196 | 				Name:    "file",
197 | 				Aliases: []string{"f"},
198 | 				Usage:   "Generate and replace the file with Enum files",
199 | 			},
200 | 			&cli.StringFlag{
201 | 				Name:    "dir",
202 | 				Aliases: []string{"d"},
203 | 				Usage:   "Generate and replace the dirctory with Enum files",
204 | 			},
205 | 			&cli.BoolFlag{
206 | 				Name:    "rewrite&gen",
207 | 				Aliases: []string{"w"},
208 | 				Value:   false,
209 | 				Usage:   "Rewrite files and generate files",
210 | 			},
211 | 		},
212 | 		Action: func(c *cli.Context) error {
213 | 			rewriteAndGen := c.Bool("rewrite&gen")
214 | 			if file := c.String("file"); file != "" {
215 | 				loopASTFile(file, rewriteAndGen)
216 | 				return nil
217 | 			}
218 | 			if dir := c.String("dir"); dir != "" {
219 | 				loopASTDir(dir, rewriteAndGen)
220 | 				return nil
221 | 			}
222 | 			log.Fatal("file or dir flag empty")
223 | 			return nil
224 | 		},
225 | 	}
226 |
227 | 	if err := app.Run(os.Args); err != nil {
228 | 		log.Fatal(err)
229 | 	}
230 | }
231 |
--------------------------------------------------------------------------------