├── .codecov.yml ├── .github └── workflows │ ├── codesee-arch-diagram.yml │ └── go.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── _examples ├── condition_provide │ └── main.go ├── goway │ └── main.go └── tutorial │ └── main.go ├── cmp.go ├── cmp_ctor.go ├── cmp_group.go ├── cmp_type.go ├── cmp_value.go ├── container.go ├── container_test.go ├── cycle.go ├── doc.go ├── docs ├── advanced.md └── tutorial.md ├── errors.go ├── go.mod ├── go.sum ├── inject.go ├── inspect.go ├── invocation.go ├── node.go ├── options.go ├── options_test.go ├── schema.go ├── stacktrace.go ├── stacktrace_test.go ├── tags.go ├── tracer.go └── tracer_test.go /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | range: 70..98 3 | round: down 4 | precision: 2 5 | status: 6 | project: 7 | default: 8 | enabled: yes 9 | target: 89 10 | if_not_found: success 11 | if_ci_failed: error 12 | patch: 13 | default: 14 | enabled: yes 15 | target: 70 -------------------------------------------------------------------------------- /.github/workflows/codesee-arch-diagram.yml: -------------------------------------------------------------------------------- 1 | # This workflow was added by CodeSee. 
Learn more at https://codesee.io/ 2 | # This is v2.0 of this workflow file 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request_target: 8 | types: [opened, synchronize, reopened] 9 | 10 | name: CodeSee 11 | 12 | permissions: read-all 13 | 14 | jobs: 15 | codesee: 16 | runs-on: ubuntu-latest 17 | continue-on-error: true 18 | name: Analyze the repo with CodeSee 19 | steps: 20 | - uses: Codesee-io/codesee-action@v2 21 | with: 22 | codesee-token: ${{ secrets.CODESEE_ARCH_DIAG_API_TOKEN }} 23 | codesee-url: https://app.codesee.io 24 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ '*' ] 6 | tags: [ 'v*' ] 7 | pull_request: 8 | branches: [ '*' ] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | go: [ "1.19.x", "1.20.x" ] 19 | include: 20 | - go: 1.20.x 21 | latest: true 22 | 23 | steps: 24 | - name: Checkout code 25 | uses: actions/checkout@v3 26 | 27 | - name: Setup Go 28 | uses: actions/setup-go@v3 29 | with: 30 | go-version: ${{ matrix.go }} 31 | cache: true 32 | cache-dependency-path: '**/go.sum' 33 | 34 | - name: Download Dependencies 35 | run: | 36 | go mod download 37 | 38 | - name: Test 39 | run: make cover 40 | 41 | - name: Upload coverage to codecov.io 42 | uses: codecov/codecov-action@v3 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | coverage.txt 2 | profile.out 3 | vendor 4 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 
4 | 5 | The format is based on 6 | [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this 7 | project adheres to 8 | [Semantic Versioning](https://semver.org/spec/v2.0.0.html): TBD, use 9 | modules or another vendor system. 10 | 11 | ## v1.12.0 12 | 13 | ### Changed 14 | 15 | - Update package name 16 | 17 | ## v1.11.1 18 | 19 | ### Fixed 20 | 21 | - [@chirino](https://github.com/chirino): fix: you could not use a 22 | struct for both di and json marshalling ([#41](https://github.com/defval/di/pull/41)). 23 | 24 | ## v1.11.0 25 | 26 | ### Added 27 | 28 | - `di.Decorate` function that applies function after type resolve. 29 | 30 | ### Changed 31 | 32 | - [@chirino](https://github.com/chirino): Prefer using "di:" field tags 33 | to control injection options to avoid conflicting with tags used by 34 | other libraries ([#38](https://github.com/defval/di/pull/38)). 35 | 36 | ## v1.10.0 37 | 38 | ### Added 39 | 40 | - [@chirino](https://github.com/chirino): Container nesting. See 41 | `AddParent()` function ([#35](https://github.com/defval/di/pull/35)). 42 | - An experimental feature: Instance decoration with `di.Decorate()`. 43 | 44 | ### Fixed 45 | 46 | - [@chirino](https://github.com/chirino): Calling `Resolve()` on a 47 | `di.Injectable` would overwrite the skip fields 48 | ([#34](https://github.com/defval/di/pull/34)). 49 | 50 | ## v1.9.0 51 | 52 | ### Added 53 | 54 | - `container.ProvideValue()` function. 55 | 56 | ## v1.8.0 57 | 58 | ### Added 59 | 60 | - `container.Apply()` function. 61 | 62 | ## v1.7.1 63 | 64 | ### Fixed 65 | 66 | - Style and coverage fixes. 67 | 68 | ## v1.7.0 69 | 70 | ### Added 71 | 72 | - Added embed fields support. 73 | 74 | ### Fixed 75 | 76 | - `di.Inject` now works with structs and pointers correctly. 77 | 78 | ## v1.6.3 79 | 80 | ### Fixed 81 | 82 | - Fix `optional` fields resolving. 83 | 84 | ## v1.6.2 85 | 86 | ### Fixed 87 | 88 | - Fix `di.As()` with several interfaces.
89 | 90 | ## v1.6.1 91 | 92 | ### Fixed 93 | 94 | - Removed debug print. 95 | - Documentation fixes. 96 | 97 | ## v1.6.0 98 | 99 | ### Changed 100 | 101 | - Changed logging interface. See `di.SetTracer()`. 102 | 103 | ### Fixed 104 | 105 | - Some documentation and test updates. 106 | 107 | ## v1.5.0 108 | 109 | ### Added 110 | 111 | - Add error to `Has()`. 112 | 113 | ### Fixed 114 | 115 | - `Has()` returns false if container could not build instance. 116 | 117 | ### Changed 118 | 119 | - The supported version of go >1.13. 120 | 121 | ## v1.4.1 122 | 123 | ### Fixed 124 | 125 | - Fix field injection into interface implementations. 126 | 127 | ## v1.4.0 128 | 129 | ### Added 130 | 131 | - `Iterate` method for lazy loaded iteration by all instances. 132 | 133 | ## v1.3.1 134 | 135 | ### Fixed 136 | 137 | - Bug: Resolve type as interface causes type reinitialization. 138 | 139 | ## v1.3.0: A release that doesn't deserve to be called `v2` 140 | 141 | ### BREAKING CHANGES 142 | 143 | - Provide duplications allowed. 144 | - Removed tag `di`. Now all public fields in injectable type will be 145 | injected. 146 | - Resolving node without tags, now returns all nodes of this type. 147 | - Now, `di:"type_name"` is a `name:"type_name"`. 148 | - Removed `di.Prototype()`: bad practice. 149 | 150 | ### Added 151 | 152 | - Tagging that allows specifying key value identity for types. 153 | - `skip:"true"` field tag option, that skips field providing. 154 | 155 | ### Fixed 156 | 157 | - A bit of bad code 158 | 159 | ## v1.2.1 160 | 161 | ### Fixed 162 | 163 | - [Using `di.WithName()` breaks when having one entry without a `di.Name()`](https://github.com/defval/di/issues/16): 164 | 165 | ## v1.2.0 166 | 167 | ### Added 168 | 169 | - Any type can be automatically resolved as a group. 170 | - The container exposes itself by default. 171 | - The only named type in the group will be resolved without a name. 
172 | - Dependency graph can be edited in the runtime (but you need to be 173 | careful with this). 174 | 175 | ## v1.1.0 176 | 177 | ### BREAKING CHANGES 178 | 179 | - Changed `di.Parameter` to `di.Inject`. 180 | - Remove `optional` support from `di` tag. 181 | - Add `optional` tag. See 182 | [this](https://github.com/defval/di#optional-parameters). 183 | 184 | ### Added 185 | 186 | - Support injection into constructor result struct via `di.Inject`. 187 | 188 | ## v1.0.2 189 | 190 | ### Added 191 | 192 | - Location of `di.Provide()`, `di.Invoke()`, `di.Resolve()` in error. 193 | 194 | ### Fixed 195 | 196 | - Fix: `di.As()` with nil causes panic. 197 | 198 | ## v1.0.1 199 | 200 | ### Fixed 201 | 202 | - `container.Provide` could not be called after container compilation 203 | now. 204 | - Improve error messages. 205 | 206 | 207 | ## v1.0.0 208 | 209 | Initial release. 210 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2020-2023 defval 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: cover 2 | cover: 3 | go test -race -coverprofile=cover.out -coverpkg=./... ./... 4 | go tool cover -html=cover.out -o cover.html -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DI 2 | 3 | [![Documentation](https://img.shields.io/badge/godoc-reference-blue.svg?style=for-the-badge&logo=go)](https://pkg.go.dev/github.com/defval/di) 4 | [![GitHub release (latest by date)](https://img.shields.io/github/v/release/defval/di?logo=semver&style=for-the-badge)](https://github.com/defval/di/releases/latest) 5 | [![GitHub Workflow Status (with branch)](https://img.shields.io/github/actions/workflow/status/defval/di/go.yml?branch=master&logo=github-actions&style=for-the-badge)](https://github.com/defval/di/actions/workflows/go.yml) 6 | [![Go Report Card](https://img.shields.io/badge/go%20report-A%2B-green?style=for-the-badge)](https://goreportcard.com/report/github.com/defval/di) 7 | [![Codecov](https://img.shields.io/codecov/c/github/defval/di?logo=codecov&style=for-the-badge)](https://codecov.io/gh/defval/di) 8 | 9 | **DI** is a dependency injection library for the Go programming language. 10 | 11 | Dependency injection is a form of inversion of control that increases modularity and extensibility in your programs. 12 | This library helps you organize responsibilities in your codebase and makes it easy to combine low-level implementations 13 | into high-level behavior without boilerplate. 
14 | 15 | ## Features 16 | 17 | - Intuitive auto wiring 18 | - Interface implementations 19 | - Constructor injection 20 | - Optional injection 21 | - Field injection 22 | - Lazy-loading 23 | - Tagging 24 | - Grouping 25 | - Iteration 26 | - Decoration 27 | - Cleanup 28 | - Container Chaining / Scopes 29 | 30 | ## Installation 31 | 32 | ```shell 33 | go get github.com/defval/di 34 | ``` 35 | 36 | ## Documentation 37 | 38 | You can use the standard [pkg.go.dev](https://pkg.go.dev/github.com/defval/di) and inline code comments. If you are new 39 | to auto-wiring libraries such as [google/wire](https://github.com/google/wire) 40 | or [uber-go/dig](https://github.com/uber-go/dig), start with the [tutorial](./docs/tutorial.md). 41 | 42 | ### Essential Reading 43 | 44 | - [Tutorial](./docs/tutorial.md) 45 | - [Examples](./_examples) 46 | - [Advanced Features](./docs/advanced.md) 47 | 48 | ## Example Usage 49 | 50 | ```go 51 | package main 52 | 53 | import ( 54 | "context" 55 | "fmt" 56 | "log" 57 | "net/http" 58 | "os" 59 | "os/signal" 60 | "syscall" 61 | 62 | "github.com/defval/di" 63 | ) 64 | 65 | func main() { 66 | di.SetTracer(&di.StdTracer{}) 67 | // create container 68 | c, err := di.New( 69 | di.Provide(NewContext), // provide application context 70 | di.Provide(NewServer), // provide http server 71 | di.Provide(NewServeMux), // provide http serve mux 72 | // controllers as []Controller group 73 | di.Provide(NewOrderController, di.As(new(Controller))), 74 | di.Provide(NewUserController, di.As(new(Controller))), 75 | ) 76 | // handle container errors 77 | if err != nil { 78 | log.Fatal(err) 79 | } 80 | // invoke function 81 | if err := c.Invoke(StartServer); err != nil { 82 | log.Fatal(err) 83 | } 84 | } 85 | ``` 86 | 87 | Full code available [here](./_examples/tutorial/main.go). 88 | 89 | ## Questions 90 | 91 | If you have any questions, feel free to create an issue. 
92 | -------------------------------------------------------------------------------- /_examples/condition_provide/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "log" 7 | "net" 8 | "os" 9 | 10 | "github.com/defval/di" 11 | ) 12 | 13 | // Configuration holds the setting that selects which connection implementation is provided. 14 | type Configuration struct { 15 | ConnectionType string 16 | } 17 | 18 | // NewConfiguration creates a configuration with ConnectionType "tcp", overridden by the CONNECTION_TYPE environment variable when it is set. 19 | func NewConfiguration() *Configuration { 20 | c := &Configuration{ 21 | ConnectionType: "tcp", 22 | } 23 | if typ, ok := os.LookupEnv("CONNECTION_TYPE"); ok { 24 | c.ConnectionType = typ 25 | } 26 | return c 27 | } 28 | 29 | // NewTCPConn creates a TCP connection. 30 | func NewTCPConn() *net.TCPConn { 31 | return &net.TCPConn{} 32 | } 33 | 34 | // NewUDPConn creates a UDP connection. 35 | func NewUDPConn() *net.UDPConn { 36 | return &net.UDPConn{} 37 | } 38 | 39 | // ProvideConfiguredConnection provides NewTCPConn or NewUDPConn as net.Conn depending on conf.ConnectionType; it returns an error for any other type. 40 | func ProvideConfiguredConnection(conf *Configuration, container *di.Container) error { 41 | switch conf.ConnectionType { 42 | case "tcp": 43 | return container.Provide(NewTCPConn, di.As(new(net.Conn))) 44 | case "udp": 45 | return container.Provide(NewUDPConn, di.As(new(net.Conn))) 46 | } 47 | return errors.New("unknown connection type") 48 | } 49 | 50 | func main() { 51 | c, err := di.New( 52 | di.Provide(NewConfiguration), 53 | di.Invoke(ProvideConfiguredConnection), 54 | ) 55 | if err != nil { 56 | log.Fatalln(err) 57 | } 58 | var conn net.Conn 59 | if err := c.Resolve(&conn); err != nil { 60 | log.Fatalln(err) 61 | } 62 | switch conn.(type) { 63 | case *net.TCPConn: 64 | fmt.Println("Provided connection: TCP") 65 | case *net.UDPConn: 66 | fmt.Println("Provided connection: UDP") 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /_examples/goway/main.go: -------------------------------------------------------------------------------- 1 |
package main 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "net/http" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | ) 11 | 12 | func main() { 13 | orders := NewOrderController() 14 | users := NewUserController() 15 | mux := NewServeMux() 16 | mux.HandleFunc("/orders", orders.RetrieveOrders) 17 | mux.HandleFunc("/users", users.RetrieveUsers) 18 | server := NewServer(mux) 19 | log.Println("start server") 20 | errChan := make(chan error) 21 | go func() { 22 | if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { 23 | errChan <- err 24 | } 25 | }() 26 | ctx, cancel := context.WithCancel(context.Background()) 27 | go func() { 28 | stop := make(chan os.Signal, 1) // signal.Notify never blocks on send, so an unbuffered channel may miss the signal 29 | signal.Notify(stop, syscall.SIGTERM, syscall.SIGINT) 30 | <-stop 31 | cancel() 32 | }() 33 | select { 34 | case <-ctx.Done(): 35 | log.Println("stop server") 36 | if err := server.Close(); err != nil { 37 | log.Fatal(err) 38 | } 39 | case err := <-errChan: 40 | log.Fatal(err) 41 | } 42 | } 43 | 44 | // NewServer creates a http server with provided mux as handler. 45 | func NewServer(mux *http.ServeMux) *http.Server { 46 | server := &http.Server{ 47 | Addr: ":8080", 48 | Handler: mux, 49 | } 50 | return server 51 | } 52 | 53 | // NewServeMux creates a new http serve mux. 54 | func NewServeMux() *http.ServeMux { 55 | return &http.ServeMux{} 56 | } 57 | 58 | // OrderController is a http controller for orders. 59 | type OrderController struct{} 60 | 61 | // NewOrderController creates an order http controller. 62 | func NewOrderController() *OrderController { 63 | return &OrderController{} 64 | } 65 | 66 | // RetrieveOrders loads orders and writes them to the writer. 67 | func (a *OrderController) RetrieveOrders(writer http.ResponseWriter, _ *http.Request) { 68 | writer.WriteHeader(http.StatusOK) 69 | _, _ = writer.Write([]byte("Orders")) 70 | } 71 | 72 | // UserController is a http endpoint for a user. 73 | type UserController struct{} 74 | 75 | // NewUserController creates a user http endpoint.
76 | func NewUserController() *UserController { 77 | return &UserController{} 78 | } 79 | 80 | // RetrieveUsers loads users and writes them using the writer. 81 | func (e *UserController) RetrieveUsers(writer http.ResponseWriter, _ *http.Request) { 82 | writer.WriteHeader(http.StatusOK) 83 | _, _ = writer.Write([]byte("Users")) 84 | } 85 | -------------------------------------------------------------------------------- /_examples/tutorial/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | "os" 9 | "os/signal" 10 | "syscall" 11 | 12 | "github.com/defval/di" 13 | ) 14 | 15 | func main() { 16 | di.SetTracer(&di.StdTracer{}) 17 | // create container 18 | c, err := di.New( 19 | di.Provide(NewContext), // provide application context 20 | di.Provide(NewServer), // provide http server 21 | di.Provide(NewServeMux), // provide http serve mux 22 | // controllers as []Controller group 23 | di.Provide(NewOrderController, di.As(new(Controller))), 24 | di.Provide(NewUserController, di.As(new(Controller))), 25 | ) 26 | // handle container errors 27 | if err != nil { 28 | log.Fatal(err) 29 | } 30 | // invoke function 31 | if err := c.Invoke(StartServer); err != nil { 32 | log.Fatal(err) 33 | } 34 | } 35 | 36 | // StartServer starts the http server and blocks until ctx is done (then closes the server) or the server fails. 37 | func StartServer(ctx context.Context, server *http.Server) error { 38 | log.Println("start server") 39 | errChan := make(chan error) 40 | go func() { 41 | if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { 42 | errChan <- err 43 | } 44 | }() 45 | select { 46 | case <-ctx.Done(): 47 | log.Println("stop server") 48 | return server.Close() 49 | case err := <-errChan: 50 | return fmt.Errorf("server error: %w", err) 51 | } 52 | } 53 | 54 | // NewContext creates the application context, which is canceled on SIGTERM or SIGINT.
55 | func NewContext() context.Context { 56 | ctx, cancel := context.WithCancel(context.Background()) 57 | go func() { 58 | stop := make(chan os.Signal, 1) // signal.Notify never blocks on send, so an unbuffered channel may miss the signal 59 | signal.Notify(stop, syscall.SIGTERM, syscall.SIGINT) 60 | <-stop 61 | cancel() 62 | }() 63 | return ctx 64 | } 65 | 66 | // NewServer creates a http server with provided mux as handler. 67 | func NewServer(mux *http.ServeMux) *http.Server { 68 | server := &http.Server{ 69 | Addr: ":8080", 70 | Handler: mux, 71 | } 72 | return server 73 | } 74 | 75 | // NewServeMux creates a new http serve mux and registers the routes of every controller on it. 76 | func NewServeMux(controllers []Controller) *http.ServeMux { 77 | mux := &http.ServeMux{} 78 | for _, controller := range controllers { 79 | controller.RegisterRoutes(mux) 80 | } 81 | return mux 82 | } 83 | 84 | // Controller is an interface that can register its routes. 85 | type Controller interface { 86 | RegisterRoutes(mux *http.ServeMux) 87 | } 88 | 89 | // OrderController is a http controller for orders. 90 | type OrderController struct{} 91 | 92 | // NewOrderController creates an order http controller. 93 | func NewOrderController() *OrderController { 94 | return &OrderController{} 95 | } 96 | 97 | // RegisterRoutes is a Controller interface implementation. 98 | func (a *OrderController) RegisterRoutes(mux *http.ServeMux) { 99 | mux.HandleFunc("/orders", a.RetrieveOrders) 100 | } 101 | 102 | // RetrieveOrders loads orders and writes them to the writer. 103 | func (a *OrderController) RetrieveOrders(writer http.ResponseWriter, _ *http.Request) { 104 | writer.WriteHeader(http.StatusOK) 105 | _, _ = writer.Write([]byte("Orders")) 106 | } 107 | 108 | // UserController is a http endpoint for a user. 109 | type UserController struct{} 110 | 111 | // NewUserController creates a user http endpoint. 112 | func NewUserController() *UserController { 113 | return &UserController{} 114 | } 115 | 116 | // RegisterRoutes is a Controller interface implementation.
117 | func (e *UserController) RegisterRoutes(mux *http.ServeMux) { 118 | mux.HandleFunc("/users", e.RetrieveUsers) 119 | } 120 | 121 | // RetrieveUsers loads users and writes them using the writer. 122 | func (e *UserController) RetrieveUsers(writer http.ResponseWriter, _ *http.Request) { 123 | writer.WriteHeader(http.StatusOK) 124 | _, _ = writer.Write([]byte("Users")) 125 | } 126 | -------------------------------------------------------------------------------- /cmp.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | 7 | // compiler builds the value of a dependency node. 8 | type compiler interface { 9 | // deps returns the nodes whose values are needed to compile this node. 10 | deps(s schema) ([]*node, error) 11 | // compile builds the node value; dependencies are the already compiled values of this node's dependencies. 12 | compile(dependencies []reflect.Value, s schema) (reflect.Value, error) 13 | } 14 | -------------------------------------------------------------------------------- /cmp_ctor.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | 7 | // ctorType describes types of constructor provider. 8 | type ctorType int 9 | 10 | const ( 11 | ctorUnknown ctorType = iota 12 | ctorValue // (deps) (result) 13 | ctorValueError // (deps) (result, error) 14 | ctorValueCleanup // (deps) (result, cleanup) 15 | ctorValueCleanupError // (deps) (result, cleanup, error) 16 | ) 17 | 18 | // constructorCompiler compiles constructor functions. 19 | type constructorCompiler struct { 20 | typ ctorType 21 | fn function 22 | } 23 | 24 | // newConstructorCompiler creates a compiler for the constructor fn; it returns false if fn's results match no supported constructor signature.
25 | func newConstructorCompiler(fn function) (*constructorCompiler, bool) { 26 | ctorType := determineCtorType(fn) 27 | if ctorType == ctorUnknown { 28 | return nil, false 29 | } 30 | return &constructorCompiler{ 31 | typ: ctorType, 32 | fn: fn, 33 | }, true 34 | } 35 | 36 | func (c constructorCompiler) deps(s schema) (deps []*node, err error) { 37 | for i := 0; i < c.fn.NumIn(); i++ { 38 | in := c.fn.Type.In(i) 39 | dep, err := s.find(in, Tags{}) 40 | if err != nil { 41 | return nil, err 42 | } 43 | deps = append(deps, dep) 44 | } 45 | return deps, nil 46 | } 47 | 48 | func (c constructorCompiler) compile(dependencies []reflect.Value, s schema) (reflect.Value, error) { 49 | // call constructor function 50 | out := funcResult(c.fn.Call(dependencies)) 51 | rv := out.value() 52 | switch c.typ { 53 | case ctorValue: 54 | return rv, nil 55 | case ctorValueError: 56 | return rv, out.error(1) 57 | case ctorValueCleanup: 58 | s.cleanup(out.cleanup()) 59 | return rv, nil 60 | case ctorValueCleanupError: 61 | s.cleanup(out.cleanup()) 62 | return rv, out.error(2) 63 | } 64 | bug() 65 | return reflect.Value{}, nil 66 | } 67 | 68 | // determineCtorType classifies fn by its results: (value), (value, error), (value, cleanup) or (value, cleanup, error); any other shape is ctorUnknown. 69 | func determineCtorType(fn function) ctorType { 70 | switch { 71 | case fn.NumOut() == 1: 72 | return ctorValue 73 | case fn.NumOut() == 2: 74 | if isError(fn.Out(1)) { 75 | return ctorValueError 76 | } 77 | if isCleanup(fn.Out(1)) { 78 | return ctorValueCleanup 79 | } 80 | case fn.NumOut() == 3 && isCleanup(fn.Out(1)) && isError(fn.Out(2)): 81 | return ctorValueCleanupError 82 | } 83 | return ctorUnknown 84 | } 85 | 86 | // funcResult is a helper type over the results returned by a reflective constructor call. 87 | type funcResult []reflect.Value 88 | 89 | // value returns the first result, the constructed value. 90 | func (r funcResult) value() reflect.Value { 91 | return r[0] 92 | } 93 | 94 | // cleanup returns the cleanup function held in the second result, or nil if that result is nil.
95 | func (r funcResult) cleanup() func() { 96 | if r[1].IsNil() { 97 | return nil 98 | } 99 | return r[1].Interface().(func()) 100 | } 101 | 102 | // error returns the error held in the result at position, or nil if that result is nil. 103 | func (r funcResult) error(position int) error { 104 | if r[position].IsNil() { 105 | return nil 106 | } 107 | return r[position].Interface().(error) 108 | } 109 | -------------------------------------------------------------------------------- /cmp_group.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | // groupCompiler builds a value of slice type rt holding the compiled values of the matched nodes. 7 | type groupCompiler struct { 8 | rt reflect.Type 9 | matched []*node 10 | } 11 | 12 | // newGroupCompiler creates a group compiler for slice type rt with the matched nodes. 13 | func newGroupCompiler(rt reflect.Type, matched []*node) *groupCompiler { 14 | return &groupCompiler{ 15 | rt: rt, 16 | matched: matched, 17 | } 18 | } 19 | 20 | func (c *groupCompiler) deps(s schema) (deps []*node, err error) { 21 | return c.matched, nil 22 | } 23 | 24 | func (c *groupCompiler) compile(dependencies []reflect.Value, s schema) (reflect.Value, error) { 25 | return reflect.Append(reflect.New(c.rt).Elem(), dependencies...), nil 26 | } 27 | -------------------------------------------------------------------------------- /cmp_type.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | // typeCompiler creates a new zero-valued instance of rt; it has no dependencies. 7 | type typeCompiler struct { 8 | rt reflect.Type 9 | } 10 | 11 | // newTypeCompiler creates a compiler that creates new instances of rt.
12 | func newTypeCompiler(rt reflect.Type) *typeCompiler { 13 | return &typeCompiler{ 14 | rt: rt, 15 | } 16 | } 17 | 18 | func (c typeCompiler) deps(s schema) (deps []*node, err error) { 19 | return nil, nil 20 | } 21 | 22 | // compile returns a new zero value of rt; for a pointer type it returns a pointer to a new zero value of the element type. 23 | func (c typeCompiler) compile(dependencies []reflect.Value, s schema) (reflect.Value, error) { 24 | if c.rt.Kind() == reflect.Ptr { 25 | // reflect.New already points at a zero value, so no explicit zeroing is needed. 26 | return reflect.New(c.rt.Elem()), nil 27 | } 28 | return reflect.New(c.rt).Elem(), nil 29 | } 30 | -------------------------------------------------------------------------------- /cmp_value.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | // valueCompiler returns a value that was provided as is; it has no dependencies. 7 | type valueCompiler struct { 8 | rv reflect.Value 9 | } 10 | 11 | func (v valueCompiler) deps(s schema) ([]*node, error) { 12 | return nil, nil 13 | } 14 | 15 | func (v valueCompiler) compile(dependencies []reflect.Value, s schema) (reflect.Value, error) { 16 | return v.rv, nil 17 | } 18 | -------------------------------------------------------------------------------- /container.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "reflect" 7 | ) 8 | 9 | // Container is a dependency injection container. 10 | type Container struct { 11 | // Dependency injection schema. 12 | schema *defaultSchema 13 | // Array of provider cleanups. NOTE(review): Cleanup() runs schema.cleanups, not this field; confirm whether it is still needed. 14 | cleanups []func() 15 | } 16 | 17 | // New constructs container with provided options.
Example usage (simplified): 18 | // 19 | // Define constructors and invocations: 20 | // 21 | // func NewHTTPServer(mux *http.ServeMux) *http.Server { 22 | // return &http.Server{ 23 | // Handler: mux, 24 | // } 25 | // } 26 | // 27 | // func NewHTTPServeMux() *http.ServeMux { 28 | // return &http.ServeMux{} 29 | // } 30 | // 31 | // func StartServer(server *http.Server) error { 32 | // return server.ListenAndServe() 33 | // } 34 | // 35 | // Use it with container: 36 | // 37 | // container, err := di.New( 38 | // di.Provide(NewHTTPServer), 39 | // di.Provide(NewHTTPServeMux), 40 | // di.Invoke(StartServer), 41 | // ) 42 | // if err != nil { 43 | // // handle error 44 | // } 45 | func New(options ...Option) (_ *Container, err error) { 46 | c := &Container{ 47 | schema: newDefaultSchema(), 48 | cleanups: []func(){}, 49 | } 50 | var di diopts 51 | // apply container diopts 52 | for _, opt := range options { 53 | opt.apply(&di) 54 | } 55 | // provide the container itself for advanced usage, e.g. conditional providing; the error is deliberately ignored 56 | _ = c.provide(func() *Container { return c }) 57 | if err := c.apply(di); err != nil { 58 | return nil, err 59 | } 60 | return c, nil 61 | } 62 | 63 | // Apply applies options to container. 64 | // 65 | // err := container.Apply( 66 | // di.Provide(NewHTTPServer), 67 | // ) 68 | // if err != nil { 69 | // // handle error 70 | // } 71 | func (c *Container) Apply(options ...Option) error { 72 | var di diopts 73 | for _, opt := range options { 74 | opt.apply(&di) 75 | } 76 | return c.apply(di) 77 | } 78 | 79 | // Provide registers a constructor that the container uses to build a type. The constructor is invoked lazily, on demand. 80 | // For more information about constructors see the Constructor interface. ProvideOption can add additional behavior to 81 | // the process of type resolving.
82 | func (c *Container) Provide(constructor Constructor, options ...ProvideOption) error { 83 | if err := c.provide(constructor, options...); err != nil { 84 | return errWithStack(err) 85 | } 86 | return nil 87 | } 88 | 89 | // ProvideValue provides value as is. 90 | func (c *Container) ProvideValue(value Value, options ...ProvideOption) error { 91 | if err := c.provideValue(value, options...); err != nil { 92 | return errWithStack(err) 93 | } 94 | return nil 95 | } 96 | 97 | // Invocation is a function whose signature looks like: 98 | // 99 | // func StartServer(server *http.Server) error { 100 | // return server.ListenAndServe() 101 | // } 102 | // 103 | // Like a constructor, an invocation may have any number of arguments and 104 | // they will be resolved automatically. The invocation can return an optional error. 105 | // That error is returned as is. 106 | type Invocation interface{} 107 | 108 | // Invoke resolves the parameters of invocation from the container and calls it with them. 109 | // Errors recognized by knownError are wrapped with errWithStack; see Invocation for details. 110 | func (c *Container) Invoke(invocation Invocation, options ...InvokeOption) error { 111 | err := c.invoke(invocation, options...) 112 | if err != nil && knownError(err) { 113 | return errWithStack(err) 114 | } 115 | if err != nil { 116 | return err 117 | } 118 | return nil 119 | } 120 | // Pointer is a pointer to a variable that receives a resolved instance, e.g. &server. 121 | type Pointer interface{} 122 | 123 | // Has reports whether the type of target exists in the container; a missing type yields false and a nil error. 124 | // 125 | // var server *http.Server 126 | // if ok, err := container.Has(&server); err == nil && ok { 127 | // // handle server existence 128 | // } 129 | // 130 | // It is like Resolve() but doesn't instantiate a type.
131 | func (c *Container) Has(target Pointer, options ...ResolveOption) (bool, error) { 132 | if _, err := c.find(target, options...); errors.Is(err, ErrTypeNotExists) { 133 | return false, nil 134 | } else if err != nil { 135 | return false, err 136 | } 137 | return true, nil 138 | } 139 | 140 | // Resolve resolves type and fills target pointer. 141 | // 142 | // var server *http.Server 143 | // if err := container.Resolve(&server); err != nil { 144 | // // handle error 145 | // } 146 | func (c *Container) Resolve(ptr Pointer, options ...ResolveOption) error { 147 | if err := c.resolve(ptr, options...); err != nil { 148 | return errWithStack(err) 149 | } 150 | return nil 151 | } 152 | 153 | // ValueFunc is a lazy-loading wrapper for iteration. 154 | type ValueFunc func() (interface{}, error) 155 | 156 | // IterateFunc is called by Iterate for each instance of the selected group. 157 | type IterateFunc func(tags Tags, value ValueFunc) error 158 | 159 | // Iterate calls fn for each instance of the group that target points to; an instance is built only when its ValueFunc is called. 160 | // 161 | // var servers []*http.Server 162 | // iterFn := func(tags di.Tags, loader di.ValueFunc) error { 163 | // i, err := loader() 164 | // if err != nil { 165 | // return err 166 | // } 167 | // // do stuff with result: i.(*http.Server) 168 | // return nil 169 | // } 170 | // container.Iterate(&servers, iterFn) 171 | func (c *Container) Iterate(target Pointer, fn IterateFunc, options ...ResolveOption) error { 172 | node, err := c.find(target, options...)
173 | if err != nil { 174 | return err 175 | } 176 | group, ok := node.compiler.(*groupCompiler) 177 | if ok { 178 | for i, n := range group.matched { 179 | n := n // the loader may be called after this iteration; do not share the loop variable (pre-Go 1.22 semantics) 180 | err = fn(n.tags, func() (interface{}, error) { 181 | v, err := n.Value(c.schema) 182 | if err != nil { 183 | return nil, err 184 | } 185 | return v.Interface(), nil 186 | }) 187 | if err != nil { 188 | return fmt.Errorf("%s with index %d failed: %w", node, i, err) 189 | } 190 | } 191 | return nil 192 | } 193 | return errors.New("iteration can be used with groups only") 194 | } 195 | 196 | // Cleanup runs the registered cleanup functions in the reverse order of their creation. 197 | func (c *Container) Cleanup() { 198 | for i := len(c.schema.cleanups) - 1; i >= 0; i-- { 199 | c.schema.cleanups[i]() 200 | } 201 | } 202 | 203 | // AddParent adds a parent container. Types are resolved from the container, 204 | // its parents, and ancestors. An error is returned if a cycle is detected in the ancestry tree. 205 | func (c *Container) AddParent(parent *Container) error { 206 | return c.schema.addParent(parent.schema) 207 | } 208 | 209 | func (c *Container) apply(di diopts) error { 210 | for _, provide := range di.values { 211 | if err := c.provideValue(provide.value, provide.options...); err != nil { 212 | return fmt.Errorf("%s: %w", provide.frame, err) 213 | } 214 | } 215 | // process di.Provide() diopts 216 | for _, provide := range di.provides { 217 | if err := c.provide(provide.constructor, provide.options...); err != nil { 218 | return fmt.Errorf("%s: %w", provide.frame, err) 219 | } 220 | } 221 | // process di.Invoke() diopts 222 | for _, invoke := range di.invokes { 223 | err := c.invoke(invoke.fn, invoke.options...)
224 | if err != nil && knownError(err) { 225 | return fmt.Errorf("%s: %w", invoke.frame, err) 226 | } 227 | if err != nil { 228 | return err 229 | } 230 | } 231 | // process di.Resolve() diopts 232 | for _, resolve := range di.resolves { 233 | if err := c.resolve(resolve.target, resolve.options...); err != nil { 234 | return fmt.Errorf("%s: %w", resolve.frame, err) 235 | } 236 | } 237 | return nil 238 | } 239 | 240 | func (c *Container) provide(constructor Constructor, options ...ProvideOption) error { 241 | if constructor == nil { 242 | return fmt.Errorf("invalid constructor signature, got nil") 243 | } 244 | params := ProvideParams{} 245 | // apply provide options 246 | for _, opt := range options { 247 | opt.applyProvide(&params) 248 | } 249 | n, err := newConstructorNode(constructor) 250 | if err != nil { 251 | return err 252 | } 253 | n.decorators = params.Decorators 254 | for k, v := range params.Tags { 255 | n.tags[k] = v 256 | } 257 | return c.provideNode(n, params) 258 | } 259 | 260 | func (c *Container) provideValue(value Value, options ...ProvideOption) error { 261 | if value == nil { 262 | return fmt.Errorf("invalid value, got nil") 263 | } 264 | params := ProvideParams{} 265 | // apply provide diopts 266 | for _, opt := range options { 267 | opt.applyProvide(&params) 268 | } 269 | v := reflect.ValueOf(value) 270 | n := &node{ 271 | compiler: valueCompiler{ 272 | rv: v, 273 | }, 274 | rv: new(reflect.Value), 275 | rt: v.Type(), 276 | tags: params.Tags, 277 | decorators: params.Decorators, 278 | } 279 | return c.provideNode(n, params) 280 | } 281 | 282 | func (c *Container) provideNode(n *node, params ProvideParams) error { 283 | c.schema.register(n) 284 | // register interfaces 285 | for _, cur := range params.Interfaces { 286 | i, err := inspectInterfacePointer(cur) 287 | if err != nil { 288 | return err 289 | } 290 | if !n.rt.Implements(i.Type) { 291 | return fmt.Errorf("%s not implement %s", n, i.Type) 292 | } 293 | c.schema.register(&node{ 294 | rv: n.rv,
295 | rt: i.Type, 296 | tags: n.tags, 297 | compiler: n.compiler, 298 | decorators: n.decorators, 299 | }) 300 | } 301 | return nil 302 | } 303 | 304 | func (c *Container) resolve(ptr Pointer, options ...ResolveOption) error { 305 | node, err := c.find(ptr, options...) 306 | if err != nil { 307 | return err 308 | } 309 | value, err := node.Value(c.schema) 310 | if err != nil { 311 | return fmt.Errorf("%s: %w", node, err) 312 | } 313 | rv := reflect.ValueOf(ptr) 314 | target := rv.Elem() 315 | if canInject(rv.Type()) { 316 | for index := range parsePopulateFields(target.Type()) { 317 | target.Field(index).Set(value.Field(index)) 318 | } 319 | } else { 320 | target.Set(value) 321 | } 322 | return nil 323 | } 324 | 325 | func (c *Container) invoke(invocation Invocation, _ ...InvokeOption) error { 326 | // params := InvokeParams{} 327 | // for _, opt := range diopts { 328 | // opt.apply(¶ms) 329 | // } 330 | if invocation == nil { 331 | return fmt.Errorf("%w, got %s", errInvalidInvocationSignature, "nil") 332 | } 333 | fn, valid := inspectFunction(invocation) 334 | if !valid { 335 | return fmt.Errorf("%w, got %s", errInvalidInvocationSignature, reflect.TypeOf(invocation)) 336 | } 337 | if !validateInvocation(fn) { 338 | return fmt.Errorf("%w, got %s", errInvalidInvocationSignature, reflect.TypeOf(invocation)) 339 | } 340 | nodes, err := parseInvocationParameters(fn, c.schema) 341 | if err != nil { 342 | return err 343 | } 344 | var args []reflect.Value 345 | for _, node := range nodes { 346 | if err := c.schema.prepare(node); err != nil { 347 | return err 348 | } 349 | v, err := node.Value(c.schema) 350 | if err != nil { 351 | return fmt.Errorf("%s: %s", node, err) 352 | } 353 | args = append(args, v) 354 | } 355 | res := funcResult(fn.Call(args)) 356 | if len(res) == 0 { 357 | return nil 358 | } 359 | return res.error(0) 360 | } 361 | 362 | func (c *Container) find(ptr Pointer, options ...ResolveOption) (*node, error) { 363 | if ptr == nil { 364 | return nil, 
fmt.Errorf("target must be a pointer, got nil") 365 | } 366 | if reflect.ValueOf(ptr).Kind() != reflect.Ptr { 367 | return nil, fmt.Errorf("target must be a pointer, got %s", reflect.TypeOf(ptr)) 368 | } 369 | params := ResolveParams{} 370 | // apply extract diopts 371 | for _, opt := range options { 372 | opt.applyResolve(¶ms) 373 | } 374 | node, err := c.schema.find(reflect.TypeOf(ptr).Elem(), params.Tags) 375 | if err != nil { 376 | return nil, err 377 | } 378 | if err := c.schema.prepare(node); err != nil { 379 | return nil, err 380 | } 381 | return node, nil 382 | } 383 | 384 | type diopts struct { 385 | // Array of di.Provide() options. 386 | provides []provideOptions 387 | // Array of di.ProvideValue() options. 388 | values []provideValueOptions 389 | // Array of di.Invoke() options. 390 | invokes []invokeOptions 391 | // Array of di.Resolve() options. 392 | resolves []resolveOptions 393 | } 394 | -------------------------------------------------------------------------------- /container_test.go: -------------------------------------------------------------------------------- 1 | package di_test 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net" 8 | "net/http" 9 | "os" 10 | "testing" 11 | 12 | "github.com/stretchr/testify/require" 13 | 14 | "github.com/defval/di" 15 | ) 16 | 17 | func init() { 18 | di.SetTracer(di.StdTracer{}) 19 | } 20 | 21 | func TestContainer_Provide(t *testing.T) { 22 | t.Run("simple constructor", func(t *testing.T) { 23 | c, err := di.New() 24 | require.NoError(t, err) 25 | require.NoError(t, c.Provide(func() *http.Server { return &http.Server{} })) 26 | }) 27 | 28 | t.Run("constructor with cleanup function", func(t *testing.T) { 29 | c, err := di.New() 30 | require.NoError(t, err) 31 | require.NoError(t, c.Provide(func() (*http.Server, func()) { 32 | return &http.Server{}, func() {} 33 | })) 34 | }) 35 | 36 | t.Run("constructor with cleanup and error", func(t *testing.T) { 37 | c, err := di.New() 38 | require.NoError(t, err) 39 
| require.NoError(t, c.Provide(func() (*http.Server, func(), error) { 40 | return &http.Server{}, func() {}, nil 41 | })) 42 | }) 43 | 44 | t.Run("provide string cause error", func(t *testing.T) { 45 | c, err := di.New() 46 | require.NoError(t, err) 47 | err = c.Provide("string") 48 | require.Error(t, err) 49 | require.Contains(t, err.Error(), "container_test.go:") 50 | require.Contains(t, err.Error(), ": invalid constructor signature, got string") 51 | }) 52 | 53 | t.Run("provide nil cause error", func(t *testing.T) { 54 | c, err := di.New() 55 | require.NoError(t, err) 56 | err = c.Provide(nil) 57 | require.Error(t, err) 58 | require.Contains(t, err.Error(), "container_test.go:") 59 | require.Contains(t, err.Error(), ": invalid constructor signature, got nil") 60 | }) 61 | 62 | t.Run("provide struct pointer cause error", func(t *testing.T) { 63 | c, err := di.New() 64 | require.NoError(t, err) 65 | err = c.Provide(&http.Server{}) 66 | require.Error(t, err) 67 | require.Contains(t, err.Error(), "container_test.go:") 68 | require.Contains(t, err.Error(), ": invalid constructor signature, got *http.Server") 69 | }) 70 | 71 | t.Run("provide constructor without result cause error", func(t *testing.T) { 72 | c, err := di.New() 73 | require.NoError(t, err) 74 | err = c.Provide(func() {}) 75 | require.Error(t, err) 76 | require.Contains(t, err.Error(), "container_test.go:") 77 | require.Contains(t, err.Error(), ": invalid constructor signature, got func()") 78 | }) 79 | 80 | t.Run("provide constructor with many resultant types cause error", func(t *testing.T) { 81 | c, err := di.New() 82 | require.NoError(t, err) 83 | ctor := func() (*http.Server, *http.ServeMux, error) { 84 | return nil, nil, nil 85 | } 86 | err = c.Provide(ctor) 87 | require.Error(t, err) 88 | require.Contains(t, err.Error(), "container_test.go:") 89 | require.Contains(t, err.Error(), ": invalid constructor signature, got func() (*http.Server, *http.ServeMux, error)") 90 | }) 91 | 92 | t.Run("provide 
constructor with incorrect result error", func(t *testing.T) { 93 | c, err := di.New() 94 | require.NoError(t, err) 95 | ctor := func() (*http.Server, *http.ServeMux) { 96 | return nil, nil 97 | } 98 | err = c.Provide(ctor) 99 | require.Error(t, err) 100 | require.Contains(t, err.Error(), "container_test.go:") 101 | require.Contains(t, err.Error(), "invalid constructor signature, got func() (*http.Server, *http.ServeMux)") 102 | }) 103 | 104 | t.Run("provide duplicate not cause error", func(t *testing.T) { 105 | c, err := di.New() 106 | require.NoError(t, err) 107 | ctor := func() *http.Server { return nil } 108 | require.NoError(t, c.Provide(ctor)) 109 | require.NoError(t, c.Provide(ctor)) 110 | }) 111 | 112 | t.Run("provide as not implemented interface cause error", func(t *testing.T) { 113 | c, err := di.New() 114 | require.NoError(t, err) 115 | // http server not implement io.Reader interface 116 | err = c.Provide(func() *http.Server { return nil }, di.As(new(io.Reader))) 117 | require.Error(t, err) 118 | require.Contains(t, err.Error(), "container_test.go:") 119 | require.Contains(t, err.Error(), ": *http.Server not implement io.Reader") 120 | }) 121 | 122 | t.Run("provide type as several interfaces", func(t *testing.T) { 123 | c, err := di.New() 124 | require.NoError(t, err) 125 | require.NotNil(t, c) 126 | file := &os.File{} 127 | require.NoError(t, c.Provide(func() *os.File { return file }, di.As(new(io.Closer), new(io.ReadCloser)))) 128 | var closer io.Closer 129 | var readCloser io.ReadCloser 130 | require.NoError(t, c.Resolve(&closer)) 131 | require.NoError(t, c.Resolve(&readCloser)) 132 | require.Equal(t, fmt.Sprintf("%p", closer), fmt.Sprintf("%p", file)) 133 | require.Equal(t, fmt.Sprintf("%p", readCloser), fmt.Sprintf("%p", file)) 134 | }) 135 | 136 | t.Run("using not interface type in di.As() cause error", func(t *testing.T) { 137 | c, err := di.New() 138 | require.NoError(t, err) 139 | err = c.Provide(func() *http.Server { return nil }, 
di.As(&http.Server{})) 140 | require.Error(t, err) 141 | require.Contains(t, err.Error(), "container_test.go:") 142 | require.Contains(t, err.Error(), ": *http.Server: not a pointer to interface") 143 | }) 144 | 145 | t.Run("using nil type in di.As() cause error", func(t *testing.T) { 146 | c, err := di.New() 147 | require.NoError(t, err) 148 | err = c.Provide(func() *http.Server { return &http.Server{} }, di.As(nil)) 149 | require.Error(t, err) 150 | require.Contains(t, err.Error(), "container_test.go:") 151 | require.Contains(t, err.Error(), ": nil: not a pointer to interface") 152 | }) 153 | } 154 | 155 | func TestContainer_ProvideValue(t *testing.T) { 156 | t.Run("provide nil value cause error", func(t *testing.T) { 157 | c, err := di.New() 158 | require.NoError(t, err) 159 | require.NotNil(t, c) 160 | err = c.ProvideValue(nil) 161 | require.Error(t, err) 162 | require.Contains(t, err.Error(), "container_test.go:") 163 | require.Contains(t, err.Error(), "invalid value, got nil") 164 | }) 165 | 166 | t.Run("provide and resolve value", func(t *testing.T) { 167 | c, err := di.New() 168 | require.NoError(t, err) 169 | require.NotNil(t, c) 170 | mux := &http.ServeMux{} 171 | err = c.ProvideValue(mux, di.As(new(http.Handler))) 172 | require.NoError(t, err) 173 | err = c.Provide(func(handler http.Handler) *http.Server { 174 | return &http.Server{ 175 | Handler: handler, 176 | } 177 | }) 178 | require.NoError(t, err) 179 | var server *http.Server 180 | err = c.Resolve(&server) 181 | require.NoError(t, err) 182 | require.Equal(t, fmt.Sprintf("%p", mux), fmt.Sprintf("%p", server.Handler)) 183 | }) 184 | 185 | t.Run("provide values by option", func(t *testing.T) { 186 | mux := &http.ServeMux{} 187 | c, err := di.New( 188 | di.ProvideValue(mux), 189 | ) 190 | require.NoError(t, err) 191 | require.NotNil(t, c) 192 | var result *http.ServeMux 193 | err = c.Resolve(&result) 194 | require.NoError(t, err) 195 | require.Equal(t, fmt.Sprintf("%p", mux), fmt.Sprintf("%p", result)) 
196 | }) 197 | 198 | t.Run("provide nil value by option", func(t *testing.T) { 199 | c, err := di.New( 200 | di.ProvideValue(nil), 201 | ) 202 | require.Error(t, err) 203 | require.Nil(t, c) 204 | require.Contains(t, err.Error(), "container_test.go:") 205 | require.Contains(t, err.Error(), "invalid value, got nil") 206 | }) 207 | } 208 | 209 | func TestContainer_Resolve(t *testing.T) { 210 | t.Run("resolve into nil cause error", func(t *testing.T) { 211 | c, err := di.New() 212 | require.NoError(t, err) 213 | require.NotNil(t, c) 214 | err = c.Resolve(nil) 215 | require.Error(t, err) 216 | require.Contains(t, err.Error(), "container_test.go:") 217 | require.Contains(t, err.Error(), ": target must be a pointer, got nil") 218 | }) 219 | 220 | t.Run("resolve into struct{} cause error", func(t *testing.T) { 221 | c, err := di.New() 222 | require.NoError(t, err) 223 | require.NotNil(t, c) 224 | err = c.Resolve(struct{}{}) 225 | require.Error(t, err) 226 | require.Contains(t, err.Error(), "container_test.go:") 227 | require.Contains(t, err.Error(), ": target must be a pointer, got struct {}") 228 | }) 229 | 230 | t.Run("resolve into string cause error", func(t *testing.T) { 231 | c, err := di.New() 232 | require.NoError(t, err) 233 | require.NotNil(t, c) 234 | err = c.Resolve("string") 235 | require.Error(t, err) 236 | require.Contains(t, err.Error(), "container_test.go:") 237 | require.Contains(t, err.Error(), ": target must be a pointer, got string") 238 | }) 239 | 240 | t.Run("resolve with failed build", func(t *testing.T) { 241 | c, err := di.New() 242 | require.NoError(t, err) 243 | err = c.Provide(func() (*http.Server, error) { 244 | return &http.Server{}, fmt.Errorf("server build failed") 245 | }) 246 | require.NoError(t, err) 247 | var server *http.Server 248 | err = c.Resolve(&server) 249 | require.Error(t, err) 250 | require.Contains(t, err.Error(), "container_test.go:") 251 | require.Contains(t, err.Error(), ": *http.Server: server build failed") 252 | }) 253 
| 254 | t.Run("resolve with failed dependency build", func(t *testing.T) { 255 | c, err := di.New() 256 | require.NoError(t, err) 257 | err = c.Provide(func() (*http.Server, error) { 258 | return &http.Server{}, fmt.Errorf("server build failed") 259 | }) 260 | require.NoError(t, err) 261 | err = c.Provide(func(server *http.Server) string { 262 | return "string" 263 | }) 264 | require.NoError(t, err) 265 | var s string 266 | err = c.Resolve(&s) 267 | require.Error(t, err) 268 | require.Contains(t, err.Error(), "container_test.go:") 269 | require.Contains(t, err.Error(), ": *http.Server: server build failed") 270 | }) 271 | 272 | t.Run("resolve cleanup error", func(t *testing.T) { 273 | c, err := di.New() 274 | require.NoError(t, err) 275 | require.NotNil(t, c) 276 | called := false 277 | cleanup := func() { 278 | called = true 279 | } 280 | require.NoError(t, c.Provide(func() (*http.Server, func(), error) { 281 | return &http.Server{}, cleanup, nil 282 | })) 283 | var server *http.Server 284 | require.NoError(t, c.Resolve(&server)) 285 | c.Cleanup() 286 | require.True(t, called) 287 | }) 288 | 289 | t.Run("resolve returns type that was created in constructor", func(t *testing.T) { 290 | c, err := di.New() 291 | require.NoError(t, err) 292 | require.NotNil(t, c) 293 | server := &http.Server{} 294 | require.NoError(t, c.Provide(func() *http.Server { return server })) 295 | var extracted *http.Server 296 | require.NoError(t, c.Resolve(&extracted)) 297 | require.Equal(t, fmt.Sprintf("%p", server), fmt.Sprintf("%p", extracted)) 298 | }) 299 | 300 | t.Run("resolve same pointer on each resolve", func(t *testing.T) { 301 | c, err := di.New() 302 | require.NoError(t, err) 303 | require.NotNil(t, c) 304 | require.NoError(t, c.Provide(func() *http.Server { return &http.Server{} })) 305 | var server1 *http.Server 306 | require.NoError(t, c.Resolve(&server1)) 307 | var server2 *http.Server 308 | require.NoError(t, c.Resolve(&server2)) 309 | require.Equal(t, fmt.Sprintf("%p", 
server1), fmt.Sprintf("%p", server2)) 310 | }) 311 | 312 | t.Run("resolve not existing dependency cause error", func(t *testing.T) { 313 | c, err := di.New() 314 | require.NoError(t, err) 315 | err = c.Provide(func(handler http.Handler) *http.Server { return &http.Server{} }) 316 | require.NoError(t, err) 317 | var server *http.Server 318 | err = c.Resolve(&server) 319 | require.Error(t, err) 320 | require.Contains(t, err.Error(), "container_test.go:") 321 | require.Contains(t, err.Error(), "*http.Server: type http.Handler not exists in the container") 322 | }) 323 | 324 | t.Run("resolve not existing type cause error", func(t *testing.T) { 325 | c, err := di.New() 326 | require.NotNil(t, c) 327 | require.NoError(t, err) 328 | err = c.Resolve(&http.Server{}) 329 | require.Error(t, err) 330 | require.True(t, errors.Is(err, di.ErrTypeNotExists)) 331 | require.Contains(t, err.Error(), "container_test.go:") 332 | require.Contains(t, err.Error(), ": type http.Server not exists in the container") 333 | }) 334 | 335 | t.Run("resolve functions", func(t *testing.T) { 336 | var result []string 337 | fn1 := func() { result = append(result, "fn1") } 338 | fn2 := func() { result = append(result, "fn2") } 339 | fn3 := func() { result = append(result, "fn3") } 340 | c, err := di.New() 341 | require.NoError(t, err) 342 | require.NotNil(t, c) 343 | type MyFunc func() 344 | require.NoError(t, c.Provide(func() MyFunc { return fn1 })) 345 | require.NoError(t, c.Provide(func() MyFunc { return fn2 })) 346 | require.NoError(t, c.Provide(func() MyFunc { return fn3 })) 347 | var funcs []MyFunc 348 | require.NoError(t, c.Resolve(&funcs)) 349 | require.Len(t, funcs, 3) 350 | for _, fn := range funcs { 351 | fn() 352 | } 353 | require.Equal(t, []string{"fn1", "fn2", "fn3"}, result) 354 | }) 355 | 356 | t.Run("container provided by default", func(t *testing.T) { 357 | var container *di.Container 358 | c, err := di.New() 359 | require.NoError(t, err) 360 | require.NotNil(t, c) 361 | 
require.NoError(t, c.Resolve(&container)) 362 | require.Equal(t, fmt.Sprintf("%p", c), fmt.Sprintf("%p", container)) 363 | }) 364 | 365 | t.Run("cycle cause error", func(t *testing.T) { 366 | c, err := di.New() 367 | require.NoError(t, err) 368 | // bool -> int32 -> int64 -> bool 369 | err = c.Provide(func(int32) bool { return true }) 370 | require.NoError(t, err) 371 | err = c.Provide(func(int64) int32 { return 0 }) 372 | require.NoError(t, err) 373 | err = c.Provide(func(bool) int64 { return 0 }) 374 | require.NoError(t, err) 375 | var b bool 376 | err = c.Resolve(&b) 377 | require.Error(t, err) 378 | require.Contains(t, err.Error(), "container_test.go:") 379 | require.Contains(t, err.Error(), ": cycle detected") // todo: improve message 380 | }) 381 | 382 | //t.Run("first resolve checks graph correctness", func(t *testing.T) { 383 | // c, err := di.New() 384 | // require.NoError(t, err) 385 | // err = c.Provide(func(handler http.Handler) *http.Server { return &http.Server{Handler: handler} }) 386 | // require.NoError(t, err) 387 | // err = c.Provide(func() string { return "" }) 388 | // require.NoError(t, err) 389 | // var s string 390 | // err = c.Resolve(&s) 391 | // require.EqualError(t, err, "") 392 | //}) 393 | 394 | t.Run("resolve not existing dependency type cause error", func(t *testing.T) { 395 | c, err := di.New() 396 | require.NoError(t, err) 397 | require.NotNil(t, c) 398 | require.NoError(t, c.Provide(func(int) int32 { return 0 })) 399 | var i int32 400 | err = c.Resolve(&i) 401 | require.Error(t, err) 402 | require.Contains(t, err.Error(), "container_test.go:") 403 | require.Contains(t, err.Error(), ": int32: type int not exists in the container") 404 | }) 405 | 406 | t.Run("resolve correct argument", func(t *testing.T) { 407 | c, err := di.New() 408 | require.NoError(t, err) 409 | mux := &http.ServeMux{} 410 | require.NoError(t, c.Provide(func() *http.ServeMux { return mux })) 411 | require.NoError(t, c.Provide(func(mux *http.ServeMux) 
*http.Server { 412 | return &http.Server{Handler: mux} 413 | })) 414 | var server *http.Server 415 | require.NoError(t, c.Resolve(&server)) 416 | require.Equal(t, fmt.Sprintf("%p", mux), fmt.Sprintf("%p", server.Handler)) 417 | }) 418 | } 419 | 420 | func TestContainer_Decorate(t *testing.T) { 421 | t.Run("decorate provide", func(t *testing.T) { 422 | c, err := di.New() 423 | require.NoError(t, err) 424 | executed := false 425 | server := &http.Server{} 426 | decorator := func(value di.Value) error { 427 | require.Equal(t, fmt.Sprintf("%p", server), fmt.Sprintf("%p", value)) 428 | executed = true 429 | return nil 430 | } 431 | err = c.Provide(func() *http.Server { return server }, di.Decorate(decorator)) 432 | require.NoError(t, err) 433 | require.False(t, executed) 434 | var result *http.Server 435 | err = c.Resolve(&result) 436 | require.NoError(t, err) 437 | require.True(t, executed) 438 | }) 439 | 440 | t.Run("decorate provide value", func(t *testing.T) { 441 | c, err := di.New() 442 | require.NoError(t, err) 443 | executed := false 444 | server := &http.Server{} 445 | err = c.ProvideValue(server, di.Decorate(func(value di.Value) error { 446 | require.Equal(t, fmt.Sprintf("%p", server), fmt.Sprintf("%p", value)) 447 | executed = true 448 | return nil 449 | })) 450 | require.NoError(t, err) 451 | require.False(t, executed) 452 | var result *http.Server 453 | err = c.Resolve(&result) 454 | require.NoError(t, err) 455 | require.True(t, executed) 456 | }) 457 | 458 | t.Run("decorate error", func(t *testing.T) { 459 | c, err := di.New() 460 | require.NoError(t, err) 461 | executed := false 462 | server := &http.Server{} 463 | err = c.ProvideValue(server, di.Decorate(func(value di.Value) error { 464 | require.Equal(t, fmt.Sprintf("%p", server), fmt.Sprintf("%p", value)) 465 | executed = true 466 | return errors.New("decorate error") 467 | })) 468 | require.NoError(t, err) 469 | require.False(t, executed) 470 | var result *http.Server 471 | err = c.Resolve(&result) 
472 | require.Error(t, err) 473 | require.Contains(t, err.Error(), "container_test.go") 474 | require.Contains(t, err.Error(), ": *http.Server: decorate error") 475 | require.True(t, executed) 476 | }) 477 | 478 | t.Run("decorate group", func(t *testing.T) { 479 | c, err := di.New() 480 | require.NoError(t, err) 481 | executed1 := 0 482 | executed2 := 0 483 | server1 := &http.Server{} 484 | server2 := &http.Server{} 485 | err = c.ProvideValue(server1, di.Decorate(func(value di.Value) error { 486 | require.Equal(t, fmt.Sprintf("%p", server1), fmt.Sprintf("%p", value)) 487 | executed1++ 488 | return nil 489 | })) 490 | require.NoError(t, err) 491 | err = c.ProvideValue(server2, di.Decorate(func(value di.Value) error { 492 | require.Equal(t, fmt.Sprintf("%p", server2), fmt.Sprintf("%p", value)) 493 | executed2++ 494 | return nil 495 | })) 496 | require.NoError(t, err) 497 | require.Zero(t, executed1) 498 | require.Zero(t, executed2) 499 | var result []*http.Server 500 | err = c.Resolve(&result) 501 | require.NoError(t, err) 502 | // check that decorate called only once 503 | err = c.Resolve(&result) 504 | require.NoError(t, err) 505 | require.Equal(t, 1, executed1) 506 | require.Equal(t, 1, executed2) 507 | }) 508 | 509 | t.Run("decorate interfaces", func(t *testing.T) { 510 | c, err := di.New() 511 | require.NoError(t, err) 512 | executed1 := 0 513 | server1 := &http.Server{} 514 | err = c.ProvideValue(server1, di.Decorate(func(value di.Value) error { 515 | require.Equal(t, fmt.Sprintf("%p", server1), fmt.Sprintf("%p", value)) 516 | executed1++ 517 | return nil 518 | }), di.As(new(io.Closer))) 519 | require.NoError(t, err) 520 | require.Zero(t, executed1) 521 | var result io.Closer 522 | err = c.Resolve(&result) 523 | require.NoError(t, err) 524 | err = c.Resolve(&result) 525 | require.NoError(t, err) 526 | require.Equal(t, 1, executed1) 527 | }) 528 | } 529 | func TestContainer_Apply(t *testing.T) { 530 | t.Run("apply applies container options with error", func(t 
*testing.T) { 531 | c, err := di.New() 532 | require.NoError(t, err) 533 | require.NotNil(t, c) 534 | err = c.Apply( 535 | di.Provide(func() *http.Server { return &http.Server{} }, di.As(new(io.Closer))), 536 | di.Provide(func() *os.File { return &os.File{} }, di.As(new(io.Closer))), 537 | ) 538 | require.NoError(t, err) 539 | var closer io.Closer 540 | err = c.Resolve(&closer) 541 | require.Error(t, err) 542 | require.Contains(t, err.Error(), "container_test.go:") 543 | require.Contains(t, err.Error(), ": multiple definitions of io.Closer, maybe you need to use group type: []io.Closer") 544 | }) 545 | } 546 | 547 | func TestContainer_Interfaces(t *testing.T) { 548 | t.Run("resolve interface with several implementations cause error", func(t *testing.T) { 549 | c, err := di.New( 550 | di.Provide(func() *http.Server { return &http.Server{} }, di.As(new(io.Closer))), 551 | di.Provide(func() *os.File { return &os.File{} }, di.As(new(io.Closer))), 552 | ) 553 | require.NoError(t, err) 554 | var closer io.Closer 555 | err = c.Resolve(&closer) 556 | require.Error(t, err) 557 | require.Contains(t, err.Error(), "container_test.go:") 558 | require.Contains(t, err.Error(), ": multiple definitions of io.Closer, maybe you need to use group type: []io.Closer") 559 | }) 560 | 561 | t.Run("resolve constructor interface argument", func(t *testing.T) { 562 | mux := &http.ServeMux{} 563 | c, err := di.New( 564 | di.Provide(func() *http.ServeMux { return mux }, di.As(new(http.Handler))), 565 | di.Provide(func(handler http.Handler) *http.Server { return &http.Server{Handler: handler} }), 566 | ) 567 | require.NoError(t, err) 568 | var handler http.Handler 569 | err = c.Resolve(&handler) 570 | require.NoError(t, err) 571 | var server *http.Server 572 | err = c.Resolve(&server) 573 | require.NoError(t, err) 574 | require.Equal(t, fmt.Sprintf("%p", mux), fmt.Sprintf("%p", server.Handler)) 575 | }) 576 | 577 | t.Run("resolve not existing unnamed definition with named", func(t *testing.T) { 
578 | c, err := di.New() 579 | require.NoError(t, err) 580 | require.NoError(t, c.Provide(http.NewServeMux, di.WithName("two"))) 581 | require.NoError(t, c.Provide(http.NewServeMux, di.WithName("three"))) 582 | var mux *http.ServeMux 583 | err = c.Resolve(&mux) 584 | require.Error(t, err) 585 | require.Contains(t, err.Error(), "container_test.go:") 586 | require.Contains(t, err.Error(), ": multiple definitions of *http.ServeMux, maybe you need to use group type: []*http.ServeMux") 587 | }) 588 | 589 | t.Run("resolve same pointer on resolve", func(t *testing.T) { 590 | c, err := di.New() 591 | require.NoError(t, err) 592 | require.NotNil(t, c) 593 | require.NoError(t, c.Provide(func() *http.ServeMux { return &http.ServeMux{} }, di.As(new(http.Handler)))) 594 | var server *http.ServeMux 595 | require.NoError(t, c.Resolve(&server)) 596 | var handler http.Handler 597 | require.NoError(t, c.Resolve(&handler)) 598 | require.Equal(t, fmt.Sprintf("%p", server), fmt.Sprintf("%p", handler)) 599 | }) 600 | } 601 | 602 | func TestContainer_Groups(t *testing.T) { 603 | t.Run("resolve multiple type instances as slice of type", func(t *testing.T) { 604 | c, err := di.New() 605 | require.NoError(t, err) 606 | require.NotNil(t, c) 607 | conn1 := &net.TCPConn{} 608 | conn2 := &net.TCPConn{} 609 | require.NoError(t, c.Provide(func() *net.TCPConn { return conn1 })) 610 | require.NoError(t, c.Provide(func() *net.TCPConn { return conn2 })) 611 | var conns []*net.TCPConn 612 | require.NoError(t, c.Resolve(&conns)) 613 | require.Len(t, conns, 2) 614 | require.Equal(t, fmt.Sprintf("%p", conn1), fmt.Sprintf("%p", conns[0])) 615 | require.Equal(t, fmt.Sprintf("%p", conn2), fmt.Sprintf("%p", conns[1])) 616 | }) 617 | 618 | t.Run("resolve not specific type of group cause error", func(t *testing.T) { 619 | c, err := di.New() 620 | require.NoError(t, err) 621 | require.NotNil(t, c) 622 | conn1 := &net.TCPConn{} 623 | conn2 := &net.TCPConn{} 624 | require.NoError(t, c.Provide(func() *net.TCPConn 
{ return conn1 })) 625 | require.NoError(t, c.Provide(func() *net.TCPConn { return conn2 })) 626 | var conn *net.TCPConn 627 | err = c.Resolve(&conn) 628 | require.Error(t, err) 629 | require.Contains(t, err.Error(), "container_test.go:") 630 | require.Contains(t, err.Error(), ": multiple definitions of *net.TCPConn, maybe you need to use group type: []*net.TCPConn") 631 | }) 632 | 633 | t.Run("resolve group of interface", func(t *testing.T) { 634 | c, err := di.New() 635 | require.NoError(t, err) 636 | require.NotNil(t, c) 637 | server := &http.Server{} 638 | file := &os.File{} 639 | require.NoError(t, c.Provide(func() *http.Server { return server }, di.As(new(io.Closer)))) 640 | require.NoError(t, c.Provide(func() *os.File { return file }, di.As(new(io.Closer)))) 641 | var closers []io.Closer 642 | require.NoError(t, c.Resolve(&closers)) 643 | require.Len(t, closers, 2) 644 | require.Equal(t, fmt.Sprintf("%p", server), fmt.Sprintf("%p", closers[0])) 645 | require.Equal(t, fmt.Sprintf("%p", file), fmt.Sprintf("%p", closers[1])) 646 | }) 647 | 648 | t.Run("group updates on provide", func(t *testing.T) { 649 | var result []string 650 | fn1 := func() { result = append(result, "fn1") } 651 | fn2 := func() { result = append(result, "fn2") } 652 | fn3 := func() { result = append(result, "fn3") } 653 | c, err := di.New() 654 | require.NoError(t, err) 655 | require.NotNil(t, c) 656 | type MyFunc func() 657 | var funcs []MyFunc 658 | require.NoError(t, c.Provide(func() MyFunc { return fn1 })) 659 | require.NoError(t, c.Resolve(&funcs)) 660 | require.Len(t, funcs, 1) 661 | require.NoError(t, c.Provide(func() MyFunc { return fn2 })) 662 | require.NoError(t, c.Resolve(&funcs)) 663 | require.Len(t, funcs, 2) 664 | require.NoError(t, c.Provide(func() MyFunc { return fn3 })) 665 | require.NoError(t, c.Resolve(&funcs)) 666 | require.Len(t, funcs, 3) 667 | }) 668 | 669 | t.Run("resolve one interface from group of type", func(t *testing.T) { 670 | c, err := di.New() 671 | 
require.NoError(t, err) 672 | require.NotNil(t, c) 673 | conn1 := &net.TCPConn{} 674 | conn2 := &net.TCPConn{} 675 | require.NoError(t, c.Provide(func() *net.TCPConn { return conn1 })) 676 | require.NoError(t, c.Provide(func() *net.TCPConn { return conn2 }, di.As(new(net.Conn)))) 677 | var conn net.Conn 678 | require.NoError(t, c.Resolve(&conn)) 679 | require.Equal(t, fmt.Sprintf("%p", conn), fmt.Sprintf("%p", conn)) 680 | }) 681 | } 682 | 683 | func TestContainer_Iterate(t *testing.T) { 684 | t.Run("iterate over nil causes error", func(t *testing.T) { 685 | c, err := di.New() 686 | require.NoError(t, err) 687 | require.NotNil(t, c) 688 | err = c.Iterate(nil, func(tags di.Tags, loader di.ValueFunc) error { 689 | return nil 690 | }) 691 | require.EqualError(t, err, "target must be a pointer, got nil") 692 | }) 693 | t.Run("iterate over struct causes error", func(t *testing.T) { 694 | c, err := di.New() 695 | require.NoError(t, err) 696 | require.NotNil(t, c) 697 | err = c.Iterate(http.ServeMux{}, func(tags di.Tags, loader di.ValueFunc) error { 698 | return nil 699 | }) 700 | require.EqualError(t, err, "target must be a pointer, got http.ServeMux") 701 | }) 702 | t.Run("iterate over struct causes error", func(t *testing.T) { 703 | c, err := di.New() 704 | require.NoError(t, err) 705 | require.NotNil(t, c) 706 | require.NoError(t, c.Provide(func() http.ServeMux { return http.ServeMux{} })) 707 | err = c.Iterate(&http.ServeMux{}, func(tags di.Tags, loader di.ValueFunc) error { 708 | return nil 709 | }) 710 | require.EqualError(t, err, "iteration can be used with groups only") 711 | }) 712 | t.Run("iterates over instances", func(t *testing.T) { 713 | c, err := di.New() 714 | require.NoError(t, err) 715 | require.NotNil(t, c) 716 | conn1 := &net.TCPConn{} 717 | conn2 := &net.TCPConn{} 718 | require.NoError(t, c.Provide(func() *net.TCPConn { return conn1 })) 719 | require.NoError(t, c.Provide(func() *net.TCPConn { return conn2 })) 720 | var iterates []*net.TCPConn 721 | 
var conn []*net.TCPConn 722 | iterFn := func(tags di.Tags, loader di.ValueFunc) error { 723 | i, err := loader() 724 | if err != nil { 725 | return err 726 | } 727 | iterates = append(iterates, i.(*net.TCPConn)) 728 | return nil 729 | } 730 | err = c.Iterate(&conn, iterFn) 731 | require.NoError(t, err) 732 | require.Len(t, iterates, 2) 733 | require.Equal(t, iterates[0], conn1) 734 | require.Equal(t, iterates[1], conn2) 735 | }) 736 | 737 | t.Run("iterates over tagged instances", func(t *testing.T) { 738 | c, err := di.New() 739 | require.NoError(t, err) 740 | require.NotNil(t, c) 741 | conn1 := &net.TCPConn{} 742 | conn2 := &net.TCPConn{} 743 | require.NoError(t, c.Provide(func() *net.TCPConn { return conn1 }, di.Tags{"conn": "tcp1"})) 744 | require.NoError(t, c.Provide(func() *net.TCPConn { return conn2 }, di.Tags{"conn": "tcp2"})) 745 | require.NoError(t, c.Provide(func() *net.TCPConn { return &net.TCPConn{} })) 746 | var iterates []*net.TCPConn 747 | var all []di.Tags 748 | var conn []*net.TCPConn 749 | iterFn := func(tags di.Tags, loader di.ValueFunc) error { 750 | all = append(all, tags) 751 | i, err := loader() 752 | if err != nil { 753 | return err 754 | } 755 | iterates = append(iterates, i.(*net.TCPConn)) 756 | return nil 757 | } 758 | err = c.Iterate(&conn, iterFn, di.Tags{"conn": "*"}) 759 | require.NoError(t, err) 760 | require.Len(t, iterates, 2) 761 | require.Equal(t, conn1, iterates[0]) 762 | require.Equal(t, conn2, iterates[1]) 763 | require.Equal(t, []di.Tags{ 764 | {"conn": "tcp1"}, 765 | {"conn": "tcp2"}, 766 | }, all) 767 | }) 768 | 769 | t.Run("iterates over instances with errors", func(t *testing.T) { 770 | c, err := di.New() 771 | require.NoError(t, err) 772 | require.NotNil(t, c) 773 | conn1 := &net.TCPConn{} 774 | conn2 := &net.TCPConn{} 775 | require.NoError(t, c.Provide(func() *net.TCPConn { return conn1 })) 776 | require.NoError(t, c.Provide(func() (*net.TCPConn, error) { return conn2, fmt.Errorf("tcp conn 2 error") })) 777 | var 
iterates []*net.TCPConn 778 | var conn []*net.TCPConn 779 | iterFn := func(tags di.Tags, loader di.ValueFunc) error { 780 | i, err := loader() 781 | if err != nil { 782 | return err 783 | } 784 | iterates = append(iterates, i.(*net.TCPConn)) 785 | return nil 786 | } 787 | err = c.Iterate(&conn, iterFn) 788 | require.EqualError(t, err, "[]*net.TCPConn with index 1 failed: tcp conn 2 error") 789 | }) 790 | } 791 | 792 | func TestContainer_Tags(t *testing.T) { 793 | t.Run("resolve named definition", func(t *testing.T) { 794 | c, err := di.New() 795 | require.NoError(t, err) 796 | first := &http.Server{} 797 | second := &http.Server{} 798 | err = c.Provide(func() *http.Server { return first }, di.WithName("first")) 799 | require.NoError(t, err) 800 | err = c.Provide(func() *http.Server { return second }, di.WithName("second")) 801 | require.NoError(t, err) 802 | var extracted *http.Server 803 | err = c.Resolve(&extracted) 804 | require.Error(t, err) 805 | require.Contains(t, err.Error(), "container_test.go:") 806 | require.Contains(t, err.Error(), ": multiple definitions of *http.Server, maybe you need to use group type: []*http.Server") 807 | err = c.Resolve(&extracted, di.Name("first")) 808 | require.NoError(t, err) 809 | require.Equal(t, fmt.Sprintf("%p", first), fmt.Sprintf("%p", extracted)) 810 | err = c.Resolve(&extracted, di.Name("second")) 811 | require.NoError(t, err) 812 | require.Equal(t, fmt.Sprintf("%p", second), fmt.Sprintf("%p", extracted)) 813 | }) 814 | 815 | t.Run("resolve single instance of group without specifying tags cause error", func(t *testing.T) { 816 | c, err := di.New() 817 | require.NoError(t, err) 818 | require.NotNil(t, c) 819 | require.NoError(t, c.Provide(http.NewServeMux, di.WithName("first"))) 820 | var mux *http.ServeMux 821 | require.NoError(t, c.Resolve(&mux)) 822 | require.NoError(t, c.Provide(http.NewServeMux, di.WithName("second"))) 823 | err = c.Resolve(&mux) 824 | require.Error(t, err) 825 | require.Contains(t, err.Error(), 
"container_test.go:") 826 | require.Contains(t, err.Error(), ": multiple definitions of *http.ServeMux, maybe you need to use group type: []*http.ServeMux") 827 | }) 828 | 829 | t.Run("resolve not found by tags instance cause error", func(t *testing.T) { 830 | c, err := di.New() 831 | require.NoError(t, err) 832 | require.NotNil(t, c) 833 | require.NoError(t, c.Provide(http.NewServeMux, di.WithName("first"))) 834 | require.NoError(t, c.Provide(http.NewServeMux, di.WithName("second"))) 835 | var mux *http.ServeMux 836 | err = c.Resolve(&mux, di.Name("unknown")) 837 | require.Error(t, err) 838 | require.Contains(t, err.Error(), "container_test.go:") 839 | require.Contains(t, err.Error(), ": type *http.ServeMux[name:unknown] not exists") 840 | }) 841 | 842 | t.Run("provide duplication of named definition", func(t *testing.T) { 843 | c, err := di.New() 844 | require.NoError(t, err) 845 | require.NotNil(t, c) 846 | require.NoError(t, c.Provide(http.NewServeMux, di.WithName("first"))) 847 | err = c.Provide(http.NewServeMux, di.WithName("first")) 848 | require.NoError(t, err) 849 | }) 850 | 851 | t.Run("resolve existing unnamed definition with named", func(t *testing.T) { 852 | c, err := di.New() 853 | require.NoError(t, err) 854 | require.NoError(t, c.Provide(http.NewServeMux)) 855 | require.NoError(t, c.Provide(http.NewServeMux, di.WithName("two"))) 856 | require.NoError(t, c.Provide(http.NewServeMux, di.WithName("three"))) 857 | var mux *http.ServeMux 858 | err = c.Resolve(&mux) 859 | require.Error(t, err) 860 | require.Contains(t, err.Error(), "container_test.go:") 861 | require.Contains(t, err.Error(), "multiple definitions of *http.ServeMux, maybe you need to use group type: []*http.ServeMux") 862 | require.NoError(t, c.Resolve(&mux, di.Name("two"))) 863 | require.NoError(t, c.Resolve(&mux, di.Name("three"))) 864 | }) 865 | 866 | t.Run("resolve instances with same tag", func(t *testing.T) { 867 | c, err := di.New() 868 | require.NoError(t, err) 869 | 
require.NoError(t, c.Provide(http.NewServeMux)) 870 | require.NoError(t, c.Provide(http.NewServeMux, di.Tags{"tag": "the_same"})) 871 | require.NoError(t, c.Provide(http.NewServeMux, di.Tags{"tag": "the_same"})) 872 | var muxs []*http.ServeMux 873 | err = c.Resolve(&muxs, di.Tags{"tag": "the_same"}) 874 | require.NoError(t, err) 875 | require.Len(t, muxs, 2) 876 | }) 877 | 878 | t.Run("resolve all instances with tag", func(t *testing.T) { 879 | c, err := di.New() 880 | require.NoError(t, err) 881 | require.NoError(t, c.Provide(http.NewServeMux)) 882 | require.NoError(t, c.Provide(http.NewServeMux, di.Tags{"server": "one"})) 883 | require.NoError(t, c.Provide(http.NewServeMux, di.Tags{"server": "two"})) 884 | var muxs []*http.ServeMux 885 | err = c.Resolve(&muxs, di.Tags{"server": "*"}) 886 | require.NoError(t, err) 887 | require.Len(t, muxs, 2) 888 | }) 889 | 890 | t.Run("resolve all instances with several tags", func(t *testing.T) { 891 | c, err := di.New() 892 | require.NoError(t, err) 893 | require.NoError(t, c.Provide(http.NewServeMux)) 894 | require.NoError(t, c.Provide(http.NewServeMux, di.Tags{"server": "one"})) 895 | require.NoError(t, c.Provide(http.NewServeMux, di.Tags{"server": "one", "http": "one"})) 896 | require.NoError(t, c.Provide(http.NewServeMux, di.Tags{"server": "two", "http": "two"})) 897 | var muxs []*http.ServeMux 898 | err = c.Resolve(&muxs, di.Tags{"server": "*", "http": "*"}) 899 | require.NoError(t, err) 900 | require.Len(t, muxs, 2) 901 | }) 902 | 903 | t.Run("provide type with tags (deprecated)", func(t *testing.T) { 904 | type Server struct { 905 | di.Tags `http:"true" server:"true"` 906 | } 907 | var s *Server 908 | _, err := di.New( 909 | di.Provide(func() *Server { return &Server{} }), 910 | di.Resolve(&s, di.Tags{"http": "true", "server": "true"}), 911 | ) 912 | require.NoError(t, err) 913 | }) 914 | 915 | t.Run("provide type with tags", func(t *testing.T) { 916 | type Server struct { 917 | di.Tags `di:"http=true,server=true"` 918 
| } 919 | var s *Server 920 | _, err := di.New( 921 | di.Provide(func() *Server { return &Server{} }), 922 | di.Resolve(&s, di.Tags{"http": "true", "server": "true"}), 923 | ) 924 | require.NoError(t, err) 925 | }) 926 | 927 | t.Run("provide type with non di tags", func(t *testing.T) { 928 | type Dep struct { 929 | } 930 | type Server struct { 931 | di.Inject 932 | Dep *Dep `json:"dep" di:""` 933 | } 934 | var s Server 935 | _, err := di.New( 936 | di.Provide(func() *Dep { return &Dep{} }), 937 | di.Resolve(&s), 938 | ) 939 | require.NoError(t, err) 940 | }) 941 | 942 | } 943 | 944 | func TestContainer_Group(t *testing.T) { 945 | t.Run("resolve group argument", func(t *testing.T) { 946 | c, err := di.New() 947 | require.NoError(t, err) 948 | server := &http.Server{} 949 | file := &os.File{} 950 | require.NoError(t, c.Provide(func() *http.Server { return server }, di.As(new(io.Closer)))) 951 | require.NoError(t, c.Provide(func() *os.File { return file }, di.As(new(io.Closer)))) 952 | type Closers []io.Closer 953 | require.NoError(t, c.Provide(func(closers []io.Closer) Closers { return closers })) 954 | var closers Closers 955 | require.NoError(t, c.Resolve(&closers)) 956 | require.Equal(t, fmt.Sprintf("%p", server), fmt.Sprintf("%p", closers[0])) 957 | require.Equal(t, fmt.Sprintf("%p", file), fmt.Sprintf("%p", closers[1])) 958 | }) 959 | 960 | t.Run("incorrect signature", func(t *testing.T) { 961 | c, err := di.New() 962 | require.NoError(t, err) 963 | err = c.Invoke(func() *http.Server { return &http.Server{} }) 964 | require.Error(t, err) 965 | require.Contains(t, err.Error(), "container_test.go:") 966 | require.Contains(t, err.Error(), ": invalid invocation signature, got func() *http.Server") 967 | }) 968 | } 969 | 970 | func TestContainer_Invoke(t *testing.T) { 971 | t.Run("invoke nil", func(t *testing.T) { 972 | c, err := di.New() 973 | require.NoError(t, err) 974 | err = c.Invoke(nil) 975 | require.Error(t, err) 976 | require.Contains(t, err.Error(), 
"container_test.go:") 977 | require.Contains(t, err.Error(), ": invalid invocation signature, got nil") 978 | }) 979 | 980 | t.Run("invoke non function type", func(t *testing.T) { 981 | c, err := di.New() 982 | require.NoError(t, err) 983 | err = c.Invoke(1) 984 | require.Error(t, err) 985 | require.Contains(t, err.Error(), "container_test.go:") 986 | require.Contains(t, err.Error(), ": invalid invocation signature, got int") 987 | }) 988 | 989 | t.Run("invoke invalid function", func(t *testing.T) { 990 | c, err := di.New() 991 | require.NoError(t, err) 992 | err = c.Invoke(func() *http.Server { return &http.Server{} }) 993 | require.Error(t, err) 994 | require.Contains(t, err.Error(), "container_test.go:") 995 | require.Contains(t, err.Error(), ": invalid invocation signature, got func() *http.Server") 996 | }) 997 | 998 | t.Run("invocation function with not provided dependency cause error", func(t *testing.T) { 999 | c, err := di.New() 1000 | require.NoError(t, err) 1001 | err = c.Invoke(func(server *http.Server) {}) 1002 | require.Error(t, err) 1003 | require.Contains(t, err.Error(), "container_test.go:") 1004 | require.Contains(t, err.Error(), ": type *http.Server not exists in the container") 1005 | }) 1006 | 1007 | t.Run("invocation function with dependency that can't be constructed", func(t *testing.T) { 1008 | c, err := di.New() 1009 | require.NoError(t, err) 1010 | err = c.Provide(func() (*http.Server, error) { return nil, fmt.Errorf("server error") }) 1011 | require.NoError(t, err) 1012 | err = c.Invoke(func(server *http.Server) {}) 1013 | require.EqualError(t, err, "*http.Server: server error") 1014 | }) 1015 | 1016 | t.Run("invoke with nil error must be called", func(t *testing.T) { 1017 | c, err := di.New() 1018 | require.NoError(t, err) 1019 | var invokeCalled bool 1020 | err = c.Invoke(func() error { 1021 | invokeCalled = true 1022 | return nil 1023 | }) 1024 | require.NoError(t, err) 1025 | require.True(t, invokeCalled) 1026 | }) 1027 | 1028 | 
t.Run("resolve dependencies in invoke", func(t *testing.T) { 1029 | c, err := di.New() 1030 | require.NoError(t, err) 1031 | server := &http.Server{} 1032 | called := false 1033 | require.NoError(t, c.Provide(func() *http.Server { return server })) 1034 | err = c.Invoke(func(in *http.Server) { 1035 | called = true 1036 | require.Equal(t, fmt.Sprintf("%p", server), fmt.Sprintf("%p", in)) 1037 | }) 1038 | require.NoError(t, err) 1039 | require.True(t, called) 1040 | }) 1041 | 1042 | t.Run("invoke return error as is", func(t *testing.T) { 1043 | c, err := di.New() 1044 | require.NoError(t, err) 1045 | err = c.Invoke(func() error { return fmt.Errorf("invoke error") }) 1046 | require.EqualError(t, err, "invoke error") 1047 | }) 1048 | 1049 | t.Run("cycle cause error", func(t *testing.T) { 1050 | c, err := di.New() 1051 | require.NoError(t, err) 1052 | // bool -> int32 -> int64 -> bool 1053 | err = c.Provide(func(int32) bool { return true }) 1054 | require.NoError(t, err) 1055 | err = c.Provide(func(int64) int32 { return 0 }) 1056 | require.NoError(t, err) 1057 | err = c.Provide(func(bool) int64 { return 0 }) 1058 | require.NoError(t, err) 1059 | err = c.Invoke(func(bool) {}) 1060 | require.Error(t, err) 1061 | require.Contains(t, err.Error(), "container_test.go:") 1062 | require.Contains(t, err.Error(), ": cycle detected") // todo: improve message 1063 | }) 1064 | } 1065 | 1066 | func TestContainer_Has(t *testing.T) { 1067 | t.Run("exists nil returns false", func(t *testing.T) { 1068 | c, err := di.New() 1069 | require.NoError(t, err) 1070 | has, err := c.Has(nil) 1071 | require.EqualError(t, err, "target must be a pointer, got nil") 1072 | require.False(t, has) 1073 | }) 1074 | 1075 | t.Run("exists return true if type exists", func(t *testing.T) { 1076 | c, err := di.New() 1077 | require.NoError(t, err) 1078 | require.NoError(t, c.Provide(func() *http.Server { return &http.Server{} })) 1079 | var server *http.Server 1080 | has, err := c.Has(&server) 1081 | 
require.NoError(t, err) 1082 | require.True(t, has) 1083 | }) 1084 | 1085 | t.Run("exists return false if type not exists", func(t *testing.T) { 1086 | c, err := di.New() 1087 | require.NoError(t, err) 1088 | var server *http.Server 1089 | has, err := c.Has(&server) 1090 | require.NoError(t, err) 1091 | require.False(t, has) 1092 | }) 1093 | 1094 | t.Run("exists interface", func(t *testing.T) { 1095 | c, err := di.New() 1096 | require.NoError(t, err) 1097 | require.NoError(t, c.Provide(func() *http.Server { return &http.Server{} }, di.As(new(io.Closer)))) 1098 | var server io.Closer 1099 | has, err := c.Has(&server) 1100 | require.NoError(t, err) 1101 | require.True(t, has) 1102 | }) 1103 | 1104 | t.Run("exists named provider", func(t *testing.T) { 1105 | c, err := di.New() 1106 | require.NoError(t, err) 1107 | err = c.Provide(func() *http.Server { return &http.Server{} }, di.Tags{"name": "server"}) 1108 | require.NoError(t, err) 1109 | var server *http.Server 1110 | has, err := c.Has(&server, di.Tags{"name": "server"}) 1111 | require.NoError(t, err) 1112 | require.True(t, has) 1113 | }) 1114 | 1115 | t.Run("type exists but no possible to build returns true", func(t *testing.T) { 1116 | c, err := di.New() 1117 | require.NoError(t, err) 1118 | require.NoError(t, c.Provide(func(b bool) *http.Server { return &http.Server{} })) 1119 | var server *http.Server 1120 | has, err := c.Has(&server) 1121 | require.EqualError(t, err, "*http.Server: type bool not exists in the container") 1122 | require.False(t, has) 1123 | }) 1124 | } 1125 | 1126 | func TestContainer_Inject(t *testing.T) { 1127 | t.Run("inject into provided struct pointer with di.Inject", func(t *testing.T) { 1128 | c, err := di.New() 1129 | require.NoError(t, err) 1130 | type InjectableType struct { 1131 | di.Inject 1132 | Mux *http.ServeMux 1133 | } 1134 | mux := &http.ServeMux{} 1135 | require.NoError(t, c.Provide(func() *http.ServeMux { return mux })) 1136 | require.NoError(t, c.Provide(func() 
*InjectableType { return &InjectableType{} })) 1137 | var result *InjectableType 1138 | require.NoError(t, c.Resolve(&result)) 1139 | require.NotNil(t, result.Mux) 1140 | require.Equal(t, fmt.Sprintf("%p", mux), fmt.Sprintf("%p", result.Mux)) 1141 | }) 1142 | 1143 | t.Run("inject into provided struct value with di.Inject", func(t *testing.T) { 1144 | c, err := di.New() 1145 | require.NoError(t, err) 1146 | type InjectableType struct { 1147 | di.Inject 1148 | Mux *http.ServeMux 1149 | } 1150 | mux := &http.ServeMux{} 1151 | require.NoError(t, c.Provide(func() *http.ServeMux { return mux })) 1152 | err = c.Provide(func() InjectableType { return InjectableType{} }) 1153 | require.NoError(t, err) 1154 | var it InjectableType 1155 | err = c.Resolve(&it) 1156 | require.NoError(t, err) 1157 | require.Equal(t, fmt.Sprintf("%p", mux), fmt.Sprintf("%p", it.Mux)) 1158 | }) 1159 | 1160 | t.Run("constructor with injectable embed pointer", func(t *testing.T) { 1161 | c, err := di.New() 1162 | require.NoError(t, err) 1163 | type InjectableType struct { 1164 | di.Inject 1165 | *http.ServeMux 1166 | } 1167 | mux := &http.ServeMux{} 1168 | require.NoError(t, c.Provide(func() *http.ServeMux { return mux })) 1169 | require.NoError(t, c.Provide(func() *InjectableType { return &InjectableType{} })) 1170 | var result *InjectableType 1171 | require.NoError(t, c.Resolve(&result)) 1172 | require.NotNil(t, result.ServeMux) 1173 | require.Equal(t, fmt.Sprintf("%p", mux), fmt.Sprintf("%p", result.ServeMux)) 1174 | }) 1175 | 1176 | t.Run("container resolve injectable parameter", func(t *testing.T) { 1177 | c, err := di.New() 1178 | require.NoError(t, err) 1179 | type Parameters struct { 1180 | di.Inject 1181 | Server *http.Server 1182 | File *os.File 1183 | } 1184 | server := &http.Server{} 1185 | file := &os.File{} 1186 | require.NoError(t, c.Provide(func() *http.Server { return server })) 1187 | require.NoError(t, c.Provide(func() *os.File { return file })) 1188 | type Result struct { 1189 | 
server *http.Server 1190 | file *os.File 1191 | } 1192 | require.NoError(t, c.Provide(func(params Parameters) *Result { 1193 | return &Result{params.Server, params.File} 1194 | })) 1195 | var extracted *Result 1196 | require.NoError(t, c.Resolve(&extracted)) 1197 | require.Equal(t, fmt.Sprintf("%p", server), fmt.Sprintf("%p", extracted.server)) 1198 | require.Equal(t, fmt.Sprintf("%p", file), fmt.Sprintf("%p", extracted.file)) 1199 | }) 1200 | 1201 | t.Run("not existing injectable field cause error", func(t *testing.T) { 1202 | c, err := di.New() 1203 | require.NoError(t, err) 1204 | type InjectableType struct { 1205 | di.Inject 1206 | Mux *http.ServeMux 1207 | } 1208 | require.NoError(t, c.Provide(func() *InjectableType { return &InjectableType{} })) 1209 | var result *InjectableType 1210 | err = c.Resolve(&result) 1211 | require.Error(t, err) 1212 | require.Contains(t, err.Error(), "container_test.go:") 1213 | require.Contains(t, err.Error(), ": *di_test.InjectableType: type *http.ServeMux not exists in the container") 1214 | }) 1215 | 1216 | t.Run("not existing and optional field set to nil (deprecated)", func(t *testing.T) { 1217 | c, err := di.New() 1218 | require.NoError(t, err) 1219 | type InjectableType struct { 1220 | di.Inject 1221 | Mux *http.ServeMux `optional:"true"` 1222 | } 1223 | require.NoError(t, c.Provide(func() *InjectableType { return &InjectableType{} })) 1224 | var result *InjectableType 1225 | require.NoError(t, c.Resolve(&result)) 1226 | require.Nil(t, result.Mux) 1227 | }) 1228 | 1229 | t.Run("not existing and optional field set to nil", func(t *testing.T) { 1230 | c, err := di.New() 1231 | require.NoError(t, err) 1232 | type InjectableType struct { 1233 | di.Inject 1234 | Mux *http.ServeMux `di:"optional"` 1235 | } 1236 | require.NoError(t, c.Provide(func() *InjectableType { return &InjectableType{} })) 1237 | var result *InjectableType 1238 | require.NoError(t, c.Resolve(&result)) 1239 | require.Nil(t, result.Mux) 1240 | }) 1241 | 1242 | 
t.Run("nested injectable field resolved correctly", func(t *testing.T) { 1243 | c, err := di.New() 1244 | require.NoError(t, err) 1245 | type NestedInjectableType struct { 1246 | di.Inject 1247 | Mux *http.ServeMux 1248 | } 1249 | type InjectableType struct { 1250 | di.Inject 1251 | Nested NestedInjectableType 1252 | } 1253 | mux := &http.ServeMux{} 1254 | require.NoError(t, c.Provide(func() *InjectableType { return &InjectableType{} })) 1255 | require.NoError(t, c.Provide(func() *http.ServeMux { return mux })) 1256 | var result *InjectableType 1257 | require.NoError(t, c.Resolve(&result)) 1258 | require.NotNil(t, result.Nested.Mux) 1259 | require.Equal(t, fmt.Sprintf("%p", mux), fmt.Sprintf("%p", result.Nested.Mux)) 1260 | var nit NestedInjectableType 1261 | require.NoError(t, c.Resolve(&nit)) 1262 | require.NotNil(t, nit) 1263 | require.Equal(t, fmt.Sprintf("%p", mux), fmt.Sprintf("%p", nit.Mux)) 1264 | }) 1265 | 1266 | t.Run("cycle in injectable fields cause error", func(t *testing.T) { 1267 | c, err := di.New() 1268 | require.NoError(t, err) 1269 | type InjectableType struct { 1270 | di.Inject 1271 | String string 1272 | } 1273 | require.NoError(t, c.Provide(func() *InjectableType { return &InjectableType{} })) 1274 | require.NoError(t, c.Provide(func(t *InjectableType) string { return "" })) 1275 | var result *InjectableType 1276 | err = c.Resolve(&result) 1277 | require.Error(t, err) 1278 | require.Contains(t, err.Error(), "container_test.go:") 1279 | require.Contains(t, err.Error(), ": cycle detected") 1280 | }) 1281 | 1282 | t.Run("optional parameter may be nil (deprecated)", func(t *testing.T) { 1283 | c, err := di.New() 1284 | require.NoError(t, err) 1285 | type Parameter struct { 1286 | di.Inject 1287 | Server *http.Server `optional:"true"` 1288 | } 1289 | type Result struct { 1290 | server *http.Server 1291 | } 1292 | require.NoError(t, c.Provide(func(params Parameter) *Result { return &Result{server: params.Server} })) 1293 | var extracted *Result 1294 
| require.NoError(t, c.Resolve(&extracted)) 1295 | require.Nil(t, extracted.server) 1296 | }) 1297 | 1298 | t.Run("optional parameter may be nil", func(t *testing.T) { 1299 | c, err := di.New() 1300 | require.NoError(t, err) 1301 | type Parameter struct { 1302 | di.Inject 1303 | Server *http.Server `di:"optional"` 1304 | } 1305 | type Result struct { 1306 | server *http.Server 1307 | } 1308 | require.NoError(t, c.Provide(func(params Parameter) *Result { return &Result{server: params.Server} })) 1309 | var extracted *Result 1310 | require.NoError(t, c.Resolve(&extracted)) 1311 | require.Nil(t, extracted.server) 1312 | }) 1313 | 1314 | t.Run("resolve group in params (deprecated)", func(t *testing.T) { 1315 | c, err := di.New() 1316 | require.NoError(t, err) 1317 | 1318 | type Fn func() 1319 | type Params struct { 1320 | di.Inject 1321 | Handlers []Fn `optional:"true"` 1322 | } 1323 | require.NoError(t, c.Provide(func() Fn { return func() {} })) 1324 | require.NoError(t, c.Provide(func() Fn { return func() {} })) 1325 | require.NoError(t, c.Provide(func(params Params) bool { 1326 | return len(params.Handlers) == 2 1327 | })) 1328 | var extracted bool 1329 | require.NoError(t, c.Resolve(&extracted)) 1330 | require.True(t, extracted) 1331 | }) 1332 | 1333 | t.Run("resolve group in params", func(t *testing.T) { 1334 | c, err := di.New() 1335 | require.NoError(t, err) 1336 | 1337 | type Fn func() 1338 | type Params struct { 1339 | di.Inject 1340 | Handlers []Fn `di:"optional"` 1341 | } 1342 | require.NoError(t, c.Provide(func() Fn { return func() {} })) 1343 | require.NoError(t, c.Provide(func() Fn { return func() {} })) 1344 | require.NoError(t, c.Provide(func(params Params) bool { 1345 | return len(params.Handlers) == 2 1346 | })) 1347 | var extracted bool 1348 | require.NoError(t, c.Resolve(&extracted)) 1349 | require.True(t, extracted) 1350 | }) 1351 | 1352 | t.Run("optional group may be nil (deprecated)", func(t *testing.T) { 1353 | c, err := di.New() 1354 | 
require.NoError(t, err) 1355 | type Params struct { 1356 | di.Inject 1357 | Handlers []http.Handler `optional:"true"` 1358 | } 1359 | require.NoError(t, c.Provide(func(params Params) bool { 1360 | return params.Handlers == nil 1361 | })) 1362 | var extracted bool 1363 | require.NoError(t, c.Resolve(&extracted)) 1364 | require.True(t, extracted) 1365 | }) 1366 | 1367 | t.Run("optional group may be nil", func(t *testing.T) { 1368 | c, err := di.New() 1369 | require.NoError(t, err) 1370 | type Params struct { 1371 | di.Inject 1372 | Handlers []http.Handler `di:"optional"` 1373 | } 1374 | require.NoError(t, c.Provide(func(params Params) bool { 1375 | return params.Handlers == nil 1376 | })) 1377 | var extracted bool 1378 | require.NoError(t, c.Resolve(&extracted)) 1379 | require.True(t, extracted) 1380 | }) 1381 | 1382 | t.Run("skip private and skip tagged fields (deprecated)", func(t *testing.T) { 1383 | c, err := di.New() 1384 | require.NoError(t, err) 1385 | type InjectableParameter struct { 1386 | di.Inject 1387 | private []http.Handler 1388 | Addrs []net.Addr `optional:"true"` 1389 | Skipped *http.ServeMux `skip:"true"` 1390 | } 1391 | type InjectableType struct { 1392 | di.Inject 1393 | private []http.Handler 1394 | Addrs []net.Addr `optional:"true"` 1395 | } 1396 | require.NoError(t, c.Provide(func(param InjectableParameter) bool { 1397 | return param.Addrs == nil 1398 | })) 1399 | require.NoError(t, c.Provide(func() *InjectableType { return &InjectableType{} })) 1400 | var extracted bool 1401 | require.NoError(t, c.Resolve(&extracted)) 1402 | require.True(t, extracted) 1403 | var result *InjectableType 1404 | require.NoError(t, c.Resolve(&result)) 1405 | 1406 | mux := http.NewServeMux() 1407 | p := InjectableParameter{Skipped: mux} 1408 | require.NoError(t, c.Resolve(&p)) 1409 | require.Equal(t, InjectableParameter{Skipped: mux}, p) 1410 | }) 1411 | 1412 | t.Run("skip private and skip tagged fields", func(t *testing.T) { 1413 | c, err := di.New() 1414 | 
require.NoError(t, err) 1415 | type InjectableParameter struct { 1416 | di.Inject 1417 | private []http.Handler 1418 | Addrs []net.Addr `di:"optional"` 1419 | Skipped *http.ServeMux `di:"skip"` 1420 | } 1421 | type InjectableType struct { 1422 | di.Inject 1423 | private []http.Handler 1424 | Addrs []net.Addr `optional:"true"` 1425 | } 1426 | require.NoError(t, c.Provide(func(param InjectableParameter) bool { 1427 | return param.Addrs == nil 1428 | })) 1429 | require.NoError(t, c.Provide(func() *InjectableType { return &InjectableType{} })) 1430 | var extracted bool 1431 | require.NoError(t, c.Resolve(&extracted)) 1432 | require.True(t, extracted) 1433 | var result *InjectableType 1434 | require.NoError(t, c.Resolve(&result)) 1435 | 1436 | mux := http.NewServeMux() 1437 | p := InjectableParameter{Skipped: mux} 1438 | require.NoError(t, c.Resolve(&p)) 1439 | require.Equal(t, InjectableParameter{Skipped: mux}, p) 1440 | }) 1441 | 1442 | t.Run("resolving not provided injectable cause error", func(t *testing.T) { 1443 | c, err := di.New() 1444 | require.NoError(t, err) 1445 | type Parameter struct { 1446 | di.Inject 1447 | Server *http.Server 1448 | } 1449 | var p Parameter 1450 | err = c.Resolve(&p) 1451 | require.Error(t, err) 1452 | require.Contains(t, err.Error(), "container_test.go:") 1453 | require.Contains(t, err.Error(), ": di_test.Parameter: type *http.Server not exists in the container") 1454 | }) 1455 | 1456 | t.Run("resolving provided injectable as interface with dependency", func(t *testing.T) { 1457 | type InjectableType struct { 1458 | di.Inject 1459 | Server *http.Server 1460 | } 1461 | ctor := func() *InjectableType { 1462 | return &InjectableType{} 1463 | } 1464 | server := &http.Server{} 1465 | c, err := di.New( 1466 | di.Provide(func() *http.Server { return server }), 1467 | di.Provide(ctor, di.As(new(di.Interface))), 1468 | ) 1469 | require.NoError(t, err) 1470 | var b di.Interface 1471 | err = c.Resolve(&b) 1472 | require.NoError(t, err) 1473 | 
require.Equal(t, fmt.Sprintf("%p", server), fmt.Sprintf("%p", b.(*InjectableType).Server)) 1474 | }) 1475 | 1476 | t.Run("resolving provided injectable as interface without dependency cause error", func(t *testing.T) { 1477 | type InjectableType struct { 1478 | di.Inject 1479 | Server *http.Server 1480 | } 1481 | ctor := func() *InjectableType { 1482 | return &InjectableType{} 1483 | } 1484 | c, err := di.New( 1485 | di.Provide(ctor, di.As(new(di.Interface))), 1486 | ) 1487 | require.NoError(t, err) 1488 | var b di.Interface 1489 | err = c.Resolve(&b) 1490 | require.Error(t, err) 1491 | require.Contains(t, err.Error(), "container_test.go:") 1492 | require.Contains(t, err.Error(), ": di.Interface: type *http.Server not exists in the container") 1493 | }) 1494 | 1495 | t.Run("invoke with inject dependency struct", func(t *testing.T) { 1496 | type InjectableParam struct { 1497 | di.Inject 1498 | Mux *http.ServeMux 1499 | } 1500 | c, err := di.New() 1501 | require.NoError(t, err) 1502 | mux := http.NewServeMux() 1503 | require.NoError(t, c.Provide(func() *http.ServeMux { return mux })) 1504 | err = c.Invoke(func(params InjectableParam) { 1505 | require.Equal(t, fmt.Sprintf("%p", mux), fmt.Sprintf("%p", params.Mux)) 1506 | }) 1507 | require.NoError(t, err) 1508 | }) 1509 | 1510 | t.Run("invoke with inject dependency pointer", func(t *testing.T) { 1511 | type InjectableParam struct { 1512 | di.Inject 1513 | Mux *http.ServeMux 1514 | } 1515 | c, err := di.New() 1516 | require.NoError(t, err) 1517 | mux := http.NewServeMux() 1518 | require.NoError(t, c.Provide(func() *http.ServeMux { return mux })) 1519 | var ip *InjectableParam 1520 | err = c.Invoke(func(params *InjectableParam) { 1521 | ip = params 1522 | }) 1523 | require.NoError(t, err) 1524 | require.Equal(t, fmt.Sprintf("%p", mux), fmt.Sprintf("%p", ip.Mux)) 1525 | }) 1526 | } 1527 | 1528 | func TestContainer_Cleanup(t *testing.T) { 1529 | t.Run("called", func(t *testing.T) { 1530 | c, err := di.New() 1531 | 
require.NoError(t, err) 1532 | var cleanupCalled bool 1533 | require.NoError(t, c.Provide(func() (*http.Server, func()) { 1534 | return &http.Server{}, func() { cleanupCalled = true } 1535 | })) 1536 | var extracted *http.Server 1537 | require.NoError(t, c.Resolve(&extracted)) 1538 | c.Cleanup() 1539 | require.True(t, cleanupCalled) 1540 | }) 1541 | 1542 | t.Run("correct order", func(t *testing.T) { 1543 | c, err := di.New() 1544 | require.NoError(t, err) 1545 | var cleanupCalls []string 1546 | require.NoError(t, c.Provide(func(handler http.Handler) (*http.Server, func()) { 1547 | return &http.Server{Handler: handler}, func() { cleanupCalls = append(cleanupCalls, "server") } 1548 | })) 1549 | require.NoError(t, c.Provide(func() (*http.ServeMux, func()) { 1550 | return &http.ServeMux{}, func() { cleanupCalls = append(cleanupCalls, "mux") } 1551 | }, di.As(new(http.Handler)))) 1552 | var server *http.Server 1553 | require.NoError(t, c.Resolve(&server)) 1554 | c.Cleanup() 1555 | require.Equal(t, []string{"server", "mux"}, cleanupCalls) 1556 | }) 1557 | } 1558 | 1559 | func TestContainer_AddParent(t *testing.T) { 1560 | t.Run("provide ancestor and resolve in child", func(t *testing.T) { 1561 | papaw, err := di.New() 1562 | require.NoError(t, err) 1563 | require.NotNil(t, papaw) 1564 | parent, err := di.New() 1565 | require.NoError(t, err) 1566 | require.NotNil(t, parent) 1567 | child, err := di.New() 1568 | require.NoError(t, err) 1569 | require.NotNil(t, child) 1570 | 1571 | require.NoError(t, parent.AddParent(papaw)) 1572 | require.NoError(t, child.AddParent(parent)) 1573 | 1574 | conn1 := &net.TCPConn{} 1575 | require.NoError(t, papaw.Provide(func() *net.TCPConn { return conn1 })) 1576 | 1577 | var conn *net.TCPConn 1578 | require.NoError(t, child.Resolve(&conn)) 1579 | require.Equal(t, fmt.Sprintf("%p", conn1), fmt.Sprintf("%p", conn)) 1580 | }) 1581 | 1582 | t.Run("resolve multiple type instances across ancestors", func(t *testing.T) { 1583 | papaw, err := 
di.New() 1584 | require.NoError(t, err) 1585 | require.NotNil(t, papaw) 1586 | parent, err := di.New() 1587 | require.NoError(t, err) 1588 | require.NotNil(t, parent) 1589 | child, err := di.New() 1590 | require.NoError(t, err) 1591 | require.NotNil(t, child) 1592 | 1593 | require.NoError(t, parent.AddParent(papaw)) 1594 | require.NoError(t, child.AddParent(parent)) 1595 | 1596 | conn1 := &net.TCPConn{} 1597 | conn2 := &net.TCPConn{} 1598 | conn3 := &net.TCPConn{} 1599 | require.NoError(t, papaw.Provide(func() *net.TCPConn { return conn1 })) 1600 | require.NoError(t, parent.Provide(func() *net.TCPConn { return conn2 })) 1601 | require.NoError(t, child.Provide(func() *net.TCPConn { return conn3 })) 1602 | 1603 | var conns []*net.TCPConn 1604 | require.NoError(t, child.Resolve(&conns)) 1605 | require.Len(t, conns, 3) 1606 | require.Equal(t, fmt.Sprintf("%p", conn1), fmt.Sprintf("%p", conns[0])) 1607 | require.Equal(t, fmt.Sprintf("%p", conn2), fmt.Sprintf("%p", conns[1])) 1608 | require.Equal(t, fmt.Sprintf("%p", conn3), fmt.Sprintf("%p", conns[2])) 1609 | }) 1610 | 1611 | t.Run("add parent errors", func(t *testing.T) { 1612 | parent, err := di.New() 1613 | require.NoError(t, err) 1614 | require.NotNil(t, parent) 1615 | err = parent.AddParent(parent) 1616 | require.Contains(t, err.Error(), "self cycle detected") 1617 | child, err := di.New() 1618 | require.NoError(t, err) 1619 | require.NotNil(t, child) 1620 | err = child.AddParent(parent) 1621 | require.NoError(t, err) 1622 | err = parent.AddParent(child) 1623 | require.Contains(t, err.Error(), "cycle detected") 1624 | err = child.AddParent(parent) 1625 | require.Contains(t, err.Error(), "parent already chained") 1626 | papaw, err := di.New() 1627 | require.NoError(t, err) 1628 | require.NotNil(t, papaw) 1629 | err = parent.AddParent(papaw) 1630 | err = papaw.AddParent(child) 1631 | require.Contains(t, err.Error(), "cycle detected") 1632 | }) 1633 | 1634 | } 1635 | 
-------------------------------------------------------------------------------- /cycle.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | const ( 8 | temporary = 1 // on the current DFS path; reaching it again means a cycle 9 | permanent = 2 // fully visited; no cycle is reachable through it 10 | ) 11 | 12 | // visit walks the dependency graph depth-first from node, following both its 13 | // constructor dependencies (node.deps) and its injectable fields (node.fields). 14 | // It returns errCycleDetected when it reaches a node that is still on the current 15 | // path. marks holds the visit state of every node and is shared by the whole 16 | // traversal. Lookup errors are wrapped with %w and prefixed by the failing node, so 17 | // errors.Is still matches them; optional fields that cannot be found are skipped. 18 | func visit(s schema, node *node, marks map[*node]int) error { 19 | if marks[node] == permanent { 20 | return nil 21 | } 22 | if marks[node] == temporary { 23 | return errCycleDetected // todo: improve message 24 | } 25 | marks[node] = temporary 26 | params, err := node.deps(s) 27 | if err != nil { 28 | return fmt.Errorf("%s: %w", node, err) 29 | } 30 | for _, param := range params { 31 | if err := visit(s, param, marks); err != nil { 32 | return err 33 | } 34 | } 35 | for _, f := range node.fields() { 36 | n, err := s.find(f.rt, f.tags) 37 | if err != nil && f.optional { 38 | continue 39 | } 40 | if err != nil { 41 | return fmt.Errorf("%s: %w", node, err) 42 | } 43 | if err := visit(s, n, marks); err != nil { 44 | return err 45 | } 46 | } 47 | marks[node] = permanent 48 | return nil 49 | } 50 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // The MIT License (MIT) 2 | // 3 | // Copyright (c) 2020 defval 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software.
14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | 23 | /* 24 | Package di provides opinionated way to connect your application components. Container 25 | allows you to inject dependencies into constructors or structures without the need to 26 | have specified each argument manually. 27 | */ 28 | package di 29 | -------------------------------------------------------------------------------- /docs/advanced.md: -------------------------------------------------------------------------------- 1 | # Advanced Features 2 | 3 | - [Modules](#modules) 4 | - [Tags](#tags) 5 | - [ProvideValue](#providevalue) 6 | - [Optional Parameters](#optional-parameters) 7 | - [Struct Field Injection](#struct-field-injection) 8 | - [Iteration](#iteration) 9 | - [Decoration](#decoration) 10 | - [Cleanup](#cleanup) 11 | - [Container Chaining / Scopes](#container-chaining--scopes) 12 | 13 | ### Modules 14 | 15 | You can group previous options into a single variable using `di.Options()`: 16 | 17 | ```go 18 | // account module 19 | account := di.Options( 20 | di.Provide(NewAccountController), 21 | di.Provide(NewAccountRepository), 22 | ) 23 | // auth module 24 | auth := di.Options( 25 | di.Provide(NewAuthController), 26 | di.Provide(NewAuthRepository), 27 | ) 28 | // build container 29 | container, err := di.New( 30 | account, 31 | auth, 32 | ) 33 | if err != nil { 34 | // handle error 35 | } 36 | ``` 37 | 38 | ### Tags 39 | 40 | If you have more than one instance of the same type, you can specify an alias. 
41 | For example, two instances of a database: a leader for writing and a follower 42 | for reading. 43 | 44 | #### Wrap type into another unique type 45 | 46 | ```go 47 | // Leader provides write database access. 48 | type Leader struct { 49 | *Database 50 | } 51 | 52 | // Follower provides read database access. 53 | type Follower struct { 54 | *Database 55 | } 56 | ``` 57 | 58 | #### Specify tags with `di.Tags` *provide option*: 59 | 60 | ```go 61 | // provide leader database 62 | di.Provide(NewLeader, di.Tags{"type": "leader"}) 63 | // provide follower database 64 | di.Provide(NewFollower, di.Tags{"type": "follower"}) 65 | ``` 66 | 67 | If you need to resolve it from the container, use `di.Tags` *resolve 68 | option*. 69 | 70 | ```go 71 | var db *Database 72 | container.Resolve(&db, di.Tags{"type": "leader"}) 73 | ``` 74 | 75 | If you need a tagged instance in another constructor, accept a parameter struct that embeds 76 | `di.Inject`. 77 | 78 | ```go 79 | // Parameters 80 | type Parameters struct { 81 | di.Inject 82 | 83 | // the di tag tells the container which tagged instance to inject. 84 | Leader *Database `di:"type=leader"` 85 | Follower *Database `di:"type=follower"` 86 | } 87 | 88 | // NewService creates a new service with provided parameters. 89 | func NewService(parameters Parameters) *Service { 90 | return &Service{ 91 | Leader: parameters.Leader, 92 | Follower: parameters.Follower, 93 | } 94 | } 95 | ``` 96 | 97 | If you need to resolve all instances with the same tag key, use `*` as the tag 98 | value: 99 | 100 | ```go 101 | var db []*Database 102 | container.Resolve(&db, di.Tags{"type": "*"}) 103 | ``` 104 | 105 | ### ProvideValue 106 | 107 | Instead of using `di.Provide` to provide a constructor, you can use `di.ProvideValue` and provide values directly. 108 | This is useful to provide primitive values or values that are easily constructed. You can combine this with the use 109 | of `di.Tags` if you have multiple values of the same type and want to identify each one.
110 | 111 | ```go 112 | c, err := di.New( 113 | di.ProvideValue(time.Duration(10*time.Second), di.Tags{"name": "http-timeout"}), 114 | ) 115 | 116 | var timeout time.Duration 117 | c.Resolve(&timeout, di.Tags{"name": "http-timeout"}) 118 | ``` 119 | 120 | To differentiate between multiple values of the same underlying type, you can also declare distinct named types instead of using 121 | `di.Tags`. These must be defined types (`type ProjectName string`), not type aliases (`type ProjectName = string`), because an alias is identical to `string`. 122 | 123 | ```go 124 | type ProjectName string 125 | type ProjectVersion string 126 | 127 | c, err := di.New( 128 | di.ProvideValue(ProjectName("my-project")), 129 | di.ProvideValue(ProjectVersion("1.0.0")), 130 | ) 131 | 132 | var pn ProjectName 133 | c.Resolve(&pn) 134 | var pv ProjectVersion 135 | c.Resolve(&pv) 136 | ``` 137 | 138 | ### Optional Parameters 139 | 140 | A field of a `di.Inject` struct tagged `di:"optional"` is skipped if its dependency does not exist 141 | in the container. 142 | 143 | ```go 144 | // ServiceParameter 145 | type ServiceParameter struct { 146 | di.Inject 147 | 148 | Logger *Logger `di:"optional"` 149 | } 150 | ``` 151 | 152 | > Constructors that declare dependencies as optional must handle the 153 | > case of those dependencies being absent. 154 | 155 | You can combine tags and `optional`. 156 | 157 | ```go 158 | // ServiceParameter 159 | type ServiceParameter struct { 160 | di.Inject 161 | 162 | StdOutLogger *Logger `di:"type=stdout"` 163 | FileLogger *Logger `di:"type=file,optional"` 164 | } 165 | ``` 166 | 167 | If you need to skip field injection, tag the field with `di:"skip"`: 168 | 169 | ```go 170 | // ServiceParameter 171 | type ServiceParameter struct { 172 | di.Inject 173 | 174 | StdOutLogger *Logger `di:"type=stdout"` 175 | FileLogger *Logger `di:"type=file,optional"` 176 | SkipField *SomeType `di:"skip"` // injection skipped 177 | } 178 | ``` 179 | 180 | ### Struct Field Injection 181 | 182 | To avoid constant constructor changes, you can use `di.Inject`. Only
And only 184 | `di`-tagged fields will be injected. Such a constructor will work with 185 | using `di` only. 186 | 187 | ```go 188 | // Controller has some endpoints. 189 | type Controller struct { 190 | di.Inject // enables struct field injection 191 | 192 | // fields must be public 193 | // tag lets to specify fields need to be injected 194 | Users UserService 195 | Friends FriendsService `di:"type=cached"` 196 | } 197 | 198 | // NewController creates a controller. 199 | func NewController() *Controller { 200 | return &Controller{} 201 | } 202 | ``` 203 | ### Iteration 204 | 205 | The `di` package provides iteration capabilities, allowing you to iterate over a group of a specific Pointer type with the `IterateFunc`. This can be useful when working with multiple instances of a type or when you need to perform actions on each instance. 206 | 207 | ```go 208 | // ValueFunc is a lazy-loading wrapper for iteration. 209 | type ValueFunc func() (interface{}, error) 210 | 211 | // IterateFunc is a function that will be called on each instance in the iterate selection. 212 | type IterateFunc func(tags Tags, value ValueFunc) error 213 | ``` 214 | 215 | To use iteration with the container, follow the example below: 216 | 217 | ```go 218 | var servers []*http.Server 219 | iterFn := func(tags di.Tags, loader ValueFunc) error { 220 | i, err := loader() 221 | if err != nil { 222 | return err 223 | } 224 | // do stuff with result: i.(*http.Server) 225 | return nil 226 | } 227 | 228 | container.Iterate(&servers, iterFn) 229 | ``` 230 | 231 | In this example, the `Iterate` method is called on the container, passing a slice of pointers to the desired type (in this case, `*http.Server`) and the iterate function, which will be executed on each instance. 232 | 233 | ### Decoration 234 | 235 | The `di` package supports decoration, allowing you to modify container instances through the use of decorators. 
This can be helpful when you need to make additional modifications to instances after they have been constructed. 236 | 237 | ```go 238 | // Decorator can modify container instance. 239 | type Decorator func(value Value) error 240 | 241 | // Decorate will be called after type construction. You can modify your pointer types. 242 | func Decorate(decorators ...Decorator) ProvideOption { 243 | return provideOption(func(params *ProvideParams) { 244 | params.Decorators = append(params.Decorators, decorators...) 245 | }) 246 | } 247 | ``` 248 | 249 | To use decorators, you can add them to the `Provide` method using the `Decorate` function. Here's an example of a decorator that logs the creation of each instance: 250 | 251 | ```go 252 | // Logger is a simple logger interface for demonstration purposes 253 | type Logger interface { 254 | Log(message string) 255 | } 256 | 257 | // logInstanceCreation is a decorator that logs the creation of instances 258 | func logInstanceCreation(logger Logger) Decorator { 259 | return func(value Value) error { 260 | logger.Log(fmt.Sprintf("Instance of type logger created")) 261 | return nil 262 | } 263 | } 264 | 265 | // Usage example 266 | container, err := di.New( 267 | di.Provide(NewMyType, di.Decorate(logInstanceCreation(myLogger))), 268 | ) 269 | ``` 270 | 271 | In this example, the `logInstanceCreation` decorator logs a message every time a new instance is created. The decorator is added to the `Provide` method using the `Decorate` function, and it is executed after the type construction. 272 | 273 | ### Cleanup 274 | 275 | If the constructor creates a value that needs to be cleaned up, then it 276 | can return a closure to clean up the resource. 
277 | 278 | ```go 279 | func NewFile(log Logger, path Path) (*os.File, func(), error) { 280 | f, err := os.Open(string(path)) 281 | if err != nil { 282 | return nil, nil, err 283 | } 284 | cleanup := func() { 285 | if err := f.Close(); err != nil { 286 | log.Log(err) 287 | } 288 | } 289 | return f, cleanup, nil 290 | } 291 | ``` 292 | 293 | When you call `container.Cleanup()`, the container iterates over the built instances and calls 294 | each instance's cleanup function, if it has one; dependents are cleaned up before their dependencies. 295 | 296 | ```go 297 | container, err := di.New( 298 | // ... 299 | di.Provide(NewFile), 300 | ) 301 | if err != nil { 302 | // handle error 303 | } 304 | // do something 305 | container.Cleanup() // file was closed 306 | ``` 307 | 308 | ### Container Chaining / Scopes 309 | 310 | You can chain containers together so that values can be resolved from a 311 | parent container. This lets you, for example, keep a configuration-scoped 312 | container and an application-scoped container. By keeping 313 | configuration values in a different container, you can re-create 314 | the application-scoped container when you make configuration changes, 315 | since each container has an independent lifecycle. 316 | 317 | **Note:** Each container must be cleaned up manually by calling its own `Cleanup()`. 318 | 319 | ```go 320 | configContainer, err := di.New( 321 | di.Provide(NewServerConfig), 322 | ) 323 | 324 | appContainer, err := di.New(di.Provide(func(config *ServerConfig) *http.Server { 325 | server := ... 326 | return server 327 | })) 328 | 329 | if err := appContainer.AddParent(configContainer); err != nil { 330 | // handle error 331 | } 332 | 333 | var server *http.Server 334 | err = appContainer.Resolve(&server) 335 | ``` -------------------------------------------------------------------------------- /docs/tutorial.md: -------------------------------------------------------------------------------- 1 | ## Tutorial 2 | 3 | Learn how to use the `di` package by building a simple application that processes HTTP requests.
4 | 5 | The full tutorial code is available 6 | [here](./../_examples/tutorial/main.go). 7 | 8 | - [Tracing](#tracing) 9 | - [Provide](#provide) 10 | - [Resolve](#resolve) 11 | - [Invoke](#invoke) 12 | - [Lazy-loading](#lazy-loading) 13 | - [Interfaces](#interfaces) 14 | - [Groups](#groups) 15 | 16 | ### Tracing 17 | 18 | Before starting, you can enable tracing to get more information about 19 | the library lifecycle. The `di` package includes a default tracer that 20 | prints output using the standard `log` package: 21 | 22 | ```go 23 | func main() { 24 | di.SetTracer(&di.StdTracer{}) 25 | //... 26 | } 27 | ``` 28 | 29 | ### Provide 30 | 31 | First, we need to provide ways to build two fundamental 32 | types: `http.Server` and `http.ServeMux`. Let's create simple 33 | functional constructors that build them: 34 | 35 | ```go 36 | // NewServer builds an HTTP server with the provided mux as handler. 37 | func NewServer(mux *http.ServeMux) *http.Server { 38 | return &http.Server{ 39 | Handler: mux, 40 | } 41 | } 42 | 43 | // NewServeMux creates a new HTTP serve mux. 44 | func NewServeMux() *http.ServeMux { 45 | return &http.ServeMux{} 46 | } 47 | ``` 48 | 49 | > Supported constructor signature: 50 | > 51 | > ```go 52 | > // cleanup and error are optional 53 | > func([dep1, dep2, depN]) (result, [cleanup, error]) 54 | > ``` 55 | 56 | Now we can teach the container how to build these types. 57 | 58 | Pass each constructor to the container with the `di.Provide` option: 59 | 60 | ```go 61 | // create container 62 | container, err := di.New( 63 | di.Provide(NewServer), 64 | di.Provide(NewServeMux), 65 | ) 66 | if err != nil { 67 | // handle error 68 | } 69 | ``` 70 | 71 | ### Resolve 72 | 73 | Next, we can resolve the built server from the container. To do this, declare 74 | a variable of the type to resolve and pass a pointer to it to the `Resolve` 75 | method. 76 | 77 | If no error occurs, we can use the variable.
78 | 79 | ```go 80 | // declare type variable 81 | var server *http.Server 82 | // resolving 83 | err := container.Resolve(&server) 84 | if err != nil { 85 | // handle error 86 | } 87 | 88 | server.ListenAndServe() 89 | ``` 90 | 91 | > The container builds each combination of type and tags only once and returns 92 | > the same instance (a singleton) every time it is requested. 93 | 94 | ### Invoke 95 | 96 | As an alternative to `Resolve`, we can use the `Invoke()` method of 97 | the `Container`. It builds the function's dependencies and calls it. The invoked 98 | function may optionally return an error. 99 | 100 | ```go 101 | // StartServer starts the server. 102 | func StartServer(server *http.Server) error { 103 | return server.ListenAndServe() 104 | } 105 | 106 | if err := container.Invoke(StartServer); err != nil { 107 | // handle error 108 | } 109 | ``` 110 | 111 | Also, you can use the `di.Invoke()` container option to call some 112 | initialization code. 113 | 114 | ```go 115 | container, err := di.New( 116 | di.Provide(NewServer), 117 | di.Invoke(StartServer), 118 | ) 119 | if err != nil { 120 | // handle error 121 | } 122 | ``` 123 | 124 | The container runs all invoke functions in the order they were 125 | declared. If one of them fails, container creation fails. 126 | 127 | ### Lazy-loading 128 | 129 | Dependencies are lazy-loaded: if nothing requests a type from 130 | the container, it won't be constructed. 131 | 132 | ### Interfaces 133 | 134 | You can provide an implementation as an interface. Use `di.As()` for this. 135 | Each argument of this option must be a pointer to an interface, such as 136 | `new(http.Handler)`. 137 | 138 | ```go 139 | di.Provide(NewServeMux, di.As(new(http.Handler))) 140 | ``` 141 | 142 | > This syntax with `new` can look strange, but I haven't found a better 143 | > way to specify the interface.
144 | > 145 | > Create an issue if you know a better way ;) 146 | 147 | Updated server constructor: 148 | 149 | ```go 150 | // NewServer creates an HTTP server with the provided handler. 151 | func NewServer(handler http.Handler) *http.Server { 152 | return &http.Server{ 153 | Handler: handler, 154 | } 155 | } 156 | ``` 157 | 158 | Final code: 159 | 160 | ```go 161 | container, err := di.New( 162 | // provide HTTP server 163 | di.Provide(NewServer), 164 | // provide HTTP serve mux as http.Handler interface 165 | di.Provide(NewServeMux, di.As(new(http.Handler))), 166 | ) 167 | if err != nil { 168 | // handle error 169 | } 170 | ``` 171 | 172 | Now the container uses `*http.ServeMux` as the implementation of `http.Handler`. 173 | Depending on interfaces rather than concrete types makes code easier to test. 174 | 175 | ### Groups 176 | 177 | ##### Grouping 178 | 179 | The container automatically groups values of the same type into a slice. It 180 | works with `di.As()` too. For example, `di.As(new(http.Handler))` 181 | automatically creates a group `[]http.Handler`. 182 | 183 | Let's add some HTTP controllers using this feature. The main function of 184 | controllers is registering routes. First, create an interface 185 | for it. 186 | 187 | ```go 188 | // Controller is an interface that can register its routes. 189 | type Controller interface { 190 | RegisterRoutes(mux *http.ServeMux) 191 | } 192 | ``` 193 | 194 | Next, create implementations for this interface. 195 | 196 | ##### Order implementation 197 | 198 | ```go 199 | // OrderController is an HTTP controller for orders. 200 | type OrderController struct {} 201 | 202 | // NewOrderController creates an order HTTP controller. 203 | func NewOrderController() *OrderController { 204 | return &OrderController{} 205 | } 206 | 207 | // RegisterRoutes is a Controller interface implementation.
208 | func (a *OrderController) RegisterRoutes(mux *http.ServeMux) { 209 | mux.HandleFunc("/orders", a.RetrieveOrders) 210 | } 211 | 212 | // RetrieveOrders loads orders and writes them to the writer. 213 | func (a *OrderController) RetrieveOrders(writer http.ResponseWriter, request *http.Request) { 214 | // implementation 215 | } 216 | ``` 217 | 218 | ##### User implementation 219 | 220 | ```go 221 | // UserController is an HTTP endpoint for users. 222 | type UserController struct {} 223 | 224 | // NewUserController creates a user HTTP endpoint. 225 | func NewUserController() *UserController { 226 | return &UserController{} 227 | } 228 | 229 | // RegisterRoutes is a Controller interface implementation. 230 | func (e *UserController) RegisterRoutes(mux *http.ServeMux) { 231 | mux.HandleFunc("/users", e.RetrieveUsers) 232 | } 233 | 234 | // RetrieveUsers loads users and writes them using the writer. 235 | func (e *UserController) RetrieveUsers(writer http.ResponseWriter, request *http.Request) { 236 | // implementation 237 | } 238 | ``` 239 | 240 | ##### Container initialization code 241 | 242 | Just like in the example with interfaces, we will use the `di.As()` provide 243 | option. 244 | 245 | ```go 246 | container, err := di.New( 247 | di.Provide(NewServer), // provide HTTP server 248 | di.Provide(NewServeMux), // provide HTTP serve mux 249 | // endpoints 250 | di.Provide(NewOrderController, di.As(new(Controller))), // provide order controller 251 | di.Provide(NewUserController, di.As(new(Controller))), // provide user controller 252 | ) 253 | if err != nil { 254 | // handle error 255 | } 256 | ``` 257 | 258 | Now we can use the `[]Controller` group in our mux. Updated code: 259 | 260 | ```go 261 | // NewServeMux creates a new HTTP serve mux. 
262 | func NewServeMux(controllers []Controller) *http.ServeMux { 263 | mux := &http.ServeMux{} 264 | 265 | for _, controller := range controllers { 266 | controller.RegisterRoutes(mux) 267 | } 268 | 269 | return mux 270 | } 271 | ``` 272 | 273 | The full tutorial code is available 274 | [here](./../_examples/tutorial/main.go) -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | var ( 9 | // ErrTypeNotExists reports that a requested type is not found in the container. It may be returned wrapped, so test for it with errors.Is. 10 | ErrTypeNotExists = errors.New("not exists in the container") 11 | ) 12 | 13 | var ( // unexported sentinel errors; knownError also recognizes them 14 | errInvalidInvocationSignature = errors.New("invalid invocation signature") 15 | errCycleDetected = errors.New("cycle detected") 16 | errFieldsNotSupported = errors.New("fields not supported") 17 | ) 18 | 19 | // knownError reports whether err is, or wraps, one of this package's sentinel errors.
20 | func knownError(err error) bool { 21 | return errors.Is(err, ErrTypeNotExists) || 22 | errors.Is(err, errInvalidInvocationSignature) || 23 | errors.Is(err, errCycleDetected) || 24 | errors.Is(err, errFieldsNotSupported) 25 | } 26 | 27 | // errWithStack prefixes err with stacktrace(1) (presumably the caller's location; see stacktrace.go) 28 | // and wraps it with %w, so errors.Is and errors.As still match the original error. 29 | func errWithStack(err error) error { 30 | return fmt.Errorf("%s: %w", stacktrace(1), err) 31 | } 32 | 33 | // bug panics with a message asking the user to report a library bug upstream. 34 | func bug() { 35 | panic("you found a bug, please create new issue for this: https://github.com/defval/di/issues/new") 36 | } 37 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/defval/di 2 | 3 | go 1.20 4 | 5 | require github.com/stretchr/testify v1.8.2 6 | 7 | require ( 8 | github.com/davecgh/go-spew v1.1.1 // indirect 9 | github.com/pmezard/go-difflib v1.0.0 // indirect 10 | gopkg.in/yaml.v3 v3.0.1 // indirect 11 | ) 12 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 5 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 6 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 7 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 8 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 9 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 10 | github.com/stretchr/testify
v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 11 | github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= 12 | github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 13 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 14 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 15 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 16 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 17 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 18 | -------------------------------------------------------------------------------- /inject.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | // Inject, when embedded in a struct, makes the container inject the struct's exported fields automatically. 11 | // 12 | // type Application struct { 13 | // di.Inject 14 | // 15 | // Server *http.Server // will be injected 16 | // } 17 | // 18 | // Use the di struct tag to select tagged instances, mark a field optional, or skip it: 19 | // 20 | // type Application struct { 21 | // di.Inject 22 | // 23 | // Public *http.Server `di:"type=public"` // *http.Server provided with tag type=public will be injected 24 | // Private *http.Server `di:"type=private"` // *http.Server provided with tag type=private will be injected 25 | // } 26 | type Inject struct { 27 | injectable 28 | } 29 | 30 | // injectable is implemented by Inject (through its embedded interface) and therefore by every struct that embeds Inject; canInject checks for it. 31 | type injectable interface { 32 | isInjectable() 33 | } 34 | 35 | type field struct { 36 | rt reflect.Type // type to resolve for the field 37 | tags Tags // tags the resolved value must be provided with 38 | optional bool // if true, a value that cannot be found is skipped instead of failing 39 | } 40 | 41 | // canInject reports whether t is a struct, or a pointer to a struct, that implements injectable (typically by embedding di.Inject).
42 | func canInject(t reflect.Type) bool { 43 | if !t.Implements(injectableInterface) { 44 | return false 45 | } 46 | if t.Kind() == reflect.Ptr { 47 | t = t.Elem() 48 | } 49 | if t.Kind() != reflect.Struct { 50 | return false 51 | } 52 | return true 53 | } 54 | 55 | // parsePopulateFields parses fields of struct that can be populated. 56 | func parsePopulateFields(rt reflect.Type) map[int]field { 57 | if !canInject(rt) { 58 | return nil 59 | } 60 | var rv reflect.Value 61 | if !rv.IsValid() { 62 | switch rt.Kind() { 63 | case reflect.Ptr: 64 | rv = reflect.New(rt.Elem()) 65 | default: 66 | rv = reflect.New(rt).Elem() 67 | } 68 | } 69 | if rt.Kind() == reflect.Ptr { 70 | rt = rt.Elem() 71 | rv = rv.Elem() 72 | } 73 | fields := make(map[int]field, rt.NumField()) 74 | // fi - field index 75 | for fi := 0; fi < rt.NumField(); fi++ { 76 | fv := rv.Field(fi) 77 | // check that field can be set 78 | if !fv.CanSet() { 79 | continue 80 | } 81 | // cur - current field 82 | cur := rt.Field(fi) 83 | f, valid := inspectStructField(rt, cur) 84 | if !valid { 85 | continue 86 | } 87 | fields[fi] = field{ 88 | rt: cur.Type, 89 | tags: f.tags, 90 | optional: f.optional, 91 | } 92 | } 93 | return fields 94 | } 95 | 96 | // inspectStructField parses struct field 97 | func inspectStructField(rt reflect.Type, f reflect.StructField) (field, bool) { 98 | 99 | result := field{ 100 | rt: f.Type, 101 | tags: Tags{}, 102 | optional: false, 103 | } 104 | if f.Tag == "" { 105 | return result, true 106 | } 107 | 108 | diTag, found := f.Tag.Lookup("di") 109 | if found { 110 | if diTag == "" { 111 | return result, true 112 | } 113 | for _, v := range strings.Split(diTag, ",") { 114 | v = strings.TrimSpace(v) 115 | switch v { 116 | case "skip": 117 | return field{}, false 118 | case "optional": 119 | result.optional = true 120 | default: 121 | kv := strings.SplitN(v, "=", 2) 122 | if len(kv) == 2 { 123 | result.tags[kv[0]] = kv[1] 124 | } else { 125 | panic(fmt.Sprintf("invalid di tag: key=value 
got: %s", v)) 126 | } 127 | } 128 | } 129 | return result, true 130 | } else { 131 | // handle the old deprecated struct tagging style. 132 | result, noSkip := inspectStructFieldDeprecated(f) 133 | tracer.Trace("Deprecation warning: please replace the field tags on '%s.%s' with: %v", rt.Name(), f.Name, newTagStyleText(result.tags, result.optional, !noSkip)) 134 | return result, noSkip 135 | } 136 | } 137 | 138 | func newTagStyleText(tags map[string]string, optional bool, skip bool) string { 139 | parts := []string{} 140 | if skip { 141 | parts = append(parts, "skip") 142 | } else { 143 | 144 | if optional { 145 | parts = append(parts, "optional") 146 | } 147 | for k, v := range tags { 148 | parts = append(parts, k+"="+v) 149 | } 150 | } 151 | return `di:"` + strings.Join(parts, ",") + `"` 152 | } 153 | 154 | func inspectStructFieldDeprecated(f reflect.StructField) (field, bool) { 155 | tags := Tags{} 156 | t := string(f.Tag) 157 | optional := false 158 | 159 | // this code copied from reflect.StructField.Lookup() method. 160 | for t != "" { 161 | // Skip leading space. 162 | i := 0 163 | for i < len(t) && t[i] == ' ' { 164 | i++ 165 | } 166 | t = t[i:] 167 | if t == "" { 168 | break 169 | } 170 | 171 | // Scan to colon. A space, a quote or a control character is a syntax error. 172 | // Strictly speaking, control chars include the range [0x7f, 0x9f], not just 173 | // [0x00, 0x1f], but in practice, we ignore the multi-byte control characters 174 | // as it is simpler to inspect the tag's bytes than the tag's runes. 175 | i = 0 176 | for i < len(t) && t[i] > ' ' && t[i] != ':' && t[i] != '"' && t[i] != 0x7f { 177 | i++ 178 | } 179 | if i == 0 || i+1 >= len(t) || t[i] != ':' || t[i+1] != '"' { 180 | break 181 | } 182 | name := string(t[:i]) 183 | t = t[i+1:] 184 | 185 | // Scan quoted string to find value. 
186 | i = 1 187 | for i < len(t) && t[i] != '"' { 188 | if t[i] == '\\' { 189 | i++ 190 | } 191 | i++ 192 | } 193 | if i >= len(t) { 194 | break 195 | } 196 | qvalue := string(t[:i+1]) 197 | t = t[i+1:] 198 | value, err := strconv.Unquote(qvalue) 199 | if err != nil { 200 | break 201 | } 202 | if name == "skip" && value == "true" { 203 | return field{ 204 | rt: f.Type, 205 | tags: tags, 206 | optional: optional, 207 | }, false 208 | } 209 | if name == "optional" { 210 | if value == "true" { 211 | optional = true 212 | } 213 | continue 214 | } 215 | tags[name] = value 216 | } 217 | return field{ 218 | rt: f.Type, 219 | tags: tags, 220 | optional: optional, 221 | }, true 222 | } 223 | 224 | var injectableInterface = reflect.TypeOf(new(injectable)).Elem() 225 | -------------------------------------------------------------------------------- /inspect.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "runtime" 7 | ) 8 | 9 | // Func is a function description. 10 | type function struct { 11 | Name string 12 | reflect.Type 13 | reflect.Value 14 | } 15 | 16 | var errorInterface = reflect.TypeOf(new(error)).Elem() 17 | 18 | // isError checks that typ have error signature. 19 | func isError(typ reflect.Type) bool { 20 | return typ.Implements(errorInterface) 21 | } 22 | 23 | // isCleanup checks that typ have cleanup signature. 24 | func isCleanup(typ reflect.Type) bool { 25 | return typ.Kind() == reflect.Func && typ.NumIn() == 0 && typ.NumOut() == 0 26 | } 27 | 28 | // InspectFunc inspects function. 
29 | func inspectFunction(fn interface{}) (function, bool) { 30 | val := reflect.ValueOf(fn) 31 | if val.Kind() != reflect.Func { 32 | return function{}, false 33 | } 34 | // Name as reported by the runtime for the function's code pointer. 35 | name := runtime.FuncForPC(val.Pointer()).Name() 36 | return function{ 37 | Name: name, 38 | Type: val.Type(), 39 | Value: val, 40 | }, true 41 | } 42 | 43 | // link describes an interface type, as extracted by inspectInterfacePointer. 44 | type link struct { 45 | Name string 46 | Type reflect.Type 47 | } 48 | 49 | // inspectInterfacePointer extracts the interface type from i, which must be a non-nil pointer to an interface such as new(http.Handler); otherwise an error is returned. 50 | func inspectInterfacePointer(i interface{}) (*link, error) { 51 | if i == nil { 52 | return nil, fmt.Errorf("nil: not a pointer to interface") 53 | } 54 | typ := reflect.TypeOf(i) 55 | if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Interface { 56 | return nil, fmt.Errorf("%s: not a pointer to interface", typ) 57 | } 58 | 59 | return &link{ 60 | Name: typ.Elem().Name(), 61 | Type: typ.Elem(), 62 | }, nil 63 | } 64 | -------------------------------------------------------------------------------- /invocation.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | // validateInvocation reports whether fn has a valid invocation signature: no results, or a single result that implements error. 4 | func validateInvocation(fn function) bool { 5 | if fn.NumOut() == 0 { 6 | return true 7 | } 8 | if fn.NumOut() == 1 && isError(fn.Out(0)) { 9 | return true 10 | } 11 | return false 12 | } 13 | 14 | // parseInvocationParameters looks up, in parameter order, the node for each parameter type of fn in s (with no tags) and fails on the first lookup error.
15 | func parseInvocationParameters(fn function, s schema) (params []*node, err error) { 16 | for i := 0; i < fn.NumIn(); i++ { 17 | in := fn.Type.In(i) 18 | node, err := s.find(in, Tags{}) 19 | if err != nil { 20 | return nil, err 21 | } 22 | params = append(params, node) 23 | } 24 | return params, nil 25 | } 26 | -------------------------------------------------------------------------------- /node.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | // newConstructorNode 9 | func newConstructorNode(ctor interface{}) (*node, error) { 10 | f, valid := inspectFunction(ctor) 11 | if !valid { 12 | return nil, fmt.Errorf("invalid constructor signature, got %s", reflect.TypeOf(ctor)) 13 | } 14 | cmp, ok := newConstructorCompiler(f) 15 | if !ok { 16 | return nil, fmt.Errorf("invalid constructor signature, got %s", f.Type) 17 | } 18 | // result type 19 | rt := f.Out(0) 20 | tags := map[string]string{} 21 | if haveTags(rt) { 22 | tmp := rt 23 | if tmp.Kind() == reflect.Ptr { 24 | tmp = tmp.Elem() 25 | } 26 | f, ok := tmp.FieldByName("Tags") 27 | if !ok { 28 | return nil, fmt.Errorf("tags usage error: need to embed di.Tags without field name") 29 | } 30 | field, ok := inspectStructField(tmp, f) 31 | if ok { 32 | tags = field.tags 33 | } 34 | } 35 | return &node{ 36 | rv: new(reflect.Value), 37 | rt: rt, 38 | tags: tags, 39 | compiler: cmp, 40 | }, nil 41 | } 42 | 43 | // node is a dependency injection node. 44 | type node struct { 45 | compiler 46 | rt reflect.Type 47 | tags Tags 48 | // rv value can be shared between nodes 49 | // initializing node always need to allocate memory for rv 50 | rv *reflect.Value 51 | // decorators 52 | decorators []Decorator 53 | } 54 | 55 | // String is a string representation of node. 56 | func (n *node) String() string { 57 | return fmt.Sprintf("%s%s", n.rt, n.tags) 58 | } 59 | 60 | // Value returns value of node. 
61 | func (n *node) Value(s schema) (reflect.Value, error) { 62 | if n.rv.IsValid() { 63 | return *n.rv, nil 64 | } 65 | nodes, _ := n.deps(s) // todo: error skipped, prepare already check dependency graph 66 | var dependencies []reflect.Value 67 | for _, node := range nodes { 68 | v, err := node.Value(s) 69 | if err != nil { 70 | return reflect.Value{}, fmt.Errorf("%s: %w", node, err) 71 | } 72 | dependencies = append(dependencies, v) 73 | } 74 | rv, err := n.compile(dependencies, s) 75 | if err != nil { 76 | tracer.Trace("%s: %s", n.String(), err) 77 | return reflect.Value{}, err 78 | } 79 | // if result value not addr, create pointer for it 80 | if !rv.CanAddr() { 81 | addr := reflect.New(rv.Type()) 82 | addr.Elem().Set(rv) 83 | rv = addr.Elem() 84 | } 85 | if err := populate(s, rv); err != nil { 86 | tracer.Trace("%s: %s", n.String(), err) 87 | return reflect.Value{}, err 88 | } 89 | for _, decorator := range n.decorators { 90 | tracer.Trace("Run resolve decorator for %s", n.String()) 91 | if err := decorator(rv.Interface()); err != nil { 92 | tracer.Trace("Decorator error %s", err) 93 | return reflect.Value{}, err 94 | } 95 | } 96 | *n.rv = rv 97 | tracer.Trace("Resolved %s", n.String()) 98 | return *n.rv, nil 99 | } 100 | 101 | func (n *node) fields() map[int]field { 102 | return parsePopulateFields(n.rt) 103 | } 104 | 105 | // populate populates node fields. 
106 | func populate(s schema, rv reflect.Value) error { 107 | if !canInject(rv.Type()) { 108 | return nil 109 | } 110 | // indirect pointer 111 | if rv.Kind() == reflect.Ptr { 112 | rv = reflect.Indirect(rv) 113 | } 114 | for index, field := range parsePopulateFields(rv.Type()) { 115 | node, err := s.find(field.rt, field.tags) 116 | if err != nil && field.optional { 117 | tracer.Trace("-- Skip optional field: %s", field) 118 | continue 119 | } 120 | if err != nil { 121 | return err 122 | } 123 | v, err := node.Value(s) 124 | if err != nil { 125 | return err 126 | } 127 | f := rv.Field(index) 128 | if !f.CanSet() { 129 | panic(fmt.Sprintf("can not set field %s(%d) of %s (addr: %t)", f.Type(), f.Pointer(), rv.Type(), rv.CanAddr())) 130 | } 131 | f.Set(v) 132 | } 133 | return nil 134 | } 135 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | // Option is a functional option that configures container. If you don't know about functional 4 | // options, see https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis. 5 | // Below presented all possible options with their description: 6 | // 7 | // - di.Provide - provide constructors 8 | // - di.ProvideValue - provide value 9 | // - di.Invoke - add invocations 10 | // - di.Resolve - resolves type 11 | type Option interface { 12 | apply(c *diopts) 13 | } 14 | 15 | // Provide returns container option that provides to container reliable way to build type. The constructor will 16 | // be invoked lazily on-demand. For more information about constructors see Constructor interface. ProvideOption can 17 | // add additional behavior to the process of type resolving. 
18 | func Provide(constructor Constructor, options ...ProvideOption) Option { 19 | frame := stacktrace(0) 20 | return option(func(c *diopts) { 21 | c.provides = append(c.provides, provideOptions{ 22 | frame, 23 | constructor, 24 | options, 25 | }) 26 | }) 27 | } 28 | 29 | // ProvideValue provides value as is. 30 | func ProvideValue(value Value, options ...ProvideOption) Option { 31 | frame := stacktrace(0) 32 | return option(func(c *diopts) { 33 | c.values = append(c.values, provideValueOptions{ 34 | frame, 35 | value, 36 | options, 37 | }) 38 | }) 39 | } 40 | 41 | // Constructor is a function with follow signature: 42 | // 43 | // func NewHTTPServer(addr string, handler http.Handler) (server *http.Server, cleanup func(), err error) { 44 | // server := &http.Server{ 45 | // Addr: addr, 46 | // } 47 | // cleanup = func() { 48 | // server.Close() 49 | // } 50 | // return server, cleanup, nil 51 | // } 52 | // 53 | // This constructor function teaches container how to build server. Arguments (addr and handler) in this function 54 | // is a dependencies. They will be resolved automatically when someone needs a server. Constructor may have unlimited 55 | // count of dependencies, but note that container should know how build each of them. 56 | // Second result of this function is a optional cleanup callback. It describes that container will do on shutdown. 57 | // Third result is a optional error. Sometimes our types cannot be constructed. 58 | type Constructor interface{} 59 | 60 | // Value is a variable of provided or resolved type. 61 | type Value interface{} 62 | 63 | // ProvideOption is a functional option interface that modify provide behaviour. See di.As(), di.WithName(). 64 | type ProvideOption interface { 65 | applyProvide(params *ProvideParams) 66 | } 67 | 68 | // As returns provide option that specifies interfaces for constructor resultant type. 
69 | // 70 | // INTERFACE USAGE: 71 | // 72 | // You can provide type as interface and resolve it later without using of direct implementation. 73 | // This creates less cohesion of code and promotes be more testable. 74 | // 75 | // Create type constructors: 76 | // 77 | // func NewServeMux() *http.ServeMux { 78 | // return &http.ServeMux{} 79 | // } 80 | // 81 | // func NewServer(handler *http.Handler) *http.Server { 82 | // return &http.Server{ 83 | // Handler: handler, 84 | // } 85 | // } 86 | // 87 | // Build container with di.As provide option: 88 | // 89 | // container, err := di.New( 90 | // di.Provide(NewServer), 91 | // di.Provide(NewServeMux, di.As(new(http.Handler)), 92 | // ) 93 | // if err != nil { 94 | // // handle error 95 | // } 96 | // var server *http.Server 97 | // if err := container.Resolve(&http.Server); err != nil { 98 | // // handle error 99 | // } 100 | // 101 | // In this example you can see how container inject type *http.ServeMux as http.Handler 102 | // interface into the server constructor. 103 | // 104 | // GROUP USAGE: 105 | // 106 | // Container automatically creates group for interfaces. For example, you can use type []http.Handler in 107 | // previous example. 108 | // 109 | // var handlers []http.Handler 110 | // if err := container.Resolve(&handlers); err != nil { 111 | // // handle error 112 | // } 113 | // 114 | // Container checks that provided type implements interface if not cause compile error. 115 | func As(interfaces ...Interface) ProvideOption { 116 | return provideOption(func(params *ProvideParams) { 117 | params.Interfaces = append(params.Interfaces, interfaces...) 118 | }) 119 | } 120 | 121 | // Interface is a pointer to interface, like new(http.Handler). Tell container that provided 122 | // type may be used as interface. 123 | type Interface interface{} 124 | 125 | // WithName modifies Provide() behavior. It adds name identity for provided type. 126 | // Deprecated: use di.Tags. 
127 | func WithName(name string) ProvideOption { 128 | return provideOption(func(params *ProvideParams) { 129 | if params.Tags == nil { 130 | params.Tags = Tags{} 131 | } 132 | params.Tags["name"] = name 133 | }) 134 | } 135 | 136 | // Decorator can modify container instance. 137 | type Decorator func(value Value) error 138 | 139 | // Decorate will be called after type construction. You can modify your pointer types. 140 | func Decorate(decorators ...Decorator) ProvideOption { 141 | return provideOption(func(params *ProvideParams) { 142 | params.Decorators = append(params.Decorators, decorators...) 143 | }) 144 | } 145 | 146 | // Resolve returns container options that resolves type into target. All resolves will be done on compile stage 147 | // after call invokes. 148 | func Resolve(target Pointer, options ...ResolveOption) Option { 149 | frame := stacktrace(0) 150 | return option(func(c *diopts) { 151 | c.resolves = append(c.resolves, resolveOptions{ 152 | frame, 153 | target, 154 | options, 155 | }) 156 | }) 157 | } 158 | 159 | // Invoke returns container option that registers container invocation. All invocations 160 | // will be called on di.New() after processing di.Provide() options. 161 | // See Container.Invoke() for details. 162 | func Invoke(fn Invocation, options ...InvokeOption) Option { 163 | frame := stacktrace(0) 164 | return option(func(c *diopts) { 165 | c.invokes = append(c.invokes, invokeOptions{ 166 | frame, 167 | fn, 168 | options, 169 | }) 170 | }) 171 | } 172 | 173 | // Options group together container options. 
174 | // 175 | // account := di.Options( 176 | // di.Provide(NewAccountController), 177 | // di.Provide(NewAccountRepository), 178 | // ) 179 | // auth := di.Options( 180 | // di.Provide(NewAuthController), 181 | // di.Provide(NewAuthRepository), 182 | // ) 183 | // container, err := di.New( 184 | // account, 185 | // auth, 186 | // ) 187 | // if err != nil { 188 | // // handle error 189 | // } 190 | func Options(options ...Option) Option { 191 | return option(func(container *diopts) { 192 | for _, opt := range options { 193 | opt.apply(container) 194 | } 195 | }) 196 | } 197 | 198 | // ProvideParams is a Provide() method options. Name is a unique identifier of type instance. Provider is a constructor 199 | // function. Interfaces is a interface that implements a provider result type. 200 | type ProvideParams struct { 201 | Tags Tags 202 | Interfaces []Interface 203 | Decorators []Decorator 204 | } 205 | 206 | func (p ProvideParams) applyProvide(params *ProvideParams) { 207 | *params = p 208 | } 209 | 210 | // InvokeOption is a functional option interface that modify invoke behaviour. 211 | type InvokeOption interface { 212 | apply(params *InvokeParams) 213 | } 214 | 215 | // InvokeParams is a invoke parameters. 216 | type InvokeParams struct { 217 | // The function 218 | Fn interface{} 219 | } 220 | 221 | func (p InvokeParams) apply(params *InvokeParams) { 222 | *params = p 223 | } 224 | 225 | // ResolveOption is a functional option interface that modify resolve behaviour. 226 | type ResolveOption interface { 227 | applyResolve(params *ResolveParams) 228 | } 229 | 230 | // Name specifies provider string identity. It needed when you have more than one 231 | // definition of same type. You can identity type by name. 
232 | // Deprecated: use di.Tags 233 | func Name(name string) ResolveOption { 234 | return resolveOption(func(params *ResolveParams) { 235 | if params.Tags == nil { 236 | params.Tags = Tags{} 237 | } 238 | params.Tags["name"] = name 239 | }) 240 | } 241 | 242 | // ResolveParams is a resolve parameters. 243 | type ResolveParams struct { 244 | Tags Tags 245 | } 246 | 247 | func (p ResolveParams) applyResolve(params *ResolveParams) { 248 | *params = p 249 | } 250 | 251 | type option func(c *diopts) 252 | 253 | func (o option) apply(c *diopts) { o(c) } 254 | 255 | type provideOption func(params *ProvideParams) 256 | 257 | func (o provideOption) applyProvide(params *ProvideParams) { 258 | o(params) 259 | } 260 | 261 | type resolveOption func(params *ResolveParams) 262 | 263 | func (o resolveOption) applyResolve(params *ResolveParams) { 264 | o(params) 265 | } 266 | 267 | // struct that contains constructor with options. 268 | type provideOptions struct { 269 | frame callerFrame 270 | constructor Constructor 271 | options []ProvideOption 272 | } 273 | 274 | // struct that contains value with options. 275 | type provideValueOptions struct { 276 | frame callerFrame 277 | value Value 278 | options []ProvideOption 279 | } 280 | 281 | // struct that contains invoke function with options. 282 | type invokeOptions struct { 283 | frame callerFrame 284 | fn Invocation 285 | options []InvokeOption 286 | } 287 | 288 | // struct that container resolve target with options. 
289 | type resolveOptions struct { 290 | frame callerFrame 291 | target Pointer 292 | options []ResolveOption 293 | } 294 | -------------------------------------------------------------------------------- /options_test.go: -------------------------------------------------------------------------------- 1 | package di_test 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/defval/di" 11 | ) 12 | 13 | func TestOptions(t *testing.T) { 14 | t.Run("simple", func(t *testing.T) { 15 | var loadedServer *http.Server 16 | var resolvedServer *http.Server 17 | server := &http.Server{} 18 | mux := &http.ServeMux{} 19 | c, err := di.New( 20 | di.Options( 21 | di.Provide(func(handler http.Handler) *http.Server { 22 | server.Handler = handler 23 | return server 24 | }), 25 | di.Provide(func() *http.ServeMux { 26 | return mux 27 | }, di.As(new(http.Handler))), 28 | di.Invoke(func(server *http.Server) { 29 | loadedServer = server 30 | }), 31 | di.Resolve(&resolvedServer), 32 | ), 33 | ) 34 | require.NoError(t, err) 35 | require.NotNil(t, c) 36 | require.Equal(t, loadedServer, server) 37 | require.Equal(t, loadedServer.Handler, mux) 38 | require.Equal(t, resolvedServer, server) 39 | }) 40 | 41 | t.Run("provide failed", func(t *testing.T) { 42 | c, err := di.New( 43 | di.Provide(func() {}), 44 | ) 45 | require.Nil(t, c) 46 | require.NotNil(t, err) 47 | require.Contains(t, err.Error(), "options_test.go:") 48 | require.Contains(t, err.Error(), ": invalid constructor signature, got func()") 49 | }) 50 | 51 | t.Run("invoke failed", func(t *testing.T) { 52 | c, err := di.New( 53 | di.Invoke(func(string2 string) {}), 54 | ) 55 | require.Nil(t, c) 56 | require.Error(t, err) 57 | require.Contains(t, err.Error(), "options_test.go:") 58 | require.Contains(t, err.Error(), ": type string not exists in the container") 59 | }) 60 | 61 | t.Run("invoke error return as is if not internal error", func(t *testing.T) { 62 | var 
myError = errors.New("my error") 63 | _, err := di.New( 64 | di.Invoke(func() error { 65 | return myError 66 | }), 67 | ) 68 | require.True(t, err == myError) 69 | }) 70 | 71 | t.Run("resolve failed", func(t *testing.T) { 72 | _, err := di.New( 73 | di.Resolve(func() {}), 74 | ) 75 | require.Error(t, err) 76 | require.Contains(t, err.Error(), "options_test.go:") 77 | require.Contains(t, err.Error(), ": target must be a pointer, got func()") 78 | }) 79 | } 80 | -------------------------------------------------------------------------------- /schema.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | // schema is a dependency injection schema. 9 | type schema interface { 10 | // find finds reflect.Type with matching Tags. 11 | find(t reflect.Type, tags Tags) (*node, error) 12 | // register cleanup 13 | cleanup(cleanup func()) 14 | } 15 | 16 | // schema is a dependency injection schema. 17 | type defaultSchema struct { 18 | parents []*defaultSchema 19 | nodes map[reflect.Type][]*node 20 | cleanups []func() 21 | } 22 | 23 | func (s *defaultSchema) cleanup(cleanup func()) { 24 | s.cleanups = append(s.cleanups, cleanup) 25 | } 26 | 27 | // newDefaultSchema creates new dependency injection schema. 28 | func newDefaultSchema() *defaultSchema { 29 | return &defaultSchema{ 30 | nodes: map[reflect.Type][]*node{}, 31 | } 32 | } 33 | 34 | // register registers reflect.Type provide function with optional Tags. Also, its registers 35 | // type [] for group. 
36 | func (s *defaultSchema) register(n *node) { 37 | defer tracer.Trace("Register %s", n) 38 | if _, ok := s.nodes[n.rt]; !ok { 39 | s.nodes[n.rt] = []*node{n} 40 | return 41 | } 42 | s.nodes[n.rt] = append(s.nodes[n.rt], n) 43 | } 44 | 45 | // used depth-first topological sort algorithm 46 | func (s *defaultSchema) prepare(n *node) error { 47 | var marks = map[*node]int{} 48 | if err := visit(s, n, marks); err != nil { 49 | return err 50 | } 51 | return nil 52 | } 53 | 54 | // find finds provideFunc by its reflect.Type and Tags. 55 | func (s *defaultSchema) find(t reflect.Type, tags Tags) (*node, error) { 56 | nodes, ok := s.list(t) 57 | // type found 58 | if ok { 59 | matched := matchTags(nodes, tags) 60 | if len(matched) == 0 { 61 | return nil, fmt.Errorf("type %s%s %w", t, tags, ErrTypeNotExists) 62 | } 63 | if len(matched) > 1 { 64 | return nil, fmt.Errorf("multiple definitions of %s%s, maybe you need to use group type: []%s%s", t, tags, t, tags) 65 | } 66 | return matched[0], nil 67 | } 68 | // if not a group and not have di.Inject 69 | if t.Kind() != reflect.Slice && !canInject(t) { 70 | return nil, fmt.Errorf("type %s%s %w", t, tags, ErrTypeNotExists) 71 | } 72 | if canInject(t) { 73 | node := &node{ 74 | compiler: newTypeCompiler(t), 75 | rt: t, 76 | rv: new(reflect.Value), 77 | } 78 | // save node for future use 79 | s.nodes[t] = append(s.nodes[t], node) 80 | return node, nil 81 | } 82 | return s.group(t, tags) 83 | } 84 | 85 | func (s *defaultSchema) group(t reflect.Type, tags Tags) (*node, error) { 86 | group, ok := s.list(t.Elem()) 87 | if !ok { 88 | return nil, fmt.Errorf("type %s%s %w", t, tags, ErrTypeNotExists) 89 | } 90 | matched := matchTags(group, tags) 91 | if len(matched) == 0 { 92 | return nil, fmt.Errorf("type %s%s %w", t, tags, ErrTypeNotExists) 93 | } 94 | node := &node{ 95 | compiler: newGroupCompiler(t, matched), 96 | rt: t, 97 | tags: tags, 98 | rv: new(reflect.Value), 99 | } 100 | return node, nil 101 | } 102 | 103 | // list lists all 
the nodes of its reflect.Type 104 | func (s *defaultSchema) list(t reflect.Type) (nodes []*node, ok bool) { 105 | for _, parent := range s.parents { 106 | if n, o := parent.list(t); o { 107 | nodes = append(nodes, n...) 108 | ok = true 109 | } 110 | } 111 | if n, o := s.nodes[t]; o { 112 | nodes = append(nodes, n...) 113 | ok = true 114 | } 115 | return nodes, ok 116 | } 117 | 118 | // isAncestor returns true if a 119 | func (s *defaultSchema) isAncestor(a *defaultSchema) bool { 120 | for _, parent := range s.parents { 121 | if parent == a { 122 | return true 123 | } 124 | if parent.isAncestor(a) { 125 | return true 126 | } 127 | } 128 | return false 129 | } 130 | 131 | func (s *defaultSchema) addParent(parent *defaultSchema) error { 132 | if parent == s { 133 | return fmt.Errorf("self cycle detected") 134 | } 135 | if parent.isAncestor(s) { 136 | return fmt.Errorf("cycle detected") 137 | } 138 | if s.isAncestor(parent) { 139 | return fmt.Errorf("parent already chained") 140 | } 141 | s.parents = append(s.parents, parent) 142 | return nil 143 | } 144 | -------------------------------------------------------------------------------- /stacktrace.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "fmt" 5 | "runtime" 6 | "strings" 7 | ) 8 | 9 | // stacktrace returns stacktrace call frame with skip. 10 | func stacktrace(skip int) (frame callerFrame) { 11 | pc, file, line, ok := runtime.Caller(skip + 2) 12 | if !ok { 13 | return callerFrame{} 14 | } 15 | f := runtime.FuncForPC(pc) 16 | return callerFrame{ 17 | function: shortFuncName(f), 18 | file: file, 19 | line: line, 20 | } 21 | } 22 | 23 | // callerFrame represents stacktrace frame. 24 | type callerFrame struct { 25 | function string 26 | file string 27 | line int 28 | } 29 | 30 | // Format formats stacktrace frame. 
31 | func (f callerFrame) Format(s fmt.State, c rune) { 32 | _, _ = fmt.Fprintf(s, "%s:%d", f.file, f.line) 33 | } 34 | 35 | func shortFuncName(f *runtime.Func) string { 36 | longName := f.Name() 37 | 38 | withoutPath := longName[strings.LastIndex(longName, "/")+1:] 39 | withoutPackage := withoutPath[strings.Index(withoutPath, ".")+1:] 40 | 41 | shortName := withoutPackage 42 | shortName = strings.Replace(shortName, "(", "", 1) 43 | shortName = strings.Replace(shortName, "*", "", 1) 44 | shortName = strings.Replace(shortName, ")", "", 1) 45 | 46 | return shortName 47 | } 48 | -------------------------------------------------------------------------------- /stacktrace_test.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func Test_stacktrace(t *testing.T) { 9 | type args struct { 10 | skip int 11 | } 12 | tests := []struct { 13 | name string 14 | args args 15 | wantFrame callerFrame 16 | }{ 17 | { 18 | name: "incorrect skip", 19 | args: args{skip: 10}, 20 | wantFrame: callerFrame{ 21 | function: "", 22 | file: "", 23 | line: 0, 24 | }, 25 | }, 26 | { 27 | name: "incorrect skip", 28 | args: args{skip: 10}, 29 | wantFrame: callerFrame{ 30 | function: "", 31 | file: "", 32 | line: 0, 33 | }, 34 | }, 35 | } 36 | for _, tt := range tests { 37 | t.Run(tt.name, func(t *testing.T) { 38 | if gotFrame := stacktrace(tt.args.skip); !reflect.DeepEqual(gotFrame, tt.wantFrame) { 39 | t.Errorf("stacktrace() = %v, want %v", gotFrame, tt.wantFrame) 40 | } 41 | }) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /tags.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "reflect" 5 | "sort" 6 | "strings" 7 | ) 8 | 9 | var iTaggable = reflect.TypeOf(new(taggable)).Elem() 10 | 11 | // Tags is a string representation of key value pairs. 
12 | // 13 | // type Server struct { 14 | // di.Tags `http:"true" server:"true"` 15 | // } 16 | // _, err := di.New( 17 | // di.Provide(func() *Server { return &Server{} }), 18 | // ) 19 | // var s *Server 20 | // c.Resolve(&s, di.Tags{"http": "true", "server": "true"}) 21 | type Tags map[string]string 22 | 23 | // injectable interface needs to struct fields injection functional. 24 | type taggable interface { 25 | isTaggable() 26 | } 27 | 28 | func (t Tags) isTaggable() { 29 | bug() 30 | } 31 | 32 | // haveTags checks that typ is taggable 33 | func haveTags(typ reflect.Type) bool { 34 | return typ.Implements(iTaggable) 35 | } 36 | 37 | func (t Tags) applyProvide(params *ProvideParams) { 38 | if params.Tags == nil { 39 | params.Tags = map[string]string{} 40 | } 41 | 42 | for k, v := range t { 43 | params.Tags[k] = v 44 | } 45 | } 46 | 47 | func (t Tags) applyResolve(params *ResolveParams) { 48 | if params.Tags == nil { 49 | params.Tags = map[string]string{} 50 | } 51 | for k, v := range t { 52 | params.Tags[k] = v 53 | } 54 | } 55 | 56 | // String is a tags string representation. 57 | func (t Tags) String() string { 58 | var keys []string 59 | for k := range t { 60 | keys = append(keys, k) 61 | } 62 | sort.Strings(keys) 63 | for i := 0; i < len(keys); i++ { 64 | keys[i] = keys[i] + ":" + t[keys[i]] 65 | } 66 | if len(keys) == 0 { 67 | return "" 68 | } 69 | return "[" + strings.Join(keys, ";") + "]" 70 | } 71 | 72 | // match checks that all of key value pairs exists in t. Not equal. 
73 | func (t Tags) match(tags Tags) bool { 74 | for k, v := range tags { 75 | tv, ok := t[k] 76 | if !ok { 77 | return false 78 | } 79 | if v == "*" { 80 | continue 81 | } 82 | if tv != v { 83 | return false 84 | } 85 | } 86 | return true 87 | } 88 | 89 | func matchTags(nodes []*node, tags Tags) []*node { 90 | matched := make([]*node, 0, 1) 91 | for i := 0; i < len(nodes); i++ { 92 | if nodes[i].tags.match(tags) { 93 | matched = append(matched, nodes[i]) 94 | } 95 | } 96 | return matched 97 | } 98 | -------------------------------------------------------------------------------- /tracer.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import "log" 4 | 5 | var tracer Tracer = &nopTracer{} 6 | 7 | // SetTracer sets global tracer. 8 | func SetTracer(t Tracer) { 9 | tracer = t 10 | } 11 | 12 | // Tracer traces dependency injection cycle. 13 | type Tracer interface { 14 | // Trace prints library logs. 15 | Trace(format string, args ...interface{}) 16 | } 17 | 18 | // StdTracer traces dependency injection cycle to stdout. 19 | type StdTracer struct { 20 | } 21 | 22 | // Trace traces debug information with default logger. 23 | func (s StdTracer) Trace(format string, args ...interface{}) { 24 | log.Printf(format, args...) 25 | } 26 | 27 | // default nop tracer 28 | type nopTracer struct { 29 | } 30 | 31 | func (n nopTracer) Trace(format string, args ...interface{}) { 32 | } 33 | -------------------------------------------------------------------------------- /tracer_test.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import "testing" 4 | 5 | func TestNopTracer_Trace(t *testing.T) { 6 | tracer := nopTracer{} 7 | tracer.Trace("test") 8 | } 9 | --------------------------------------------------------------------------------