├── .gitignore ├── Makefile ├── .bumpversion.cfg ├── make ├── go.mod ├── .github └── workflows │ └── ci.yaml ├── README.md ├── examples ├── optionalargument │ └── example_test.go ├── basic │ └── example_test.go ├── hook │ └── example_test.go └── cleanup │ └── example_test.go ├── LICENSE ├── go.sum ├── export_test.go ├── di.go └── di_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | /.cache 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | include /ci/Makefile 2 | -------------------------------------------------------------------------------- /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.3.0 3 | commit = True 4 | tag = True 5 | -------------------------------------------------------------------------------- /make: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | eval "$(curl -SsLf https://github.com/go-tk/ci/raw/v1/make.sh || echo "exit ${?}")" 4 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-tk/di 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/go-tk/testcase v0.8.0 7 | github.com/stretchr/testify v1.8.0 8 | ) 9 | 10 | require ( 11 | github.com/davecgh/go-spew v1.1.1 // indirect 12 | github.com/pmezard/go-difflib v1.0.0 // indirect 13 | gopkg.in/yaml.v3 v3.0.1 // indirect 14 | ) 15 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: ci 2 | defaults: 3 | run: 4 | shell: bash 5 | on: 6 | push: 7 | branches: 8 | - main 9 | 
pull_request: 10 | branches: 11 | - main 12 | jobs: 13 | test: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v2 17 | - run: >- 18 | ./make 19 | POST_GEN='git diff --exit-code' 20 | POST_FMT='git diff --exit-code' 21 | POST_FMTMOD='git diff --exit-code' 22 | TEST_FLAGS='-race -coverprofile=coverage.txt' 23 | - uses: codecov/codecov-action@v2 24 | with: 25 | files: coverage.txt 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # di 2 | 3 | [![GoDev](https://pkg.go.dev/badge/golang.org/x/pkgsite.svg)](https://pkg.go.dev/github.com/go-tk/di) 4 | [![Workflow Status](https://github.com/go-tk/di/actions/workflows/ci.yaml/badge.svg?branch=main)](https://github.com/go-tk/di/actions/workflows/ci.yaml?query=branch%3Amain) 5 | [![Coverage Status](https://codecov.io/gh/go-tk/di/branch/main/graph/badge.svg)](https://codecov.io/gh/go-tk/di/branch/main) 6 | 7 | Tiny dependency injection framework 8 | 9 | ## Usage 10 | 11 | Looking into the examples is an easy way to get an idea of the functionality this library provides. 12 | 13 | ## Examples 14 | 15 | - [Basic](examples/basic/example_test.go) 16 | - [Optional Argument](examples/optionalargument/example_test.go) 17 | - [Cleanup](examples/cleanup/example_test.go) 18 | - [Hook](examples/hook/example_test.go) 19 | -------------------------------------------------------------------------------- /examples/optionalargument/example_test.go: -------------------------------------------------------------------------------- 1 | package di_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/go-tk/di" 8 | ) 9 | 10 | func Example() { 11 | var program di.Program 12 | 13 | substractFooWithBar(&program) 14 | provideFoo(&program) 15 | // NOTE: Program will rearrange Functions properly basing on dependency analysis. 
16 | 17 | defer program.Clean() 18 | program.MustRun(context.Background()) 19 | // Output: 20 | // foo = 100 21 | // foo - bar = 99 22 | } 23 | 24 | func provideFoo(program *di.Program) { 25 | var foo int 26 | program.MustNewFunction( 27 | di.Result("FOO", &foo), 28 | di.Body(func(context.Context) error { 29 | foo = 100 30 | fmt.Printf("foo = %d\n", foo) 31 | return nil 32 | }), 33 | ) 34 | } 35 | 36 | func substractFooWithBar(program *di.Program) { 37 | var ( 38 | foo int 39 | bar int = 1 40 | ) 41 | program.MustNewFunction( 42 | di.Argument("FOO", &foo), 43 | di.OptionalArgument("BAR", &bar), 44 | di.Body(func(context.Context) error { 45 | fmt.Printf("foo - bar = %d\n", foo-bar) 46 | return nil 47 | }), 48 | ) 49 | } 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Roy O'Young 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/go-tk/testcase v0.7.1 h1:NuAU98179W2KKfawEfAJm7Wp5tSswPr+vwYgR5K1GeY= 5 | github.com/go-tk/testcase v0.7.1/go.mod h1:rDUZ94OdR2u4H2yp59RYxUZpDUrTiuhhFmVUzsayXgo= 6 | github.com/go-tk/testcase v0.8.0 h1:vsytisbYFz/sMwSGx9zzzKmtmdN/8mTeBrVgExfqfE0= 7 | github.com/go-tk/testcase v0.8.0/go.mod h1:hYePj/cathPqoq9Ys68BmDtflQtmtjdFUCRF/vYjffM= 8 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 9 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 10 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 11 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 12 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 13 | github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= 14 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 15 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 16 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 17 | gopkg.in/yaml.v3 
v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 18 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 19 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 20 | -------------------------------------------------------------------------------- /examples/basic/example_test.go: -------------------------------------------------------------------------------- 1 | package di_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/go-tk/di" 8 | ) 9 | 10 | func Example() { 11 | var program di.Program 12 | 13 | provideBaz(&program) 14 | provideFoo(&program) 15 | showAll(&program) 16 | provideBar(&program) 17 | // NOTE: Program will rearrange Functions properly basing on dependency analysis. 18 | 19 | defer program.Clean() 20 | program.MustRun(context.Background()) 21 | // Output: 22 | // foo = 100 23 | // bar = 200 24 | // baz = 300 25 | // foo, bar, baz = 100, 200, 300 26 | } 27 | 28 | func provideFoo(program *di.Program) { 29 | var foo int 30 | program.MustNewFunction( 31 | di.Result("FOO", &foo), 32 | di.Body(func(context.Context) error { 33 | foo = 100 34 | fmt.Printf("foo = %d\n", foo) 35 | return nil 36 | }), 37 | ) 38 | } 39 | 40 | func provideBar(program *di.Program) { 41 | var ( 42 | foo int 43 | bar int 44 | ) 45 | program.MustNewFunction( 46 | di.Argument("FOO", &foo), 47 | di.Result("BAR", &bar), 48 | di.Body(func(context.Context) error { 49 | bar = foo * 2 50 | fmt.Printf("bar = %d\n", bar) 51 | return nil 52 | }), 53 | ) 54 | } 55 | 56 | func provideBaz(program *di.Program) { 57 | var ( 58 | foo int 59 | bar int 60 | baz int 61 | ) 62 | program.MustNewFunction( 63 | di.Argument("FOO", &foo), 64 | di.Argument("BAR", &bar), 65 | di.Result("BAZ", &baz), 66 | di.Body(func(context.Context) error { 67 | baz = foo + bar 68 | fmt.Printf("baz = %d\n", baz) 69 | return nil 70 | }), 71 | ) 72 | } 73 | 74 | func showAll(program *di.Program) { 75 | var ( 76 | foo int 77 
| bar int 78 | baz int 79 | ) 80 | program.MustNewFunction( 81 | di.Argument("FOO", &foo), 82 | di.Argument("BAR", &bar), 83 | di.Argument("BAZ", &baz), 84 | di.Body(func(context.Context) error { 85 | fmt.Printf("foo, bar, baz = %d, %d, %d\n", foo, bar, baz) 86 | return nil 87 | }), 88 | ) 89 | } 90 | -------------------------------------------------------------------------------- /examples/hook/example_test.go: -------------------------------------------------------------------------------- 1 | package di_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/go-tk/di" 9 | ) 10 | 11 | func Example() { 12 | var program di.Program 13 | 14 | showUserNameList(&program) 15 | modifyUserNameList(&program) 16 | provideUserNameList(&program) 17 | provideAdditionalUserName(&program) 18 | // NOTE: Program will rearrange Functions properly basing on dependency analysis. 19 | 20 | defer program.Clean() 21 | program.MustRun(context.Background()) 22 | // Output: 23 | // user name list: tom,jeff,spike 24 | } 25 | 26 | func provideUserNameList(program *di.Program) { 27 | var userNameList []string 28 | program.MustNewFunction( 29 | di.Result("USER_NAME_LIST", &userNameList), 30 | di.Body(func(context.Context) error { 31 | userNameList = []string{"tom", "jeff"} 32 | return nil 33 | }), 34 | ) 35 | } 36 | 37 | func provideAdditionalUserName(program *di.Program) { 38 | var additionalUserName string 39 | program.MustNewFunction( 40 | di.Result("ADDITIONAL_USER_NAME", &additionalUserName), 41 | di.Body(func(context.Context) error { 42 | additionalUserName = "spike" 43 | return nil 44 | }), 45 | ) 46 | } 47 | 48 | func showUserNameList(program *di.Program) { 49 | var userNameList []string 50 | program.MustNewFunction( 51 | di.Argument("USER_NAME_LIST", &userNameList), 52 | di.Body(func(context.Context) error { 53 | fmt.Printf("user name list: %v\n", strings.Join(userNameList, ",")) 54 | return nil 55 | }), 56 | ) 57 | } 58 | 59 | func modifyUserNameList(program 
*di.Program) { 60 | var ( 61 | additionalUserName string 62 | userNameList *[]string 63 | ) 64 | program.MustNewFunction( 65 | di.Argument("ADDITIONAL_USER_NAME", &additionalUserName), 66 | di.Body(func(context.Context) error { return nil }), 67 | di.Hook("USER_NAME_LIST", &userNameList, func(context.Context) error { 68 | *userNameList = append(*userNameList, additionalUserName) 69 | return nil 70 | }), 71 | ) 72 | } 73 | -------------------------------------------------------------------------------- /examples/cleanup/example_test.go: -------------------------------------------------------------------------------- 1 | package di_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io/ioutil" 7 | "os" 8 | "path/filepath" 9 | 10 | "github.com/go-tk/di" 11 | ) 12 | 13 | func Example() { 14 | var program di.Program 15 | 16 | writeTempFile(&program) 17 | provideTempDirName(&program) 18 | provideTempFile(&program) 19 | // NOTE: Program will rearrange Functions properly basing on dependency analysis. 20 | 21 | defer program.Clean() 22 | program.MustRun(context.Background()) 23 | // Output: 24 | // 1. create temp dir 25 | // 2. create and open temp file 26 | // 3. write temp file 27 | // 4. close and delete temp file 28 | // 5. delete temp dir 29 | } 30 | 31 | func provideTempDirName(program *di.Program) { 32 | var tempDirName string 33 | program.MustNewFunction( 34 | di.Result("TEMP_DIR_NAME", &tempDirName), 35 | di.Body(func(context.Context) error { 36 | fmt.Println("1. create temp dir") 37 | var err error 38 | tempDirName, err = ioutil.TempDir("", "") 39 | return err 40 | }), 41 | di.Cleanup(func() { 42 | fmt.Println("5. 
delete temp dir") 43 | os.Remove(tempDirName) 44 | }), 45 | ) 46 | } 47 | 48 | func provideTempFile(program *di.Program) { 49 | var ( 50 | tempDirName string 51 | tempFile *os.File 52 | ) 53 | program.MustNewFunction( 54 | di.Argument("TEMP_DIR_NAME", &tempDirName), 55 | di.Result("TEMP_FILE", &tempFile), 56 | di.Body(func(context.Context) error { 57 | fmt.Println("2. create and open temp file") 58 | var err error 59 | tempFile, err = os.Create(filepath.Join(tempDirName, "temp")) 60 | return err 61 | }), 62 | di.Cleanup(func() { 63 | fmt.Println("4. close and delete temp file") 64 | tempFile.Close() 65 | os.Remove(tempFile.Name()) 66 | }), 67 | ) 68 | } 69 | 70 | func writeTempFile(program *di.Program) { 71 | var tempFile *os.File 72 | program.MustNewFunction( 73 | di.Argument("TEMP_FILE", &tempFile), 74 | di.Body(func(context.Context) error { 75 | fmt.Println("3. write temp file") 76 | _, err := tempFile.WriteString("hello world") 77 | return err 78 | }), 79 | ) 80 | } 81 | -------------------------------------------------------------------------------- /export_test.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "reflect" 7 | ) 8 | 9 | type Function = function 10 | 11 | func (p *Program) Dump(buffer *bytes.Buffer) { 12 | for i := range p.functions { 13 | function := &p.functions[i] 14 | p.dumpFunction(i, function, buffer) 15 | } 16 | for i := range p.arguments { 17 | argument := &p.arguments[i] 18 | p.dumpArgument(i, argument, buffer) 19 | } 20 | for i := range p.results { 21 | result := &p.results[i] 22 | p.dumpResult(i, result, buffer) 23 | } 24 | for i := range p.hooks { 25 | hook := &p.hooks[i] 26 | p.dumpHook(i, hook, buffer) 27 | } 28 | fmt.Fprintf(buffer, "SortedFunctionIndexes: %v\n", p.sortedFunctionIndexes) 29 | fmt.Fprintf(buffer, "CalledFunctionCount: %v\n", p.calledFunctionCount) 30 | } 31 | 32 | func (p *Program) dumpFunction(i int, function *function, 
buffer *bytes.Buffer) { 33 | fmt.Fprintf(buffer, "Function[%d]:\n", i) 34 | fmt.Fprintf(buffer, "\tIndex: %v\n", function.Index) 35 | fmt.Fprintf(buffer, "\tName: %v\n", function.Name) 36 | fmt.Fprintf(buffer, "\tArgumentIndexes: %v\n", function.ArgumentIndexes) 37 | fmt.Fprintf(buffer, "\tResultIndexes: %v\n", function.ResultIndexes) 38 | fmt.Fprintf(buffer, "\tHasBody: %v\n", function.Body != nil) 39 | fmt.Fprintf(buffer, "\tHookIndexes: %v\n", function.HookIndexes) 40 | fmt.Fprintf(buffer, "\tHasCleanup: %v\n", function.Cleanup != nil) 41 | } 42 | 43 | func (p *Program) dumpArgument(i int, argument *argument, buffer *bytes.Buffer) { 44 | fmt.Fprintf(buffer, "Argument[%d]:\n", i) 45 | fmt.Fprintf(buffer, "\tFunctionIndex: %v\n", argument.FunctionIndex) 46 | fmt.Fprintf(buffer, "\tValueRef: %v\n", argument.ValueRef) 47 | fmt.Fprintf(buffer, "\tHasValueReceiver: %v\n", argument.ValueReceiver != (reflect.Value{})) 48 | fmt.Fprintf(buffer, "\tIsOptional: %v\n", argument.IsOptional) 49 | fmt.Fprintf(buffer, "\tResultIndex: %v\n", argument.ResultIndex) 50 | fmt.Fprintf(buffer, "\tReceiveValueAddr: %v\n", argument.ReceiveValueAddr) 51 | } 52 | 53 | func (p *Program) dumpResult(i int, result *result, buffer *bytes.Buffer) { 54 | fmt.Fprintf(buffer, "Result[%d]:\n", i) 55 | fmt.Fprintf(buffer, "\tFunctionIndex: %v\n", result.FunctionIndex) 56 | fmt.Fprintf(buffer, "\tValueName: %v\n", result.ValueName) 57 | fmt.Fprintf(buffer, "\tHasValue: %v\n", result.Value != (reflect.Value{})) 58 | fmt.Fprintf(buffer, "\tHookIndexes: %v\n", result.HookIndexes) 59 | } 60 | 61 | func (p *Program) dumpHook(i int, hook *hook, buffer *bytes.Buffer) { 62 | fmt.Fprintf(buffer, "Hook[%d]:\n", i) 63 | fmt.Fprintf(buffer, "\tFunctionIndex: %v\n", hook.FunctionIndex) 64 | fmt.Fprintf(buffer, "\tValueRef: %v\n", hook.ValueRef) 65 | fmt.Fprintf(buffer, "\tHasValueReceiver: %v\n", hook.ValueReceiver != (reflect.Value{})) 66 | fmt.Fprintf(buffer, "\tHasCallback: %v\n", hook.Callback != nil) 67 | 
fmt.Fprintf(buffer, "\tReceiveValueAddr: %v\n", hook.ReceiveValueAddr) 68 | } 69 | 70 | func (p *Program) DumpAsString() string { 71 | var buffer bytes.Buffer 72 | p.Dump(&buffer) 73 | return buffer.String() 74 | } 75 | -------------------------------------------------------------------------------- /di.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "reflect" 8 | "runtime" 9 | "strings" 10 | ) 11 | 12 | // Program consists of DI Functions which are containers for dependency injection. 13 | type Program struct { 14 | functions []function 15 | arguments []argument 16 | results []result 17 | hooks []hook 18 | sortedFunctionIndexes []int 19 | calledFunctionCount int 20 | } 21 | 22 | type function struct { 23 | Index int 24 | Name string 25 | ArgumentIndexes []int 26 | ResultIndexes []int 27 | Body func(context.Context) error 28 | HookIndexes []int 29 | Cleanup func() 30 | } 31 | 32 | // FunctionBuilder is the type of function that constructs a DI Function. 33 | type FunctionBuilder func(function *function, program *Program) (err error) 34 | 35 | // NewFunction add a DI Function into the Program. 36 | func (p *Program) NewFunction(functionBuilders ...FunctionBuilder) error { 37 | pc, _, _, _ := runtime.Caller(1) 38 | functionName := runtime.FuncForPC(pc).Name() 39 | return p.doNewFunction(functionName, functionBuilders...) 40 | } 41 | 42 | // MustNewFunction likes NewFunction but panics when an error occurs. 
43 | func (p *Program) MustNewFunction(functionBuilders ...FunctionBuilder) { 44 | pc, _, _, _ := runtime.Caller(1) 45 | functionName := runtime.FuncForPC(pc).Name() 46 | if err := p.doNewFunction(functionName, functionBuilders...); err != nil { 47 | panic(fmt.Sprintf("new function: %v", err)) 48 | } 49 | } 50 | 51 | func (p *Program) doNewFunction(functionName string, functionBuilders ...FunctionBuilder) (returnedErr error) { 52 | functionIndex := len(p.functions) 53 | p.functions = append(p.functions, function{Index: functionIndex}) 54 | defer func() { 55 | if returnedErr != nil { 56 | p.functions = p.functions[:functionIndex] 57 | } 58 | }() 59 | function := &p.functions[functionIndex] 60 | function.Name = functionName 61 | for _, functionBuilder := range functionBuilders { 62 | if err := functionBuilder(function, p); err != nil { 63 | return err 64 | } 65 | } 66 | if function.Body == nil { 67 | return fmt.Errorf("%w; functionName=%q", ErrBodyRequired, functionName) 68 | } 69 | return nil 70 | } 71 | 72 | // ErrBodyRequired is returned by Program.NewFunction() when no body is specified. 73 | var ErrBodyRequired = errors.New("di: body required") 74 | 75 | type argument struct { 76 | FunctionIndex int 77 | ValueRef string 78 | ValueReceiver reflect.Value 79 | IsOptional bool 80 | ResultIndex int 81 | ReceiveValueAddr bool 82 | } 83 | 84 | // Argument specifies an argument for a DI Function. 85 | func Argument(valueRef string, rawValueReceiverPtr interface{}) FunctionBuilder { 86 | return argument1(valueRef, rawValueReceiverPtr, false) 87 | } 88 | 89 | // OptionalArgument specifies an optional argument for a DI Function. 
90 | func OptionalArgument(valueRef string, rawValueReceiverPtr interface{}) FunctionBuilder { 91 | return argument1(valueRef, rawValueReceiverPtr, true) 92 | } 93 | 94 | func argument1(valueRef string, rawValueReceiverPtr interface{}, isOptional bool) FunctionBuilder { 95 | return func(function *function, program *Program) error { 96 | if valueRef == "" { 97 | return fmt.Errorf("%w: empty value ref; functionName=%q", ErrInvalidArgument, function.Name) 98 | } 99 | if rawValueReceiverPtr == nil { 100 | return fmt.Errorf("%w: no value receiver; functionName=%q valueRef=%q", ErrInvalidArgument, function.Name, valueRef) 101 | } 102 | valueReceiverPtr := reflect.ValueOf(rawValueReceiverPtr) 103 | if valueReceiverPtr.Kind() != reflect.Ptr { 104 | return fmt.Errorf("%w: invalid value receiver pointer; valueReceiverPtrType=%q functionName=%q valueRef=%q", 105 | ErrInvalidArgument, valueReceiverPtr.Type(), function.Name, valueRef) 106 | } 107 | if valueReceiverPtr.IsNil() { 108 | return fmt.Errorf("%w: no value receiver; functionName=%q valueRef=%q", ErrInvalidArgument, function.Name, valueRef) 109 | } 110 | argumentIndex := len(program.arguments) 111 | program.arguments = append(program.arguments, argument{}) 112 | argument := &program.arguments[argumentIndex] 113 | argument.FunctionIndex = function.Index 114 | argument.ValueRef = valueRef 115 | argument.ValueReceiver = valueReceiverPtr.Elem() 116 | argument.IsOptional = isOptional 117 | argument.ResultIndex = -1 118 | function.ArgumentIndexes = append(function.ArgumentIndexes, argumentIndex) 119 | return nil 120 | } 121 | } 122 | 123 | // ErrInvalidArgument is returned by Program.NewFunction() when an invalid argument is specified. 124 | var ErrInvalidArgument = errors.New("di: invalid augment") 125 | 126 | type result struct { 127 | FunctionIndex int 128 | ValueName string 129 | Value reflect.Value 130 | HookIndexes []int 131 | } 132 | 133 | // Result specifies a result for a DI Function. 
134 | func Result(valueName string, rawValuePtr interface{}) FunctionBuilder { 135 | return func(function *function, program *Program) error { 136 | if valueName == "" { 137 | return fmt.Errorf("%w: empty value name; functionName=%q", ErrInvalidResult, function.Name) 138 | } 139 | if rawValuePtr == nil { 140 | return fmt.Errorf("%w: no value; functionName=%q valueName=%q", ErrInvalidResult, function.Name, valueName) 141 | } 142 | valuePtr := reflect.ValueOf(rawValuePtr) 143 | if valuePtr.Kind() != reflect.Ptr { 144 | return fmt.Errorf("%w: invalid value pointer; valuePtrType=%q functionName=%q valueName=%q", 145 | ErrInvalidResult, valuePtr.Type(), function.Name, valueName) 146 | } 147 | if valuePtr.IsNil() { 148 | return fmt.Errorf("%w: no value; functionName=%q valueName=%q", ErrInvalidResult, function.Name, valueName) 149 | } 150 | resultIndex := len(program.results) 151 | program.results = append(program.results, result{}) 152 | result := &program.results[resultIndex] 153 | result.FunctionIndex = function.Index 154 | result.ValueName = valueName 155 | result.Value = valuePtr.Elem() 156 | function.ResultIndexes = append(function.ResultIndexes, resultIndex) 157 | return nil 158 | } 159 | } 160 | 161 | // ErrInvalidResult is returned by Program.NewFunction() when an invalid result is specified. 162 | var ErrInvalidResult = errors.New("di: invalid result") 163 | 164 | // Body specifies the body for a DI Function. 165 | func Body(body func(context.Context) error) FunctionBuilder { 166 | return func(function *function, program *Program) error { 167 | if body == nil { 168 | return fmt.Errorf("%w; functionName=%q", ErrNilBody, function.Name) 169 | } 170 | function.Body = body 171 | return nil 172 | } 173 | } 174 | 175 | // ErrNilBody is returned by Program.NewFunction() when nil body is specified. 176 | var ErrNilBody = errors.New("di: nil body") 177 | 178 | // Cleanup specifies the cleanup for a DI Function. 
179 | func Cleanup(cleanup func()) FunctionBuilder { 180 | return func(function *function, program *Program) error { 181 | if cleanup == nil { 182 | return fmt.Errorf("%w; functionName=%q", ErrNilCleanup, function.Name) 183 | } 184 | function.Cleanup = cleanup 185 | return nil 186 | } 187 | } 188 | 189 | // ErrNilCleanup is returned by Program.NewFunction() when nil cleanup is specified. 190 | var ErrNilCleanup = errors.New("di: nil cleanup") 191 | 192 | type hook struct { 193 | FunctionIndex int 194 | ValueRef string 195 | ValueReceiver reflect.Value 196 | Callback func(context.Context) error 197 | ReceiveValueAddr bool 198 | } 199 | 200 | // Hook specifies a hook for a DI Function. 201 | func Hook(valueRef string, rawValueReceiverPtr interface{}, callback func(context.Context) error) FunctionBuilder { 202 | return func(function *function, program *Program) error { 203 | if valueRef == "" { 204 | return fmt.Errorf("%w: empty value ref; functionName=%q", ErrInvalidHook, function.Name) 205 | } 206 | if rawValueReceiverPtr == nil { 207 | return fmt.Errorf("%w: no value receiver; functionName=%q valueRef=%q", ErrInvalidHook, function.Name, valueRef) 208 | } 209 | valueReceiverPtr := reflect.ValueOf(rawValueReceiverPtr) 210 | if valueReceiverPtr.Kind() != reflect.Ptr { 211 | return fmt.Errorf("%w: invalid value receiver pointer; valueReceiverPtrType=%q functionName=%q valueRef=%q", 212 | ErrInvalidHook, valueReceiverPtr.Type(), function.Name, valueRef) 213 | } 214 | if valueReceiverPtr.IsNil() { 215 | return fmt.Errorf("%w: no value receiver; functionName=%q valueRef=%q", ErrInvalidHook, function.Name, valueRef) 216 | } 217 | if callback == nil { 218 | return fmt.Errorf("%w: nil callback; functionName=%q valueRef=%q", ErrInvalidHook, function.Name, valueRef) 219 | } 220 | hookIndex := len(program.hooks) 221 | program.hooks = append(program.hooks, hook{}) 222 | hook := &program.hooks[hookIndex] 223 | hook.FunctionIndex = function.Index 224 | hook.ValueRef = valueRef 225 
| hook.ValueReceiver = valueReceiverPtr.Elem() 226 | hook.Callback = callback 227 | function.HookIndexes = append(function.HookIndexes, hookIndex) 228 | return nil 229 | } 230 | } 231 | 232 | // ErrInvalidHook is returned by Program.NewFunction() when an invalid hook is specified. 233 | var ErrInvalidHook = errors.New("di: invalid hook") 234 | 235 | // Run calls all DI Functions added into the Program, the order in which DI Functions are to be called 236 | // is based on dependency analysis. 237 | func (p *Program) Run(ctx context.Context) error { 238 | if err := p.resolve(); err != nil { 239 | return err 240 | } 241 | if err := p.sortFunctions(); err != nil { 242 | return err 243 | } 244 | return p.callFunctions(ctx) 245 | } 246 | 247 | func (p *Program) resolve() error { 248 | valueName2ResultIndex := make(map[string]int, len(p.results)) 249 | for resultIndex := range p.results { 250 | result := &p.results[resultIndex] 251 | if resultIndex2, ok := valueName2ResultIndex[result.ValueName]; ok { 252 | result2 := &p.results[resultIndex2] 253 | return fmt.Errorf("%w; valueName=%q functionName1=%q functionName2=%q", 254 | ErrDuplicateValueName, result.ValueName, p.functions[result.FunctionIndex].Name, 255 | p.functions[result2.FunctionIndex].Name) 256 | } 257 | valueName2ResultIndex[result.ValueName] = resultIndex 258 | } 259 | for argumentIndex := range p.arguments { 260 | argument := &p.arguments[argumentIndex] 261 | resultIndex, ok := valueName2ResultIndex[argument.ValueRef] 262 | if !ok { 263 | if argument.IsOptional { 264 | continue 265 | } 266 | return fmt.Errorf("%w; valueRef=%q functionName=%q", 267 | ErrValueNotFound, argument.ValueRef, p.functions[argument.FunctionIndex].Name) 268 | } 269 | result := &p.results[resultIndex] 270 | valueType := result.Value.Type() 271 | valueReceiverType := argument.ValueReceiver.Type() 272 | if valueReceiverType == reflect.PtrTo(valueType) { 273 | argument.ReceiveValueAddr = true 274 | } else { 275 | if 
!valueType.AssignableTo(valueReceiverType) { 276 | return fmt.Errorf("%w; valueReceiverType=%q valueType=%q valueRef=%q functionName=%q", 277 | ErrIncompatibleValueReceiver, valueReceiverType, valueType, argument.ValueRef, 278 | p.functions[argument.FunctionIndex].Name) 279 | } 280 | } 281 | argument.ResultIndex = resultIndex 282 | } 283 | for hookIndex := range p.hooks { 284 | hook := &p.hooks[hookIndex] 285 | resultIndex, ok := valueName2ResultIndex[hook.ValueRef] 286 | if !ok { 287 | return fmt.Errorf("%w; valueRef=%q functionName=%q", 288 | ErrValueNotFound, hook.ValueRef, p.functions[hook.FunctionIndex].Name) 289 | } 290 | result := &p.results[resultIndex] 291 | valueType := result.Value.Type() 292 | valueReceiverType := hook.ValueReceiver.Type() 293 | if valueReceiverType == reflect.PtrTo(valueType) { 294 | hook.ReceiveValueAddr = true 295 | } else { 296 | if !valueType.AssignableTo(valueReceiverType) { 297 | return fmt.Errorf("%w; valueReceiverType=%q valueType=%q valueRef=%q functionName=%q", 298 | ErrIncompatibleValueReceiver, valueReceiverType, valueType, hook.ValueRef, 299 | p.functions[hook.FunctionIndex].Name) 300 | } 301 | } 302 | result.HookIndexes = append(result.HookIndexes, hookIndex) 303 | } 304 | return nil 305 | } 306 | 307 | func (p *Program) sortFunctions() error { 308 | var walk func(*function, interface{}) bool 309 | var path []interface{} 310 | visitedFunctionIndexes := make(map[int]struct{}, len(p.functions)) 311 | walk = func(function *function, from interface{}) bool { 312 | functionIndex := function.Index 313 | if _, ok := visitedFunctionIndexes[functionIndex]; ok { 314 | return true 315 | } 316 | pathLength := len(path) 317 | if from == nil { 318 | path = append(path, function) 319 | } else { 320 | path = append(path, from, function) 321 | } 322 | if functionIndex < 0 { 323 | return false 324 | } 325 | function.Index = -1 326 | for _, argumentIndex := range function.ArgumentIndexes { 327 | argument := &p.arguments[argumentIndex] 328 | 
if argument.ResultIndex < 0 { 329 | continue 330 | } 331 | result := &p.results[argument.ResultIndex] 332 | function2 := &p.functions[result.FunctionIndex] 333 | if !walk(function2, argument) { 334 | return false 335 | } 336 | } 337 | for _, resultIndex := range function.ResultIndexes { 338 | result := &p.results[resultIndex] 339 | for _, hookIndex := range result.HookIndexes { 340 | hook := &p.hooks[hookIndex] 341 | function2 := &p.functions[hook.FunctionIndex] 342 | if !walk(function2, hook) { 343 | return false 344 | } 345 | } 346 | } 347 | function.Index = functionIndex 348 | path = path[:pathLength] 349 | visitedFunctionIndexes[functionIndex] = struct{}{} 350 | p.sortedFunctionIndexes = append(p.sortedFunctionIndexes, functionIndex) 351 | return true 352 | } 353 | dumpPath := func() string { 354 | var builder strings.Builder 355 | n := len(path) 356 | for i, j := 0, n-1; i < j; i += 2 { 357 | function, from := path[i].(*function), path[i+1] 358 | switch from := from.(type) { 359 | case *argument: 360 | builder.WriteString(fmt.Sprintf("%s@argument:%s => ", function.Name, from.ValueRef)) 361 | case *hook: 362 | builder.WriteString(fmt.Sprintf("%s@hook:%s => ", function.Name, from.ValueRef)) 363 | default: 364 | panic("unreachable code") 365 | } 366 | } 367 | function := path[n-1].(*function) 368 | builder.WriteString(function.Name) 369 | return builder.String() 370 | } 371 | for functionIndex := range p.functions { 372 | function := &p.functions[functionIndex] 373 | if !walk(function, nil) { 374 | return fmt.Errorf("%w; path=%q", ErrCircularDependencies, dumpPath()) 375 | } 376 | } 377 | return nil 378 | } 379 | 380 | func (p *Program) callFunctions(ctx context.Context) error { 381 | for _, functionIndex := range p.sortedFunctionIndexes { 382 | function := &p.functions[functionIndex] 383 | for _, argumentIndex := range function.ArgumentIndexes { 384 | argument := &p.arguments[argumentIndex] 385 | if argument.ResultIndex < 0 { 386 | continue 387 | } 388 | result 
:= &p.results[argument.ResultIndex] 389 | if argument.ReceiveValueAddr { 390 | argument.ValueReceiver.Set(result.Value.Addr()) 391 | } else { 392 | argument.ValueReceiver.Set(result.Value) 393 | } 394 | } 395 | if err := function.Body(ctx); err != nil { 396 | return fmt.Errorf("call function; functionName=%q: %w", function.Name, err) 397 | } 398 | p.calledFunctionCount++ 399 | for _, resultIndex := range function.ResultIndexes { 400 | result := &p.results[resultIndex] 401 | for _, hookIndex := range result.HookIndexes { 402 | hook := &p.hooks[hookIndex] 403 | if hook.ReceiveValueAddr { 404 | hook.ValueReceiver.Set(result.Value.Addr()) 405 | } else { 406 | hook.ValueReceiver.Set(result.Value) 407 | } 408 | if err := hook.Callback(ctx); err != nil { 409 | function2 := &p.functions[hook.FunctionIndex] 410 | return fmt.Errorf("do callback; functionName=%q valueRef=%q: %w", function2.Name, hook.ValueRef, err) 411 | } 412 | } 413 | } 414 | } 415 | return nil 416 | } 417 | 418 | var ( 419 | // ErrDuplicateValueName is returned by Program.Run() when a value name used by Result() is duplicated. 420 | ErrDuplicateValueName = errors.New("di: duplicate value name") 421 | 422 | // ErrValueNotFound is returned by Program.Run() when a value used by Argument()/Hook() does not exist. 423 | ErrValueNotFound = errors.New("di: value not found") 424 | 425 | // ErrIncompatibleValueReceiver is returned by Program.Run() when a value receiver used by Argument()/Hook() is incompatible with the referenced value. 426 | ErrIncompatibleValueReceiver = errors.New("di: incompatible value receiver") 427 | 428 | // ErrCircularDependencies is returned by Program.Run() when circular dependencies are detected. 429 | ErrCircularDependencies = errors.New("di: circular dependencies") 430 | ) 431 | 432 | // MustRun is like Run but panics when an error occurs.
433 | func (p *Program) MustRun(ctx context.Context) { 434 | if err := p.Run(ctx); err != nil { 435 | panic(fmt.Sprintf("run program: %v", err)) 436 | } 437 | } 438 | 439 | // Clean calls the cleanups of the DI Functions that have been called, in the reverse 440 | // of the order in which those DI Functions were called; DI Functions without a cleanup are skipped. 441 | func (p *Program) Clean() { 442 | for i := p.calledFunctionCount - 1; i >= 0; i-- { 443 | functionIndex := p.sortedFunctionIndexes[i] 444 | function := &p.functions[functionIndex] 445 | if cleanup := function.Cleanup; cleanup != nil { 446 | cleanup() 447 | } 448 | } 449 | } 450 | -------------------------------------------------------------------------------- /di_test.go: -------------------------------------------------------------------------------- 1 | package di_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | . "github.com/go-tk/di" 8 | "github.com/go-tk/testcase" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestArgument(t *testing.T) { 13 | type C struct { 14 | functionBuilder FunctionBuilder 15 | 16 | errStr string 17 | err error 18 | repr string 19 | } 20 | tc := testcase.New(func(t *testing.T, c *C) { 21 | testcase.Callback(t, "0") 22 | 23 | var p Program 24 | err := p.NewFunction(c.functionBuilder, Body(func(context.Context) error { return nil })) 25 | if c.errStr == "" { 26 | assert.NoError(t, err) 27 | } else { 28 | assert.EqualError(t, err, c.errStr) 29 | if c.err != nil { 30 | assert.ErrorIs(t, err, c.err) 31 | } 32 | } 33 | if t.Failed() { 34 | t.FailNow() 35 | } 36 | assert.Equal(t, c.repr, p.DumpAsString()) 37 | }) 38 | 39 | tc.WithCallback("0", func(t *testing.T, c *C) { 40 | c.functionBuilder = Argument("", nil) 41 | c.err = ErrInvalidArgument 42 | c.errStr = c.err.Error() + `: empty value ref; functionName="github.com/go-tk/di_test.TestArgument.func1"` 43 | c.repr = ` 44 | SortedFunctionIndexes: [] 45 | CalledFunctionCount: 0 46 | `[1:] 47 | }).RunParallel(t) 48 | 49 | tc.WithCallback("0",
func(t *testing.T, c *C) { 50 | c.functionBuilder = Argument("foo", nil) 51 | c.err = ErrInvalidArgument 52 | c.errStr = c.err.Error() + `: no value receiver; functionName="github.com/go-tk/di_test.TestArgument.func1" valueRef="foo"` 53 | c.repr = ` 54 | SortedFunctionIndexes: [] 55 | CalledFunctionCount: 0 56 | `[1:] 57 | }).RunParallel(t) 58 | 59 | tc.WithCallback("0", func(t *testing.T, c *C) { 60 | c.functionBuilder = Argument("foo", 0) 61 | c.err = ErrInvalidArgument 62 | c.errStr = c.err.Error() + `: invalid value receiver pointer; valueReceiverPtrType="int" functionName="github.com/go-tk/di_test.TestArgument.func1" valueRef="foo"` 63 | c.repr = ` 64 | SortedFunctionIndexes: [] 65 | CalledFunctionCount: 0 66 | `[1:] 67 | }).RunParallel(t) 68 | 69 | tc.WithCallback("0", func(t *testing.T, c *C) { 70 | c.functionBuilder = Argument("foo", (*string)(nil)) 71 | c.err = ErrInvalidArgument 72 | c.errStr = c.err.Error() + `: no value receiver; functionName="github.com/go-tk/di_test.TestArgument.func1" valueRef="foo"` 73 | c.repr = ` 74 | SortedFunctionIndexes: [] 75 | CalledFunctionCount: 0 76 | `[1:] 77 | }).RunParallel(t) 78 | 79 | tc.WithCallback("0", func(t *testing.T, c *C) { 80 | c.functionBuilder = Argument("foo", new(string)) 81 | c.repr = ` 82 | Function[0]: 83 | Index: 0 84 | Name: github.com/go-tk/di_test.TestArgument.func1 85 | ArgumentIndexes: [0] 86 | ResultIndexes: [] 87 | HasBody: true 88 | HookIndexes: [] 89 | HasCleanup: false 90 | Argument[0]: 91 | FunctionIndex: 0 92 | ValueRef: foo 93 | HasValueReceiver: true 94 | IsOptional: false 95 | ResultIndex: -1 96 | ReceiveValueAddr: false 97 | SortedFunctionIndexes: [] 98 | CalledFunctionCount: 0 99 | `[1:] 100 | }).RunParallel(t) 101 | 102 | tc.WithCallback("0", func(t *testing.T, c *C) { 103 | c.functionBuilder = OptionalArgument("foo", new(string)) 104 | c.repr = ` 105 | Function[0]: 106 | Index: 0 107 | Name: github.com/go-tk/di_test.TestArgument.func1 108 | ArgumentIndexes: [0] 109 | ResultIndexes: 
[] 110 | HasBody: true 111 | HookIndexes: [] 112 | HasCleanup: false 113 | Argument[0]: 114 | FunctionIndex: 0 115 | ValueRef: foo 116 | HasValueReceiver: true 117 | IsOptional: true 118 | ResultIndex: -1 119 | ReceiveValueAddr: false 120 | SortedFunctionIndexes: [] 121 | CalledFunctionCount: 0 122 | `[1:] 123 | }).RunParallel(t) 124 | } 125 | 126 | func TestHook(t *testing.T) { 127 | type C struct { 128 | functionBuilder FunctionBuilder 129 | 130 | errStr string 131 | err error 132 | repr string 133 | } 134 | tc := testcase.New(func(t *testing.T, c *C) { 135 | testcase.Callback(t, "0") 136 | 137 | var p Program 138 | err := p.NewFunction(c.functionBuilder, Body(func(context.Context) error { return nil })) 139 | if c.errStr == "" { 140 | assert.NoError(t, err) 141 | } else { 142 | assert.EqualError(t, err, c.errStr) 143 | if c.err != nil { 144 | assert.ErrorIs(t, err, c.err) 145 | } 146 | } 147 | if t.Failed() { 148 | t.FailNow() 149 | } 150 | assert.Equal(t, c.repr, p.DumpAsString()) 151 | }) 152 | 153 | tc.WithCallback("0", func(t *testing.T, c *C) { 154 | c.functionBuilder = Hook("", nil, nil) 155 | c.err = ErrInvalidHook 156 | c.errStr = c.err.Error() + `: empty value ref; functionName="github.com/go-tk/di_test.TestHook.func1"` 157 | c.repr = ` 158 | SortedFunctionIndexes: [] 159 | CalledFunctionCount: 0 160 | `[1:] 161 | }).RunParallel(t) 162 | 163 | tc.WithCallback("0", func(t *testing.T, c *C) { 164 | c.functionBuilder = Hook("foo", nil, nil) 165 | c.err = ErrInvalidHook 166 | c.errStr = c.err.Error() + `: no value receiver; functionName="github.com/go-tk/di_test.TestHook.func1" valueRef="foo"` 167 | c.repr = ` 168 | SortedFunctionIndexes: [] 169 | CalledFunctionCount: 0 170 | `[1:] 171 | }).RunParallel(t) 172 | 173 | tc.WithCallback("0", func(t *testing.T, c *C) { 174 | c.functionBuilder = Hook("foo", 0, nil) 175 | c.err = ErrInvalidHook 176 | c.errStr = c.err.Error() + `: invalid value receiver pointer; valueReceiverPtrType="int" 
functionName="github.com/go-tk/di_test.TestHook.func1" valueRef="foo"` 177 | c.repr = ` 178 | SortedFunctionIndexes: [] 179 | CalledFunctionCount: 0 180 | `[1:] 181 | }).RunParallel(t) 182 | 183 | tc.WithCallback("0", func(t *testing.T, c *C) { 184 | c.functionBuilder = Hook("foo", (*string)(nil), nil) 185 | c.err = ErrInvalidHook 186 | c.errStr = c.err.Error() + `: no value receiver; functionName="github.com/go-tk/di_test.TestHook.func1" valueRef="foo"` 187 | c.repr = ` 188 | SortedFunctionIndexes: [] 189 | CalledFunctionCount: 0 190 | `[1:] 191 | }).RunParallel(t) 192 | 193 | tc.WithCallback("0", func(t *testing.T, c *C) { 194 | c.functionBuilder = Hook("foo", new(string), nil) 195 | c.err = ErrInvalidHook 196 | c.errStr = c.err.Error() + `: nil callback; functionName="github.com/go-tk/di_test.TestHook.func1" valueRef="foo"` 197 | c.repr = ` 198 | SortedFunctionIndexes: [] 199 | CalledFunctionCount: 0 200 | `[1:] 201 | }).RunParallel(t) 202 | 203 | tc.WithCallback("0", func(t *testing.T, c *C) { 204 | c.functionBuilder = Hook("foo", new(string), func(context.Context) error { return nil }) 205 | c.repr = ` 206 | Function[0]: 207 | Index: 0 208 | Name: github.com/go-tk/di_test.TestHook.func1 209 | ArgumentIndexes: [] 210 | ResultIndexes: [] 211 | HasBody: true 212 | HookIndexes: [0] 213 | HasCleanup: false 214 | Hook[0]: 215 | FunctionIndex: 0 216 | ValueRef: foo 217 | HasValueReceiver: true 218 | HasCallback: true 219 | ReceiveValueAddr: false 220 | SortedFunctionIndexes: [] 221 | CalledFunctionCount: 0 222 | `[1:] 223 | }).RunParallel(t) 224 | } 225 | 226 | func TestResult(t *testing.T) { 227 | type C struct { 228 | functionBuilder FunctionBuilder 229 | 230 | errStr string 231 | err error 232 | repr string 233 | } 234 | tc := testcase.New(func(t *testing.T, c *C) { 235 | testcase.Callback(t, "0") 236 | 237 | var p Program 238 | err := p.NewFunction(c.functionBuilder, Body(func(context.Context) error { return nil })) 239 | if c.errStr == "" { 240 | 
assert.NoError(t, err) 241 | } else { 242 | assert.EqualError(t, err, c.errStr) 243 | if c.err != nil { 244 | assert.ErrorIs(t, err, c.err) 245 | } 246 | } 247 | if t.Failed() { 248 | t.FailNow() 249 | } 250 | assert.Equal(t, c.repr, p.DumpAsString()) 251 | }) 252 | 253 | tc.WithCallback("0", func(t *testing.T, c *C) { 254 | c.functionBuilder = Result("", nil) 255 | c.err = ErrInvalidResult 256 | c.errStr = c.err.Error() + `: empty value name; functionName="github.com/go-tk/di_test.TestResult.func1"` 257 | c.repr = ` 258 | SortedFunctionIndexes: [] 259 | CalledFunctionCount: 0 260 | `[1:] 261 | }).RunParallel(t) 262 | 263 | tc.WithCallback("0", func(t *testing.T, c *C) { 264 | c.functionBuilder = Result("foo", nil) 265 | c.err = ErrInvalidResult 266 | c.errStr = c.err.Error() + `: no value; functionName="github.com/go-tk/di_test.TestResult.func1" valueName="foo"` 267 | c.repr = ` 268 | SortedFunctionIndexes: [] 269 | CalledFunctionCount: 0 270 | `[1:] 271 | }).RunParallel(t) 272 | 273 | tc.WithCallback("0", func(t *testing.T, c *C) { 274 | c.functionBuilder = Result("foo", 0) 275 | c.err = ErrInvalidResult 276 | c.errStr = c.err.Error() + `: invalid value pointer; valuePtrType="int" functionName="github.com/go-tk/di_test.TestResult.func1" valueName="foo"` 277 | c.repr = ` 278 | SortedFunctionIndexes: [] 279 | CalledFunctionCount: 0 280 | `[1:] 281 | }).RunParallel(t) 282 | 283 | tc.WithCallback("0", func(t *testing.T, c *C) { 284 | c.functionBuilder = Result("foo", (*string)(nil)) 285 | c.err = ErrInvalidResult 286 | c.errStr = c.err.Error() + `: no value; functionName="github.com/go-tk/di_test.TestResult.func1" valueName="foo"` 287 | c.repr = ` 288 | SortedFunctionIndexes: [] 289 | CalledFunctionCount: 0 290 | `[1:] 291 | }).RunParallel(t) 292 | 293 | tc.WithCallback("0", func(t *testing.T, c *C) { 294 | c.functionBuilder = Result("foo", new(string)) 295 | c.repr = ` 296 | Function[0]: 297 | Index: 0 298 | Name: github.com/go-tk/di_test.TestResult.func1 299 | 
ArgumentIndexes: [] 300 | ResultIndexes: [0] 301 | HasBody: true 302 | HookIndexes: [] 303 | HasCleanup: false 304 | Result[0]: 305 | FunctionIndex: 0 306 | ValueName: foo 307 | HasValue: true 308 | HookIndexes: [] 309 | SortedFunctionIndexes: [] 310 | CalledFunctionCount: 0 311 | `[1:] 312 | }).RunParallel(t) 313 | } 314 | 315 | func TestBody(t *testing.T) { 316 | type C struct { 317 | functionBuilder FunctionBuilder 318 | 319 | errStr string 320 | err error 321 | repr string 322 | } 323 | tc := testcase.New(func(t *testing.T, c *C) { 324 | testcase.Callback(t, "0") 325 | 326 | var p Program 327 | err := p.NewFunction(c.functionBuilder, Body(func(context.Context) error { return nil })) 328 | if c.errStr == "" { 329 | assert.NoError(t, err) 330 | } else { 331 | assert.EqualError(t, err, c.errStr) 332 | if c.err != nil { 333 | assert.ErrorIs(t, err, c.err) 334 | } 335 | } 336 | if t.Failed() { 337 | t.FailNow() 338 | } 339 | assert.Equal(t, c.repr, p.DumpAsString()) 340 | }) 341 | 342 | tc.WithCallback("0", func(t *testing.T, c *C) { 343 | c.functionBuilder = Body(nil) 344 | c.err = ErrNilBody 345 | c.errStr = c.err.Error() + `; functionName="github.com/go-tk/di_test.TestBody.func1"` 346 | c.repr = ` 347 | SortedFunctionIndexes: [] 348 | CalledFunctionCount: 0 349 | `[1:] 350 | }).RunParallel(t) 351 | 352 | tc.WithCallback("0", func(t *testing.T, c *C) { 353 | c.functionBuilder = Body(func(context.Context) error { return nil }) 354 | c.repr = ` 355 | Function[0]: 356 | Index: 0 357 | Name: github.com/go-tk/di_test.TestBody.func1 358 | ArgumentIndexes: [] 359 | ResultIndexes: [] 360 | HasBody: true 361 | HookIndexes: [] 362 | HasCleanup: false 363 | SortedFunctionIndexes: [] 364 | CalledFunctionCount: 0 365 | `[1:] 366 | }).RunParallel(t) 367 | } 368 | 369 | func TestCleanup(t *testing.T) { 370 | type C struct { 371 | functionBuilder FunctionBuilder 372 | 373 | errStr string 374 | err error 375 | repr string 376 | } 377 | tc := testcase.New(func(t *testing.T, c *C) { 
378 | testcase.Callback(t, "0") 379 | 380 | var p Program 381 | err := p.NewFunction(c.functionBuilder, Body(func(context.Context) error { return nil })) 382 | if c.errStr == "" { 383 | assert.NoError(t, err) 384 | } else { 385 | assert.EqualError(t, err, c.errStr) 386 | if c.err != nil { 387 | assert.ErrorIs(t, err, c.err) 388 | } 389 | } 390 | if t.Failed() { 391 | t.FailNow() 392 | } 393 | assert.Equal(t, c.repr, p.DumpAsString()) 394 | }) 395 | 396 | tc.WithCallback("0", func(t *testing.T, c *C) { 397 | c.functionBuilder = Cleanup(nil) 398 | c.err = ErrNilCleanup 399 | c.errStr = c.err.Error() + `; functionName="github.com/go-tk/di_test.TestCleanup.func1"` 400 | c.repr = ` 401 | SortedFunctionIndexes: [] 402 | CalledFunctionCount: 0 403 | `[1:] 404 | }).RunParallel(t) 405 | 406 | tc.WithCallback("0", func(t *testing.T, c *C) { 407 | c.functionBuilder = Cleanup(func() {}) 408 | c.repr = ` 409 | Function[0]: 410 | Index: 0 411 | Name: github.com/go-tk/di_test.TestCleanup.func1 412 | ArgumentIndexes: [] 413 | ResultIndexes: [] 414 | HasBody: true 415 | HookIndexes: [] 416 | HasCleanup: true 417 | SortedFunctionIndexes: [] 418 | CalledFunctionCount: 0 419 | `[1:] 420 | }).RunParallel(t) 421 | } 422 | 423 | func TestNewFunction(t *testing.T) { 424 | type C struct { 425 | functionBuilders []FunctionBuilder 426 | 427 | errStr string 428 | err error 429 | repr string 430 | } 431 | tc := testcase.New(func(t *testing.T, c *C) { 432 | testcase.Callback(t, "0") 433 | 434 | var p Program 435 | err := p.NewFunction(c.functionBuilders...) 
436 | if c.errStr == "" { 437 | assert.NoError(t, err) 438 | } else { 439 | assert.EqualError(t, err, c.errStr) 440 | if c.err != nil { 441 | assert.ErrorIs(t, err, c.err) 442 | } 443 | } 444 | if t.Failed() { 445 | t.FailNow() 446 | } 447 | assert.Equal(t, c.repr, p.DumpAsString()) 448 | }) 449 | 450 | tc.WithCallback("0", func(t *testing.T, c *C) { 451 | c.err = ErrBodyRequired 452 | c.errStr = c.err.Error() + `; functionName="github.com/go-tk/di_test.TestNewFunction.func1"` 453 | c.repr = ` 454 | SortedFunctionIndexes: [] 455 | CalledFunctionCount: 0 456 | `[1:] 457 | }).RunParallel(t) 458 | 459 | tc.WithCallback("0", func(t *testing.T, c *C) { 460 | var ( 461 | arg1 int 462 | arg2 string 463 | res1 float64 464 | res2 byte 465 | var1 int32 466 | var2 int64 467 | ) 468 | c.functionBuilders = []FunctionBuilder{ 469 | Argument("arg1", &arg1), 470 | Argument("arg2", &arg2), 471 | Result("res1", &res1), 472 | Result("res2", &res2), 473 | Body(func(context.Context) error { return nil }), 474 | Hook("var1", &var1, func(context.Context) error { return nil }), 475 | Hook("var2", &var2, func(context.Context) error { return nil }), 476 | Cleanup(func() {}), 477 | } 478 | c.repr = ` 479 | Function[0]: 480 | Index: 0 481 | Name: github.com/go-tk/di_test.TestNewFunction.func1 482 | ArgumentIndexes: [0 1] 483 | ResultIndexes: [0 1] 484 | HasBody: true 485 | HookIndexes: [0 1] 486 | HasCleanup: true 487 | Argument[0]: 488 | FunctionIndex: 0 489 | ValueRef: arg1 490 | HasValueReceiver: true 491 | IsOptional: false 492 | ResultIndex: -1 493 | ReceiveValueAddr: false 494 | Argument[1]: 495 | FunctionIndex: 0 496 | ValueRef: arg2 497 | HasValueReceiver: true 498 | IsOptional: false 499 | ResultIndex: -1 500 | ReceiveValueAddr: false 501 | Result[0]: 502 | FunctionIndex: 0 503 | ValueName: res1 504 | HasValue: true 505 | HookIndexes: [] 506 | Result[1]: 507 | FunctionIndex: 0 508 | ValueName: res2 509 | HasValue: true 510 | HookIndexes: [] 511 | Hook[0]: 512 | FunctionIndex: 0 513 | 
ValueRef: var1 514 | HasValueReceiver: true 515 | HasCallback: true 516 | ReceiveValueAddr: false 517 | Hook[1]: 518 | FunctionIndex: 0 519 | ValueRef: var2 520 | HasValueReceiver: true 521 | HasCallback: true 522 | ReceiveValueAddr: false 523 | SortedFunctionIndexes: [] 524 | CalledFunctionCount: 0 525 | `[1:] 526 | }).RunParallel(t) 527 | } 528 | 529 | func TestProgram_MustNewFunction(t *testing.T) { 530 | { 531 | var p Program 532 | p.MustNewFunction(Body(func(context.Context) error { return nil })) 533 | } 534 | assert.PanicsWithValue(t, `new function: di: body required; functionName="github.com/go-tk/di_test.TestProgram_MustNewFunction.func2"`, func() { 535 | var p Program 536 | p.MustNewFunction() 537 | }) 538 | } 539 | 540 | func TestProgram_Run(t *testing.T) { 541 | type C struct { 542 | p *Program 543 | ctx *context.Context 544 | 545 | errStr string 546 | err error 547 | repr string 548 | } 549 | tc := testcase.New(func(t *testing.T, c *C) { 550 | var p Program 551 | ctx := context.Background() 552 | c.p = &p 553 | c.ctx = &ctx 554 | 555 | testcase.Callback(t, "0") 556 | 557 | err := p.Run(ctx) 558 | if c.errStr == "" { 559 | assert.NoError(t, err) 560 | } else { 561 | assert.EqualError(t, err, c.errStr) 562 | if c.err != nil { 563 | assert.ErrorIs(t, err, c.err) 564 | } 565 | } 566 | if t.Failed() { 567 | t.FailNow() 568 | } 569 | assert.Equal(t, c.repr, p.DumpAsString()) 570 | }) 571 | 572 | tc.WithCallback("0", func(t *testing.T, c *C) { 573 | func() { 574 | var res1 int 575 | err := c.p.NewFunction(Result("res1", &res1), Body(func(context.Context) error { return nil })) 576 | if err != nil { 577 | t.Fatal(err) 578 | } 579 | }() 580 | func() { 581 | var res1 int 582 | err := c.p.NewFunction(Result("res1", &res1), Body(func(context.Context) error { return nil })) 583 | if err != nil { 584 | t.Fatal(err) 585 | } 586 | }() 587 | c.err = ErrDuplicateValueName 588 | c.errStr = c.err.Error() + `; valueName="res1" 
functionName1="github.com/go-tk/di_test.TestProgram_Run.func2.2" functionName2="github.com/go-tk/di_test.TestProgram_Run.func2.1"` 589 | c.repr = ` 590 | Function[0]: 591 | Index: 0 592 | Name: github.com/go-tk/di_test.TestProgram_Run.func2.1 593 | ArgumentIndexes: [] 594 | ResultIndexes: [0] 595 | HasBody: true 596 | HookIndexes: [] 597 | HasCleanup: false 598 | Function[1]: 599 | Index: 1 600 | Name: github.com/go-tk/di_test.TestProgram_Run.func2.2 601 | ArgumentIndexes: [] 602 | ResultIndexes: [1] 603 | HasBody: true 604 | HookIndexes: [] 605 | HasCleanup: false 606 | Result[0]: 607 | FunctionIndex: 0 608 | ValueName: res1 609 | HasValue: true 610 | HookIndexes: [] 611 | Result[1]: 612 | FunctionIndex: 1 613 | ValueName: res1 614 | HasValue: true 615 | HookIndexes: [] 616 | SortedFunctionIndexes: [] 617 | CalledFunctionCount: 0 618 | `[1:] 619 | }).RunParallel(t) 620 | 621 | tc.WithCallback("0", func(t *testing.T, c *C) { 622 | func() { 623 | var arg int 624 | err := c.p.NewFunction(Argument("arg", &arg), Body(func(context.Context) error { return nil })) 625 | if err != nil { 626 | t.Fatal(err) 627 | } 628 | }() 629 | c.err = ErrValueNotFound 630 | c.errStr = c.err.Error() + `; valueRef="arg" functionName="github.com/go-tk/di_test.TestProgram_Run.func3.1"` 631 | c.repr = ` 632 | Function[0]: 633 | Index: 0 634 | Name: github.com/go-tk/di_test.TestProgram_Run.func3.1 635 | ArgumentIndexes: [0] 636 | ResultIndexes: [] 637 | HasBody: true 638 | HookIndexes: [] 639 | HasCleanup: false 640 | Argument[0]: 641 | FunctionIndex: 0 642 | ValueRef: arg 643 | HasValueReceiver: true 644 | IsOptional: false 645 | ResultIndex: -1 646 | ReceiveValueAddr: false 647 | SortedFunctionIndexes: [] 648 | CalledFunctionCount: 0 649 | `[1:] 650 | }).RunParallel(t) 651 | 652 | tc.WithCallback("0", func(t *testing.T, c *C) { 653 | func() { 654 | var arg int 655 | err := c.p.NewFunction(OptionalArgument("arg", &arg), Body(func(context.Context) error { return nil })) 656 | if err != nil { 
657 | t.Fatal(err) 658 | } 659 | }() 660 | c.repr = ` 661 | Function[0]: 662 | Index: 0 663 | Name: github.com/go-tk/di_test.TestProgram_Run.func4.1 664 | ArgumentIndexes: [0] 665 | ResultIndexes: [] 666 | HasBody: true 667 | HookIndexes: [] 668 | HasCleanup: false 669 | Argument[0]: 670 | FunctionIndex: 0 671 | ValueRef: arg 672 | HasValueReceiver: true 673 | IsOptional: true 674 | ResultIndex: -1 675 | ReceiveValueAddr: false 676 | SortedFunctionIndexes: [0] 677 | CalledFunctionCount: 1 678 | `[1:] 679 | }).RunParallel(t) 680 | 681 | tc.WithCallback("0", func(t *testing.T, c *C) { 682 | func() { 683 | var val int 684 | err := c.p.NewFunction(Result("val", &val), Body(func(context.Context) error { return nil })) 685 | if err != nil { 686 | t.Fatal(err) 687 | } 688 | }() 689 | func() { 690 | var val string 691 | err := c.p.NewFunction(Argument("val", &val), Body(func(context.Context) error { return nil })) 692 | if err != nil { 693 | t.Fatal(err) 694 | } 695 | }() 696 | c.err = ErrIncompatibleValueReceiver 697 | c.errStr = c.err.Error() + `; valueReceiverType="string" valueType="int" valueRef="val" functionName="github.com/go-tk/di_test.TestProgram_Run.func5.2"` 698 | c.repr = ` 699 | Function[0]: 700 | Index: 0 701 | Name: github.com/go-tk/di_test.TestProgram_Run.func5.1 702 | ArgumentIndexes: [] 703 | ResultIndexes: [0] 704 | HasBody: true 705 | HookIndexes: [] 706 | HasCleanup: false 707 | Function[1]: 708 | Index: 1 709 | Name: github.com/go-tk/di_test.TestProgram_Run.func5.2 710 | ArgumentIndexes: [0] 711 | ResultIndexes: [] 712 | HasBody: true 713 | HookIndexes: [] 714 | HasCleanup: false 715 | Argument[0]: 716 | FunctionIndex: 1 717 | ValueRef: val 718 | HasValueReceiver: true 719 | IsOptional: false 720 | ResultIndex: -1 721 | ReceiveValueAddr: false 722 | Result[0]: 723 | FunctionIndex: 0 724 | ValueName: val 725 | HasValue: true 726 | HookIndexes: [] 727 | SortedFunctionIndexes: [] 728 | CalledFunctionCount: 0 729 | `[1:] 730 | }).RunParallel(t) 731 | 732 
| tc.WithCallback("0", func(t *testing.T, c *C) { 733 | func() { 734 | var hook int 735 | err := c.p.NewFunction(Body(func(context.Context) error { return nil }), 736 | Hook("hook", &hook, func(context.Context) error { return nil })) 737 | if err != nil { 738 | t.Fatal(err) 739 | } 740 | }() 741 | c.err = ErrValueNotFound 742 | c.errStr = c.err.Error() + `; valueRef="hook" functionName="github.com/go-tk/di_test.TestProgram_Run.func6.1"` 743 | c.repr = ` 744 | Function[0]: 745 | Index: 0 746 | Name: github.com/go-tk/di_test.TestProgram_Run.func6.1 747 | ArgumentIndexes: [] 748 | ResultIndexes: [] 749 | HasBody: true 750 | HookIndexes: [0] 751 | HasCleanup: false 752 | Hook[0]: 753 | FunctionIndex: 0 754 | ValueRef: hook 755 | HasValueReceiver: true 756 | HasCallback: true 757 | ReceiveValueAddr: false 758 | SortedFunctionIndexes: [] 759 | CalledFunctionCount: 0 760 | `[1:] 761 | }).RunParallel(t) 762 | 763 | tc.WithCallback("0", func(t *testing.T, c *C) { 764 | func() { 765 | var val int 766 | err := c.p.NewFunction(Result("val", &val), Body(func(context.Context) error { return nil })) 767 | if err != nil { 768 | t.Fatal(err) 769 | } 770 | }() 771 | func() { 772 | var val string 773 | err := c.p.NewFunction(Body(func(context.Context) error { return nil }), 774 | Hook("val", &val, func(context.Context) error { return nil })) 775 | if err != nil { 776 | t.Fatal(err) 777 | } 778 | }() 779 | c.err = ErrIncompatibleValueReceiver 780 | c.errStr = c.err.Error() + `; valueReceiverType="string" valueType="int" valueRef="val" functionName="github.com/go-tk/di_test.TestProgram_Run.func7.2"` 781 | c.repr = ` 782 | Function[0]: 783 | Index: 0 784 | Name: github.com/go-tk/di_test.TestProgram_Run.func7.1 785 | ArgumentIndexes: [] 786 | ResultIndexes: [0] 787 | HasBody: true 788 | HookIndexes: [] 789 | HasCleanup: false 790 | Function[1]: 791 | Index: 1 792 | Name: github.com/go-tk/di_test.TestProgram_Run.func7.2 793 | ArgumentIndexes: [] 794 | ResultIndexes: [] 795 | HasBody: true 
796 | HookIndexes: [0] 797 | HasCleanup: false 798 | Result[0]: 799 | FunctionIndex: 0 800 | ValueName: val 801 | HasValue: true 802 | HookIndexes: [] 803 | Hook[0]: 804 | FunctionIndex: 1 805 | ValueRef: val 806 | HasValueReceiver: true 807 | HasCallback: true 808 | ReceiveValueAddr: false 809 | SortedFunctionIndexes: [] 810 | CalledFunctionCount: 0 811 | `[1:] 812 | }).RunParallel(t) 813 | 814 | tc.WithCallback("0", func(t *testing.T, c *C) { 815 | func() { 816 | var val int 817 | var pval *int 818 | err := c.p.NewFunction(Argument("val", &pval), Result("val", &val), Body(func(context.Context) error { return nil })) 819 | if err != nil { 820 | t.Fatal(err) 821 | } 822 | }() 823 | c.err = ErrCircularDependencies 824 | c.errStr = c.err.Error() + `; path="github.com/go-tk/di_test.TestProgram_Run.func8.1@argument:val => github.com/go-tk/di_test.TestProgram_Run.func8.1"` 825 | c.repr = ` 826 | Function[0]: 827 | Index: -1 828 | Name: github.com/go-tk/di_test.TestProgram_Run.func8.1 829 | ArgumentIndexes: [0] 830 | ResultIndexes: [0] 831 | HasBody: true 832 | HookIndexes: [] 833 | HasCleanup: false 834 | Argument[0]: 835 | FunctionIndex: 0 836 | ValueRef: val 837 | HasValueReceiver: true 838 | IsOptional: false 839 | ResultIndex: 0 840 | ReceiveValueAddr: true 841 | Result[0]: 842 | FunctionIndex: 0 843 | ValueName: val 844 | HasValue: true 845 | HookIndexes: [] 846 | SortedFunctionIndexes: [] 847 | CalledFunctionCount: 0 848 | `[1:] 849 | }).RunParallel(t) 850 | 851 | tc.WithCallback("0", func(t *testing.T, c *C) { 852 | func() { 853 | var val int 854 | var pval *int 855 | err := c.p.NewFunction(Result("val", &val), Body(func(context.Context) error { return nil }), 856 | Hook("val", &pval, func(context.Context) error { return nil })) 857 | if err != nil { 858 | t.Fatal(err) 859 | } 860 | }() 861 | c.err = ErrCircularDependencies 862 | c.errStr = c.err.Error() + `; path="github.com/go-tk/di_test.TestProgram_Run.func9.1@hook:val => 
github.com/go-tk/di_test.TestProgram_Run.func9.1"` 863 | c.repr = ` 864 | Function[0]: 865 | Index: -1 866 | Name: github.com/go-tk/di_test.TestProgram_Run.func9.1 867 | ArgumentIndexes: [] 868 | ResultIndexes: [0] 869 | HasBody: true 870 | HookIndexes: [0] 871 | HasCleanup: false 872 | Result[0]: 873 | FunctionIndex: 0 874 | ValueName: val 875 | HasValue: true 876 | HookIndexes: [0] 877 | Hook[0]: 878 | FunctionIndex: 0 879 | ValueRef: val 880 | HasValueReceiver: true 881 | HasCallback: true 882 | ReceiveValueAddr: true 883 | SortedFunctionIndexes: [] 884 | CalledFunctionCount: 0 885 | `[1:] 886 | }).RunParallel(t) 887 | 888 | tc.WithCallback("0", func(t *testing.T, c *C) { 889 | func() { 890 | var val int 891 | err := c.p.NewFunction(Argument("val", &val), Body(func(context.Context) error { return nil })) 892 | if err != nil { 893 | t.Fatal(err) 894 | } 895 | }() 896 | func() { 897 | var val int 898 | err := c.p.NewFunction(Argument("val", &val), Body(func(context.Context) error { return nil })) 899 | if err != nil { 900 | t.Fatal(err) 901 | } 902 | }() 903 | func() { 904 | var val int 905 | err := c.p.NewFunction(Result("val", &val), Body(func(context.Context) error { return nil })) 906 | if err != nil { 907 | t.Fatal(err) 908 | } 909 | }() 910 | c.repr = ` 911 | Function[0]: 912 | Index: 0 913 | Name: github.com/go-tk/di_test.TestProgram_Run.func10.1 914 | ArgumentIndexes: [0] 915 | ResultIndexes: [] 916 | HasBody: true 917 | HookIndexes: [] 918 | HasCleanup: false 919 | Function[1]: 920 | Index: 1 921 | Name: github.com/go-tk/di_test.TestProgram_Run.func10.2 922 | ArgumentIndexes: [1] 923 | ResultIndexes: [] 924 | HasBody: true 925 | HookIndexes: [] 926 | HasCleanup: false 927 | Function[2]: 928 | Index: 2 929 | Name: github.com/go-tk/di_test.TestProgram_Run.func10.3 930 | ArgumentIndexes: [] 931 | ResultIndexes: [0] 932 | HasBody: true 933 | HookIndexes: [] 934 | HasCleanup: false 935 | Argument[0]: 936 | FunctionIndex: 0 937 | ValueRef: val 938 | 
HasValueReceiver: true 939 | IsOptional: false 940 | ResultIndex: 0 941 | ReceiveValueAddr: false 942 | Argument[1]: 943 | FunctionIndex: 1 944 | ValueRef: val 945 | HasValueReceiver: true 946 | IsOptional: false 947 | ResultIndex: 0 948 | ReceiveValueAddr: false 949 | Result[0]: 950 | FunctionIndex: 2 951 | ValueName: val 952 | HasValue: true 953 | HookIndexes: [] 954 | SortedFunctionIndexes: [2 0 1] 955 | CalledFunctionCount: 3 956 | `[1:] 957 | }).RunParallel(t) 958 | 959 | tc.WithCallback("0", func(t *testing.T, c *C) { 960 | func() { 961 | var ( 962 | x int 963 | y *int 964 | ) 965 | err := c.p.NewFunction(Argument("x", &x), Argument("y", &y), Body(func(context.Context) error { 966 | x += 1 967 | *y -= 1 968 | return nil 969 | })) 970 | if err != nil { 971 | t.Fatal(err) 972 | } 973 | }() 974 | func() { 975 | var ( 976 | x int 977 | y int 978 | ) 979 | err := c.p.NewFunction(Argument("x", &x), Argument("y", &y), Body(func(context.Context) error { 980 | assert.Equal(t, 100, x) 981 | assert.Equal(t, 98, y) 982 | y -= 1 983 | return nil 984 | })) 985 | if err != nil { 986 | t.Fatal(err) 987 | } 988 | }() 989 | func() { 990 | var ( 991 | x int 992 | y int 993 | ) 994 | err := c.p.NewFunction(Result("x", &x), Result("y", &y), Body(func(context.Context) error { 995 | x = 100 996 | y = 99 997 | return nil 998 | })) 999 | if err != nil { 1000 | t.Fatal(err) 1001 | } 1002 | }() 1003 | c.repr = ` 1004 | Function[0]: 1005 | Index: 0 1006 | Name: github.com/go-tk/di_test.TestProgram_Run.func11.1 1007 | ArgumentIndexes: [0 1] 1008 | ResultIndexes: [] 1009 | HasBody: true 1010 | HookIndexes: [] 1011 | HasCleanup: false 1012 | Function[1]: 1013 | Index: 1 1014 | Name: github.com/go-tk/di_test.TestProgram_Run.func11.2 1015 | ArgumentIndexes: [2 3] 1016 | ResultIndexes: [] 1017 | HasBody: true 1018 | HookIndexes: [] 1019 | HasCleanup: false 1020 | Function[2]: 1021 | Index: 2 1022 | Name: github.com/go-tk/di_test.TestProgram_Run.func11.3 1023 | ArgumentIndexes: [] 1024 | 
ResultIndexes: [0 1] 1025 | HasBody: true 1026 | HookIndexes: [] 1027 | HasCleanup: false 1028 | Argument[0]: 1029 | FunctionIndex: 0 1030 | ValueRef: x 1031 | HasValueReceiver: true 1032 | IsOptional: false 1033 | ResultIndex: 0 1034 | ReceiveValueAddr: false 1035 | Argument[1]: 1036 | FunctionIndex: 0 1037 | ValueRef: y 1038 | HasValueReceiver: true 1039 | IsOptional: false 1040 | ResultIndex: 1 1041 | ReceiveValueAddr: true 1042 | Argument[2]: 1043 | FunctionIndex: 1 1044 | ValueRef: x 1045 | HasValueReceiver: true 1046 | IsOptional: false 1047 | ResultIndex: 0 1048 | ReceiveValueAddr: false 1049 | Argument[3]: 1050 | FunctionIndex: 1 1051 | ValueRef: y 1052 | HasValueReceiver: true 1053 | IsOptional: false 1054 | ResultIndex: 1 1055 | ReceiveValueAddr: false 1056 | Result[0]: 1057 | FunctionIndex: 2 1058 | ValueName: x 1059 | HasValue: true 1060 | HookIndexes: [] 1061 | Result[1]: 1062 | FunctionIndex: 2 1063 | ValueName: y 1064 | HasValue: true 1065 | HookIndexes: [] 1066 | SortedFunctionIndexes: [2 0 1] 1067 | CalledFunctionCount: 3 1068 | `[1:] 1069 | }).RunParallel(t) 1070 | 1071 | tc.WithCallback("0", func(t *testing.T, c *C) { 1072 | func() { 1073 | var ( 1074 | x int 1075 | y int 1076 | ) 1077 | err := c.p.NewFunction(Argument("x", &x), Argument("y", &y), Body(func(context.Context) error { 1078 | assert.Equal(t, 100, x) 1079 | assert.Equal(t, 98, y) 1080 | y -= 1 1081 | return nil 1082 | })) 1083 | if err != nil { 1084 | t.Fatal(err) 1085 | } 1086 | }() 1087 | func() { 1088 | var ( 1089 | x int 1090 | y int 1091 | ) 1092 | err := c.p.NewFunction(Result("x", &x), Result("y", &y), Body(func(context.Context) error { 1093 | x = 100 1094 | y = 99 1095 | return nil 1096 | })) 1097 | if err != nil { 1098 | t.Fatal(err) 1099 | } 1100 | }() 1101 | func() { 1102 | var ( 1103 | x int 1104 | y *int 1105 | ) 1106 | err := c.p.NewFunction(Body(func(context.Context) error { return nil }), 1107 | Hook("x", &x, func(context.Context) error { x += 1; return nil }), 1108 | 
Hook("y", &y, func(context.Context) error { *y -= 1; return nil })) 1109 | 1110 | if err != nil { 1111 | t.Fatal(err) 1112 | } 1113 | }() 1114 | c.repr = ` 1115 | Function[0]: 1116 | Index: 0 1117 | Name: github.com/go-tk/di_test.TestProgram_Run.func12.1 1118 | ArgumentIndexes: [0 1] 1119 | ResultIndexes: [] 1120 | HasBody: true 1121 | HookIndexes: [] 1122 | HasCleanup: false 1123 | Function[1]: 1124 | Index: 1 1125 | Name: github.com/go-tk/di_test.TestProgram_Run.func12.2 1126 | ArgumentIndexes: [] 1127 | ResultIndexes: [0 1] 1128 | HasBody: true 1129 | HookIndexes: [] 1130 | HasCleanup: false 1131 | Function[2]: 1132 | Index: 2 1133 | Name: github.com/go-tk/di_test.TestProgram_Run.func12.3 1134 | ArgumentIndexes: [] 1135 | ResultIndexes: [] 1136 | HasBody: true 1137 | HookIndexes: [0 1] 1138 | HasCleanup: false 1139 | Argument[0]: 1140 | FunctionIndex: 0 1141 | ValueRef: x 1142 | HasValueReceiver: true 1143 | IsOptional: false 1144 | ResultIndex: 0 1145 | ReceiveValueAddr: false 1146 | Argument[1]: 1147 | FunctionIndex: 0 1148 | ValueRef: y 1149 | HasValueReceiver: true 1150 | IsOptional: false 1151 | ResultIndex: 1 1152 | ReceiveValueAddr: false 1153 | Result[0]: 1154 | FunctionIndex: 1 1155 | ValueName: x 1156 | HasValue: true 1157 | HookIndexes: [0] 1158 | Result[1]: 1159 | FunctionIndex: 1 1160 | ValueName: y 1161 | HasValue: true 1162 | HookIndexes: [1] 1163 | Hook[0]: 1164 | FunctionIndex: 2 1165 | ValueRef: x 1166 | HasValueReceiver: true 1167 | HasCallback: true 1168 | ReceiveValueAddr: false 1169 | Hook[1]: 1170 | FunctionIndex: 2 1171 | ValueRef: y 1172 | HasValueReceiver: true 1173 | HasCallback: true 1174 | ReceiveValueAddr: true 1175 | SortedFunctionIndexes: [2 1 0] 1176 | CalledFunctionCount: 3 1177 | `[1:] 1178 | }).RunParallel(t) 1179 | 1180 | tc.WithCallback("0", func(t *testing.T, c *C) { 1181 | var cancel context.CancelFunc 1182 | *c.ctx, cancel = context.WithTimeout(context.Background(), 0) 1183 | _ = cancel 1184 | func() { 1185 | err := 
c.p.NewFunction(Body(func(ctx context.Context) error { 1186 | <-ctx.Done() 1187 | return ctx.Err() 1188 | })) 1189 | if err != nil { 1190 | t.Fatal(err) 1191 | } 1192 | }() 1193 | c.err = context.DeadlineExceeded 1194 | c.errStr = `call function; functionName="github.com/go-tk/di_test.TestProgram_Run.func13.1": ` + c.err.Error() 1195 | c.repr = ` 1196 | Function[0]: 1197 | Index: 0 1198 | Name: github.com/go-tk/di_test.TestProgram_Run.func13.1 1199 | ArgumentIndexes: [] 1200 | ResultIndexes: [] 1201 | HasBody: true 1202 | HookIndexes: [] 1203 | HasCleanup: false 1204 | SortedFunctionIndexes: [0] 1205 | CalledFunctionCount: 0 1206 | `[1:] 1207 | }).RunParallel(t) 1208 | 1209 | tc.WithCallback("0", func(t *testing.T, c *C) { 1210 | var cancel context.CancelFunc 1211 | *c.ctx, cancel = context.WithCancel(context.Background()) 1212 | cancel() 1213 | func() { 1214 | var x int 1215 | err := c.p.NewFunction(Result("x", &x), Body(func(context.Context) error { 1216 | x = 100 1217 | return nil 1218 | })) 1219 | if err != nil { 1220 | t.Fatal(err) 1221 | } 1222 | }() 1223 | func() { 1224 | var x int 1225 | err := c.p.NewFunction(Body(func(context.Context) error { return nil }), 1226 | Hook("x", &x, func(context.Context) error { return context.Canceled })) 1227 | if err != nil { 1228 | t.Fatal(err) 1229 | } 1230 | }() 1231 | c.err = context.Canceled 1232 | c.errStr = `do callback; functionName="github.com/go-tk/di_test.TestProgram_Run.func14.2" valueRef="x": ` + c.err.Error() 1233 | c.repr = ` 1234 | Function[0]: 1235 | Index: 0 1236 | Name: github.com/go-tk/di_test.TestProgram_Run.func14.1 1237 | ArgumentIndexes: [] 1238 | ResultIndexes: [0] 1239 | HasBody: true 1240 | HookIndexes: [] 1241 | HasCleanup: false 1242 | Function[1]: 1243 | Index: 1 1244 | Name: github.com/go-tk/di_test.TestProgram_Run.func14.2 1245 | ArgumentIndexes: [] 1246 | ResultIndexes: [] 1247 | HasBody: true 1248 | HookIndexes: [0] 1249 | HasCleanup: false 1250 | Result[0]: 1251 | FunctionIndex: 0 1252 | 
ValueName: x 1253 | HasValue: true 1254 | HookIndexes: [0] 1255 | Hook[0]: 1256 | FunctionIndex: 1 1257 | ValueRef: x 1258 | HasValueReceiver: true 1259 | HasCallback: true 1260 | ReceiveValueAddr: false 1261 | SortedFunctionIndexes: [1 0] 1262 | CalledFunctionCount: 2 1263 | `[1:] 1264 | }).RunParallel(t) 1265 | }

// TestProgram_MustRun checks that MustRun returns normally for an empty
// Program and panics with the expected message when a Function declares an
// Argument "x" that no Function provides as a Result.
func TestProgram_MustRun(t *testing.T) {
	{
		var p Program
		p.MustRun(context.Background())
	}
	assert.PanicsWithValue(t, `run program: di: value not found; valueRef="x" functionName="github.com/go-tk/di_test.TestProgram_MustRun.func1"`, func() {
		var p Program
		var x int
		p.MustNewFunction(Argument("x", &x), Body(func(context.Context) error { return nil }))
		p.MustRun(context.Background())
	})
}

// TestProgram_Clean builds a Program in callback "0", runs it, calls Clean and
// then lets callback "1" inspect C.seq, the string to which every body, hook
// callback and cleanup appends one letter when it is invoked.
//
// NOTE: the expected error strings embed the runtime names of the closures
// below (e.g. "TestProgram_Clean.func4.3"), so the order and nesting of the
// function literals in this test must not be changed.
func TestProgram_Clean(t *testing.T) {
	type C struct {
		p   *Program
		ctx *context.Context

		errStr string // expected Run error message; empty means Run must succeed
		err    error  // optional error that the Run error must match via ErrorIs
		seq    string // letters recorded by bodies, hook callbacks and cleanups
	}
	tc := testcase.New(func(t *testing.T, c *C) {
		var p Program
		ctx := context.Background()
		c.p = &p
		c.ctx = &ctx

		// Callback "0" registers the Functions and the expected error, and may
		// replace ctx through c.ctx.
		testcase.Callback(t, "0")

		err := p.Run(ctx)
		if c.errStr == "" {
			assert.NoError(t, err)
		} else {
			assert.EqualError(t, err, c.errStr)
			if c.err != nil {
				assert.ErrorIs(t, err, c.err)
			}
		}
		// Stop before Clean if Run did not behave as expected.
		if t.Failed() {
			t.FailNow()
		}
		p.Clean()

		// Callback "1" checks the recorded sequence.
		testcase.Callback(t, "1")
	})

	// Everything succeeds: E, C, F and A are recorded by Run, then the
	// cleanups B, D and G are recorded by Clean.
	tc.WithCallback("0", func(t *testing.T, c *C) {
		func() {
			var (
				x int
				y int
			)
			err := c.p.NewFunction(
				Argument("y", &y),
				Result("x", &x),
				Body(func(context.Context) error { c.seq += "A"; return nil }),
				Cleanup(func() { c.seq += "B" }))
			if err != nil {
				t.Fatal(err)
			}
		}()
		func() {
			var y int
			err := c.p.NewFunction(
				Result("y", &y),
				Body(func(context.Context) error { c.seq += "C"; return nil }),
				Cleanup(func() { c.seq += "D" }))
			if err != nil {
				t.Fatal(err)
			}
		}()
		func() {
			var y int
			err := c.p.NewFunction(
				Body(func(context.Context) error { c.seq += "E"; return nil }),
				Hook("y", &y, func(context.Context) error { c.seq += "F"; return nil }),
				Cleanup(func() { c.seq += "G" }))
			if err != nil {
				t.Fatal(err)
			}
		}()
	}).WithCallback("1", func(t *testing.T, c *C) {
		assert.Equal(t, "ECFABDG", c.seq)
	}).RunParallel(t)

	// The hook callback on "y" fails: body A is never recorded and neither is
	// its cleanup B; only D and G are expected from Clean.
	tc.WithCallback("0", func(t *testing.T, c *C) {
		func() {
			var (
				x int
				y int
			)
			err := c.p.NewFunction(
				Argument("y", &y),
				Result("x", &x),
				Body(func(context.Context) error { c.seq += "A"; return nil }),
				Cleanup(func() { c.seq += "B" }))
			if err != nil {
				t.Fatal(err)
			}
		}()
		func() {
			var y int
			err := c.p.NewFunction(
				Result("y", &y),
				Body(func(context.Context) error { c.seq += "C"; return nil }),
				Cleanup(func() { c.seq += "D" }))
			if err != nil {
				t.Fatal(err)
			}
		}()
		func() {
			var y int
			err := c.p.NewFunction(
				Body(func(context.Context) error { c.seq += "E"; return nil }),
				Hook("y", &y, func(context.Context) error { c.seq += "F"; return context.Canceled }),
				Cleanup(func() { c.seq += "G" }))
			if err != nil {
				t.Fatal(err)
			}
		}()
		c.err = context.Canceled
		c.errStr = `do callback; functionName="github.com/go-tk/di_test.TestProgram_Clean.func4.3" valueRef="y": ` + c.err.Error()
	}).WithCallback("1", func(t *testing.T, c *C) {
		assert.Equal(t, "ECFDG", c.seq)
	}).RunParallel(t)

	// Body A itself fails: A is recorded, but its cleanup B is not expected;
	// D and G are still expected from Clean.
	tc.WithCallback("0", func(t *testing.T, c *C) {
		func() {
			var (
				x int
				y int
			)
			err := c.p.NewFunction(
				Argument("y", &y),
				Result("x", &x),
				Body(func(context.Context) error { c.seq += "A"; return context.Canceled }),
				Cleanup(func() { c.seq += "B" }))
			if err != nil {
				t.Fatal(err)
			}
		}()
		func() {
			var y int
			err := c.p.NewFunction(
				Result("y", &y),
				Body(func(context.Context) error { c.seq += "C"; return nil }),
				Cleanup(func() { c.seq += "D" }))
			if err != nil {
				t.Fatal(err)
			}
		}()
		func() {
			var y int
			err := c.p.NewFunction(
				Body(func(context.Context) error { c.seq += "E"; return nil }),
				Hook("y", &y, func(context.Context) error { c.seq += "F"; return nil }),
				Cleanup(func() { c.seq += "G" }))
			if err != nil {
				t.Fatal(err)
			}
		}()
		c.err = context.Canceled
		c.errStr = `call function; functionName="github.com/go-tk/di_test.TestProgram_Clean.func6.1": ` + c.err.Error()
	}).WithCallback("1", func(t *testing.T, c *C) {
		assert.Equal(t, "ECFADG", c.seq)
	}).RunParallel(t)
}

--------------------------------------------------------------------------------