├── .build └── test.sh ├── .codecov.yml ├── .gitignore ├── .travis.yml ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── _tutorial └── main.go ├── container.go ├── container_test.go ├── di ├── container.go ├── container_test.go ├── dot.go ├── errors.go ├── internal │ ├── ditest │ │ ├── bar.go │ │ ├── baz.go │ │ ├── foo.go │ │ ├── fooer_group.go │ │ ├── full.go │ │ ├── incorrect.go │ │ ├── interfaces.go │ │ └── qux.go │ ├── graphkv │ │ ├── directed_graph.go │ │ ├── directed_graph_test.go │ │ ├── edge.go │ │ ├── errors.go │ │ ├── graph.go │ │ ├── graph_test.go │ │ ├── graphkv.go │ │ ├── graphkv_test.go │ │ ├── key.go │ │ ├── output.go │ │ ├── output_test.go │ │ ├── sort.go │ │ └── sort_test.go │ └── reflection │ │ ├── func.go │ │ ├── iface.go │ │ └── reflection.go ├── invoker.go ├── key.go ├── options.go ├── panic.go ├── parameter.go ├── parameter_bag.go ├── parameter_bag_test.go ├── parameter_list.go ├── provider.go ├── provider_ctor.go ├── provider_embed.go ├── provider_group.go ├── provider_iface.go ├── provider_stub.go └── singleton.go ├── doc.go ├── go.mod ├── go.sum ├── graph.png ├── logo.png ├── options.go └── options_test.go /.build/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Runs the tests of every package and merges their coverage profiles into coverage.txt. 3 | set -e 4 | # Start from an empty report: the loop below appends each package profile to coverage.txt. 5 | if [[ -f coverage.txt ]]; then 6 | rm coverage.txt 7 | fi 8 | # The ditest package is skipped (presumably fixtures only); -coverpkg=./... records coverage of every module package, not just the one under test. 9 | for d in $(go list ./... | grep -v ditest); do 10 | go test -coverprofile=profile.out -coverpkg=./...
-covermode=atomic "$d" 11 | if [[ -f profile.out ]]; then 12 | cat profile.out >> coverage.txt 13 | rm profile.out 14 | fi 15 | done -------------------------------------------------------------------------------- /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | range: 70..98 3 | round: down 4 | precision: 2 5 | status: 6 | project: 7 | default: 8 | enabled: yes 9 | target: 92 10 | if_not_found: success 11 | if_ci_failed: error 12 | patch: 13 | default: 14 | enabled: yes 15 | target: 70 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | coverage.txt 2 | profile.out 3 | vendor 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | go_import_path: github.com/defval/inject/v2 2 | language: go 3 | sudo: false 4 | 5 | matrix: 6 | include: 7 | - go: "1.11.x" 8 | - go: "1.12.x" 9 | - go: "1.13.x" 10 | fast_finish: true 11 | 12 | env: 13 | global: 14 | - GO111MODULE=on 15 | 16 | script: 17 | - make test 18 | 19 | after_success: 20 | - bash <(curl -s https://codecov.io/bash) 21 | 22 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 6 | 7 | ## Unreleased 8 | 9 | ## Fixed 10 | 11 | - Cleanup ordering 12 | - Cleanup with prototypes 13 | - Removed duplicate function `resolveParameterProvider()` 14 | - Refactor graph storage 15 | - Change internal di container interface.
16 | - Code style fixes 17 | - Documentation fixes 18 | 19 | ## v2.2.2 20 | 21 | Internal refactoring 22 | 23 | ### Added 24 | 25 | - Visualize parameter bag 26 | 27 | ### Fixed 28 | 29 | - Visualize type detection 30 | 31 | ## v2.2.1 32 | 33 | ### Fixed 34 | 35 | - Invoke: interface is nil, not error 36 | 37 | ## v2.2.0 38 | 39 | ### Added 40 | 41 | - `container.Invoke()` for invocations 42 | 43 | ## v2.1.1 44 | 45 | ### Fixed 46 | 47 | - Incorrect di.Parameter resolving 48 | 49 | ## v2.1.0 50 | 51 | ### Added 52 | 53 | - Visualization 54 | 55 | ## v2.0.1 56 | 57 | ### Added 58 | 59 | - Helper methods to ParameterBag 60 | 61 | ## v2.0.0 62 | 63 | Massive refactoring and rethinking features. 64 | 65 | ### Changed 66 | 67 | - Graph implementation 68 | - Simplify injection code 69 | - Documentation 70 | 71 | ### Added 72 | 73 | - Prototypes 74 | - Cleanup 75 | - Parameter bag 76 | - Optional parameters 77 | - Low-level container interface 78 | 79 | ### Removed 80 | 81 | - Replacing (investigating) 82 | - Non constructor providers 83 | - Combined providers 84 | 85 | ## v1.5.2 86 | 87 | ### Fixed 88 | 89 | - Checksum problem 90 | 91 | ## v1.5.1 92 | 93 | ### Added 94 | 95 | - Ability to extract `github.com/emicklei/*dot.Graph` from container 96 | 97 | ## v1.5.0 98 | 99 | ### Added 100 | 101 | - Error `inject.ErrTypeNotProvided` 102 | 103 | ## v1.4.4 104 | 105 | ### Changed 106 | 107 | - Internal refactoring of adding nodes 108 | 109 | ## v1.4.3 110 | 111 | ### Fixed 112 | 113 | - Improve test coverage 114 | 115 | ### Changed 116 | 117 | - Internal refactoring of groups creation 118 | 119 | ## v1.4.2 120 | 121 | ### Fixed 122 | 123 | - Replace: check that provider implement interface 124 | 125 | ## v1.4.1 126 | 127 | ### Fixed 128 | 129 | - Lint 130 | 131 | ## v1.4.0 132 | 133 | ### Change 134 | 135 | - `Container.WriteTo()` signature to `io.WriterTo` 136 | 137 | ## v1.3.1 138 | 139 | ### Added 140 | 141 | - Documentation 142 | 143 | ## v1.3.0 144 | 145 | ### Added 146 
| 147 | - Graph visualization 148 | 149 | ## v1.2.1 150 | 151 | ### Fixed 152 | 153 | - inject.As() allows provide same interface without name 154 | 155 | ## v1.2.0 156 | 157 | ### Changed 158 | 159 | - Combined provider declaration 160 | - Some refactoring 161 | 162 | ## v1.1.1 163 | 164 | ### Changed 165 | 166 | - Refactor graph storage 167 | 168 | ## v1.1.0 169 | 170 | ### Added 171 | 172 | - Combined provider 173 | 174 | ### Changed 175 | 176 | - Provider refactoring 177 | 178 | ## v1.0.0 179 | 180 | - Initial release -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2019 Maxim Bovtunov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | ## Injector makefile 2 | ## `make test` delegates to .build/test.sh, which runs the tests and writes coverage.txt. 3 | .PHONY: test ## Run tests 4 | test: 5 | @.build/test.sh -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Tweet](https://img.shields.io/twitter/url/http/shields.io.svg?style=social)](https://twitter.com/intent/tweet?text=Dependency%20injection%20container%20for%20Golang&url=https://github.com/defval/inject&hashtags=golang,go,di,dependency-injection) 3 | 4 | [![Documentation](https://img.shields.io/badge/godoc-reference-blue.svg?color=24B898&style=for-the-badge&logo=go&logoColor=ffffff)](https://godoc.org/github.com/defval/inject) 5 | ![Release](https://img.shields.io/github/tag/defval/inject.svg?label=release&color=24B898&logo=github&style=for-the-badge) 6 | [![Build Status](https://img.shields.io/travis/defval/inject.svg?style=for-the-badge&logo=travis)](https://travis-ci.org/defval/inject) 7 | [![Code Coverage](https://img.shields.io/codecov/c/github/defval/inject.svg?style=for-the-badge&logo=codecov)](https://codecov.io/gh/defval/inject) 8 | 9 | ## This repository will be archived 10 | 11 | After using this library in production for a year, I made some 12 | conclusions about library API and really useful features. I don't want 13 | to make breaking changes and have a `v3` version. It's not that popular. 14 | 15 | I put old and new ideas in [goava/di](https://github.com/goava/di). 16 | 17 | ## How will dependency injection help me? 18 | 19 | Dependency injection is one form of the broader technique of inversion 20 | of control. It is used to increase modularity of the program and make it 21 | extensible.
22 | 23 | ## Contents 24 | 25 | - [Installing](#installing) 26 | - [Tutorial](#tutorial) 27 | - [Providing](#providing) 28 | - [Extraction](#extraction) 29 | - [Invocation](#invocation) 30 | - [Lazy-loading](#lazy-loading) 31 | - [Interfaces](#interfaces) 32 | - [Groups](#groups) 33 | - [Advanced features](#advanced-features) 34 | - [Named definitions](#named-definitions) 35 | - [Optional parameters](#optional-parameters) 36 | - [Parameter Bag](#parameter-bag) 37 | - [Prototypes](#prototypes) 38 | - [Cleanup](#cleanup) 39 | - [Visualization](#visualization) 40 | - [Contributing](#contributing) 41 | 42 | ## Installing 43 | 44 | ```shell 45 | go get -u github.com/defval/inject/v2 46 | ``` 47 | 48 | This library follows [SemVer](http://semver.org/) strictly. 49 | 50 | ## Tutorial 51 | 52 | Let's learn to use Inject by example. We will code a simple application 53 | that processes HTTP requests. 54 | 55 | The full tutorial code is available [here](./_tutorial/main.go). 56 | 57 | ### Providing 58 | 59 | To start, we will need to create two fundamental types: `http.Server` 60 | and `http.ServeMux`. Let's create simple constructors that initialize 61 | them: 62 | 63 | ```go 64 | // NewServer creates a http server with provided mux as handler. 65 | func NewServer(mux *http.ServeMux) *http.Server { 66 | return &http.Server{ 67 | Handler: mux, 68 | } 69 | } 70 | 71 | // NewServeMux creates a new http serve mux. 72 | func NewServeMux() *http.ServeMux { 73 | return &http.ServeMux{} 74 | } 75 | ``` 76 | 77 | > Supported constructor signature: 78 | > 79 | > ```go 80 | > func([dep1, dep2, depN]) (result, [cleanup, error]) 81 | > ``` 82 | 83 | Now let's teach the container to build these types.
84 | 85 | ```go 86 | container := inject.New( 87 | // provide http server 88 | inject.Provide(NewServer), 89 | // provide http serve mux 90 | inject.Provide(NewServeMux), 91 | ) 92 | ``` 93 | 94 | The function `inject.New()` parses our constructors, compiles the dependency 95 | graph and returns a `*inject.Container` for interaction. The container 96 | panics if the graph cannot be compiled. 97 | 98 | > I think that panicking during application initialization, rather than at 99 | > runtime, is normal. 100 | 101 | ### Extraction 102 | 103 | We can extract the built server from the container. To do this, declare a 104 | variable of the extracted type and pass a pointer to it to the `Extract` 105 | function. 106 | 107 | > If the extracted type is not found or building the instance causes an 108 | > error, `Extract` returns an error. 109 | 110 | If no error occurred, we can use the variable as if we had built it 111 | ourselves. 112 | 113 | ```go 114 | // declare type variable 115 | var server *http.Server 116 | // extracting 117 | err := container.Extract(&server) 118 | if err != nil { 119 | // check extraction error 120 | } 121 | 122 | server.ListenAndServe() 123 | ``` 124 | 125 | > Note that by default, the container creates instances as singletons. 126 | > But you can change this behaviour. See [Prototypes](#prototypes). 127 | 128 | ### Invocation 129 | 130 | As an alternative to extraction, we can use the `Invoke()` function. It 131 | resolves the function's dependencies and calls the function. The invoked function 132 | may optionally return an error. 133 | 134 | ```go 135 | // StartServer starts the server. 136 | func StartServer(server *http.Server) error { 137 | return server.ListenAndServe() 138 | } 139 | 140 | container.Invoke(StartServer) 141 | ``` 142 | 143 | ### Lazy-loading 144 | 145 | Dependencies are lazy-loaded. If nothing requires a type from 146 | the container, it will not be constructed. 147 | 148 | ### Interfaces 149 | 150 | Inject makes it possible to provide an implementation as an interface.
151 | 152 | ```go 153 | // NewServer creates a http server with provided handler. 154 | func NewServer(handler http.Handler) *http.Server { 155 | return &http.Server{ 156 | Handler: handler, 157 | } 158 | } 159 | ``` 160 | 161 | To tell the container which type should be used as the implementation of 162 | `http.Handler`, we use the `inject.As()` option. The arguments of this 163 | option must be pointers to interfaces, like `new(Endpoint)`. 164 | 165 | > This syntax may seem strange, but I have not found a better way to 166 | > specify the interface. 167 | 168 | Updated container initialization code: 169 | 170 | ```go 171 | container := inject.New( 172 | // provide http server 173 | inject.Provide(NewServer), 174 | // provide http serve mux as http.Handler interface 175 | inject.Provide(NewServeMux, inject.As(new(http.Handler))), 176 | ) 177 | ``` 178 | 179 | Now the container uses the provided `*http.ServeMux` as `http.Handler` in the server 180 | constructor. Using interfaces contributes to writing more testable code. 181 | 182 | ### Groups 183 | 184 | The container automatically groups all implementations of an interface into a 185 | `[]<interface>` group. For example, providing with 186 | `inject.As(new(http.Handler))` automatically creates a group 187 | `[]http.Handler`. 188 | 189 | Let's add some http controllers using this feature. Controllers have 190 | a common behavior: registering routes. First, we will create an 191 | interface for it. 192 | 193 | ```go 194 | // Controller is an interface that can register its routes. 195 | type Controller interface { 196 | RegisterRoutes(mux *http.ServeMux) 197 | } 198 | ``` 199 | 200 | Now we will write controllers that implement the `Controller` interface. 201 | 202 | ##### OrderController 203 | 204 | ```go 205 | // OrderController is a http controller for orders. 206 | type OrderController struct {} 207 | 208 | // NewOrderController creates an order http controller.
209 | func NewOrderController() *OrderController { 210 | return &OrderController{} 211 | } 212 | 213 | // RegisterRoutes is a Controller interface implementation. 214 | func (a *OrderController) RegisterRoutes(mux *http.ServeMux) { 215 | mux.HandleFunc("/orders", a.RetrieveOrders) 216 | } 217 | 218 | // Retrieve loads orders and writes it to the writer. 219 | func (a *OrderController) RetrieveOrders(writer http.ResponseWriter, request *http.Request) { 220 | // implementation 221 | } 222 | ``` 223 | 224 | ##### UserController 225 | 226 | ```go 227 | // UserController is a http endpoint for a user. 228 | type UserController struct {} 229 | 230 | // NewUserController creates a user http endpoint. 231 | func NewUserController() *UserController { 232 | return &UserController{} 233 | } 234 | 235 | // RegisterRoutes is a Controller interface implementation. 236 | func (e *UserController) RegisterRoutes(mux *http.ServeMux) { 237 | mux.HandleFunc("/users", e.RetrieveUsers) 238 | } 239 | 240 | // Retrieve loads users and writes it using the writer. 241 | func (e *UserController) RetrieveUsers(writer http.ResponseWriter, request *http.Request) { 242 | // implementation 243 | } 244 | ``` 245 | 246 | Just like in the example with interfaces, we will use `inject.As()` 247 | provide option. 248 | 249 | ```go 250 | container := inject.New( 251 | inject.Provide(NewServer), // provide http server 252 | inject.Provide(NewServeMux), // provide http serve mux 253 | // endpoints 254 | inject.Provide(NewOrderController, inject.As(new(Controller))), // provide order controller 255 | inject.Provide(NewUserController, inject.As(new(Controller))), // provide user controller 256 | ) 257 | ``` 258 | 259 | Now, we can use `[]Controller` group in our mux. See updated code: 260 | 261 | ```go 262 | // NewServeMux creates a new http serve mux. 
263 | func NewServeMux(controllers []Controller) *http.ServeMux { 264 | mux := &http.ServeMux{} 265 | 266 | for _, controller := range controllers { 267 | controller.RegisterRoutes(mux) 268 | } 269 | 270 | return mux 271 | } 272 | ``` 273 | 274 | ## Advanced features 275 | 276 | ### Named definitions 277 | 278 | In some cases you have more than one instance of the same type. For example, 279 | two database instances: master - for writing, slave - for reading. 280 | 281 | The first way is to wrap the types: 282 | 283 | ```go 284 | // MasterDatabase provides write database access. 285 | type MasterDatabase struct { 286 | *Database 287 | } 288 | 289 | // SlaveDatabase provides read database access. 290 | type SlaveDatabase struct { 291 | *Database 292 | } 293 | ``` 294 | 295 | The second way is to use named definitions with the `inject.WithName()` provide 296 | option: 297 | 298 | ```go 299 | // provide master database 300 | inject.Provide(NewMasterDatabase, inject.WithName("master")) 301 | // provide slave database 302 | inject.Provide(NewSlaveDatabase, inject.WithName("slave")) 303 | ``` 304 | 305 | To extract a named definition from the container, use the `inject.Name()` extract 306 | option. 307 | 308 | ```go 309 | var db *Database 310 | container.Extract(&db, inject.Name("master")) 311 | ``` 312 | 313 | To inject a named definition into another constructor, use 314 | `di.Parameter` with embedding. 315 | 316 | ```go 317 | // ServiceParameters holds the named databases required by NewService. 318 | type ServiceParameters struct { 319 | di.Parameter 320 | 321 | // the `di` tag tells the container which named definition to inject into the field. 322 | MasterDatabase *Database `di:"master"` 323 | SlaveDatabase *Database `di:"slave"` 324 | } 325 | 326 | // NewService creates a new service with the provided parameters.
327 | func NewService(parameters ServiceParameters) *Service { 328 | return &Service{ 329 | MasterDatabase: parameters.MasterDatabase, 330 | SlaveDatabase: parameters.SlaveDatabase, 331 | } 332 | } 333 | ``` 334 | 335 | ### Optional parameters 336 | 337 | `di.Parameter` also provides the ability to skip a dependency if it does not exist 338 | in the container. 339 | 340 | ```go 341 | // ServiceParameter holds the service dependencies; Logger may be absent. 342 | type ServiceParameter struct { 343 | di.Parameter 344 | 345 | Logger *Logger `di:"optional"` 346 | } 347 | ``` 348 | 349 | > Constructors that declare dependencies as optional must handle the 350 | > case of those dependencies being absent. 351 | 352 | You can use naming and optional parameters together. 353 | 354 | ```go 355 | // ServiceParameter holds a required "stdout" logger and an optional "file" logger. 356 | type ServiceParameter struct { 357 | di.Parameter 358 | 359 | StdOutLogger *Logger `di:"stdout"` 360 | FileLogger *Logger `di:"file,optional"` 361 | } 362 | ``` 363 | 364 | ### Parameter Bag 365 | 366 | If you need to specify some parameters at the definition level, you can use 367 | the `inject.ParameterBag` provide option. This is a `map[string]interface{}` 368 | that is transformed into the `di.ParameterBag` type. 369 | 370 | ```go 371 | // Provide server with parameter bag 372 | inject.Provide(NewServer, inject.ParameterBag{ 373 | "addr": ":8080", 374 | }) 375 | 376 | // NewServer creates a server with the provided parameter bag. Note: use the di.ParameterBag type, 377 | // not inject.ParameterBag. 378 | func NewServer(pb di.ParameterBag) *http.Server { 379 | return &http.Server{ 380 | Addr: pb.RequireString("addr"), 381 | } 382 | } 383 | ``` 384 | 385 | ### Prototypes 386 | 387 | If you want to create a new instance on each extraction, use the 388 | `inject.Prototype()` provide option. 389 | 390 | ```go 391 | inject.Provide(NewRequestContext, inject.Prototype()) 392 | ``` 393 | 394 | > todo: real use case 395 | 396 | ### Cleanup 397 | 398 | If a provider creates a value that needs to be cleaned up, then it can 399 | return a closure to clean up the resource.
400 | 401 | ```go 402 | func NewFile(log Logger, path Path) (*os.File, func(), error) { 403 | f, err := os.Open(string(path)) 404 | if err != nil { 405 | return nil, nil, err 406 | } 407 | cleanup := func() { 408 | if err := f.Close(); err != nil { 409 | log.Log(err) 410 | } 411 | } 412 | return f, cleanup, nil 413 | } 414 | ``` 415 | 416 | When `container.Cleanup()` is called, the container iterates over the created instances and 417 | calls their cleanup functions, if they exist. 418 | 419 | ```go 420 | container := inject.New( 421 | // ... 422 | inject.Provide(NewFile), 423 | ) 424 | 425 | // do something 426 | container.Cleanup() // file was closed 427 | ``` 428 | 429 | > Cleanup currently works incorrectly with prototype providers. 430 | 431 | ## Visualization 432 | 433 | The dependency graph can be visualized with 434 | [Graphviz](https://www.graphviz.org/). To do so, extract the graph and take its string 435 | representation: 436 | 437 | ```go 438 | var graph *di.Graph 439 | if err := container.Extract(&graph); err != nil { 440 | // handle err 441 | } 442 | 443 | dotGraph := graph.String() // use string representation 444 | ``` 445 | 446 | Then paste it into an online Graphviz tool: 448 | 449 | 450 | 451 | ## Contributing 452 | 453 | I will be glad if you contribute to this library. I don't know much 454 | English, so contributions to the documentation are very meaningful to me.
455 | 456 | [![](https://sourcerer.io/fame/defval/defval/inject/images/0)](https://sourcerer.io/fame/defval/defval/inject/links/0)[![](https://sourcerer.io/fame/defval/defval/inject/images/1)](https://sourcerer.io/fame/defval/defval/inject/links/1)[![](https://sourcerer.io/fame/defval/defval/inject/images/2)](https://sourcerer.io/fame/defval/defval/inject/links/2)[![](https://sourcerer.io/fame/defval/defval/inject/images/3)](https://sourcerer.io/fame/defval/defval/inject/links/3)[![](https://sourcerer.io/fame/defval/defval/inject/images/4)](https://sourcerer.io/fame/defval/defval/inject/links/4)[![](https://sourcerer.io/fame/defval/defval/inject/images/5)](https://sourcerer.io/fame/defval/defval/inject/links/5)[![](https://sourcerer.io/fame/defval/defval/inject/images/6)](https://sourcerer.io/fame/defval/defval/inject/links/6)[![](https://sourcerer.io/fame/defval/defval/inject/images/7)](https://sourcerer.io/fame/defval/defval/inject/links/7) 457 | 458 | -------------------------------------------------------------------------------- /_tutorial/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/defval/inject/v2" 7 | ) 8 | 9 | func main() { 10 | container := inject.New( 11 | inject.Provide(NewServer), // provide http server 12 | inject.Provide(NewServeMux), // provide http serve mux 13 | // endpoints 14 | inject.Provide(NewOrderController, inject.As(new(Controller))), // provide order controller 15 | inject.Provide(NewUserController, inject.As(new(Controller))), // provide user controller 16 | ) 17 | 18 | var server *http.Server 19 | err := container.Extract(&server) 20 | if err != nil { 21 | panic(err) 22 | } 23 | 24 | server.ListenAndServe() 25 | } 26 | 27 | // NewServer creates a http server with provided mux as handler. 
28 | func NewServer(mux *http.ServeMux) *http.Server { 29 | return &http.Server{ 30 | Handler: mux, 31 | } 32 | } 33 | 34 | // NewServeMux creates a new http serve mux. 35 | func NewServeMux(controllers []Controller) *http.ServeMux { 36 | mux := &http.ServeMux{} 37 | 38 | for _, controller := range controllers { 39 | controller.RegisterRoutes(mux) 40 | } 41 | 42 | return mux 43 | } 44 | 45 | // Controller is an interface that can register its routes. 46 | type Controller interface { 47 | RegisterRoutes(mux *http.ServeMux) 48 | } 49 | 50 | // OrderController is a http controller for orders. 51 | type OrderController struct{} 52 | 53 | // NewOrderController creates a auth http controller. 54 | func NewOrderController() *OrderController { 55 | return &OrderController{} 56 | } 57 | 58 | // RegisterRoutes is a Controller interface implementation. 59 | func (a *OrderController) RegisterRoutes(mux *http.ServeMux) { 60 | mux.HandleFunc("/orders", a.RetrieveOrders) 61 | } 62 | 63 | // Retrieve loads orders and writes it to the writer. 64 | func (a *OrderController) RetrieveOrders(writer http.ResponseWriter, request *http.Request) { 65 | writer.WriteHeader(http.StatusOK) 66 | _, _ = writer.Write([]byte("Orders")) 67 | } 68 | 69 | // UserController is a http endpoint for a user. 70 | type UserController struct{} 71 | 72 | // NewUserController creates a user http endpoint. 73 | func NewUserController() *UserController { 74 | return &UserController{} 75 | } 76 | 77 | // RegisterRoutes is a Controller interface implementation. 78 | func (e *UserController) RegisterRoutes(mux *http.ServeMux) { 79 | mux.HandleFunc("/users", e.RetrieveUsers) 80 | } 81 | 82 | // Retrieve loads users and writes it using the writer. 
83 | func (e *UserController) RetrieveUsers(writer http.ResponseWriter, request *http.Request) { 84 | writer.WriteHeader(http.StatusOK) 85 | _, _ = writer.Write([]byte("Users")) 86 | } 87 | -------------------------------------------------------------------------------- /container.go: -------------------------------------------------------------------------------- 1 | package inject 2 | 3 | import ( 4 | "github.com/defval/inject/v2/di" 5 | ) 6 | 7 | // New creates a new container with provided options. 8 | func New(options ...Option) *Container { 9 | var c = &Container{ 10 | container: di.New(), 11 | } 12 | // apply options. 13 | for _, opt := range options { 14 | opt.apply(c) 15 | } 16 | c.compile() 17 | return c 18 | } 19 | 20 | // Container is a dependency injection container. 21 | type Container struct { 22 | providers []provide 23 | container *di.Container 24 | } 25 | 26 | // Extract populates given target pointer with type instance provided in the container. 27 | // 28 | // var server *http.Server 29 | // if err = container.Extract(&server); err != nil { 30 | // // extract failed 31 | // } 32 | // 33 | // If the target type does not exist in a container or instance type building failed, Extract() returns an error. 34 | // Use ExtractOption for modifying the behavior of this function. 35 | func (c *Container) Extract(target interface{}, options ...ExtractOption) (err error) { 36 | var params = di.ExtractParams{} 37 | // apply extract options 38 | for _, opt := range options { 39 | opt.apply(¶ms) 40 | } 41 | return c.container.Extract(target, params) 42 | } 43 | 44 | // Invoke invokes custom function. Dependencies of function will be resolved via container. 45 | func (c *Container) Invoke(fn interface{}) error { 46 | return c.container.Invoke(fn) 47 | } 48 | 49 | // Cleanup cleanup container. 
50 | func (c *Container) Cleanup() { 51 | c.container.Cleanup() 52 | } 53 | 54 | func (c *Container) compile() { 55 | for _, po := range c.providers { 56 | c.container.Provide(po.provider, po.params) 57 | } 58 | c.container.Compile() 59 | return 60 | } 61 | 62 | type provide struct { 63 | provider interface{} 64 | params di.ProvideParams 65 | } 66 | -------------------------------------------------------------------------------- /container_test.go: -------------------------------------------------------------------------------- 1 | package inject_test 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "net/http" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/require" 10 | 11 | "github.com/defval/inject/v2" 12 | ) 13 | 14 | func TestContainer(t *testing.T) { 15 | var HTTPBundle = inject.Bundle( 16 | inject.Provide(ProvideAddr("0.0.0.0", "8080")), 17 | inject.Provide(NewMux, inject.As(new(http.Handler))), 18 | inject.Provide(NewHTTPServer, inject.Prototype(), inject.WithName("server")), 19 | ) 20 | 21 | c := inject.New(HTTPBundle) 22 | 23 | var server1 *http.Server 24 | err := c.Extract(&server1, inject.Name("server")) 25 | require.NoError(t, err) 26 | 27 | var server2 *http.Server 28 | err = c.Extract(&server2, inject.Name("server")) 29 | require.NoError(t, err) 30 | 31 | err = c.Invoke(PrintAddr) 32 | require.NoError(t, err) 33 | } 34 | 35 | // Addr 36 | type Addr string 37 | 38 | // ProvideAddr 39 | func ProvideAddr(host string, port string) func() Addr { 40 | return func() Addr { 41 | return Addr(net.JoinHostPort(host, port)) 42 | } 43 | } 44 | 45 | // NewHTTPServer 46 | func NewHTTPServer(addr Addr, handler http.Handler) *http.Server { 47 | return &http.Server{ 48 | Addr: string(addr), 49 | Handler: handler, 50 | } 51 | } 52 | 53 | // NewMux 54 | func NewMux() *http.ServeMux { 55 | return &http.ServeMux{} 56 | } 57 | 58 | // PrintAddr 59 | func PrintAddr(addr Addr) { 60 | fmt.Println(addr) 61 | } 62 | 
-------------------------------------------------------------------------------- /di/container.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | 7 | "github.com/defval/inject/v2/di/internal/graphkv" 8 | "github.com/defval/inject/v2/di/internal/reflection" 9 | ) 10 | 11 | // Interactor is a helper interface. 12 | type Interactor interface { 13 | Extract(target interface{}, options ...ExtractOption) error 14 | Invoke(fn interface{}, options ...InvokeOption) error 15 | } 16 | 17 | // Builder is helper interface. 18 | type Builder interface { 19 | Provide(provider interface{}, options ...ProvideOption) 20 | } 21 | 22 | // New create new container. 23 | func New() *Container { 24 | return &Container{ 25 | graph: graphkv.New(), 26 | } 27 | } 28 | 29 | // Container is a dependency injection container. 30 | type Container struct { 31 | compiled bool 32 | graph *graphkv.Graph 33 | cleanups []func() 34 | } 35 | 36 | // Provide adds constructor into container with parameters. 
37 | func (c *Container) Provide(constructor interface{}, options ...ProvideOption) { 38 | params := ProvideParams{} 39 | for _, opt := range options { 40 | opt.apply(¶ms) 41 | } 42 | provider := internalProvider(newProviderConstructor(params.Name, constructor)) 43 | key := provider.Key() 44 | if c.graph.Exists(key) { 45 | panicf("The `%s` type already exists in container", provider.Key()) 46 | } 47 | if !params.IsPrototype { 48 | provider = asSingleton(provider) 49 | } 50 | // add provider to graph 51 | c.graph.Add(key, provider) 52 | // parse embed parameters 53 | for _, param := range provider.ParameterList() { 54 | if param.embed { 55 | embed := newProviderEmbed(param) 56 | c.graph.Add(embed.Key(), embed) 57 | } 58 | } 59 | // provide parameter bag 60 | if len(params.Parameters) != 0 { 61 | parameterBugProvider := createParameterBugProvider(provider.Key(), params.Parameters) 62 | c.graph.Add(parameterBugProvider.Key(), parameterBugProvider) 63 | } 64 | // process interfaces 65 | for _, iface := range params.Interfaces { 66 | c.processProviderInterface(provider, iface) 67 | } 68 | } 69 | 70 | // Compile compiles the container. It iterates over all nodes 71 | // in graph and register their parameters. 72 | func (c *Container) Compile() { 73 | graphProvider := func() *Graph { return &Graph{graph: c.graph.DOTGraph()} } 74 | interactorProvider := func() Interactor { return c } 75 | c.Provide(graphProvider) 76 | c.Provide(interactorProvider) 77 | for _, node := range c.graph.Nodes() { 78 | c.registerProviderParameters(node.Value.(internalProvider)) 79 | } 80 | if err := c.graph.CheckCycles(); err != nil { 81 | panic(err.Error()) 82 | } 83 | c.compiled = true 84 | } 85 | 86 | // Extract builds instance of target type and fills target pointer. 
87 | func (c *Container) Extract(target interface{}, options ...ExtractOption) error { 88 | params := ExtractParams{} 89 | for _, opt := range options { 90 | opt.apply(¶ms) 91 | } 92 | if !c.compiled { 93 | return fmt.Errorf("container not compiled") 94 | } 95 | if target == nil { 96 | return fmt.Errorf("extract target must be a pointer, got `nil`") 97 | } 98 | if !reflection.IsPtr(target) { 99 | return fmt.Errorf("extract target must be a pointer, got `%s`", reflect.TypeOf(target)) 100 | } 101 | typ := reflect.TypeOf(target) 102 | param := parameter{ 103 | name: params.Name, 104 | res: typ.Elem(), 105 | embed: isEmbedParameter(typ), 106 | } 107 | value, err := param.ResolveValue(c) 108 | if err != nil { 109 | return err 110 | } 111 | targetValue := reflect.ValueOf(target).Elem() 112 | targetValue.Set(value) 113 | return nil 114 | } 115 | 116 | // Invoke calls provided function. 117 | func (c *Container) Invoke(fn interface{}, options ...InvokeOption) error { 118 | params := InvokeParams{} 119 | for _, opt := range options { 120 | opt.apply(¶ms) 121 | } 122 | if !c.compiled { 123 | return fmt.Errorf("container not compiled") 124 | } 125 | invoker, err := newInvoker(fn) 126 | if err != nil { 127 | return err 128 | } 129 | return invoker.Invoke(c) 130 | } 131 | 132 | // Cleanup runs destructors in order that was been created. 133 | func (c *Container) Cleanup() { 134 | for _, cleanup := range c.cleanups { 135 | cleanup() 136 | } 137 | } 138 | 139 | // processProviderInterface represents instances as interfaces and groups. 
140 | func (c *Container) processProviderInterface(provider internalProvider, as interface{}) { 141 | // Wrap provider as the `as` interface. If that interface key is already taken, it is replaced by a stub so that resolving the bare interface reports "have several implementations"; every implementation is still collected in the interface's group below. 142 | iface := newProviderInterface(provider, as) 143 | key := iface.Key() 144 | if c.graph.Exists(key) { 145 | stub := newProviderStub(key, "have several implementations") 146 | c.graph.Replace(key, stub) 147 | } else { 148 | // add interface node 149 | c.graph.Add(key, iface) 150 | } 151 | // create group 152 | group := newProviderGroup(key) 153 | groupKey := group.Key() 154 | // check exists 155 | if c.graph.Exists(groupKey) { 156 | // if exists use existing group 157 | node := c.graph.Get(groupKey) 158 | group = node.Value.(*providerGroup) 159 | } else { 160 | // else add new group to graph 161 | c.graph.Add(groupKey, group) 162 | } 163 | // add provider reference into group 164 | providerKey := provider.Key() 165 | group.Add(providerKey) 166 | } 167 | 168 | // registerProviderParameters adds a graph edge from the provider of each of p's parameters to p, and panics when a non-optional parameter has no provider in the container. 169 | func (c *Container) registerProviderParameters(p internalProvider) { 170 | for _, param := range p.ParameterList() { 171 | provider, exists := param.ResolveProvider(c) 172 | if exists { 173 | c.graph.Edge(provider.Key(), p.Key()) 174 | continue 175 | } 176 | if !exists && !param.optional { 177 | panicf("%s: dependency %s not exists in container", p.Key(), param) 178 | } 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /di/container_test.go: -------------------------------------------------------------------------------- 1 | package di_test 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | "reflect" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/require" 12 | 13 | 
"github.com/defval/inject/v2/di" 14 | "github.com/defval/inject/v2/di/internal/ditest" 15 | ) 16 | 17 | func TestContainerCompileErrors(t *testing.T) { 18 | t.Run("dependency cycle cause panic", func(t *testing.T) { 19 | c := NewTestContainer(t) 20 |
c.MustProvide(ditest.NewCycleFooBar) 21 | c.MustProvide(ditest.NewBar) 22 | c.MustCompileError("the graph cannot be cyclic") 23 | }) 24 | 25 | t.Run("not existing dependency cause compile error", func(t *testing.T) { 26 | c := NewTestContainer(t) 27 | c.MustProvide(ditest.NewBar) 28 | c.MustCompileError("*ditest.Bar: dependency *ditest.Foo not exists in container") 29 | }) 30 | 31 | t.Run("not existing non pointer dependency cause compile error", func(t *testing.T) { 32 | c := NewTestContainer(t) 33 | type TestStruct struct { 34 | } 35 | 36 | c.MustProvide(func(s TestStruct) bool { 37 | return true 38 | }) 39 | 40 | require.PanicsWithValue(t, "bool: dependency di_test.TestStruct not exists in container", func() { 41 | c.Compile() 42 | }) 43 | }) 44 | } 45 | 46 | func TestContainerProvideErrors(t *testing.T) { 47 | t.Run("provide string cause panic", func(t *testing.T) { 48 | c := NewTestContainer(t) 49 | c.MustProvideError("string", "The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `string`") 50 | }) 51 | 52 | t.Run("provide nil cause panic", func(t *testing.T) { 53 | c := NewTestContainer(t) 54 | c.MustProvideError(nil, "The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `nil`") 55 | }) 56 | 57 | t.Run("provide struct pointer cause panic", func(t *testing.T) { 58 | c := NewTestContainer(t) 59 | c.MustProvideError(&ditest.Foo{}, "The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `*ditest.Foo`") 60 | }) 61 | 62 | t.Run("provide constructor without result cause panic", func(t *testing.T) { 63 | c := NewTestContainer(t) 64 | c.MustProvideError(ditest.ConstructorWithoutResult, "The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `github.com/defval/inject/v2/di/internal/ditest.ConstructorWithoutResult`") 65 | }) 66 | 67 | t.Run("provide constructor with many results cause panic", func(t *testing.T) { 68 | c
:= NewTestContainer(t) 69 | c.MustProvideError(ditest.ConstructorWithManyResults, "The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `github.com/defval/inject/v2/di/internal/ditest.ConstructorWithManyResults`") 70 | }) 71 | 72 | t.Run("provide constructor with incorrect result error argument", func(t *testing.T) { 73 | c := NewTestContainer(t) 74 | c.MustProvideError(ditest.ConstructorWithIncorrectResultError, "The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `github.com/defval/inject/v2/di/internal/ditest.ConstructorWithIncorrectResultError`") 75 | }) 76 | 77 | t.Run("provide duplicate", func(t *testing.T) { 78 | c := NewTestContainer(t) 79 | c.MustProvide(ditest.NewFoo) 80 | c.MustProvideError(ditest.NewFoo, "The `*ditest.Foo` type already exists in container") 81 | }) 82 | 83 | t.Run("provide as not implemented interface cause error", func(t *testing.T) { 84 | c := NewTestContainer(t) 85 | c.MustProvide(ditest.NewFoo) 86 | c.MustProvideError(ditest.NewBar, "*ditest.Bar not implement ditest.Barer", new(ditest.Barer)) 87 | }) 88 | 89 | t.Run("provide as not interface cause error", func(t *testing.T) { 90 | c := NewTestContainer(t) 91 | c.MustProvide(ditest.NewFoo) 92 | c.MustProvideError(ditest.NewBar, "*ditest.Foo: not a pointer to interface", new(ditest.Foo)) 93 | }) 94 | } 95 | 96 | func TestContainerExtractErrors(t *testing.T) { 97 | t.Run("container panic on trying extract before compilation", func(t *testing.T) { 98 | c := NewTestContainer(t) 99 | foo := &ditest.Foo{} 100 | c.MustProvide(ditest.CreateFooConstructor(foo)) 101 | var extracted *ditest.Foo 102 | c.MustExtractError(&extracted, "container not compiled") 103 | }) 104 | 105 | t.Run("extract into string cause error", func(t *testing.T) { 106 | c := NewTestContainer(t) 107 | c.MustProvide(ditest.NewFoo) 108 | c.MustCompile() 109 | c.MustExtractError("string", "extract target must be a pointer, got `string`") 110 | })
111 | 112 | t.Run("extract into struct cause error", func(t *testing.T) { 113 | c := NewTestContainer(t) 114 | c.MustProvide(ditest.NewFoo) 115 | c.MustCompile() 116 | c.MustExtractError(struct{}{}, "extract target must be a pointer, got `struct {}`") 117 | }) 118 | 119 | t.Run("extract into nil cause error", func(t *testing.T) { 120 | c := NewTestContainer(t) 121 | c.MustProvide(ditest.NewFoo) 122 | c.MustCompile() 123 | c.MustExtractError(nil, "extract target must be a pointer, got `nil`") 124 | }) 125 | 126 | t.Run("container does not find type because its named", func(t *testing.T) { 127 | c := NewTestContainer(t) 128 | foo := &ditest.Foo{} 129 | c.MustProvideWithName("foo", ditest.CreateFooConstructor(foo)) 130 | c.MustCompile() 131 | 132 | var extracted *ditest.Foo 133 | c.MustExtractError(&extracted, "*ditest.Foo: not exists in container") 134 | }) 135 | 136 | t.Run("extract returns error because dependency constructing failed", func(t *testing.T) { 137 | c := NewTestContainer(t) 138 | c.MustProvide(ditest.CreateFooConstructorWithError(errors.New("internal error"))) 139 | c.MustProvide(ditest.NewBar) 140 | c.MustCompile() 141 | var bar *ditest.Bar 142 | c.MustExtractError(&bar, "*ditest.Foo: internal error") 143 | }) 144 | 145 | t.Run("extract interface with multiple implementations cause error", func(t *testing.T) { 146 | c := NewTestContainer(t) 147 | c.MustProvide(ditest.NewFoo) 148 | c.MustProvide(ditest.NewBar, new(ditest.Fooer)) 149 | c.MustProvide(ditest.NewBaz, new(ditest.Fooer)) 150 | c.MustCompile() 151 | 152 | var extracted ditest.Fooer 153 | c.MustExtractError(&extracted, "ditest.Fooer: have several implementations") 154 | }) 155 | } 156 | 157 | func TestContainerInvokeErrors(t *testing.T) { 158 | t.Run("invoke function with incorrect signature cause error", func(t *testing.T) { 159 | c := NewTestContainer(t) 160 | c.MustCompile() 161 | c.MustInvokeError(func() *ditest.Foo { 162 | return nil 163 | }, "the invoke function must be a function like
`func([dep1, dep2, ...]) [error]`, got `func() *ditest.Foo`") 164 | }) 165 | 166 | t.Run("invoke function with undefined dependency cause error", func(t *testing.T) { 167 | c := NewTestContainer(t) 168 | c.MustCompile() 169 | c.MustInvokeError(func(foo *ditest.Foo) {}, "could not resolve invoke parameters: *ditest.Foo: not exists in container") 170 | }) 171 | 172 | t.Run("invoke before compile cause error", func(t *testing.T) { 173 | c := NewTestContainer(t) 174 | c.MustInvokeError(func() {}, "container not compiled") 175 | }) 176 | } 177 | 178 | func TestContainerProvide(t *testing.T) { 179 | t.Run("container successfully accept simple constructor", func(t *testing.T) { 180 | c := NewTestContainer(t) 181 | c.MustProvide(ditest.NewFoo) 182 | }) 183 | 184 | t.Run("container successfully accept constructor with error", func(t *testing.T) { 185 | c := NewTestContainer(t) 186 | c.MustProvide(ditest.CreateFooConstructorWithError(nil)) 187 | }) 188 | 189 | t.Run("container successfully accept constructor with cleanup function", func(t *testing.T) { 190 | c := NewTestContainer(t) 191 | 192 | cleanup := func() {} 193 | c.MustProvide(ditest.CreateFooConstructorWithCleanup(cleanup)) 194 | }) 195 | 196 | } 197 | 198 | func TestContainerExtract(t *testing.T) { 199 | t.Run("container extract correct pointer", func(t *testing.T) { 200 | c := NewTestContainer(t) 201 | foo := &ditest.Foo{} 202 | c.MustProvide(ditest.CreateFooConstructor(foo)) 203 | c.MustCompile() 204 | 205 | var extracted *ditest.Foo 206 | c.MustExtractPtr(foo, &extracted) 207 | }) 208 | 209 | t.Run("container extract same pointer on each extraction", func(t *testing.T) { 210 | c := NewTestContainer(t) 211 | foo := &ditest.Foo{} 212 | c.MustProvide(ditest.CreateFooConstructor(foo)) 213 | c.MustCompile() 214 | 215 | var extracted1 *ditest.Foo 216 | c.MustExtractPtr(foo, &extracted1) 217 | 218 | var extracted2 *ditest.Foo 219 | c.MustExtractPtr(foo, &extracted2) 220 | }) 221 | 222 | t.Run("container extract
instance if error is nil", func(t *testing.T) { 223 | c := NewTestContainer(t) 224 | c.MustProvide(ditest.CreateFooConstructorWithError(nil)) 225 | c.MustCompile() 226 | 227 | var extracted *ditest.Foo 228 | c.MustExtract(&extracted) 229 | }) 230 | 231 | t.Run("container extract instance if cleanup and error is nil", func(t *testing.T) { 232 | c := NewTestContainer(t) 233 | 234 | c.MustProvide(ditest.CreateFooConstructorWithCleanupAndError(nil, nil)) 235 | c.MustCompile() 236 | 237 | var extracted *ditest.Foo 238 | c.MustExtract(&extracted) 239 | }) 240 | 241 | t.Run("container extract correct named pointer", func(t *testing.T) { 242 | c := NewTestContainer(t) 243 | foo := &ditest.Foo{} 244 | c.MustProvideWithName("foo", ditest.CreateFooConstructor(foo)) 245 | c.MustCompile() 246 | 247 | var extracted *ditest.Foo 248 | c.MustExtractWithName("foo", &extracted) 249 | }) 250 | 251 | t.Run("container extract correct interface implementation", func(t *testing.T) { 252 | c := NewTestContainer(t) 253 | bar := &ditest.Bar{} 254 | c.MustProvide(ditest.NewFoo) 255 | c.MustProvide(ditest.CreateBarConstructor(bar), new(ditest.Fooer)) 256 | c.MustCompile() 257 | 258 | var extracted ditest.Fooer 259 | c.MustExtractPtr(bar, &extracted) 260 | }) 261 | 262 | t.Run("container creates group from interface and extract it", func(t *testing.T) { 263 | c := NewTestContainer(t) 264 | c.MustProvide(ditest.NewFoo) 265 | c.MustProvide(ditest.NewBar, new(ditest.Fooer)) 266 | c.MustProvide(ditest.NewBaz, new(ditest.Fooer)) 267 | c.MustCompile() 268 | 269 | var group []ditest.Fooer 270 | c.MustExtract(&group) 271 | require.Len(t, group, 2) 272 | }) 273 | 274 | t.Run("container extract new instance of prototype by each extraction", func(t *testing.T) { 275 | c := NewTestContainer(t) 276 | c.MustProvide(ditest.NewFoo) 277 | c.MustProvidePrototype(ditest.NewBar) 278 | c.MustCompile() 279 | 280 | var extracted1 *ditest.Bar 281 | c.MustExtract(&extracted1) 282 | var extracted2 *ditest.Bar 283 |
c.MustExtract(&extracted2) 284 | 285 | c.MustNotEqualPointer(extracted1, extracted2) 286 | }) 287 | 288 | t.Run("container resolve interactor", func(t *testing.T) { 289 | c := NewTestContainer(t) 290 | foo := ditest.NewFoo() 291 | c.MustProvide(ditest.CreateFooConstructor(foo)) 292 | c.MustCompile() 293 | var interactor di.Interactor 294 | c.MustExtract(&interactor) 295 | var extractedFoo *ditest.Foo 296 | require.NoError(t, interactor.Extract(&extractedFoo)) 297 | c.MustEqualPointer(foo, extractedFoo) 298 | }) 299 | } 300 | 301 | func TestContainerResolve(t *testing.T) { 302 | t.Run("container resolve correct argument", func(t *testing.T) { 303 | c := NewTestContainer(t) 304 | foo := &ditest.Foo{} 305 | c.MustProvide(ditest.CreateFooConstructor(foo)) 306 | c.MustProvide(ditest.NewBar) 307 | c.MustCompile() 308 | 309 | var bar *ditest.Bar 310 | c.MustExtract(&bar) 311 | c.MustEqualPointer(foo, bar.Foo()) 312 | }) 313 | 314 | t.Run("container resolve correct interface implementation", func(t *testing.T) { 315 | c := NewTestContainer(t) 316 | 317 | foo := ditest.NewFoo() 318 | bar := ditest.NewBar(foo) 319 | 320 | c.MustProvide(ditest.CreateFooConstructor(foo)) 321 | c.MustProvide(ditest.CreateBarConstructor(bar), new(ditest.Fooer)) 322 | c.MustProvide(ditest.NewQux) 323 | c.MustCompile() 324 | 325 | var qux *ditest.Qux 326 | c.MustExtract(&qux) 327 | c.MustEqualPointer(bar, qux.Fooer()) 328 | }) 329 | 330 | t.Run("container resolve correct group", func(t *testing.T) { 331 | c := NewTestContainer(t) 332 | 333 | c.MustProvide(ditest.NewFoo) 334 | c.MustProvide(ditest.NewBar, new(ditest.Fooer)) 335 | c.MustProvide(ditest.NewBaz, new(ditest.Fooer)) 336 | c.MustProvide(ditest.NewFooerGroup) 337 | c.MustCompile() 338 | 339 | var bar *ditest.Bar 340 | c.MustExtract(&bar) 341 | 342 | var baz *ditest.Baz 343 | c.MustExtract(&baz) 344 | 345 | var group *ditest.FooerGroup 346 | c.MustExtract(&group) 347 | require.Len(t, group.Fooers(), 2) 348 | c.MustEqualPointer(bar,
group.Fooers()[0]) 349 | c.MustEqualPointer(baz, group.Fooers()[1]) 350 | }) 351 | } 352 | 353 | func TestContainerResolveEmbedParameters(t *testing.T) { 354 | t.Run("container resolve embed parameters", func(t *testing.T) { 355 | c := NewTestContainer(t) 356 | foo := ditest.NewFoo() 357 | bar := ditest.NewBar(foo) 358 | c.MustProvide(ditest.CreateFooConstructor(foo)) 359 | c.MustProvide(ditest.CreateBarConstructor(bar)) 360 | c.MustProvide(ditest.NewBazFromParameters) 361 | c.MustCompile() 362 | 363 | var extracted *ditest.Baz 364 | c.MustExtract(&extracted) 365 | c.MustEqualPointer(foo, extracted.Foo()) 366 | c.MustEqualPointer(bar, extracted.Bar()) 367 | }) 368 | 369 | t.Run("container skip optional parameter", func(t *testing.T) { 370 | c := NewTestContainer(t) 371 | foo := ditest.NewFoo() 372 | c.MustProvide(ditest.CreateFooConstructor(foo)) 373 | c.MustProvide(ditest.NewBazFromParameters) 374 | c.MustCompile() 375 | 376 | var extracted *ditest.Baz 377 | c.MustExtract(&extracted) 378 | c.MustEqualPointer(foo, extracted.Foo()) 379 | require.Nil(t, extracted.Bar()) 380 | }) 381 | 382 | t.Run("container resolve optional not existing group as nil", func(t *testing.T) { 383 | c := NewTestContainer(t) 384 | type Params struct { 385 | di.Parameter 386 | Handlers []http.Handler `di:"optional"` 387 | } 388 | c.MustProvide(func(params Params) bool { 389 | return params.Handlers == nil 390 | }) 391 | c.MustCompile() 392 | var extracted bool 393 | c.MustExtract(&extracted) 394 | require.True(t, extracted) 395 | }) 396 | 397 | t.Run("container skip private fields in parameter", func(t *testing.T) { 398 | c := NewTestContainer(t) 399 | type Param struct { 400 | di.Parameter 401 | private []http.Handler `di:"optional"` 402 | Addrs []net.Addr `di:"optional"` 403 | HaveNotTag string 404 | } 405 | c.MustProvide(func(param Param) bool { 406 | return param.Addrs == nil 407 | }) 408 | c.MustCompile() 409 | var extracted bool 410 | c.MustExtract(&extracted) 411 | require.True(t,
extracted) 412 | }) 413 | } 414 | 415 | func TestContainerInvoke(t *testing.T) { 416 | t.Run("container call invoke function", func(t *testing.T) { 417 | c := NewTestContainer(t) 418 | c.MustCompile() 419 | var invokeCalled bool 420 | c.MustInvoke(func() { 421 | invokeCalled = true 422 | }) 423 | require.True(t, invokeCalled) 424 | }) 425 | 426 | t.Run("container resolve dependencies in invoke function", func(t *testing.T) { 427 | c := NewTestContainer(t) 428 | foo := ditest.NewFoo() 429 | c.MustProvide(ditest.CreateFooConstructor(foo)) 430 | c.MustCompile() 431 | c.MustInvoke(func(invokeFoo *ditest.Foo) { 432 | c.MustEqualPointer(foo, invokeFoo) 433 | }) 434 | }) 435 | 436 | t.Run("container invoke return correct error", func(t *testing.T) { 437 | c := NewTestContainer(t) 438 | c.MustProvide(ditest.NewFoo) 439 | c.Compile() 440 | c.MustInvokeError(func(foo *ditest.Foo) error { 441 | return errors.New("invoke error") 442 | }, "invoke error") 443 | }) 444 | 445 | t.Run("container invoke with nil error", func(t *testing.T) { 446 | c := NewTestContainer(t) 447 | c.MustProvide(ditest.NewFoo) 448 | c.Compile() 449 | c.MustInvoke(func(foo *ditest.Foo) error { 450 | return nil 451 | }) 452 | }) 453 | } 454 | 455 | func TestContainerResolveParameterBag(t *testing.T) { 456 | t.Run("container extract correct parameter bag for type", func(t *testing.T) { 457 | c := NewTestContainer(t) 458 | 459 | c.Provide(ditest.NewFooWithParameters, di.ProvideParams{ 460 | Parameters: di.ParameterBag{ 461 | "name": "test", 462 | }, 463 | }) 464 | 465 | c.MustCompile() 466 | 467 | var foo *ditest.Foo 468 | err := c.Extract(&foo) 469 | 470 | require.NoError(t, err) 471 | require.Equal(t, "test", foo.Name) 472 | }) 473 | 474 | t.Run("container extract correct parameter bag for named type", func(t *testing.T) { 475 | c := NewTestContainer(t) 476 | 477 | c.Provide(ditest.NewFooWithParameters, di.ProvideParams{ 478 | Name: "named", 479 | Parameters: di.ParameterBag{ 480 | "name": "test", 481 | },
482 | }) 483 | 484 | c.MustCompile() 485 | 486 | var foo *ditest.Foo 487 | err := c.Extract(&foo, di.ExtractParams{ 488 | Name: "named", 489 | }) 490 | 491 | require.NoError(t, err) 492 | require.Equal(t, "test", foo.Name) 493 | }) 494 | } 495 | 496 | func TestContainerCleanup(t *testing.T) { 497 | t.Run("container run cleanup function after container close", func(t *testing.T) { 498 | c := NewTestContainer(t) 499 | var cleanupCalled bool 500 | c.MustProvide(ditest.CreateFooConstructorWithCleanup(func() { cleanupCalled = true })) 501 | c.MustCompile() 502 | 503 | var extracted *ditest.Foo 504 | c.MustExtract(&extracted) 505 | c.Cleanup() 506 | 507 | require.True(t, cleanupCalled) 508 | }) 509 | 510 | t.Run("cleanup run in correct order", func(t *testing.T) { 511 | c := NewTestContainer(t) 512 | var cleanupCalls []string 513 | c.MustProvide(func(bar *ditest.Bar) (*ditest.Foo, func()) { 514 | return &ditest.Foo{}, func() { cleanupCalls = append(cleanupCalls, "foo") } 515 | }) 516 | c.MustProvide(func() (*ditest.Bar, func()) { 517 | return &ditest.Bar{}, func() { cleanupCalls = append(cleanupCalls, "bar") } 518 | }) 519 | c.MustCompile() 520 | 521 | var foo *ditest.Foo 522 | c.MustExtract(&foo) 523 | c.Cleanup() 524 | require.Equal(t, []string{"bar", "foo"}, cleanupCalls) 525 | }) 526 | 527 | t.Run("cleanup for every prototyped instance", func(t *testing.T) { 528 | c := NewTestContainer(t) 529 | var cleanupCalls []string 530 | c.Provide(func() (*ditest.Foo, func()) { 531 | return &ditest.Foo{}, func() { 532 | cleanupCalls = append(cleanupCalls, fmt.Sprintf("foo_%d", len(cleanupCalls))) 533 | } 534 | }, di.ProvideParams{ 535 | IsPrototype: true, 536 | }) 537 | c.MustCompile() 538 | var foo1, foo2 *ditest.Foo 539 | c.MustExtract(&foo1) 540 | c.MustExtract(&foo2) 541 | c.Cleanup() 542 | require.Equal(t, []string{"foo_0", "foo_1"}, cleanupCalls) 543 | }) 544 | } 545 | 546 | func TestContainer_GraphVisualizing(t *testing.T) { 547 | t.Run("graph", func(t *testing.T) { 548 |
c := NewTestContainer(t) 549 | 550 | c.MustProvide(ditest.NewLogger) 551 | c.MustProvide(ditest.NewServer) 552 | c.MustProvide(ditest.NewRouter, new(http.Handler)) 553 | c.MustProvide(ditest.NewAccountController, new(ditest.Controller)) 554 | c.MustProvide(ditest.NewAuthController, new(ditest.Controller)) 555 | c.MustCompile() 556 | 557 | var graph *di.Graph 558 | require.NoError(t, c.Extract(&graph)) 559 | 560 | fmt.Println(graph.String()) 561 | 562 | require.Equal(t, `digraph { 563 | subgraph cluster_s3 { 564 | ID = "cluster_s3"; 565 | bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded"; 566 | n9[color="#46494C",fontcolor="white",fontname="COURIER",label="*di.Graph",shape="box",style="filled"]; 567 | n10[color="#46494C",fontcolor="white",fontname="COURIER",label="di.Interactor",shape="box",style="filled"]; 568 | 569 | }subgraph cluster_s2 { 570 | ID = "cluster_s2"; 571 | bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded"; 572 | n6[color="#46494C",fontcolor="white",fontname="COURIER",label="*ditest.AccountController",shape="box",style="filled"]; 573 | n8[color="#46494C",fontcolor="white",fontname="COURIER",label="*ditest.AuthController",shape="box",style="filled"]; 574 | n7[color="#E54B4B",fontcolor="white",fontname="COURIER",label="[]ditest.Controller",shape="doubleoctagon",style="filled"]; 575 | n4[color="#E5984B",fontcolor="white",fontname="COURIER",label="ditest.RouterParams",shape="box",style="filled"]; 576 | 577 | }subgraph cluster_s0 { 578 | ID = "cluster_s0"; 579 | bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded"; 580 | n1[color="#46494C",fontcolor="white",fontname="COURIER",label="*log.Logger",shape="box",style="filled"]; 581 | 582 | }subgraph cluster_s1 { 583 | ID = "cluster_s1"; 584 | bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded"; 585 |
n3[color="#46494C",fontcolor="white",fontname="COURIER",label="*http.ServeMux",shape="box",style="filled"]; 586 | n2[color="#46494C",fontcolor="white",fontname="COURIER",label="*http.Server",shape="box",style="filled"]; 587 | n5[color="#2589BD",fontcolor="white",fontname="COURIER",label="http.Handler",style="filled"]; 588 | 589 | }splines="ortho"; 590 | n6->n7[color="#949494"]; 591 | n8->n7[color="#949494"]; 592 | n3->n5[color="#949494"]; 593 | n1->n2[color="#949494"]; 594 | n1->n3[color="#949494"]; 595 | n1->n6[color="#949494"]; 596 | n1->n8[color="#949494"]; 597 | n7->n4[color="#949494"]; 598 | n4->n3[color="#949494"]; 599 | n5->n2[color="#949494"]; 600 | 601 | }`, graph.String()) 602 | }) 603 | } 604 | 605 | // NewTestContainer returns a TestContainer that pairs t with a fresh di.New() container. 606 | func NewTestContainer(t *testing.T) *TestContainer { 607 | return &TestContainer{t, di.New()} 608 | } 609 | 610 | // TestContainer embeds *di.Container and adds Must* helpers that report failures through t via require. 611 | type TestContainer struct { 612 | t *testing.T 613 | *di.Container 614 | } 615 | 616 | func (c *TestContainer) MustProvide(provider interface{}, as ...interface{}) { 617 | require.NotPanics(c.t, func() { 618 | c.Provide(provider, di.ProvideParams{ 619 | Interfaces: as, 620 | }) 621 | }, "provide should not panic") 622 | } 623 | 624 | func (c *TestContainer) MustProvidePrototype(provider interface{}, as ...interface{}) { 625 | require.NotPanics(c.t, func() { 626 | c.Provide(provider, di.ProvideParams{ 627 | Interfaces: as, 628 | IsPrototype: true, 629 | }) 630 | }) 631 | } 632 | 633 | func (c *TestContainer) MustProvideWithName(name string, provider interface{}, as ...interface{}) { 634 | require.NotPanics(c.t, func() { 635 | c.Provide(provider, di.ProvideParams{ 636 | Name: name, 637 | Interfaces: as, 638 | }) 639 | }) 640 | } 641 | 642 | func (c *TestContainer) MustProvideError(provider interface{}, msg string, as ...interface{}) { 643 | require.PanicsWithValue(c.t, msg, 
func() { 644 | c.Provide(provider, di.ProvideParams{ 645 | Interfaces: as, 646 | }) 647 | }) 648 | } 649 | 650 | func (c
*TestContainer) MustCompile() { 651 | require.NotPanics(c.t, func() { 652 | c.Compile() 653 | }) 654 | } 655 | 656 | func (c *TestContainer) MustCompileError(msg string) { 657 | require.PanicsWithValue(c.t, msg, func() { 658 | c.Compile() 659 | }) 660 | } 661 | 662 | func (c *TestContainer) MustExtract(target interface{}) { 663 | require.NoError(c.t, c.Extract(target)) 664 | } 665 | 666 | func (c *TestContainer) MustExtractWithName(name string, target interface{}) { 667 | require.NoError(c.t, c.Extract(target, di.ExtractParams{ 668 | Name: name, 669 | })) 670 | } 671 | 672 | func (c *TestContainer) MustExtractError(target interface{}, msg string) { 673 | require.EqualError(c.t, c.Extract(target, di.ExtractParams{}), msg) 674 | } 675 | 676 | func (c *TestContainer) MustExtractWithNameError(name string, target interface{}, msg string) { 677 | require.EqualError(c.t, c.Extract(target, di.ExtractParams{ 678 | Name: name, 679 | }), msg) 680 | } 681 | 682 | // MustExtractPtr extracts a value from the container into target and checks that the extracted value and expected are the same pointer.
683 | func (c *TestContainer) MustExtractPtr(expected, target interface{}) { 684 | c.MustExtract(target) 685 | 686 | // target is a pointer; compare the value it now points to 687 | actual := reflect.ValueOf(target).Elem().Interface() 688 | c.MustEqualPointer(expected, actual) 689 | } 690 | 691 | func (c *TestContainer) MustExtractPtrWithName(expected interface{}, name string, target interface{}) { 692 | c.MustExtractWithName(name, target) 693 | 694 | actual := reflect.ValueOf(target).Elem().Interface() 695 | c.MustEqualPointer(expected, actual) 696 | } 697 | 698 | func (c *TestContainer) MustInvoke(fn interface{}) { 699 | require.NoError(c.t, c.Invoke(fn)) 700 | } 701 | 702 | func (c *TestContainer) MustInvokeError(fn interface{}, msg string) { 703 | require.EqualError(c.t, c.Invoke(fn), msg) 704 | } 705 | // MustEqualPointer fails the test unless expected and actual format to the same %p address. NOTE(review): require.Equal is given actual before expected, so its failure output labels the two values swapped. 706 | func (c *TestContainer) MustEqualPointer(expected interface{}, actual interface{}) { 707 | require.Equal(c.t, 708 | fmt.Sprintf("%p", actual), 709 | fmt.Sprintf("%p", expected), 710 | "actual and expected pointers should be equal", 711 | ) 712 | } 713 | 714 | func (c *TestContainer) MustNotEqualPointer(expected interface{}, actual interface{}) { 715 | require.NotEqual(c.t, 716 | fmt.Sprintf("%p", actual), 717 | fmt.Sprintf("%p", expected), 718 | "actual and expected pointers should not be equal", 719 | ) 720 | } 721 | -------------------------------------------------------------------------------- /di/dot.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/emicklei/dot" 7 | ) 8 | 9 | // Graph wraps a github.com/emicklei/dot graph describing the container's dependencies. 10 | type Graph struct { 11 | graph *dot.Graph 12 | } 13 | // WriteTo writes the graph's DOT text to writer. 
NOTE(review): go vet's stdmethods check expects WriteTo to have the io.WriterTo signature WriteTo(io.Writer) (int64, error) and flags this method. 14 | func (g *Graph) WriteTo(writer io.Writer) { 15 | g.graph.Write(writer) 16 | } 17 | // String returns the graph's DOT text. 18 | func (g *Graph) String() string { 19 | return g.graph.String() 20 | } 21 | -------------------------------------------------------------------------------- /di/errors.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3
| import "fmt" 4 | 5 | // ErrParameterProvideFailed pairs key k with the error err reported for it; Error() renders "<k>: <err>". 6 | type ErrParameterProvideFailed struct { 7 | k key 8 | err error 9 | } 10 | 11 | func (e ErrParameterProvideFailed) Error() string { 12 | return fmt.Sprintf("%s: %s", e.k, e.err) 13 | } 14 | 15 | // ErrParameterProviderNotFound reports that param has no provider; Error() renders "<param>: not exists in container". 16 | type ErrParameterProviderNotFound struct { 17 | param parameter 18 | } 19 | 20 | func (e ErrParameterProviderNotFound) Error() string { 21 | return fmt.Sprintf("%s: not exists in container", e.param) 22 | } 23 | -------------------------------------------------------------------------------- /di/internal/ditest/bar.go: -------------------------------------------------------------------------------- 1 | package ditest 2 | 3 | // Bar is a test type that holds a *Foo. 4 | type Bar struct { 5 | foo *Foo 6 | } 7 | 8 | // NewBar returns a new Bar holding foo. 9 | func NewBar(foo *Foo) *Bar { 10 | return &Bar{ 11 | foo: foo, 12 | } 13 | } 14 | 15 | // CreateBarConstructor returns a constructor that stores its foo argument in bar and returns that same bar pointer. 16 | func CreateBarConstructor(bar *Bar) func(foo *Foo) *Bar { 17 | return func(foo *Foo) *Bar { 18 | bar.foo = foo 19 | return bar 20 | } 21 | } 22 | // Foo returns the *Foo held by b. 23 | func (b *Bar) Foo() *Foo { return b.foo } 24 | -------------------------------------------------------------------------------- /di/internal/ditest/baz.go: -------------------------------------------------------------------------------- 1 | package ditest 2 | 3 | import "github.com/defval/inject/v2/di" 4 | 5 | // Baz is a test type that holds a *Foo and a *Bar. 6 | type Baz struct { 7 | foo *Foo 8 | bar *Bar 9 | } 10 | 11 | // NewBaz returns a new Baz holding foo and bar. 
12 | func NewBaz(foo *Foo, bar *Bar) *Baz { 13 | return &Baz{ 14 | foo: foo, 15 | bar: bar, 16 | } 17 | } 18 | 19 | // BazParameters is the embedded-parameter struct for NewBazFromParameters: Foo is required and Bar is tagged optional. 20 | type BazParameters struct { 21 | di.Parameter 22 | 23 | Foo *Foo `di:""` 24 | Bar *Bar `di:"optional"` 25 | } 26 | 27 | // NewBazFromParameters builds a Baz from params; Bar may be nil when the optional *Bar is not provided. 28 | func NewBazFromParameters(params BazParameters) *Baz { 29 | return &Baz{ 30 | foo: params.Foo, 31 | bar: params.Bar, 32 | } 33 | } 34 | 35 | func (b *Baz) Foo() *Foo { return b.foo } 36 | func (b *Baz) Bar() *Bar { return b.bar } 37 |
-------------------------------------------------------------------------------- /di/internal/ditest/foo.go: -------------------------------------------------------------------------------- 1 | package ditest 2 | 3 | import "github.com/defval/inject/v2/di" 4 | 5 | // Foo is the basic test type; Name is set by NewFooWithParameters. 6 | type Foo struct { 7 | Name string 8 | } 9 | 10 | // NewFoo returns a new zero-valued Foo. 11 | func NewFoo() *Foo { 12 | return &Foo{} 13 | } 14 | 15 | // NewFooWithParameters returns a Foo whose Name is read from parameters with RequireString("name"). 16 | func NewFooWithParameters(parameters di.ParameterBag) *Foo { 17 | return &Foo{Name: parameters.RequireString("name")} 18 | } 19 | 20 | // NewCycleFooBar returns a Foo that depends on *Bar; provided together with NewBar (which needs *Foo) it forms a dependency cycle. 21 | func NewCycleFooBar(bar *Bar) *Foo { 22 | return &Foo{} 23 | } 24 | 25 | // CreateFooConstructor returns a constructor that always returns the given foo pointer. 26 | func CreateFooConstructor(foo *Foo) func() *Foo { 27 | return func() *Foo { 28 | return foo 29 | } 30 | } 31 | 32 | // CreateFooConstructorWithError returns a constructor that returns a new Foo together with err. 33 | func CreateFooConstructorWithError(err error) func() (*Foo, error) { 34 | return func() (foo *Foo, e error) { 35 | return &Foo{}, err 36 | } 37 | } 38 | 39 | // CreateFooConstructorWithCleanup returns a constructor that returns a new Foo and the given cleanup function. 40 | func CreateFooConstructorWithCleanup(cleanup func()) func() (*Foo, func()) { 41 | return func() (foo *Foo, i func()) { 42 | return &Foo{}, cleanup 43 | } 44 | } 45 | 46 | // CreateFooConstructorWithCleanupAndError returns a constructor that returns a new Foo, the given cleanup function and err. 
47 | func CreateFooConstructorWithCleanupAndError(cleanup func(), err error) func() (*Foo, func(), error) { 48 | return func() (foo *Foo, i func(), e error) { 49 | return &Foo{}, cleanup, err 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /di/internal/ditest/fooer_group.go: -------------------------------------------------------------------------------- 1 | package ditest 2 | 3 | // FooerGroup is a test type that holds a slice of Fooer values. 4 | type FooerGroup struct { 5 | fooers []Fooer 6 | } 7 | 8 | // NewFooerGroup returns a FooerGroup holding fooers. 9 | func NewFooerGroup(fooers []Fooer) *FooerGroup { 10 | return &FooerGroup{fooers: fooers} 11 | } 12 | // Fooers returns the slice passed to NewFooerGroup. 13 | func (g *FooerGroup) Fooers() []Fooer { 14 | return g.fooers 15 | } 16 |
-------------------------------------------------------------------------------- /di/internal/ditest/full.go: -------------------------------------------------------------------------------- 1 | package ditest 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "os" 7 | 8 | "github.com/defval/inject/v2/di" 9 | ) 10 | 11 | // NewLogger returns a logger writing to os.Stdout with no prefix and no flags; it prints "Logger loaded!" on return. 12 | func NewLogger() *log.Logger { 13 | logger := log.New(os.Stdout, "", 0) 14 | defer logger.Println("Logger loaded!") 15 | 16 | return logger 17 | } 18 | 19 | // NewServer returns an *http.Server serving handler; no timeouts are configured. 20 | func NewServer(logger *log.Logger, handler http.Handler) *http.Server { 21 | defer logger.Println("Server created!") 22 | return &http.Server{ 23 | Handler: handler, 24 | } 25 | } 26 | 27 | // RouterParams is the embedded-parameter struct for NewRouter; Controllers is tagged optional. 28 | type RouterParams struct { 29 | di.Parameter 30 | Controllers []Controller `di:"optional"` 31 | } 32 | 33 | // NewRouter returns a new ServeMux after every controller in params has registered its routes on it. 34 | func NewRouter(logger *log.Logger, params RouterParams) *http.ServeMux { 35 | logger.Println("Create router!") 36 | defer logger.Println("Router created!") 37 | 38 | mux := &http.ServeMux{} 39 | 40 | for _, ctrl := range params.Controllers { 41 | ctrl.RegisterRoutes(mux) 42 | } 43 | 44 | return mux 45 | } 46 | 47 | // Controller is implemented by types that register routes on a ServeMux. 48 | type Controller interface { 49 | RegisterRoutes(mux *http.ServeMux) 50 | } 51 | 52 | // AccountController is a sample Controller that only logs its registration. 53 | type AccountController struct { 54 | Logger *log.Logger 55 | } 56 | 57 | // NewAccountController returns an AccountController that logs through logger. 
58 | func NewAccountController(logger *log.Logger) *AccountController { 59 | return &AccountController{Logger: logger} 60 | } 61 | 62 | // RegisterRoutes logs the registration; it does not add any routes to mux. 63 | func (c *AccountController) RegisterRoutes(mux *http.ServeMux) { 64 | c.Logger.Println("AccountController registered!") 65 | 66 | // register your routes 67 | } 68 | 69 | // AuthController is a sample Controller that only logs its registration. 70 | type AuthController struct { 71 | Logger *log.Logger 72 | } 73 | 74 | // NewAuthController returns an AuthController that logs through logger. 75 | func NewAuthController(logger *log.Logger) *AuthController { 76 | return &AuthController{Logger: logger} 77 | } 78 | 79 | // RegisterRoutes logs the registration; it does not add any routes to mux. 80 | func (c
*AuthController) RegisterRoutes(mux *http.ServeMux) { 81 | c.Logger.Println("AuthController registered!") 82 | 83 | // register your routes 84 | } 85 | -------------------------------------------------------------------------------- /di/internal/ditest/incorrect.go: -------------------------------------------------------------------------------- 1 | package ditest 2 | 3 | // ConstructorWithoutResult is an invalid constructor: it returns nothing. 4 | func ConstructorWithoutResult() { 5 | 6 | } 7 | 8 | // ConstructorWithManyResults is an invalid constructor: it returns two values plus an error. 9 | func ConstructorWithManyResults() (*Foo, *Bar, error) { 10 | return &Foo{}, &Bar{}, nil 11 | } 12 | 13 | // ConstructorWithIncorrectResultError is an invalid constructor: its second result is *Bar rather than a cleanup func or an error. 14 | func ConstructorWithIncorrectResultError() (*Foo, *Bar) { 15 | return &Foo{}, &Bar{} 16 | } 17 | -------------------------------------------------------------------------------- /di/internal/ditest/interfaces.go: -------------------------------------------------------------------------------- 1 | package ditest 2 | 3 | // Fooer is implemented by types that expose a *Foo. 4 | type Fooer interface { 5 | Foo() *Foo 6 | } 7 | 8 | // Barer is implemented by types that expose a *Bar. 9 | type Barer interface { 10 | Bar() *Bar 11 | } 12 | 13 | // Bazer is implemented by types that expose a *Baz. 14 | type Bazer interface { 15 | Baz() *Baz 16 | } 17 | -------------------------------------------------------------------------------- /di/internal/ditest/qux.go: -------------------------------------------------------------------------------- 1 | package ditest 2 | 3 | // Qux is a test type that depends on a Fooer interface value. 4 | type Qux struct { 5 | fooer Fooer 6 | } 7 | 8 | // NewQux returns a Qux holding foo. 9 | func NewQux(foo Fooer) *Qux { 10 | return &Qux{ 11 | fooer: foo, 12 | } 13 | } 14 | // Fooer returns the Fooer passed to NewQux. 
15 | func (q *Qux) Fooer() Fooer { return q.fooer } 16 | -------------------------------------------------------------------------------- /di/internal/graphkv/directed_graph.go: -------------------------------------------------------------------------------- 1 | package graphkv 2 | 3 | // directedGraph is a graph supporting directed edges between nodes.
4 | type directedGraph struct { 5 | *graph 6 | edges *directedEdgeList 7 | } 8 | 9 | // newDirectedGraph creates a graph of nodes with directed edges. 10 | func newDirectedGraph() *directedGraph { 11 | return &directedGraph{ 12 | graph: newGraph(), 13 | edges: newDirectedEdgeList(), 14 | } 15 | } 16 | 17 | // Copy returns a clone of the directed graph. 18 | func (g *directedGraph) Copy() *directedGraph { 19 | return &directedGraph{ 20 | graph: g.graph.Copy(), 21 | edges: g.edges.Copy(), 22 | } 23 | } 24 | 25 | // EdgeCount returns the number of direced edges between nodes. 26 | func (g *directedGraph) EdgeCount() int { 27 | return g.edges.Count() 28 | } 29 | 30 | // AddEdge adds the edge to the graph. 31 | func (g *directedGraph) AddEdge(from Key, to Key) { 32 | // prevent adding an edge referring to missing nodes 33 | if !g.NodeExists(from) { 34 | g.AddNode(from) 35 | } 36 | if !g.NodeExists(to) { 37 | g.AddNode(to) 38 | } 39 | 40 | g.edges.Add(from, to) 41 | } 42 | 43 | // RemoveEdge removes the edge from the graph. 44 | func (g *directedGraph) RemoveEdge(from Key, to Key) { 45 | g.edges.Remove(from, to) 46 | } 47 | 48 | // HasEdges determines whether the graph contains any edges to or from the node. 49 | func (g *directedGraph) HasEdges(node Key) bool { 50 | if g.HasIncomingEdges(node) { 51 | return true 52 | } 53 | return g.HasOutgoingEdges(node) 54 | } 55 | 56 | // EdgeExists checks whether the edge exists within the graph. 57 | func (g *directedGraph) EdgeExists(from Key, to Key) bool { 58 | return g.edges.Exists(from, to) 59 | } 60 | 61 | // HasIncomingEdges checks whether the graph contains any directed 62 | // edges pointing to the node. 63 | func (g *directedGraph) HasIncomingEdges(node Key) bool { 64 | return g.edges.HasIncomingEdges(node) 65 | } 66 | 67 | // IncomingEdges returns the nodes belonging to directed edges pointing 68 | // towards the specified node. 
69 | func (g *directedGraph) IncomingEdges(node Key) []Key { 70 | return g.edges.IncomingEdges(node) 71 | } 72 | 73 | // IncomingEdgeCount returns the number of edges pointing from the specified 74 | // node (indegree). 75 | func (g *directedGraph) IncomingEdgeCount(node Key) int { 76 | return g.edges.IncomingEdgeCount(node) 77 | } 78 | 79 | // HasOutgoingEdges checks whether the graph contains any directed 80 | // edges pointing from the node. 81 | func (g *directedGraph) HasOutgoingEdges(node Key) bool { 82 | return g.edges.HasOutgoingEdges(node) 83 | } 84 | 85 | // OutgoingEdges returns the nodes belonging to directed edges pointing 86 | // from the specified node. 87 | func (g *directedGraph) OutgoingEdges(node Key) []Key { 88 | return g.edges.OutgoingEdges(node) 89 | } 90 | 91 | // OutgoingEdgeCount returns the number of edges pointing from the specified 92 | // node (outdegree). 93 | func (g *directedGraph) OutgoingEdgeCount(node Key) int { 94 | return g.edges.OutgoingEdgeCount(node) 95 | } 96 | 97 | // RootNodes finds the entry-point nodes to the graph, i.e. those without 98 | // incoming edges. 99 | func (g *directedGraph) RootNodes() []Key { 100 | results := make([]Key, 0) 101 | for _, node := range g.Nodes() { 102 | if !g.HasIncomingEdges(node) { 103 | results = append(results, node) 104 | } 105 | } 106 | return results 107 | } 108 | 109 | // IsolatedNodes finds independent nodes in the graph, i.e. those without edges. 110 | func (g *directedGraph) IsolatedNodes() []Key { 111 | results := make([]Key, 0) 112 | for _, node := range g.Nodes() { 113 | if !g.HasEdges(node) { 114 | results = append(results, node) 115 | } 116 | } 117 | return results 118 | } 119 | 120 | // AdjacencyMatrix returns a matrix indicating whether pairs of nodes are 121 | // adjacent or not within the graph. 
122 | func (g *directedGraph) AdjacencyMatrix() map[Key]map[Key]bool { 123 | matrix := make(map[Key]map[Key]bool, g.NodeCount()) 124 | for _, a := range g.Nodes() { 125 | matrix[a] = make(map[Key]bool, g.NodeCount()) 126 | 127 | for _, b := range g.Nodes() { 128 | matrix[a][b] = g.EdgeExists(a, b) 129 | } 130 | } 131 | return matrix 132 | } 133 | 134 | // RemoveTransitives removes any transitive edges so that as fewest possible 135 | // edges exist while matching the reachability of the original graph. 136 | func (g *directedGraph) RemoveTransitives() { 137 | for _, a := range g.Nodes() { 138 | for _, b := range g.Nodes() { 139 | if !g.EdgeExists(a, b) { 140 | continue 141 | } 142 | for _, c := range g.Nodes() { 143 | if g.EdgeExists(b, c) { 144 | g.RemoveEdge(a, c) 145 | } 146 | } 147 | } 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /di/internal/graphkv/directed_graph_test.go: -------------------------------------------------------------------------------- 1 | package graphkv 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func newTestDirectedGraph() *directedGraph { 10 | graph := newDirectedGraph() 11 | graph.AddNodes("A", "B", "C", "D") 12 | return graph 13 | } 14 | 15 | func TestNewDirectedGraph(t *testing.T) { 16 | graph := newDirectedGraph() 17 | assert.NotNil(t, graph, "graph should not be nil") 18 | assert.Zero(t, graph.NodeCount(), "graph.NodeCount() should equal zero") 19 | assert.Empty(t, graph.Nodes(), "graph.Nodes() should equal empty") 20 | assert.Zero(t, graph.EdgeCount(), "graph.EdgeCount() should equal zero") 21 | } 22 | 23 | func TestDirectedGraphAddEdge(t *testing.T) { 24 | graph := newTestDirectedGraph() 25 | graph.AddEdge("A", "B") 26 | graph.AddEdge("B", "D") 27 | graph.AddEdge("C", "B") 28 | 29 | assert.Equal(t, 3, graph.EdgeCount(), "graph.EdgeCount() should equal 3") 30 | assert.True(t, graph.EdgeExists("A", "B"), "graph.EdgeExists(A, B) 
should equal true") 31 | assert.True(t, graph.EdgeExists("B", "D"), "graph.EdgeExists(B, D) should equal true") 32 | assert.True(t, graph.EdgeExists("C", "B"), "graph.EdgeExists(C, B) should equal true") 33 | } 34 | 35 | func TestDirectedGraphAddEdgeDuplicate(t *testing.T) { 36 | graph := newTestDirectedGraph() 37 | graph.AddEdge("A", "B") 38 | graph.AddEdge("B", "C") 39 | graph.AddEdge("B", "C") 40 | 41 | assert.Equal(t, 2, graph.EdgeCount(), "graph.EdgeCount() should equal 2") 42 | assert.True(t, graph.EdgeExists("A", "B"), "graph.EdgeExists(A, B) should equal true") 43 | assert.True(t, graph.EdgeExists("B", "C"), "graph.EdgeExists(B, C) should equal true") 44 | } 45 | 46 | func TestDirectedGraphAddEdgeMissingNodes(t *testing.T) { 47 | graph := newDirectedGraph() 48 | graph.AddEdge("A", "B") 49 | graph.AddEdge("B", "C") 50 | 51 | assert.Equal(t, 3, graph.NodeCount(), "graph.NodeCount() should equal 2") 52 | assert.Equal(t, 2, graph.EdgeCount(), "graph.EdgeCount() should equal 2") 53 | assert.True(t, graph.EdgeExists("A", "B"), "graph.EdgeExists(A, B) should equal true") 54 | assert.True(t, graph.EdgeExists("B", "C"), "graph.EdgeExists(B, C) should equal true") 55 | } 56 | 57 | func TestDirectedGraphRemoveEdge(t *testing.T) { 58 | graph := newTestDirectedGraph() 59 | graph.AddEdge("A", "B") 60 | graph.AddEdge("B", "D") 61 | graph.AddEdge("C", "B") 62 | graph.RemoveEdge("A", "B") 63 | graph.RemoveEdge("C", "B") 64 | 65 | assert.Equal(t, 1, graph.EdgeCount(), "graph.EdgeCount() should equal 1") 66 | assert.False(t, graph.EdgeExists("A", "B"), "graph.EdgeExists(A, B) should equal false") 67 | assert.False(t, graph.EdgeExists("C", "B"), "graph.EdgeExists(C, B) should equal false") 68 | assert.True(t, graph.EdgeExists("B", "D"), "graph.EdgeExists(B, D) should equal true") 69 | } 70 | 71 | func TestDirectedGraphRemoveEdgeMissing(t *testing.T) { 72 | graph := newTestDirectedGraph() 73 | graph.AddEdge("C", "B") 74 | graph.RemoveEdge("D", "A") 75 | graph.RemoveEdge("C", 
"B") 76 | 77 | assert.Zero(t, graph.EdgeCount(), "graph.EdgeCount() should equal zero") 78 | } 79 | 80 | func TestDirectedGraphHasEdges(t *testing.T) { 81 | graph := newTestDirectedGraph() 82 | graph.AddEdge("A", "C") 83 | 84 | assert.True(t, graph.HasEdges("A"), "graph.HasEdges(A) should equal true") 85 | assert.False(t, graph.HasEdges("B"), "graph.HasEdges(B) should equal false") 86 | assert.True(t, graph.HasEdges("C"), "graph.HasEdges(C) should equal true") 87 | assert.False(t, graph.HasEdges("D"), "graph.HasEdges(D) should equal false") 88 | } 89 | 90 | func TestDirectedGraphIncomingEdgeCount(t *testing.T) { 91 | graph := newTestDirectedGraph() 92 | graph.AddEdge("A", "C") 93 | graph.AddEdge("B", "C") 94 | 95 | assert.Zero(t, graph.IncomingEdgeCount("A"), "graph.IncomingEdgeCount(A) should equal 0") 96 | assert.Zero(t, graph.IncomingEdgeCount("B"), "graph.IncomingEdgeCount(B) should equal 0") 97 | assert.Equal(t, 2, graph.IncomingEdgeCount("C"), "graph.IncomingEdgeCount(C) should equal 1") 98 | assert.Zero(t, graph.IncomingEdgeCount("D"), "graph.IncomingEdgeCount(D) should equal 0") 99 | } 100 | 101 | func TestDirectedGraphOutgoingEdgeCount(t *testing.T) { 102 | graph := newTestDirectedGraph() 103 | graph.AddEdge("A", "B") 104 | graph.AddEdge("A", "C") 105 | 106 | assert.Equal(t, 2, graph.OutgoingEdgeCount("A"), "graph.OutgoingEdgeCount(A) should equal 2") 107 | assert.Zero(t, graph.OutgoingEdgeCount("B"), "graph.OutgoingEdgeCount(B) should equal 0") 108 | assert.Zero(t, graph.OutgoingEdgeCount("C"), "graph.OutgoingEdgeCount(C) should equal 0") 109 | assert.Zero(t, graph.OutgoingEdgeCount("D"), "graph.OutgoingEdgeCount(D) should equal 0") 110 | } 111 | 112 | func TestDirectedGraphRootNodes(t *testing.T) { 113 | graph := newTestDirectedGraph() 114 | graph.AddEdge("A", "B") 115 | graph.AddEdge("B", "C") 116 | graph.AddEdge("D", "C") 117 | graph.AddEdge("E", "C") 118 | graph.AddEdge("F", "E") 119 | 120 | assert.Equal(t, []Key{"A", "D", "F"}, graph.RootNodes(), 
"graph.RootNodes() should equal [A, D, F]") 121 | } 122 | 123 | func TestDirectedGraphIsolatedNodes(t *testing.T) { 124 | graph := newTestDirectedGraph() 125 | graph.AddEdge("A", "C") 126 | 127 | assert.Equal(t, []Key{"B", "D"}, graph.IsolatedNodes(), "graph.IsolatedNodes() should equal [B, D]") 128 | } 129 | 130 | func TestDirectedGraphAdjacencyMatrix(t *testing.T) { 131 | graph := newTestDirectedGraph() 132 | graph.AddEdge("A", "C") 133 | graph.AddEdge("A", "B") 134 | graph.AddEdge("B", "D") 135 | graph.AddEdge("C", "A") 136 | graph.AddEdge("D", "D") 137 | 138 | expected := map[interface{}]map[interface{}]bool{ 139 | "A": map[interface{}]bool{"A": false, "B": true, "C": true, "D": false}, 140 | "B": map[interface{}]bool{"A": false, "B": false, "C": false, "D": true}, 141 | "C": map[interface{}]bool{"A": true, "B": false, "C": false, "D": false}, 142 | "D": map[interface{}]bool{"D": true, "A": false, "B": false, "C": false}, 143 | } 144 | 145 | assert.Equal(t, expected, graph.AdjacencyMatrix(), "graph.AdjacencyMatrix() should equal [B, D]") 146 | } 147 | -------------------------------------------------------------------------------- /di/internal/graphkv/edge.go: -------------------------------------------------------------------------------- 1 | package graphkv 2 | 3 | type directedEdgeList struct { 4 | outgoingEdges map[Key]*nodeList 5 | incomingEdges map[Key]*nodeList 6 | } 7 | 8 | func newDirectedEdgeList() *directedEdgeList { 9 | return &directedEdgeList{ 10 | outgoingEdges: make(map[Key]*nodeList), 11 | incomingEdges: make(map[Key]*nodeList), 12 | } 13 | } 14 | 15 | func (l *directedEdgeList) Copy() *directedEdgeList { 16 | outgoingEdges := make(map[Key]*nodeList, len(l.outgoingEdges)) 17 | for node, edges := range l.outgoingEdges { 18 | outgoingEdges[node] = edges.Copy() 19 | } 20 | 21 | incomingEdges := make(map[Key]*nodeList, len(l.incomingEdges)) 22 | for node, edges := range l.incomingEdges { 23 | incomingEdges[node] = edges.Copy() 24 | } 25 | 26 | 
return &directedEdgeList{ 27 | outgoingEdges: outgoingEdges, 28 | incomingEdges: incomingEdges, 29 | } 30 | } 31 | 32 | func (l *directedEdgeList) Count() int { 33 | return len(l.outgoingEdges) 34 | } 35 | 36 | func (l *directedEdgeList) HasOutgoingEdges(node Key) bool { 37 | _, ok := l.outgoingEdges[node] 38 | return ok 39 | } 40 | 41 | func (l *directedEdgeList) OutgoingEdgeCount(node Key) int { 42 | if list := l.outgoingNodeList(node, false); list != nil { 43 | return list.Count() 44 | } 45 | return 0 46 | } 47 | 48 | func (l *directedEdgeList) outgoingNodeList(node Key, create bool) *nodeList { 49 | if list, ok := l.outgoingEdges[node]; ok { 50 | return list 51 | } 52 | if create { 53 | list := newNodeList() 54 | l.outgoingEdges[node] = list 55 | return list 56 | } 57 | return nil 58 | } 59 | 60 | func (l *directedEdgeList) OutgoingEdges(node Key) []Key { 61 | if list := l.outgoingNodeList(node, false); list != nil { 62 | return list.Nodes() 63 | } 64 | return nil 65 | } 66 | 67 | func (l *directedEdgeList) HasIncomingEdges(node Key) bool { 68 | _, ok := l.incomingEdges[node] 69 | return ok 70 | } 71 | 72 | func (l *directedEdgeList) IncomingEdgeCount(node Key) int { 73 | if list := l.incomingNodeList(node, false); list != nil { 74 | return list.Count() 75 | } 76 | return 0 77 | } 78 | 79 | func (l *directedEdgeList) incomingNodeList(node Key, create bool) *nodeList { 80 | if list, ok := l.incomingEdges[node]; ok { 81 | return list 82 | } 83 | if create { 84 | list := newNodeList() 85 | l.incomingEdges[node] = list 86 | return list 87 | } 88 | return nil 89 | } 90 | 91 | func (l *directedEdgeList) IncomingEdges(node Key) []Key { 92 | if list := l.incomingNodeList(node, false); list != nil { 93 | return list.Nodes() 94 | } 95 | return nil 96 | } 97 | 98 | func (l *directedEdgeList) Add(from Key, to Key) { 99 | l.outgoingNodeList(from, true).Add(to) 100 | l.incomingNodeList(to, true).Add(from) 101 | } 102 | 103 | func (l *directedEdgeList) Remove(from Key, to 
Key) { 104 | if list := l.outgoingNodeList(from, false); list != nil { 105 | list.Remove(to) 106 | 107 | if list.Count() == 0 { 108 | delete(l.outgoingEdges, from) 109 | } 110 | } 111 | if list := l.incomingNodeList(to, false); list != nil { 112 | list.Remove(from) 113 | 114 | if list.Count() == 0 { 115 | delete(l.incomingEdges, to) 116 | } 117 | } 118 | } 119 | 120 | func (l *directedEdgeList) Exists(from Key, to Key) bool { 121 | if list := l.outgoingNodeList(from, false); list != nil { 122 | return list.Exists(to) 123 | } 124 | return false 125 | } 126 | -------------------------------------------------------------------------------- /di/internal/graphkv/errors.go: -------------------------------------------------------------------------------- 1 | package graphkv 2 | 3 | import "fmt" 4 | 5 | // ErrKeyAlreadyExists 6 | type ErrKeyAlreadyExists struct { 7 | Key Key 8 | } 9 | 10 | func (e ErrKeyAlreadyExists) Error() string { 11 | return fmt.Sprintf("%s already exists", e.Key) 12 | } 13 | 14 | // ErrNodeNotExists 15 | type ErrNodeNotExists struct { 16 | Key Key 17 | } 18 | 19 | // ErrNodeNotExists 20 | func (e ErrNodeNotExists) Error() string { 21 | return fmt.Sprintf("%s not exists", e.Key) 22 | } 23 | -------------------------------------------------------------------------------- /di/internal/graphkv/graph.go: -------------------------------------------------------------------------------- 1 | package graphkv 2 | 3 | type graph struct { 4 | nodes *nodeList 5 | } 6 | 7 | func newGraph() *graph { 8 | return &graph{ 9 | nodes: newNodeList(), 10 | } 11 | } 12 | 13 | // Copy returns a clone of the graph. 14 | func (g *graph) Copy() *graph { 15 | return &graph{ 16 | nodes: g.nodes.Copy(), 17 | } 18 | } 19 | 20 | // Nodes returns the graph's nodes. 21 | // The slice is mutable for performance reasons but should not be mutated. 22 | func (g *graph) Nodes() []Key { 23 | return g.nodes.Nodes() 24 | } 25 | 26 | // NodeCount returns the number of nodes. 
27 | func (g *graph) NodeCount() int { 28 | return g.nodes.Count() 29 | } 30 | 31 | // AddNode inserts the specified node into the graph. 32 | // A node can be any value, e.g. int, string, pointer to a struct, map etc. 33 | // Duplicate nodes are ignored. 34 | func (g *graph) AddNode(node Key) { 35 | g.AddNodes(node) 36 | } 37 | 38 | // AddNodes inserts the specified nodes into the graph. 39 | // A node can be any value, e.g. int, string, pointer to a struct, map etc. 40 | // Duplicate nodes are ignored. 41 | func (g *graph) AddNodes(nodes ...Key) { 42 | g.nodes.Add(nodes...) 43 | } 44 | 45 | // RemoveNode removes the specified nodes from the graph. 46 | // If the node does not exist within the graph the call will fail silently. 47 | func (g *graph) RemoveNode(node Key) { 48 | g.RemoveNodes(node) 49 | } 50 | 51 | // RemoveNodes removes the specified nodes from the graph. 52 | // If a node does not exist within the graph the call will fail silently. 53 | func (g *graph) RemoveNodes(nodes ...Key) { 54 | g.nodes.Remove(nodes...) 55 | } 56 | 57 | // NodeExists determines whether the specified node exists within the graph. 
58 | func (g *graph) NodeExists(node Key) bool { 59 | return g.nodes.Exists(node) 60 | } 61 | -------------------------------------------------------------------------------- /di/internal/graphkv/graph_test.go: -------------------------------------------------------------------------------- 1 | package graphkv 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestNewGraph(t *testing.T) { 12 | graph := newGraph() 13 | assert.NotNil(t, graph, "graph should not be nil") 14 | assert.Zero(t, graph.NodeCount(), "graph.NodeCount() should equal zero") 15 | assert.Empty(t, graph.Nodes(), "graph.Nodes() should equal empty") 16 | } 17 | 18 | func TestGraphAddNode(t *testing.T) { 19 | graph := newGraph() 20 | graph.AddNode("A") 21 | graph.AddNode("B") 22 | graph.AddNode("C") 23 | 24 | assert.Equal(t, 3, graph.NodeCount(), "graph.NodeCount() should equal 3") 25 | assert.Equal(t, []interface{}{"A", "B", "C"}, graph.Nodes(), "graph.Nodes() should equal [A, B, C]") 26 | } 27 | 28 | func TestGraphAddNodeDuplicate(t *testing.T) { 29 | graph := newGraph() 30 | graph.AddNode("A") 31 | graph.AddNode("B") 32 | graph.AddNode("C") 33 | graph.AddNode("A") 34 | 35 | assert.Equal(t, 3, graph.NodeCount(), "graph.NodeCount() should equal 3") 36 | assert.Equal(t, []interface{}{"A", "B", "C"}, graph.Nodes(), "graph.Nodes() should equal [A, B, C]") 37 | } 38 | 39 | func BenchmarkGraphAddNodes(b *testing.B) { 40 | for i := 12.0; i <= 20; i++ { 41 | count := int(math.Pow(2, i)) 42 | 43 | b.Run(fmt.Sprintf("%d", count), func(b *testing.B) { 44 | graph := newGraph() 45 | for i := 0; i < count; i++ { 46 | graph.AddNode(i) 47 | } 48 | }) 49 | } 50 | } 51 | 52 | func TestGraphAddNodes(t *testing.T) { 53 | graph := newGraph() 54 | graph.AddNodes("A", "B", "C") 55 | 56 | assert.Equal(t, 3, graph.NodeCount(), "graph.NodeCount() should equal 3") 57 | assert.Equal(t, []interface{}{"A", "B", "C"}, graph.Nodes(), "graph.Nodes() should equal 
[A, B, C]") 58 | } 59 | 60 | func TestGraphRemoveNode(t *testing.T) { 61 | graph := newGraph() 62 | graph.AddNode("A") 63 | graph.AddNode("B") 64 | graph.AddNode("C") 65 | graph.AddNode("D") 66 | graph.RemoveNode("A") 67 | graph.RemoveNode("C") 68 | 69 | assert.Equal(t, 2, graph.NodeCount(), "graph.NodeCount() should equal 2") 70 | assert.Equal(t, []interface{}{"B", "D"}, graph.Nodes(), "graph.Nodes() should equal [B, D]") 71 | } 72 | 73 | func TestGraphRemoveNodeMissing(t *testing.T) { 74 | graph := newGraph() 75 | graph.AddNode("A") 76 | graph.AddNode("B") 77 | graph.AddNode("C") 78 | graph.AddNode("D") 79 | graph.RemoveNode("A") 80 | graph.RemoveNode("A") 81 | graph.RemoveNode("E") 82 | 83 | assert.Equal(t, 3, graph.NodeCount(), "graph.NodeCount() should equal 2") 84 | assert.Equal(t, []interface{}{"B", "C", "D"}, graph.Nodes(), "graph.Nodes() should equal [B, C, D]") 85 | } 86 | 87 | func TestGraphRemoveNodes(t *testing.T) { 88 | graph := newGraph() 89 | graph.AddNode("A") 90 | graph.AddNode("B") 91 | graph.AddNode("C") 92 | graph.AddNode("D") 93 | graph.RemoveNodes("A", "C") 94 | 95 | assert.Equal(t, 2, graph.NodeCount(), "graph.NodeCount() should equal 2") 96 | assert.Equal(t, []interface{}{"B", "D"}, graph.Nodes(), "graph.Nodes() should equal [B, D]") 97 | } 98 | 99 | func TestGraphNodeExists(t *testing.T) { 100 | graph := newGraph() 101 | assert.False(t, graph.NodeExists("A"), "graph.NodeExists(\"A\") should equal false") 102 | assert.False(t, graph.NodeExists("B"), "graph.NodeExists(\"B\") should equal false") 103 | 104 | graph.AddNode("A") 105 | assert.True(t, graph.NodeExists("A"), "graph.NodeExists(\"A\") should equal true") 106 | assert.False(t, graph.NodeExists("B"), "graph.NodeExists(\"B\") should equal false") 107 | 108 | graph.RemoveNode("A") 109 | assert.False(t, graph.NodeExists("A"), "graph.NodeExists(\"A\") should equal false") 110 | assert.False(t, graph.NodeExists("B"), "graph.NodeExists(\"B\") should equal false") 111 | } 112 | 
-------------------------------------------------------------------------------- /di/internal/graphkv/graphkv.go: -------------------------------------------------------------------------------- 1 | package graphkv 2 | 3 | import ( 4 | "github.com/emicklei/dot" 5 | ) 6 | 7 | // Node 8 | type Node struct { 9 | Key Key 10 | Value interface{} 11 | } 12 | 13 | // Graph 14 | type Graph struct { 15 | dag *directedGraph 16 | values map[Key]interface{} 17 | } 18 | 19 | // New 20 | // AddNode 21 | // NodeExists 22 | // Sort 23 | // AddEdge 24 | func New() *Graph { 25 | return &Graph{ 26 | dag: newDirectedGraph(), 27 | values: map[Key]interface{}{}, 28 | } 29 | } 30 | 31 | // Get 32 | func (g *Graph) Get(key Key) Node { 33 | return Node{Key: key, Value: g.values[key]} 34 | } 35 | 36 | // Replace 37 | func (g *Graph) Replace(key Key, value interface{}) { 38 | g.values[key] = value 39 | } 40 | 41 | // Add 42 | func (g *Graph) Add(key Key, value interface{}) { 43 | g.dag.AddNode(key) 44 | g.values[key] = value 45 | } 46 | 47 | // Edge 48 | func (g *Graph) Edge(from Key, to Key) { 49 | g.dag.AddEdge(from, to) 50 | } 51 | 52 | // Exists 53 | func (g *Graph) Exists(key Key) bool { 54 | return g.dag.NodeExists(key) 55 | } 56 | 57 | // Nodes 58 | func (g *Graph) Nodes() []Node { 59 | var nodes []Node 60 | for _, key := range g.dag.Nodes() { 61 | nodes = append(nodes, Node{key, g.values[key]}) 62 | } 63 | return nodes 64 | } 65 | 66 | // CheckCycles 67 | func (g *Graph) CheckCycles() error { 68 | _, err := g.dag.DFSSort() 69 | return err // todo: errors 70 | } 71 | 72 | // DOTGraph 73 | func (g *Graph) DOTGraph() *dot.Graph { 74 | return g.dag.DOTGraph() 75 | } 76 | -------------------------------------------------------------------------------- /di/internal/graphkv/graphkv_test.go: -------------------------------------------------------------------------------- 1 | package graphkv 2 | 3 | // todo: tests 4 | 
-------------------------------------------------------------------------------- /di/internal/graphkv/key.go: -------------------------------------------------------------------------------- 1 | package graphkv 2 | 3 | // Key represents a graph node. 4 | type Key = interface{} 5 | 6 | type nodeList struct { 7 | nodes []Key 8 | set map[Key]bool 9 | } 10 | 11 | func newNodeList() *nodeList { 12 | return &nodeList{ 13 | nodes: make([]Key, 0), 14 | set: make(map[Key]bool), 15 | } 16 | } 17 | 18 | func (l *nodeList) Copy() *nodeList { 19 | nodes := make([]Key, len(l.nodes)) 20 | copy(nodes, l.nodes) 21 | 22 | set := make(map[Key]bool, len(nodes)) 23 | for _, node := range nodes { 24 | set[node] = true 25 | } 26 | 27 | return &nodeList{ 28 | nodes: nodes, 29 | set: set, 30 | } 31 | } 32 | 33 | func (l *nodeList) Nodes() []Key { 34 | return l.nodes 35 | } 36 | 37 | func (l *nodeList) Count() int { 38 | return len(l.nodes) 39 | } 40 | 41 | func (l *nodeList) Exists(node Key) bool { 42 | _, ok := l.set[node] 43 | return ok 44 | } 45 | 46 | func (l *nodeList) Add(nodes ...Key) { 47 | for _, node := range nodes { 48 | if l.Exists(node) { 49 | continue 50 | } 51 | 52 | l.nodes = append(l.nodes, node) 53 | l.set[node] = true 54 | } 55 | } 56 | 57 | func (l *nodeList) Remove(nodes ...Key) { 58 | for i := len(l.nodes) - 1; i >= 0; i-- { 59 | for j, node := range nodes { 60 | if l.nodes[i] == node { 61 | copy(l.nodes[i:], l.nodes[i+1:]) 62 | l.nodes[len(l.nodes)-1] = nil 63 | l.nodes = l.nodes[:len(l.nodes)-1] 64 | 65 | delete(l.set, node) 66 | 67 | copy(nodes[j:], nodes[j+1:]) 68 | nodes[len(nodes)-1] = nil 69 | nodes = nodes[:len(nodes)-1] 70 | 71 | break 72 | } 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /di/internal/graphkv/output.go: -------------------------------------------------------------------------------- 1 | package graphkv 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/emicklei/dot" 7 | ) 8 | 9 | // 
NodeVisualizer 10 | type NodeVisualizer interface { 11 | Visualize(node *dot.Node) 12 | SubGraph() string 13 | IsAlwaysVisible() bool 14 | } 15 | 16 | // DOTGraph returns a textual representation of the graph in the DOT graph 17 | // description language. 18 | func (g *directedGraph) DOTGraph() *dot.Graph { 19 | root := dot.NewGraph(dot.Directed) 20 | root.Attr("splines", "ortho") 21 | 22 | subgraphs := make(map[string]*dot.Graph) 23 | itemsByNode := make(map[Key]dot.Node) 24 | for _, node := range g.Nodes() { 25 | nv := node.(NodeVisualizer) 26 | 27 | if !g.HasOutgoingEdges(node) && !nv.IsAlwaysVisible() { 28 | continue 29 | } 30 | 31 | name := fmt.Sprintf("%s", node) 32 | subgraph, ok := subgraphs[nv.SubGraph()] 33 | if !ok { 34 | subgraph = root.Subgraph(nv.SubGraph(), dot.ClusterOption{}) 35 | subgraphs[nv.SubGraph()] = subgraph 36 | applySubGraphStyle(subgraph) 37 | } 38 | item := subgraph.Node(name) 39 | nv.Visualize(&item) 40 | itemsByNode[node] = item 41 | 42 | } 43 | 44 | for fromNode, fromItem := range itemsByNode { 45 | for _, toNode := range g.OutgoingEdges(fromNode) { 46 | if toItem, ok := itemsByNode[toNode]; ok { 47 | root.Edge(fromItem, toItem).Attr("color", "#949494") 48 | } 49 | } 50 | } 51 | 52 | return root 53 | } 54 | 55 | func applySubGraphStyle(graph *dot.Graph) { 56 | graph.Attr("label", "") 57 | graph.Attr("style", "rounded") 58 | graph.Attr("bgcolor", "#E8E8E8") 59 | graph.Attr("color", "lightgrey") 60 | graph.Attr("fontname", "COURIER") 61 | graph.Attr("fontcolor", "#46494C") 62 | } 63 | -------------------------------------------------------------------------------- /di/internal/graphkv/output_test.go: -------------------------------------------------------------------------------- 1 | package graphkv 2 | -------------------------------------------------------------------------------- /di/internal/graphkv/sort.go: -------------------------------------------------------------------------------- 1 | package graphkv 2 | 3 | import ( 4 | 
"errors" 5 | ) 6 | 7 | // Errors relating to the DFSSorter. 8 | var ( 9 | ErrCyclicGraph = errors.New("the graph cannot be cyclic") 10 | ) 11 | 12 | // DFSSorter topologically sorts a directed graph's nodes based on the 13 | // directed edges between them using the Depth-first search algorithm. 14 | type DFSSorter struct { 15 | graph *directedGraph 16 | sorted []Key 17 | visiting map[Key]bool 18 | discovered map[Key]bool 19 | } 20 | 21 | // NewDFSSorter returns a new DFS sorter. 22 | func NewDFSSorter(graph *directedGraph) *DFSSorter { 23 | return &DFSSorter{ 24 | graph: graph, 25 | } 26 | } 27 | 28 | func (s *DFSSorter) init() { 29 | s.sorted = make([]Key, 0, s.graph.NodeCount()) 30 | s.visiting = make(map[Key]bool) 31 | s.discovered = make(map[Key]bool, s.graph.NodeCount()) 32 | } 33 | 34 | // Sort returns the sorted nodes. 35 | func (s *DFSSorter) Sort() ([]Key, error) { 36 | s.init() 37 | 38 | // > while there are unmarked nodes do 39 | for _, node := range s.graph.Nodes() { 40 | if err := s.visit(node); err != nil { 41 | return nil, err 42 | } 43 | } 44 | 45 | // as the nodes were appended to the slice for performance reasons, 46 | // rather than prepended as correctly stated by the algorithm, 47 | // we need to reverse the sorted slice 48 | for i, j := 0, len(s.sorted)-1; i < j; i, j = i+1, j-1 { 49 | s.sorted[i], s.sorted[j] = s.sorted[j], s.sorted[i] 50 | } 51 | 52 | return s.sorted, nil 53 | } 54 | 55 | // See https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search 56 | func (s *DFSSorter) visit(node Key) error { 57 | // > if n has a permanent mark then return 58 | if discovered, ok := s.discovered[node]; ok && discovered { 59 | return nil 60 | } 61 | // > if n has a temporary mark then stop (not a DAG) 62 | if visiting, ok := s.visiting[node]; ok && visiting { 63 | return ErrCyclicGraph 64 | } 65 | 66 | // > mark n temporarily 67 | s.visiting[node] = true 68 | 69 | // > for each node m with an edge from n to m do 70 | for _, outgoing := range 
s.graph.OutgoingEdges(node) { 71 | if err := s.visit(outgoing); err != nil { 72 | return err 73 | } 74 | } 75 | 76 | s.discovered[node] = true 77 | delete(s.visiting, node) 78 | 79 | s.sorted = append(s.sorted, node) 80 | return nil 81 | } 82 | 83 | // DFSSort returns the graph's nodes in topological order based on the 84 | // directed edges between them using the Depth-first search algorithm. 85 | func (g *directedGraph) DFSSort() ([]Key, error) { 86 | sorter := NewDFSSorter(g) 87 | return sorter.Sort() 88 | } 89 | 90 | // Errors relating to the CoffmanGrahamSorter. 91 | var ( 92 | ErrDependencyOrder = errors.New("the topological dependency order is incorrect") 93 | ) 94 | 95 | // CoffmanGrahamSorter sorts a graph's nodes into a sequence of levels, 96 | // arranging so that a node which comes after another in the order is 97 | // assigned to a lower level, and that a level never exceeds the width. 98 | // See https://en.wikipedia.org/wiki/Coffman–Graham_algorithm 99 | type CoffmanGrahamSorter struct { 100 | graph *directedGraph 101 | width int 102 | } 103 | 104 | // NewCoffmanGrahamSorter returns a new Coffman-Graham sorter. 105 | func NewCoffmanGrahamSorter(graph *directedGraph, width int) *CoffmanGrahamSorter { 106 | return &CoffmanGrahamSorter{ 107 | graph: graph, 108 | width: width, 109 | } 110 | } 111 | 112 | // Sort returns the sorted nodes. 
113 | func (s *CoffmanGrahamSorter) Sort() ([][]Key, error) { 114 | // create a copy of the graph and remove transitive edges 115 | reduced := s.graph.Copy() 116 | reduced.RemoveTransitives() 117 | 118 | // topologically sort the graph nodes 119 | nodes, err := reduced.DFSSort() 120 | if err != nil { 121 | return nil, err 122 | } 123 | 124 | layers := make([][]Key, 0) 125 | levels := make(map[Key]int, len(nodes)) 126 | 127 | for _, node := range nodes { 128 | dependantLevel := -1 129 | for _, dependant := range reduced.IncomingEdges(node) { 130 | level, ok := levels[dependant] 131 | if !ok { 132 | return nil, ErrDependencyOrder 133 | } 134 | if level > dependantLevel { 135 | dependantLevel = level 136 | } 137 | } 138 | 139 | level := -1 140 | // find the first unfilled layer outgoing the dependent layer 141 | // skip this if the dependent layer is the last 142 | if dependantLevel < len(layers)-1 { 143 | for i := dependantLevel + 1; i < len(layers); i++ { 144 | // ensure the layer doesn't exceed the desired width 145 | if len(layers[i]) < s.width { 146 | level = i 147 | break 148 | } 149 | } 150 | } 151 | // create a new layer new none was found 152 | if level == -1 { 153 | layers = append(layers, make([]Key, 0, 1)) 154 | level = len(layers) - 1 155 | } 156 | 157 | layers[level] = append(layers[level], node) 158 | levels[node] = level 159 | } 160 | 161 | return layers, nil 162 | } 163 | 164 | // CoffmanGrahamSort sorts the graph's nodes into a sequence of levels, 165 | // arranging so that a node which comes after another in the order is 166 | // assigned to a lower level, and that a level never exceeds the specified width. 
167 | func (g *directedGraph) CoffmanGrahamSort(width int) ([][]Key, error) { 168 | sorter := NewCoffmanGrahamSorter(g, width) 169 | return sorter.Sort() 170 | } 171 | -------------------------------------------------------------------------------- /di/internal/graphkv/sort_test.go: -------------------------------------------------------------------------------- 1 | package graphkv 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestDFSSorter(t *testing.T) { 10 | graph := newDirectedGraph() 11 | graph.AddNodes(0, 1, 2, 3, 4, 5, 6, 7) 12 | graph.AddEdge(0, 2) 13 | graph.AddEdge(1, 2) 14 | graph.AddEdge(1, 5) 15 | graph.AddEdge(1, 6) 16 | graph.AddEdge(2, 5) 17 | graph.AddEdge(3, 5) 18 | graph.AddEdge(5, 6) 19 | graph.AddEdge(5, 7) 20 | 21 | sorted, err := graph.DFSSort() 22 | 23 | assert.NoError(t, err, "graph.DFSSort() error should be nil") 24 | assert.Equal(t, []Key{4, 3, 1, 0, 2, 5, 7, 6}, sorted, "graph.DFSSort() nodes should equal [4, 3, 1, 0, 2, 5, 7, 6]") 25 | } 26 | 27 | func TestDFSSorterCyclic(t *testing.T) { 28 | graph := newDirectedGraph() 29 | graph.AddNodes(0, 1) 30 | graph.AddEdge(0, 1) 31 | graph.AddEdge(1, 0) 32 | 33 | sorted, err := graph.DFSSort() 34 | 35 | assert.EqualError(t, err, ErrCyclicGraph.Error(), "graph.DFSSort() error should be ErrCyclicGraph") 36 | assert.Nil(t, sorted, "graph.DFSSort() nodes should be nil") 37 | } 38 | 39 | func TestCoffmanGrahamSorter(t *testing.T) { 40 | graph := newDirectedGraph() 41 | 42 | graph.AddNodes(0, 1, 2, 3, 4, 5, 6, 7, 8) 43 | graph.AddEdge(0, 2) 44 | graph.AddEdge(0, 5) 45 | graph.AddEdge(1, 2) 46 | graph.AddEdge(2, 3) 47 | graph.AddEdge(2, 4) 48 | graph.AddEdge(3, 6) 49 | graph.AddEdge(4, 6) 50 | graph.AddEdge(5, 7) 51 | graph.AddEdge(6, 7) 52 | graph.AddEdge(6, 8) 53 | 54 | sorted, err := graph.CoffmanGrahamSort(2) 55 | 56 | assert.NoError(t, err, "graph.CoffmanGrahamSort(2)0 error should be nil") 57 | assert.Equal(t, [][]Key{ 58 | []Key{1, 0}, 59 | 
[]Key{5, 2}, 60 | []Key{4, 3}, 61 | []Key{6}, 62 | []Key{8, 7}, 63 | }, sorted, "graph.CoffmanGrahamSort(2) nodes should equal [[1, 0], [5, 2], [4, 3], [6], [8, 7]]") 64 | } 65 | 66 | func TestCoffmanGrahamSorterCyclic(t *testing.T) { 67 | graph := newDirectedGraph() 68 | 69 | graph.AddNodes(0, 1, 2, 3, 4, 5, 6, 7, 8) 70 | graph.AddEdge(0, 2) 71 | graph.AddEdge(0, 5) 72 | graph.AddEdge(1, 2) 73 | graph.AddEdge(2, 0) // cyclic edge 74 | graph.AddEdge(2, 3) 75 | graph.AddEdge(2, 4) 76 | graph.AddEdge(3, 6) 77 | graph.AddEdge(4, 6) 78 | graph.AddEdge(5, 7) 79 | graph.AddEdge(6, 7) 80 | graph.AddEdge(6, 8) 81 | 82 | sorted, err := graph.CoffmanGrahamSort(2) 83 | 84 | assert.EqualError(t, err, ErrCyclicGraph.Error(), "graph.CoffmanGrahamSort(2) error should be ErrCyclicGraph") 85 | assert.Nil(t, sorted, "graph.CoffmanGrahamSort(2) nodes should be nil") 86 | } 87 | -------------------------------------------------------------------------------- /di/internal/reflection/func.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "runtime" 7 | ) 8 | 9 | // IsFunc 10 | func IsFunc(value interface{}) bool { 11 | return reflect.ValueOf(value).Kind() == reflect.Func 12 | } 13 | 14 | // Func 15 | type Func struct { 16 | Name string 17 | reflect.Type 18 | reflect.Value 19 | } 20 | 21 | // InspectFunction 22 | func InspectFunction(fn interface{}) *Func { 23 | if !IsFunc(fn) { 24 | panic(fmt.Sprintf("%s: not a function", reflect.TypeOf(fn).Kind())) // todo: improve message 25 | } 26 | 27 | val := reflect.ValueOf(fn) 28 | fnpc := runtime.FuncForPC(val.Pointer()) 29 | 30 | return &Func{ 31 | Name: fnpc.Name(), 32 | Type: val.Type(), 33 | Value: val, 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /di/internal/reflection/iface.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 
3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | // InspectInterfacePtr 9 | func InspectInterfacePtr(iface interface{}) *Interface { 10 | typ := reflect.TypeOf(iface) 11 | if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Interface { 12 | panic(fmt.Sprintf("%s: not a pointer to interface", typ)) // todo: improve message 13 | } 14 | 15 | return &Interface{ 16 | Name: typ.Elem().Name(), 17 | Type: typ.Elem(), 18 | } 19 | } 20 | 21 | // Interface 22 | type Interface struct { 23 | Name string 24 | Type reflect.Type 25 | } 26 | -------------------------------------------------------------------------------- /di/internal/reflection/reflection.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import "reflect" 4 | 5 | var errorInterface = reflect.TypeOf(new(error)).Elem() 6 | 7 | // IsError 8 | func IsError(typ reflect.Type) bool { 9 | return typ.Implements(errorInterface) 10 | } 11 | 12 | // IsCleanup 13 | func IsCleanup(typ reflect.Type) bool { 14 | return typ.Kind() == reflect.Func && typ.NumIn() == 0 && typ.NumOut() == 0 15 | } 16 | 17 | // IsPtr 18 | func IsPtr(value interface{}) bool { 19 | return reflect.ValueOf(value).Kind() == reflect.Ptr 20 | } 21 | -------------------------------------------------------------------------------- /di/invoker.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | 7 | "github.com/defval/inject/v2/di/internal/reflection" 8 | ) 9 | 10 | type invokerType int 11 | 12 | const ( 13 | invokerUnknown invokerType = iota 14 | invokerStd // func (deps) {} 15 | invokerError // func (deps) error {} 16 | ) 17 | 18 | func determineInvokerType(fn *reflection.Func) (invokerType, error) { 19 | if fn.NumOut() == 0 { 20 | return invokerStd, nil 21 | } 22 | if fn.NumOut() == 1 && reflection.IsError(fn.Out(0)) { 23 | return invokerError, nil 24 | } 25 | return invokerUnknown, 
fmt.Errorf("the invoke function must be a function like `func([dep1, dep2, ...]) [error]`, got `%s`", fn.Type) 26 | } 27 | 28 | type invoker struct { 29 | typ invokerType 30 | fn *reflection.Func 31 | } 32 | 33 | func newInvoker(fn interface{}) (*invoker, error) { 34 | if fn == nil { 35 | return nil, fmt.Errorf("the invoke function must be a function like `func([dep1, dep2, ...]) [error]`, got `%s`", "nil") 36 | } 37 | if !reflection.IsFunc(fn) { 38 | return nil, fmt.Errorf("the invoke function must be a function like `func([dep1, dep2, ...]) [error]`, got `%s`", reflect.ValueOf(fn).Type()) 39 | } 40 | ifn := reflection.InspectFunction(fn) 41 | typ, err := determineInvokerType(ifn) 42 | if err != nil { 43 | return nil, err 44 | } 45 | return &invoker{ 46 | typ: typ, 47 | fn: reflection.InspectFunction(fn), 48 | }, nil 49 | } 50 | 51 | func (i *invoker) Invoke(c *Container) error { 52 | plist := i.parameters() 53 | values, err := plist.Resolve(c) 54 | if err != nil { 55 | return fmt.Errorf("could not resolve invoke parameters: %s", err) 56 | } 57 | results := i.fn.Call(values) 58 | if len(results) == 0 { 59 | return nil 60 | } 61 | if results[0].Interface() == nil { 62 | return nil 63 | } 64 | return results[0].Interface().(error) 65 | } 66 | 67 | func (i *invoker) parameters() parameterList { 68 | var plist parameterList 69 | for j := 0; j < i.fn.NumIn(); j++ { 70 | ptype := i.fn.In(j) 71 | p := parameter{ 72 | res: ptype, 73 | optional: false, 74 | embed: isEmbedParameter(ptype), 75 | } 76 | plist = append(plist, p) 77 | } 78 | return plist 79 | } 80 | -------------------------------------------------------------------------------- /di/key.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | 7 | "github.com/emicklei/dot" 8 | ) 9 | 10 | // key is a id of provider in container 11 | type key struct { 12 | name string 13 | res reflect.Type 14 | typ providerType 15 | } 16 | 17 | 
// String represent resultKey as string. 18 | func (k key) String() string { 19 | if k.name == "" { 20 | return fmt.Sprintf("%s", k.res) 21 | } 22 | return fmt.Sprintf("%s[%s]", k.res, k.name) 23 | } 24 | 25 | // IsAlwaysVisible 26 | func (k key) IsAlwaysVisible() bool { 27 | return k.typ == ptConstructor 28 | } 29 | 30 | // Package 31 | func (k key) SubGraph() string { 32 | var pkg string 33 | switch k.res.Kind() { 34 | case reflect.Slice, reflect.Ptr: 35 | pkg = k.res.Elem().PkgPath() 36 | default: 37 | pkg = k.res.PkgPath() 38 | } 39 | 40 | return pkg 41 | } 42 | 43 | // Visualize 44 | func (k key) Visualize(node *dot.Node) { 45 | node.Label(k.String()) 46 | node.Attr("fontname", "COURIER") 47 | node.Attr("style", "filled") 48 | node.Attr("fontcolor", "white") 49 | switch k.typ { 50 | case ptConstructor: 51 | node.Attr("shape", "box") 52 | node.Attr("color", "#46494C") 53 | case ptGroup: 54 | node.Attr("shape", "doubleoctagon") 55 | node.Attr("color", "#E54B4B") 56 | case ptInterface: 57 | node.Attr("color", "#2589BD") 58 | case ptEmbedParameter: 59 | node.Attr("shape", "box") 60 | node.Attr("color", "#E5984B") 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /di/options.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | // ExtractOption 4 | type ProvideOption interface { 5 | apply(params *ProvideParams) 6 | } 7 | 8 | type provideOption func(params *ProvideParams) 9 | 10 | func (o provideOption) apply(params *ProvideParams) { 11 | o(params) 12 | } 13 | 14 | // ProvideParams is a `Provide()` method options. Name is a unique identifier of type instance. Provider is a constructor 15 | // function. Interfaces is a interface that implements a provider result type. 
16 | type ProvideParams struct { 17 | Name string 18 | Interfaces []interface{} 19 | Parameters ParameterBag 20 | IsPrototype bool 21 | } 22 | 23 | func (p ProvideParams) apply(params *ProvideParams) { 24 | *params = p 25 | } 26 | 27 | // As 28 | func As(interfaces ...interface{}) ProvideOption { 29 | return provideOption(func(params *ProvideParams) { 30 | params.Interfaces = append(params.Interfaces, interfaces...) 31 | }) 32 | } 33 | 34 | // InvokeParams is a invoke parameters. 35 | type InvokeParams struct{} 36 | 37 | func (p InvokeParams) apply(params *InvokeParams) { 38 | *params = p 39 | } 40 | 41 | // InvokeOption 42 | type InvokeOption interface { 43 | apply(params *InvokeParams) 44 | } 45 | 46 | // ExtractParams 47 | type ExtractParams struct { 48 | Name string 49 | } 50 | 51 | func (p ExtractParams) apply(params *ExtractParams) { 52 | *params = p 53 | } 54 | 55 | // ExtractOption 56 | type ExtractOption interface { 57 | apply(params *ExtractParams) 58 | } 59 | 60 | type extractOption func(params *ExtractParams) 61 | 62 | func (o extractOption) apply(params *ExtractParams) { 63 | o(params) 64 | } 65 | -------------------------------------------------------------------------------- /di/panic.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import "fmt" 4 | 5 | func panicf(format string, a ...interface{}) { 6 | panic(fmt.Sprintf(format, a...)) 7 | } 8 | -------------------------------------------------------------------------------- /di/parameter.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | 7 | // Parameter 8 | type Parameter struct { 9 | internalParameter 10 | } 11 | 12 | // parameterRequired 13 | type parameter struct { 14 | name string 15 | res reflect.Type 16 | optional bool 17 | embed bool 18 | } 19 | 20 | func (p parameter) String() string { 21 | return key{name: p.name, res: p.res}.String() 22 | } 
23 | 24 | // ResolveProvider resolves parameter provider 25 | func (p parameter) ResolveProvider(c *Container) (internalProvider, bool) { 26 | for _, pt := range providerLookupSequence { 27 | k := key{ 28 | name: p.name, 29 | res: p.res, 30 | typ: pt, 31 | } 32 | if !c.graph.Exists(k) { 33 | continue 34 | } 35 | node := c.graph.Get(k) 36 | return node.Value.(internalProvider), true 37 | } 38 | return nil, false 39 | } 40 | 41 | func (p parameter) ResolveValue(c *Container) (reflect.Value, error) { 42 | provider, exists := p.ResolveProvider(c) 43 | if !exists && p.optional { 44 | return reflect.New(p.res).Elem(), nil 45 | } 46 | if !exists { 47 | return reflect.Value{}, ErrParameterProviderNotFound{param: p} 48 | } 49 | pl := provider.ParameterList() 50 | values, err := pl.Resolve(c) 51 | if err != nil { 52 | return reflect.Value{}, err 53 | } 54 | value, cleanup, err := provider.Provide(values...) 55 | if err != nil { 56 | return value, ErrParameterProvideFailed{k: provider.Key(), err: err} 57 | } 58 | if cleanup != nil { 59 | c.cleanups = append(c.cleanups, cleanup) 60 | } 61 | return value, nil 62 | } 63 | 64 | // isEmbedParameter 65 | func isEmbedParameter(typ reflect.Type) bool { 66 | return typ.Kind() == reflect.Struct && typ.Implements(parameterInterface) 67 | } 68 | 69 | // internalParameter 70 | type internalParameter interface { 71 | isDependencyInjectionParameter() 72 | } 73 | 74 | // parameterInterface 75 | var parameterInterface = reflect.TypeOf(new(internalParameter)).Elem() 76 | -------------------------------------------------------------------------------- /di/parameter_bag.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | // createParameterBugProvider 9 | func createParameterBugProvider(key key, parameters ParameterBag) internalProvider { 10 | return newProviderConstructor(key.String(), func() ParameterBag { return parameters }) 11 | } 12 | 13 | 
// parameterBagType 14 | var parameterBagType = reflect.TypeOf(ParameterBag{}) 15 | 16 | // ParameterBag 17 | type ParameterBag map[string]interface{} 18 | 19 | // Exists 20 | func (b ParameterBag) Exists(key string) bool { 21 | _, ok := b[key] 22 | return ok 23 | } 24 | 25 | // Get 26 | func (b ParameterBag) Get(key string) (interface{}, bool) { 27 | value, ok := b[key] 28 | return value, ok 29 | } 30 | 31 | // String 32 | func (b ParameterBag) String(key string) (string, bool) { 33 | value, ok := b[key].(string) 34 | return value, ok 35 | } 36 | 37 | // Int64 38 | func (b ParameterBag) Int64(key string) (int64, bool) { 39 | value, ok := b[key].(int64) 40 | return value, ok 41 | } 42 | 43 | // Int 44 | func (b ParameterBag) Int(key string) (int, bool) { 45 | value, ok := b[key].(int) 46 | return value, ok 47 | } 48 | 49 | // Float64 50 | func (b ParameterBag) Float64(key string) (float64, bool) { 51 | value, ok := b[key].(float64) 52 | return value, ok 53 | } 54 | 55 | // Require 56 | func (b ParameterBag) Require(key string) interface{} { 57 | value, ok := b[key] 58 | if !ok { 59 | panic(fmt.Sprintf("value for string key `%s` not found", key)) 60 | } 61 | return value 62 | } 63 | 64 | // RequireString 65 | func (b ParameterBag) RequireString(key string) string { 66 | value, ok := b[key].(string) 67 | if !ok { 68 | panic(fmt.Sprintf("value for string key `%s` not found", key)) 69 | } 70 | return value 71 | } 72 | 73 | // RequireInt64 74 | func (b ParameterBag) RequireInt64(key string) int64 { 75 | value, ok := b[key].(int64) 76 | if !ok { 77 | panic(fmt.Sprintf("value for string key `%s` not found", key)) 78 | } 79 | return value 80 | } 81 | 82 | // RequireInt 83 | func (b ParameterBag) RequireInt(key string) int { 84 | value, ok := b[key].(int) 85 | if !ok { 86 | panic(fmt.Sprintf("value for string key `%s` not found", key)) 87 | } 88 | return value 89 | } 90 | 91 | // RequireFloat64 92 | func (b ParameterBag) RequireFloat64(key string) float64 { 93 | value, ok 
:= b[key].(float64) 94 | if !ok { 95 | panic(fmt.Sprintf("value for string key `%s` not found", key)) 96 | } 97 | return value 98 | } 99 | -------------------------------------------------------------------------------- /di/parameter_bag_test.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestParameterBag_Get(t *testing.T) { 10 | t.Run("key exists", func(t *testing.T) { 11 | pb := ParameterBag{ 12 | "get": "get", 13 | } 14 | 15 | v, ok := pb.Get("get") 16 | require.Equal(t, "get", v) 17 | require.True(t, ok) 18 | }) 19 | 20 | t.Run("key not exists", func(t *testing.T) { 21 | pb := ParameterBag{} 22 | 23 | v, ok := pb.Get("get") 24 | require.Equal(t, nil, v) 25 | require.False(t, ok) 26 | }) 27 | } 28 | 29 | func TestParameterBag_GetType(t *testing.T) { 30 | t.Run("key exists", func(t *testing.T) { 31 | pb := ParameterBag{ 32 | "string": "string", 33 | "int64": int64(64), 34 | "int": int(64), 35 | "float64": float64(64), 36 | } 37 | 38 | s, ok := pb.String("string") 39 | require.Equal(t, "string", s) 40 | require.True(t, ok) 41 | 42 | i64, ok := pb.Int64("int64") 43 | require.Equal(t, int64(64), i64) 44 | require.True(t, ok) 45 | 46 | i, ok := pb.Int("int") 47 | require.Equal(t, int(64), i) 48 | require.True(t, ok) 49 | 50 | f64, ok := pb.Float64("float64") 51 | require.Equal(t, float64(64), f64) 52 | require.True(t, ok) 53 | }) 54 | 55 | t.Run("key not exists", func(t *testing.T) { 56 | pb := ParameterBag{} 57 | 58 | s, ok := pb.String("string") 59 | require.Equal(t, "", s) 60 | require.False(t, ok) 61 | 62 | i64, ok := pb.Int64("int64") 63 | require.Equal(t, int64(0), i64) 64 | require.False(t, ok) 65 | 66 | i, ok := pb.Int("int") 67 | require.Equal(t, int(0), i) 68 | require.False(t, ok) 69 | 70 | f64, ok := pb.Float64("float64") 71 | require.Equal(t, float64(0), f64) 72 | require.False(t, ok) 73 | }) 74 | } 75 | 76 
| func TestParameterBag_Exists(t *testing.T) { 77 | pb := ParameterBag{} 78 | 79 | require.False(t, pb.Exists("not existing key")) 80 | } 81 | 82 | func TestParameterBag_Require(t *testing.T) { 83 | t.Run("key exists", func(t *testing.T) { 84 | pb := ParameterBag{ 85 | "require": "require", 86 | } 87 | 88 | value := pb.Require("require") 89 | require.Equal(t, "require", value) 90 | }) 91 | 92 | t.Run("key not exists", func(t *testing.T) { 93 | pb := ParameterBag{} 94 | 95 | require.PanicsWithValue(t, "value for string key `not existing key` not found", func() { 96 | pb.Require("not existing key") 97 | }) 98 | }) 99 | } 100 | 101 | func TestParameterBag_RequireTypes(t *testing.T) { 102 | t.Run("key exists", func(t *testing.T) { 103 | pb := ParameterBag{ 104 | "string": "string", 105 | "int64": int64(64), 106 | "int": int(64), 107 | "float64": float64(64), 108 | } 109 | 110 | s := pb.RequireString("string") 111 | require.Equal(t, "string", s) 112 | 113 | i64 := pb.RequireInt64("int64") 114 | require.Equal(t, int64(64), i64) 115 | 116 | i := pb.RequireInt("int") 117 | require.Equal(t, int(64), i) 118 | 119 | f64 := pb.RequireFloat64("float64") 120 | require.Equal(t, float64(64), f64) 121 | }) 122 | 123 | t.Run("key not exists", func(t *testing.T) { 124 | pb := ParameterBag{} 125 | 126 | require.PanicsWithValue(t, "value for string key `string` not found", func() { 127 | pb.RequireString("string") 128 | }) 129 | 130 | require.PanicsWithValue(t, "value for string key `int64` not found", func() { 131 | pb.RequireInt64("int64") 132 | }) 133 | 134 | require.PanicsWithValue(t, "value for string key `int` not found", func() { 135 | pb.RequireInt("int") 136 | }) 137 | 138 | require.PanicsWithValue(t, "value for string key `float64` not found", func() { 139 | pb.RequireFloat64("float64") 140 | }) 141 | }) 142 | } 143 | -------------------------------------------------------------------------------- /di/parameter_list.go: 
-------------------------------------------------------------------------------- 1 | package di 2 | 3 | import "reflect" 4 | 5 | // parameterList 6 | type parameterList []parameter 7 | 8 | // ResolveValues loads all parameters presented in parameter list. 9 | func (pl parameterList) Resolve(c *Container) ([]reflect.Value, error) { 10 | var values []reflect.Value 11 | for _, p := range pl { 12 | value, err := p.ResolveValue(c) 13 | if err != nil { 14 | return nil, err 15 | } 16 | values = append(values, value) 17 | } 18 | 19 | return values, nil 20 | } 21 | -------------------------------------------------------------------------------- /di/provider.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import "reflect" 4 | 5 | // provider lookup sequence 6 | var providerLookupSequence = []providerType{ptConstructor, ptInterface, ptGroup, ptEmbedParameter} 7 | 8 | // providerType 9 | type providerType int 10 | 11 | const ( 12 | ptUnknown providerType = iota 13 | ptConstructor 14 | ptInterface 15 | ptGroup 16 | ptEmbedParameter 17 | ) 18 | 19 | // provider 20 | type internalProvider interface { 21 | // The identity of result type. 22 | Key() key 23 | // ParameterList returns array of dependencies. 24 | ParameterList() parameterList 25 | // Provide provides value from provided parameters. 
26 | Provide(values ...reflect.Value) (reflect.Value, func(), error) 27 | } 28 | -------------------------------------------------------------------------------- /di/provider_ctor.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "reflect" 7 | 8 | "github.com/defval/inject/v2/di/internal/reflection" 9 | ) 10 | 11 | type ctorType int 12 | 13 | const ( 14 | ctorUnknown ctorType = iota // unknown ctor signature 15 | ctorStd // (deps) (result) 16 | ctorError // (deps) (result, error) 17 | ctorCleanup // (deps) (result, cleanup) 18 | ctorCleanupError // (deps) (result, cleanup, error) 19 | ) 20 | 21 | // newProviderConstructor 22 | func newProviderConstructor(name string, ctor interface{}) *providerConstructor { 23 | if ctor == nil { 24 | panicf("The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `%s`", "nil") 25 | } 26 | if !reflection.IsFunc(ctor) { 27 | panicf("The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `%s`", reflect.ValueOf(ctor).Type()) 28 | } 29 | fn := reflection.InspectFunction(ctor) 30 | ctorType := determineCtorType(fn) 31 | return &providerConstructor{ 32 | name: name, 33 | ctor: fn, 34 | ctorType: ctorType, 35 | } 36 | } 37 | 38 | // providerConstructor 39 | type providerConstructor struct { 40 | name string 41 | ctor *reflection.Func 42 | ctorType ctorType 43 | clean *reflection.Func 44 | } 45 | 46 | func (c providerConstructor) Key() key { 47 | return key{ 48 | name: c.name, 49 | res: c.ctor.Out(0), 50 | typ: ptConstructor, 51 | } 52 | } 53 | 54 | func (c providerConstructor) ParameterList() parameterList { 55 | var plist parameterList 56 | for i := 0; i < c.ctor.NumIn(); i++ { 57 | ptype := c.ctor.In(i) 58 | var name string 59 | if ptype == parameterBagType { 60 | name = c.Key().String() 61 | } 62 | p := parameter{ 63 | name: name, 64 | res: ptype, 65 | optional: false, 66 
| embed: isEmbedParameter(ptype), 67 | } 68 | plist = append(plist, p) 69 | } 70 | return plist 71 | } 72 | 73 | // Provide 74 | func (c *providerConstructor) Provide(values ...reflect.Value) (reflect.Value, func(), error) { 75 | out := callResult(c.ctor.Call(values)) 76 | switch c.ctorType { 77 | case ctorStd: 78 | return out.instance(), nil, nil 79 | case ctorError: 80 | return out.instance(), nil, out.error(1) 81 | case ctorCleanup: 82 | return out.instance(), out.cleanup(), nil 83 | case ctorCleanupError: 84 | return out.instance(), out.cleanup(), out.error(2) 85 | } 86 | return reflect.Value{}, nil, errors.New("you found a bug, please create new issue for " + 87 | "this: https://github.com/defval/inject/issues/new") 88 | } 89 | 90 | // determineCtorType 91 | func determineCtorType(fn *reflection.Func) ctorType { 92 | if fn.NumOut() == 1 { 93 | return ctorStd 94 | } 95 | if fn.NumOut() == 2 { 96 | if reflection.IsError(fn.Out(1)) { 97 | return ctorError 98 | } 99 | if reflection.IsCleanup(fn.Out(1)) { 100 | return ctorCleanup 101 | } 102 | } 103 | if fn.NumOut() == 3 && reflection.IsCleanup(fn.Out(1)) && reflection.IsError(fn.Out(2)) { 104 | return ctorCleanupError 105 | } 106 | panic(fmt.Sprintf("The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `%s`", fn.Name)) 107 | } 108 | 109 | // callResult 110 | type callResult []reflect.Value 111 | 112 | func (r callResult) instance() reflect.Value { 113 | return r[0] 114 | } 115 | 116 | func (r callResult) cleanup() func() { 117 | if r[1].IsNil() { 118 | return nil 119 | } 120 | return r[1].Interface().(func()) 121 | } 122 | 123 | func (r callResult) error(position int) error { 124 | if r[position].IsNil() { 125 | return nil 126 | } 127 | return r[position].Interface().(error) 128 | } 129 | -------------------------------------------------------------------------------- /di/provider_embed.go: -------------------------------------------------------------------------------- 1 | 
package di 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | ) 7 | 8 | // createStructProvider 9 | func newProviderEmbed(p parameter) *providerEmbed { 10 | var embedType reflect.Type 11 | if p.res.Kind() == reflect.Ptr { 12 | embedType = p.res.Elem() 13 | } else { 14 | embedType = p.res 15 | } 16 | 17 | return &providerEmbed{ 18 | key: key{ 19 | name: p.name, 20 | res: p.res, 21 | typ: ptEmbedParameter, 22 | }, 23 | embedType: embedType, 24 | embedValue: reflect.New(embedType).Elem(), 25 | } 26 | } 27 | 28 | type providerEmbed struct { 29 | key key 30 | embedType reflect.Type 31 | embedValue reflect.Value 32 | } 33 | 34 | func (p *providerEmbed) Key() key { 35 | return p.key 36 | } 37 | 38 | func (p *providerEmbed) ParameterList() parameterList { 39 | var plist parameterList 40 | for i := 0; i < p.embedType.NumField(); i++ { 41 | name, optional, isDependency := p.inspectFieldTag(i) 42 | if !isDependency { 43 | continue 44 | } 45 | field := p.embedType.Field(i) 46 | plist = append(plist, parameter{ 47 | name: name, 48 | res: field.Type, 49 | optional: optional, 50 | embed: isEmbedParameter(field.Type), 51 | }) 52 | } 53 | return plist 54 | } 55 | 56 | func (p *providerEmbed) Provide(values ...reflect.Value) (reflect.Value, func(), error) { 57 | for i, offset := 0, 0; i < p.embedType.NumField(); i++ { 58 | _, _, isDependency := p.inspectFieldTag(i) 59 | if !isDependency { 60 | offset++ 61 | continue 62 | } 63 | 64 | p.embedValue.Field(i).Set(values[i-offset]) 65 | } 66 | 67 | return p.embedValue, nil, nil 68 | } 69 | 70 | func (p *providerEmbed) inspectFieldTag(num int) (name string, optional bool, isDependency bool) { 71 | fieldType := p.embedType.Field(num) 72 | fieldValue := p.embedValue.Field(num) 73 | tag, tagExists := fieldType.Tag.Lookup("di") 74 | if !tagExists || !fieldValue.CanSet() { 75 | return "", false, false 76 | } 77 | name, optional = p.parseTag(tag) 78 | return name, optional, true 79 | } 80 | 81 | func (p *providerEmbed) parseTag(tag string) (name 
string, optional bool) { 82 | options := strings.Split(tag, ",") 83 | if len(options) == 0 { 84 | return "", false 85 | } 86 | if len(options) == 1 && options[0] == "optional" { 87 | return "", true 88 | } 89 | if len(options) == 1 { 90 | return options[0], false 91 | } 92 | if len(options) == 2 && options[1] == "optional" { 93 | return options[0], true 94 | } 95 | panic("incorrect di tag") 96 | } 97 | -------------------------------------------------------------------------------- /di/provider_group.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | 7 | // newProviderGroup creates new group from provided resultKey. 8 | func newProviderGroup(k key) *providerGroup { 9 | ifaceKey := key{ 10 | res: reflect.SliceOf(k.res), 11 | typ: ptGroup, 12 | } 13 | 14 | return &providerGroup{ 15 | result: ifaceKey, 16 | pl: parameterList{}, 17 | } 18 | } 19 | 20 | // providerGroup 21 | type providerGroup struct { 22 | result key 23 | pl parameterList 24 | } 25 | 26 | // Add 27 | func (i *providerGroup) Add(k key) { 28 | i.pl = append(i.pl, parameter{ 29 | name: k.name, 30 | res: k.res, 31 | optional: false, 32 | embed: false, 33 | }) 34 | } 35 | 36 | // resultKey 37 | func (i providerGroup) Key() key { 38 | return i.result 39 | } 40 | 41 | // parameters 42 | func (i providerGroup) ParameterList() parameterList { 43 | return i.pl 44 | } 45 | 46 | // Provide 47 | func (i providerGroup) Provide(values ...reflect.Value) (reflect.Value, func(), error) { 48 | group := reflect.New(i.result.res).Elem() 49 | return reflect.Append(group, values...), nil, nil 50 | } 51 | -------------------------------------------------------------------------------- /di/provider_iface.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "reflect" 5 | 6 | "github.com/defval/inject/v2/di/internal/reflection" 7 | ) 8 | 9 | // newProviderInterface 10 | 
func newProviderInterface(provider internalProvider, as interface{}) *providerInterface { 11 | iface := reflection.InspectInterfacePtr(as) 12 | if !provider.Key().res.Implements(iface.Type) { 13 | panicf("%s not implement %s", provider.Key(), iface.Type) 14 | } 15 | return &providerInterface{ 16 | res: key{ 17 | name: provider.Key().name, 18 | res: iface.Type, 19 | typ: ptInterface, 20 | }, 21 | provider: provider, 22 | } 23 | } 24 | 25 | // providerInterface 26 | type providerInterface struct { 27 | res key 28 | provider internalProvider 29 | } 30 | 31 | func (i *providerInterface) Key() key { 32 | return i.res 33 | } 34 | 35 | func (i *providerInterface) ParameterList() parameterList { 36 | var plist parameterList 37 | plist = append(plist, parameter{ 38 | name: i.provider.Key().name, 39 | res: i.provider.Key().res, 40 | optional: false, 41 | embed: false, 42 | }) 43 | return plist 44 | } 45 | 46 | func (i *providerInterface) Provide(values ...reflect.Value) (reflect.Value, func(), error) { 47 | return values[0], nil, nil 48 | } 49 | -------------------------------------------------------------------------------- /di/provider_stub.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | // providerStub 9 | type providerStub struct { 10 | msg string 11 | res key 12 | } 13 | 14 | // newProviderStub 15 | func newProviderStub(k key, msg string) *providerStub { 16 | return &providerStub{res: k, msg: msg} 17 | } 18 | 19 | func (m *providerStub) Key() key { 20 | return m.res 21 | } 22 | 23 | func (m *providerStub) ParameterList() parameterList { 24 | return parameterList{} 25 | } 26 | 27 | func (m *providerStub) Provide(values ...reflect.Value) (reflect.Value, func(), error) { 28 | return reflect.Value{}, nil, fmt.Errorf(m.msg) 29 | } 30 | -------------------------------------------------------------------------------- /di/singleton.go: 
-------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | 7 | // asSingleton creates a singleton wrapper. 8 | func asSingleton(provider internalProvider) *singletonWrapper { 9 | return &singletonWrapper{internalProvider: provider} 10 | } 11 | 12 | // singletonWrapper is a embedParamProvider wrapper. Stores provided value for prevent reinitialization. 13 | type singletonWrapper struct { 14 | internalProvider // source provider 15 | value reflect.Value // value cache 16 | } 17 | 18 | // Provide 19 | func (s *singletonWrapper) Provide(values ...reflect.Value) (reflect.Value, func(), error) { 20 | if s.value.IsValid() { 21 | return s.value, nil, nil 22 | } 23 | value, cleanup, err := s.internalProvider.Provide(values...) 24 | s.value = value 25 | 26 | return value, cleanup, err 27 | } 28 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // The MIT License (MIT) 2 | // 3 | // Copyright (c) 2019 Maxim Bovtunov 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | 23 | /* 24 | Package inject makes dependency injection easy. A container allows you to inject dependencies 25 | into constructors or structures without having to specify 26 | each argument manually. 27 | 28 | Provide 29 | 30 | First of all, when creating a new container, you need to describe 31 | how to create each instance of a dependency. To do this, use the container 32 | option inject.Provide(). 33 | 34 | container := New( 35 | Provide(NewDependency), 36 | Provide(NewAnotherDependency), 37 | ) 38 | 39 | func NewDependency(dependency *pkg.AnotherDependency) *pkg.Dependency { 40 | return &pkg.Dependency{ 41 | dependency: dependency, 42 | } 43 | } 44 | 45 | func NewAnotherDependency() (dependency *pkg.AnotherDependency, err error) { 46 | if dependency, err = initAnotherDependency(); err != nil { 47 | return nil, err 48 | } 49 | 50 | return dependency, nil 51 | } 52 | 53 | Now the container knows how to create *pkg.Dependency and *pkg.AnotherDependency. 54 | For advanced providing, see the inject.Provide() and inject.ProvideOption documentation. 55 | 56 | Extract 57 | 58 | After building a container, it is easy to get any previously provided type. 59 | To do this, use the container's Extract() method. 60 | 61 | var anotherDependency *pkg.AnotherDependency 62 | if err := container.Extract(&anotherDependency); err != nil { 63 | // handle error 64 | } 65 | 66 | The container collects the dependencies of *pkg.AnotherDependency, creates its instance and 67 | places it in the target pointer. 68 | For advanced extraction, see the Extract() and inject.ExtractOption documentation.
69 | */ 70 | package inject 71 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/defval/inject/v2 2 | 3 | require ( 4 | github.com/davecgh/go-spew v1.1.1 // indirect 5 | github.com/emicklei/dot v0.10.1 6 | github.com/kr/pretty v0.1.0 // indirect 7 | github.com/stretchr/testify v1.4.0 8 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect 9 | gopkg.in/yaml.v2 v2.2.8 // indirect 10 | ) 11 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/emicklei/dot v0.10.1 h1:bkzvwgIhhw/cuxxnJy5/5+ZL3GnhFxFfv0eolHtWE2w= 6 | github.com/emicklei/dot v0.10.1/go.mod h1:kZg82Ikwc4pqb31Ct2yb0B7RUqxh3JESIXw2uWSv/xY= 7 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 8 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 9 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 10 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 11 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 12 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 13 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 14 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 15 | github.com/stretchr/testify v1.4.0 
h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= 16 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 17 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 18 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 19 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= 20 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 21 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 22 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 23 | gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= 24 | gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 25 | -------------------------------------------------------------------------------- /graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d3fvxl/inject/f8416f73ad0c0ce9487ef3ceae5035a86c465ee7/graph.png -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d3fvxl/inject/f8416f73ad0c0ce9487ef3ceae5035a86c465ee7/logo.png -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package inject 2 | 3 | import "github.com/defval/inject/v2/di" 4 | 5 | // OPTIONS 6 | 7 | // Option configures a Container. See inject.Provide(), inject.Bundle(), inject.Replace(). 8 | type Option interface { 9 | apply(*Container) 10 | } 11 | 12 | // Provide returns a container option that explains how to create an instance of a type inside a container.
13 | // 14 | // The first argument is the constructor function. A constructor is a function that creates an instance of the required 15 | // type. It can take an unlimited number of arguments needed to create an instance - the first returned value. 16 | // 17 | // func NewServer(mux *http.ServeMux) *http.Server { 18 | // return &http.Server{ 19 | // Handle: mux, 20 | // } 21 | // } 22 | // 23 | // Optionally, you can return a cleanup function and initializing error. 24 | // 25 | // func NewServer(mux *http.ServeMux) (*http.Server, cleanup func(), err error) { 26 | // if time.Now().Day = 1 { 27 | // return nil, nil, errors.New("the server is down on the first day of a month") 28 | // } 29 | // 30 | // server := &http.Server{ 31 | // Handler: mux, 32 | // } 33 | // 34 | // cleanup := func() { 35 | // _ = server.Close() 36 | // } 37 | // 38 | // return server, cleanup, nil 39 | // } 40 | // 41 | // Other function signatures will cause error. 42 | func Provide(provider interface{}, options ...ProvideOption) Option { 43 | return option(func(container *Container) { 44 | // todo: add provider 45 | var params = di.ProvideParams{ 46 | Parameters: map[string]interface{}{}, 47 | } 48 | 49 | for _, opt := range options { 50 | opt.apply(¶ms) 51 | } 52 | container.providers = append(container.providers, provide{ 53 | provider: provider, 54 | params: params, 55 | }) 56 | }) 57 | } 58 | 59 | // Bundle group together container options. 
60 | // 61 | // accountBundle := inject.Bundle( 62 | // inject.Provide(NewAccountController), 63 | // inject.Provide(NewAccountRepository), 64 | // ) 65 | // 66 | // authBundle := inject.Bundle( 67 | // inject.Provide(NewAuthController), 68 | // inject.Provide(NewAuthRepository), 69 | // ) 70 | // 71 | // container, _ := New( 72 | // accountBundle, 73 | // authBundle, 74 | // ) 75 | func Bundle(options ...Option) Option { 76 | return option(func(container *Container) { 77 | for _, opt := range options { 78 | opt.apply(container) 79 | } 80 | }) 81 | } 82 | 83 | // ProvideOption modifies default provide behavior. See inject.WithName(), inject.As(), inject.Prototype(). 84 | type ProvideOption interface { 85 | apply(params *di.ProvideParams) 86 | } 87 | 88 | // WithName sets string identifier for provided value. 89 | // 90 | // inject.Provide(&http.Server{}, inject.WithName("first")) 91 | // inject.Provide(&http.Server{}, inject.WithName("second")) 92 | // 93 | // container.Extract(&server, inject.Name("second")) 94 | func WithName(name string) ProvideOption { 95 | return provideOption(func(provider *di.ProvideParams) { 96 | provider.Name = name 97 | }) 98 | } 99 | 100 | // As specifies interfaces that implement provider instance. Provide with As() automatically checks that constructor 101 | // result implements interface and creates slice group with it. 102 | // 103 | // Provide(&http.ServerMux{}, inject.As(new(http.Handler))) 104 | // 105 | // var handler http.Handler 106 | // container.Extract(&handler) // extract as interface 107 | // 108 | // var handlers []http.Handler 109 | // container.Extract(&handlers) // extract group 110 | func As(ifaces ...interface{}) ProvideOption { 111 | return provideOption(func(provider *di.ProvideParams) { 112 | provider.Interfaces = append(provider.Interfaces, ifaces...) 113 | 114 | }) 115 | } 116 | 117 | // Prototype modifies Provide() behavior. By default, each type resolves as a singleton. 
This option sets that 118 | // each type resolving creates a new instance of the type. 119 | // 120 | // Provide(&http.Server{], inject.Prototype()) 121 | // 122 | // var server1 *http.Server 123 | // var server2 *http.Server 124 | // container.Extract(&server1, &server2) 125 | func Prototype() ProvideOption { 126 | return provideOption(func(provider *di.ProvideParams) { 127 | provider.IsPrototype = true 128 | }) 129 | } 130 | 131 | // ParameterBag is a provider parameter bag. It stores a construction parameters. It is a alternative way to 132 | // configure type. 133 | // 134 | // inject.Provide(NewServer, inject.ParameterBag{ 135 | // "addr": ":8080", 136 | // }) 137 | // 138 | // NewServer(pb inject.ParameterBag) *http.Server { 139 | // return &http.Server{ 140 | // Addr: pb.RequireString("addr"), 141 | // } 142 | // } 143 | type ParameterBag map[string]interface{} 144 | 145 | func (p ParameterBag) apply(provider *di.ProvideParams) { 146 | for k, v := range p { 147 | provider.Parameters[k] = v 148 | } 149 | } 150 | 151 | // ExtractOption modifies default extract behavior. See inject.Name(). 152 | type ExtractOption interface { 153 | apply(params *di.ExtractParams) 154 | } 155 | 156 | // EXTRACT OPTIONS. 157 | 158 | // Name specify definition name. 
159 | func Name(name string) ExtractOption { 160 | return extractOption(func(eo *di.ExtractParams) { 161 | eo.Name = name 162 | }) 163 | } 164 | 165 | type option func(container *Container) 166 | 167 | func (o option) apply(container *Container) { o(container) } 168 | 169 | type provideOption func(provider *di.ProvideParams) 170 | 171 | func (o provideOption) apply(provider *di.ProvideParams) { o(provider) } 172 | 173 | type extractOption func(eo *di.ExtractParams) 174 | 175 | func (o extractOption) apply(eo *di.ExtractParams) { o(eo) } 176 | 177 | type extractOptions struct { 178 | name string 179 | target interface{} 180 | } 181 | -------------------------------------------------------------------------------- /options_test.go: -------------------------------------------------------------------------------- 1 | package inject 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | 9 | "github.com/defval/inject/v2/di" 10 | ) 11 | 12 | // TestProvideOptions 13 | func TestProvideOptions(t *testing.T) { 14 | opts := &di.ProvideParams{ 15 | Parameters: map[string]interface{}{}, 16 | } 17 | 18 | for _, opt := range []ProvideOption{ 19 | WithName("test"), 20 | As(new(http.Handler)), 21 | Prototype(), 22 | ParameterBag{ 23 | "test": "test", 24 | }, 25 | } { 26 | opt.apply(opts) 27 | } 28 | 29 | require.Equal(t, &di.ProvideParams{ 30 | Name: "test", 31 | Interfaces: []interface{}{new(http.Handler)}, 32 | IsPrototype: true, 33 | Parameters: map[string]interface{}{ 34 | "test": "test", 35 | }, 36 | }, opts) 37 | } 38 | 39 | func TestExtractOptions(t *testing.T) { 40 | opts := &di.ExtractParams{} 41 | 42 | for _, opt := range []ExtractOption{ 43 | Name("test"), 44 | } { 45 | opt.apply(opts) 46 | } 47 | 48 | require.Equal(t, &di.ExtractParams{ 49 | Name: "test", 50 | }, opts) 51 | } 52 | --------------------------------------------------------------------------------