├── .build
└── test.sh
├── .codecov.yml
├── .gitignore
├── .travis.yml
├── CHANGELOG.md
├── LICENSE
├── Makefile
├── README.md
├── _tutorial
└── main.go
├── container.go
├── container_test.go
├── di
├── container.go
├── container_test.go
├── dot.go
├── errors.go
├── internal
│ ├── ditest
│ │ ├── bar.go
│ │ ├── baz.go
│ │ ├── foo.go
│ │ ├── fooer_group.go
│ │ ├── full.go
│ │ ├── incorrect.go
│ │ ├── interfaces.go
│ │ └── qux.go
│ ├── graphkv
│ │ ├── directed_graph.go
│ │ ├── directed_graph_test.go
│ │ ├── edge.go
│ │ ├── errors.go
│ │ ├── graph.go
│ │ ├── graph_test.go
│ │ ├── graphkv.go
│ │ ├── graphkv_test.go
│ │ ├── key.go
│ │ ├── output.go
│ │ ├── output_test.go
│ │ ├── sort.go
│ │ └── sort_test.go
│ └── reflection
│ │ ├── func.go
│ │ ├── iface.go
│ │ └── reflection.go
├── invoker.go
├── key.go
├── options.go
├── panic.go
├── parameter.go
├── parameter_bag.go
├── parameter_bag_test.go
├── parameter_list.go
├── provider.go
├── provider_ctor.go
├── provider_embed.go
├── provider_group.go
├── provider_iface.go
├── provider_stub.go
└── singleton.go
├── doc.go
├── go.mod
├── go.sum
├── graph.png
├── logo.png
├── options.go
└── options_test.go
/.build/test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
#
# Runs `go test` with atomic coverage for every package except the ditest
# fixtures and concatenates the per-package profiles into coverage.txt.

set -eo pipefail

rm -f coverage.txt

# Resolve the package list as a plain assignment: with `set -e` and
# `pipefail`, a failing `go list` now aborts the script. Inside a `for` word
# list the failure would be ignored, the loop would run zero times, and the
# build would pass without testing anything.
packages=$(go list ./... | grep -v ditest)

# $packages is intentionally unquoted so it splits into one package per word.
for d in $packages; do
  go test -coverprofile=profile.out -coverpkg=./... -covermode=atomic "$d"
  if [[ -f profile.out ]]; then
    cat profile.out >> coverage.txt
    rm profile.out
  fi
done
--------------------------------------------------------------------------------
/.codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 | range: 70..98
3 | round: down
4 | precision: 2
5 | status:
6 | project:
7 | default:
8 | enabled: yes
9 | target: 92
10 | if_not_found: success
11 | if_ci_failed: error
12 | patch:
13 | default:
14 | enabled: yes
15 | target: 70
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | coverage.txt
2 | profile.out
3 | vendor
4 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | go_import_path: github.com/defval/inject/v2
2 | language: go
3 | sudo: false
4 |
5 | matrix:
6 | include:
7 | - go: "1.11.x"
8 | - go: "1.12.x"
9 | - go: "1.13.x"
10 | fast_finish: true
11 |
12 | env:
13 | global:
14 | - GO111MODULE=on
15 |
16 | script:
17 | - make test
18 |
19 | after_success:
20 | - bash <(curl -s https://codecov.io/bash)
21 |
22 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 | All notable changes to this project will be documented in this file.
3 |
4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
6 |
7 | ## Unreleased
8 |
9 | ### Fixed
10 |
11 | - Cleanup ordering
12 | - Cleanup with prototypes
13 | - Removed duplicate function `resolveParameterProvider()`
14 | - Refactor graph storage
15 | - Change internal di container interface.
16 | - Code style fixes
17 | - Documentation fixes
18 |
19 | ## v2.2.2
20 |
21 | Internal refactoring
22 |
23 | ### Added
24 |
25 | - Visualize parameter bag
26 |
27 | ### Fixed
28 |
29 | - Visualize type detection
30 |
31 | ## v2.2.1
32 |
33 | ### Fixed
34 |
35 | - Invoke: interface is nil, not error
36 |
37 | ## v2.2.0
38 |
39 | ### Added
40 |
41 | - `container.Invoke()` for invocations
42 |
43 | ## v2.1.1
44 |
45 | ### Fixed
46 |
47 | - Incorrect di.Parameter resolving
48 |
49 | ## v2.1.0
50 |
51 | ### Added
52 |
53 | - Visualization
54 |
55 | ## v2.0.1
56 |
57 | ### Added
58 |
59 | - Helper methods to ParameterBag
60 |
61 | ## v2.0.0
62 |
63 | Massive refactoring and rethinking features.
64 |
65 | ### Changed
66 |
67 | - Graph implementation
68 | - Simplify injection code
69 | - Documentation
70 |
71 | ### Added
72 |
73 | - Prototypes
74 | - Cleanup
75 | - Parameter bag
76 | - Optional parameters
77 | - Low-level container interface
78 |
79 | ### Removed
80 |
81 | - Replacing (investigating)
82 | - Non constructor providers
83 | - Combined providers
84 |
85 | ## v1.5.2
86 |
87 | ### Fixed
88 |
89 | - Checksum problem
90 |
91 | ## v1.5.1
92 |
93 | ### Added
94 |
95 | - Ability to extract `*dot.Graph` (`github.com/emicklei/dot`) from the container
96 |
97 | ## v1.5.0
98 |
99 | ### Added
100 |
101 | - Error `inject.ErrTypeNotProvided`
102 |
103 | ## v1.4.4
104 |
105 | ### Changed
106 |
107 | - Internal refactoring of adding nodes
108 |
109 | ## v1.4.3
110 |
111 | ### Fixed
112 |
113 | - Improve test coverage
114 |
115 | ### Changed
116 |
117 | - Internal refactoring of groups creation
118 |
119 | ## v1.4.2
120 |
121 | ### Fixed
122 |
123 | - Replace: check that the provider implements the interface
124 |
125 | ## v1.4.1
126 |
127 | ### Fixed
128 |
129 | - Lint
130 |
131 | ## v1.4.0
132 |
133 | ### Changed
134 |
135 | - `Container.WriteTo()` signature to `io.WriterTo`
136 |
137 | ## v1.3.1
138 |
139 | ### Added
140 |
141 | - Documentation
142 |
143 | ## v1.3.0
144 |
145 | ### Added
146 |
147 | - Graph visualization
148 |
149 | ## v1.2.1
150 |
151 | ### Fixed
152 |
153 | - inject.As() allows provide same interface without name
154 |
155 | ## v1.2.0
156 |
157 | ### Changed
158 |
159 | - Combined provider declaration
160 | - Some refactoring
161 |
162 | ## v1.1.1
163 |
164 | ### Changed
165 |
166 | - Refactor graph storage
167 |
168 | ## v1.1.0
169 |
170 | ### Added
171 |
172 | - Combined provider
173 |
174 | ### Changed
175 |
176 | - Provider refactoring
177 |
178 | ## v1.0.0
179 |
180 | - Initial release
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2019 Maxim Bovtunov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | ## Injector makefile
2 |
3 | .PHONY: test ## Run tests
4 | test:
5 | @.build/test.sh
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
[](https://twitter.com/intent/tweet?text=Dependency%20injection%20container%20for%20Golang&url=https://github.com/defval/inject&hashtags=golang,go,di,dependency-injection)
3 |
4 | [](https://godoc.org/github.com/defval/inject)
5 | 
6 | [](https://travis-ci.org/defval/inject)
7 | [](https://codecov.io/gh/defval/inject)
8 |
9 | ## This repository will be archived
10 |
11 | After using this library in production for a year, I drew some
12 | conclusions about its API and about which features are really useful. I don't want
13 | to make breaking changes and release a `v3` version, since the library is not that popular.
14 |
15 | I put old and new ideas in [goava/di](https://github.com/goava/di).
16 |
17 | ## How will dependency injection help me?
18 |
19 | Dependency injection is one form of the broader technique of inversion
20 | of control. It is used to increase modularity of the program and make it
21 | extensible.
22 |
23 | ## Contents
24 |
25 | - [Installing](#installing)
26 | - [Tutorial](#tutorial)
27 | - [Providing](#providing)
28 | - [Extraction](#extraction)
29 | - [Invocation](#invocation)
30 | - [Lazy-loading](#lazy-loading)
31 | - [Interfaces](#interfaces)
32 | - [Groups](#groups)
33 | - [Advanced features](#advanced-features)
34 | - [Named definitions](#named-definitions)
35 | - [Optional parameters](#optional-parameters)
36 | - [Parameter Bag](#parameter-bag)
37 | - [Prototypes](#prototypes)
38 | - [Cleanup](#cleanup)
39 | - [Visualization](#visualization)
40 | - [Contributing](#contributing)
41 |
42 | ## Installing
43 |
44 | ```shell
45 | go get -u github.com/defval/inject/v2
46 | ```
47 |
48 | This library follows [SemVer](http://semver.org/) strictly.
49 |
50 | ## Tutorial
51 |
52 | Let's learn to use Inject by example. We will code a simple application
53 | that processes HTTP requests.
54 |
55 | The full tutorial code is available [here](./_tutorial/main.go)
56 |
57 | ### Providing
58 |
59 | To start, we will need to create two fundamental types: `http.Server`
60 | and `http.ServeMux`. Let's create simple constructors that initialize
61 | them:
62 |
63 | ```go
64 | // NewServer creates a http server with provided mux as handler.
65 | func NewServer(mux *http.ServeMux) *http.Server {
66 | return &http.Server{
67 | Handler: mux,
68 | }
69 | }
70 |
71 | // NewServeMux creates a new http serve mux.
72 | func NewServeMux() *http.ServeMux {
73 | return &http.ServeMux{}
74 | }
75 | ```
76 |
77 | > Supported constructor signature:
78 | >
79 | > ```go
80 | > func([dep1, dep2, depN]) (result, [cleanup, error])
81 | > ```
82 |
83 | Now let's teach a container to build these types.
84 |
85 | ```go
86 | container := inject.New(
87 | // provide http server
88 | inject.Provide(NewServer),
89 | // provide http serve mux
90 | inject.Provide(NewServeMux),
91 | )
92 | ```
93 |
94 | The function `inject.New()` parses our constructors, compiles the dependency
95 | graph and returns a `*inject.Container` for interaction. The container
96 | panics if the graph cannot be compiled.
97 |
98 | > In my opinion, panicking during application initialization, rather than at
99 | > runtime, is acceptable and common.
100 |
101 | ### Extraction
102 |
103 | We can extract the built server from the container. To do this, declare a
104 | variable of the extracted type and pass a pointer to it to the `Extract`
105 | function.
106 |
107 | > If the extracted type is not found or building the instance fails,
108 | > `Extract` returns an error.
109 |
110 | If no error occurred, we can use the variable as if we had built it
111 | ourselves.
112 |
113 | ```go
114 | // declare type variable
115 | var server *http.Server
116 | // extracting
117 | err := container.Extract(&server)
118 | if err != nil {
119 | // check extraction error
120 | }
121 |
122 | server.ListenAndServe()
123 | ```
124 |
125 | > Note that by default, the container creates instances as a singleton.
126 | > But you can change this behaviour. See [Prototypes](#prototypes).
127 |
128 | ### Invocation
129 |
130 | As an alternative to extraction, we can use the `Invoke()` function. It
131 | resolves the function's dependencies and calls the function. The invoked
132 | function may optionally return an error.
133 |
134 | ```go
135 | // StartServer starts the server.
136 | func StartServer(server *http.Server) error {
137 | return server.ListenAndServe()
138 | }
139 |
140 | container.Invoke(StartServer)
141 | ```
142 |
143 | ### Lazy-loading
144 |
145 | Dependencies are lazy-loaded: if nothing requires a type from
146 | the container, it will not be constructed.
147 |
148 | ### Interfaces
149 |
150 | Inject makes it possible to provide an implementation as an interface.
151 |
152 | ```go
153 | // NewServer creates an http server with the provided handler.
154 | func NewServer(handler http.Handler) *http.Server {
155 | return &http.Server{
156 | Handler: handler,
157 | }
158 | }
159 | ```
160 |
161 | To tell the container which implementation to use for `http.Handler`,
162 | we use the `inject.As()` option. The arguments of this
163 | option must be pointers to interfaces, like `new(http.Handler)`.
164 |
165 | > This syntax may seem strange, but I have not found a better way to
166 | > specify the interface.
167 |
168 | Updated container initialization code:
169 |
170 | ```go
171 | container := inject.New(
172 | // provide http server
173 | inject.Provide(NewServer),
174 | // provide http serve mux as http.Handler interface
175 | inject.Provide(NewServeMux, inject.As(new(http.Handler))),
176 | )
177 | ```
178 |
179 | Now the container uses the provided `*http.ServeMux` as `http.Handler` in the server
180 | constructor. Using interfaces helps you write more testable code.
181 |
182 | ### Groups
183 |
184 | The container automatically groups all implementations of an interface into a
185 | `[]<interface>` group. For example, providing with
186 | `inject.As(new(http.Handler))` automatically creates the group
187 | `[]http.Handler`.
188 |
189 | Let's add some http controllers using this feature. Controllers share a
190 | typical behavior: registering routes. First, we will create an
191 | interface for it.
192 |
193 | ```go
194 | // Controller is an interface that can register its routes.
195 | type Controller interface {
196 | RegisterRoutes(mux *http.ServeMux)
197 | }
198 | ```
199 |
200 | Now we will write controllers and implement `Controller` interface.
201 |
202 | ##### OrderController
203 |
204 | ```go
205 | // OrderController is a http controller for orders.
206 | type OrderController struct {}
207 |
208 | // NewOrderController creates an order http controller.
209 | func NewOrderController() *OrderController {
210 | return &OrderController{}
211 | }
212 |
213 | // RegisterRoutes is a Controller interface implementation.
214 | func (a *OrderController) RegisterRoutes(mux *http.ServeMux) {
215 | mux.HandleFunc("/orders", a.RetrieveOrders)
216 | }
217 |
218 | // RetrieveOrders loads orders and writes them to the writer.
219 | func (a *OrderController) RetrieveOrders(writer http.ResponseWriter, request *http.Request) {
220 | // implementation
221 | }
222 | ```
223 |
224 | ##### UserController
225 |
226 | ```go
227 | // UserController is a http endpoint for a user.
228 | type UserController struct {}
229 |
230 | // NewUserController creates a user http endpoint.
231 | func NewUserController() *UserController {
232 | return &UserController{}
233 | }
234 |
235 | // RegisterRoutes is a Controller interface implementation.
236 | func (e *UserController) RegisterRoutes(mux *http.ServeMux) {
237 | mux.HandleFunc("/users", e.RetrieveUsers)
238 | }
239 |
240 | // RetrieveUsers loads users and writes them using the writer.
241 | func (e *UserController) RetrieveUsers(writer http.ResponseWriter, request *http.Request) {
242 | // implementation
243 | }
244 | ```
245 |
246 | Just like in the example with interfaces, we will use `inject.As()`
247 | provide option.
248 |
249 | ```go
250 | container := inject.New(
251 | inject.Provide(NewServer), // provide http server
252 | inject.Provide(NewServeMux), // provide http serve mux
253 | // endpoints
254 | inject.Provide(NewOrderController, inject.As(new(Controller))), // provide order controller
255 | inject.Provide(NewUserController, inject.As(new(Controller))), // provide user controller
256 | )
257 | ```
258 |
259 | Now, we can use `[]Controller` group in our mux. See updated code:
260 |
261 | ```go
262 | // NewServeMux creates a new http serve mux.
263 | func NewServeMux(controllers []Controller) *http.ServeMux {
264 | mux := &http.ServeMux{}
265 |
266 | for _, controller := range controllers {
267 | controller.RegisterRoutes(mux)
268 | }
269 |
270 | return mux
271 | }
272 | ```
273 |
274 | ## Advanced features
275 |
276 | ### Named definitions
277 |
278 | In some cases you have more than one instance of the same type. For example,
279 | two database instances: master for writing and slave for reading.
280 |
281 | The first way is to wrap the types:
282 |
283 | ```go
284 | // MasterDatabase provide write database access.
285 | type MasterDatabase struct {
286 | *Database
287 | }
288 |
289 | // SlaveDatabase provide read database access.
290 | type SlaveDatabase struct {
291 | *Database
292 | }
293 | ```
294 |
295 | The second way is to use named definitions with the `inject.WithName()` provide
296 | option:
297 |
298 | ```go
299 | // provide master database
300 | inject.Provide(NewMasterDatabase, inject.WithName("master"))
301 | // provide slave database
302 | inject.Provide(NewSlaveDatabase, inject.WithName("slave"))
303 | ```
304 |
305 | To extract a named definition from the container, use the `inject.Name()` extract
306 | option.
307 |
308 | ```go
309 | var db *Database
310 | container.Extract(&db, inject.Name("master"))
311 | ```
312 |
313 | To inject a named definition into another constructor, embed
314 | `di.Parameter` in a parameter struct.
315 |
316 | ```go
317 | // ServiceParameters
318 | type ServiceParameters struct {
319 | di.Parameter
320 |
321 | // use `di` tag for the container to know that field need to be injected.
322 | MasterDatabase *Database `di:"master"`
323 | SlaveDatabase *Database `di:"slave"`
324 | }
325 |
326 | // NewService creates new service with provided parameters.
327 | func NewService(parameters ServiceParameters) *Service {
328 | return &Service{
329 | MasterDatabase: parameters.MasterDatabase,
330 | SlaveDatabase: parameters.SlaveDatabase,
331 | }
332 | }
333 | ```
334 |
335 | ### Optional parameters
336 |
337 | `di.Parameter` also provides the ability to skip a dependency if it does not exist
338 | in the container.
339 |
340 | ```go
341 | // ServiceParameter
342 | type ServiceParameter struct {
343 | di.Parameter
344 |
345 | Logger *Logger `di:"optional"`
346 | }
347 | ```
348 |
349 | > Constructors that declare dependencies as optional must handle the
350 | > case of those dependencies being absent.
351 |
352 | You can use naming and optional together.
353 |
354 | ```go
355 | // ServiceParameter
356 | type ServiceParameter struct {
357 | di.Parameter
358 |
359 | StdOutLogger *Logger `di:"stdout"`
360 | FileLogger *Logger `di:"file,optional"`
361 | }
362 | ```
363 |
364 | ### Parameter Bag
365 |
366 | If you need to specify some parameters at the definition level, you can use the
367 | `inject.ParameterBag` provide option. This is a `map[string]interface{}`
368 | that is transformed into the `di.ParameterBag` type.
369 |
370 | ```go
371 | // Provide server with parameter bag
372 | inject.Provide(NewServer, inject.ParameterBag{
373 | "addr": ":8080",
374 | })
375 |
376 | // NewServer creates a server with the provided parameter bag. Note: use the di.ParameterBag type.
377 | // Not inject.ParameterBag.
378 | func NewServer(pb di.ParameterBag) *http.Server {
379 | return &http.Server{
380 | Addr: pb.RequireString("addr"),
381 | }
382 | }
383 | ```
384 |
385 | ### Prototypes
386 |
387 | If you want to create a new instance on each extraction use
388 | `inject.Prototype()` provide option.
389 |
390 | ```go
391 | inject.Provide(NewRequestContext, inject.Prototype())
392 | ```
393 |
394 | > todo: real use case
395 |
396 | ### Cleanup
397 |
398 | If a provider creates a value that needs to be cleaned up, then it can
399 | return a closure to clean up the resource.
400 |
401 | ```go
402 | func NewFile(log Logger, path Path) (*os.File, func(), error) {
403 | f, err := os.Open(string(path))
404 | if err != nil {
405 | return nil, nil, err
406 | }
407 | cleanup := func() {
408 | if err := f.Close(); err != nil {
409 | log.Log(err)
410 | }
411 | }
412 | return f, cleanup, nil
413 | }
414 | ```
415 |
416 | When `container.Cleanup()` is called, it iterates over the created instances and calls
417 | their cleanup functions, if any.
418 |
419 | ```go
420 | container := inject.New(
421 | // ...
422 | inject.Provide(NewFile),
423 | )
424 |
425 | // do something
426 | container.Cleanup() // file was closed
427 | ```
428 |
429 | > Cleanup currently works incorrectly with prototype providers.
430 |
431 | ## Visualization
432 |
433 | The dependency graph can be rendered with
434 | [Graphviz](https://www.graphviz.org/). To do so, get its string
435 | representation:
436 |
437 | ```go
438 | var graph *di.Graph
439 | if err = container.Extract(&graph); err != nil {
440 | // handle err
441 | }
442 |
443 | dotGraph := graph.String() // use string representation
444 | ```
445 |
446 | Then paste it into an online Graphviz tool:
448 |
449 |
450 |
451 | ## Contributing
452 |
453 | I will be glad if you contribute to this library. I don't know much
454 | English, so contributing to the documentation is very meaningful to me.
455 |
456 | [](https://sourcerer.io/fame/defval/defval/inject/links/0)[](https://sourcerer.io/fame/defval/defval/inject/links/1)[](https://sourcerer.io/fame/defval/defval/inject/links/2)[](https://sourcerer.io/fame/defval/defval/inject/links/3)[](https://sourcerer.io/fame/defval/defval/inject/links/4)[](https://sourcerer.io/fame/defval/defval/inject/links/5)[](https://sourcerer.io/fame/defval/defval/inject/links/6)[](https://sourcerer.io/fame/defval/defval/inject/links/7)
457 |
458 |
--------------------------------------------------------------------------------
/_tutorial/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "net/http"
5 |
6 | "github.com/defval/inject/v2"
7 | )
8 |
// main wires the HTTP server, the serve mux and the controllers together with
// inject, extracts the server and starts serving. Construction and serving
// errors abort the program with a panic.
func main() {
	container := inject.New(
		inject.Provide(NewServer),   // provide http server
		inject.Provide(NewServeMux), // provide http serve mux
		// endpoints
		inject.Provide(NewOrderController, inject.As(new(Controller))), // provide order controller
		inject.Provide(NewUserController, inject.As(new(Controller))),  // provide user controller
	)

	var server *http.Server
	if err := container.Extract(&server); err != nil {
		panic(err)
	}

	// ListenAndServe always returns a non-nil error. Surface it instead of
	// letting the program exit as if it had succeeded.
	if err := server.ListenAndServe(); err != nil {
		panic(err)
	}
}
26 |
// NewServer returns an *http.Server that dispatches every request to mux.
// No address is configured, so net/http's default listen address applies.
func NewServer(mux *http.ServeMux) *http.Server {
	server := new(http.Server)
	server.Handler = mux
	return server
}
33 |
// NewServeMux builds an empty http.ServeMux and lets every controller from the
// injected []Controller group register its routes on it.
func NewServeMux(controllers []Controller) *http.ServeMux {
	mux := &http.ServeMux{}
	for i := range controllers {
		controllers[i].RegisterRoutes(mux)
	}
	return mux
}
44 |
45 | // Controller is an interface that can register its routes.
46 | type Controller interface {
47 | RegisterRoutes(mux *http.ServeMux)
48 | }
49 |
50 | // OrderController is a http controller for orders.
51 | type OrderController struct{}
52 |
53 | // NewOrderController creates a auth http controller.
54 | func NewOrderController() *OrderController {
55 | return &OrderController{}
56 | }
57 |
58 | // RegisterRoutes is a Controller interface implementation.
59 | func (a *OrderController) RegisterRoutes(mux *http.ServeMux) {
60 | mux.HandleFunc("/orders", a.RetrieveOrders)
61 | }
62 |
63 | // Retrieve loads orders and writes it to the writer.
64 | func (a *OrderController) RetrieveOrders(writer http.ResponseWriter, request *http.Request) {
65 | writer.WriteHeader(http.StatusOK)
66 | _, _ = writer.Write([]byte("Orders"))
67 | }
68 |
69 | // UserController is a http endpoint for a user.
70 | type UserController struct{}
71 |
72 | // NewUserController creates a user http endpoint.
73 | func NewUserController() *UserController {
74 | return &UserController{}
75 | }
76 |
77 | // RegisterRoutes is a Controller interface implementation.
78 | func (e *UserController) RegisterRoutes(mux *http.ServeMux) {
79 | mux.HandleFunc("/users", e.RetrieveUsers)
80 | }
81 |
82 | // Retrieve loads users and writes it using the writer.
83 | func (e *UserController) RetrieveUsers(writer http.ResponseWriter, request *http.Request) {
84 | writer.WriteHeader(http.StatusOK)
85 | _, _ = writer.Write([]byte("Users"))
86 | }
87 |
--------------------------------------------------------------------------------
/container.go:
--------------------------------------------------------------------------------
1 | package inject
2 |
3 | import (
4 | "github.com/defval/inject/v2/di"
5 | )
6 |
7 | // New creates a new container with provided options.
8 | func New(options ...Option) *Container {
9 | var c = &Container{
10 | container: di.New(),
11 | }
12 | // apply options.
13 | for _, opt := range options {
14 | opt.apply(c)
15 | }
16 | c.compile()
17 | return c
18 | }
19 |
20 | // Container is a dependency injection container.
21 | type Container struct {
22 | providers []provide
23 | container *di.Container
24 | }
25 |
26 | // Extract populates given target pointer with type instance provided in the container.
27 | //
28 | // var server *http.Server
29 | // if err = container.Extract(&server); err != nil {
30 | // // extract failed
31 | // }
32 | //
33 | // If the target type does not exist in a container or instance type building failed, Extract() returns an error.
34 | // Use ExtractOption for modifying the behavior of this function.
35 | func (c *Container) Extract(target interface{}, options ...ExtractOption) (err error) {
36 | var params = di.ExtractParams{}
37 | // apply extract options
38 | for _, opt := range options {
39 | opt.apply(¶ms)
40 | }
41 | return c.container.Extract(target, params)
42 | }
43 |
44 | // Invoke invokes custom function. Dependencies of function will be resolved via container.
45 | func (c *Container) Invoke(fn interface{}) error {
46 | return c.container.Invoke(fn)
47 | }
48 |
49 | // Cleanup cleanup container.
50 | func (c *Container) Cleanup() {
51 | c.container.Cleanup()
52 | }
53 |
54 | func (c *Container) compile() {
55 | for _, po := range c.providers {
56 | c.container.Provide(po.provider, po.params)
57 | }
58 | c.container.Compile()
59 | return
60 | }
61 |
62 | type provide struct {
63 | provider interface{}
64 | params di.ProvideParams
65 | }
66 |
--------------------------------------------------------------------------------
/container_test.go:
--------------------------------------------------------------------------------
1 | package inject_test
2 |
3 | import (
4 | "fmt"
5 | "net"
6 | "net/http"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/require"
10 |
11 | "github.com/defval/inject/v2"
12 | )
13 |
// TestContainer wires an HTTP bundle, extracts a named prototype twice and
// invokes a function whose dependency is resolved from the container.
func TestContainer(t *testing.T) {
	var HTTPBundle = inject.Bundle(
		inject.Provide(ProvideAddr("0.0.0.0", "8080")),
		inject.Provide(NewMux, inject.As(new(http.Handler))),
		inject.Provide(NewHTTPServer, inject.Prototype(), inject.WithName("server")),
	)

	c := inject.New(HTTPBundle)

	var server1 *http.Server
	err := c.Extract(&server1, inject.Name("server"))
	require.NoError(t, err)
	require.NotNil(t, server1)
	require.Equal(t, "0.0.0.0:8080", server1.Addr)

	var server2 *http.Server
	err = c.Extract(&server2, inject.Name("server"))
	require.NoError(t, err)
	// NewHTTPServer is provided with inject.Prototype(), so every extraction
	// must call the constructor again and yield a distinct instance.
	require.False(t, server1 == server2, "prototype must return a new instance on each extraction")

	err = c.Invoke(PrintAddr)
	require.NoError(t, err)
}
34 |
35 | // Addr
36 | type Addr string
37 |
38 | // ProvideAddr
39 | func ProvideAddr(host string, port string) func() Addr {
40 | return func() Addr {
41 | return Addr(net.JoinHostPort(host, port))
42 | }
43 | }
44 |
45 | // NewHTTPServer
46 | func NewHTTPServer(addr Addr, handler http.Handler) *http.Server {
47 | return &http.Server{
48 | Addr: string(addr),
49 | Handler: handler,
50 | }
51 | }
52 |
53 | // NewMux
54 | func NewMux() *http.ServeMux {
55 | return &http.ServeMux{}
56 | }
57 |
58 | // PrintAddr
59 | func PrintAddr(addr Addr) {
60 | fmt.Println(addr)
61 | }
62 |
--------------------------------------------------------------------------------
/di/container.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 |
7 | "github.com/defval/inject/v2/di/internal/graphkv"
8 | "github.com/defval/inject/v2/di/internal/reflection"
9 | )
10 |
11 | // Interactor is a helper interface.
12 | type Interactor interface {
13 | Extract(target interface{}, options ...ExtractOption) error
14 | Invoke(fn interface{}, options ...InvokeOption) error
15 | }
16 |
17 | // Builder is helper interface.
18 | type Builder interface {
19 | Provide(provider interface{}, options ...ProvideOption)
20 | }
21 |
22 | // New create new container.
23 | func New() *Container {
24 | return &Container{
25 | graph: graphkv.New(),
26 | }
27 | }
28 |
29 | // Container is a dependency injection container.
30 | type Container struct {
31 | compiled bool
32 | graph *graphkv.Graph
33 | cleanups []func()
34 | }
35 |
36 | // Provide adds constructor into container with parameters.
37 | func (c *Container) Provide(constructor interface{}, options ...ProvideOption) {
38 | params := ProvideParams{}
39 | for _, opt := range options {
40 | opt.apply(¶ms)
41 | }
42 | provider := internalProvider(newProviderConstructor(params.Name, constructor))
43 | key := provider.Key()
44 | if c.graph.Exists(key) {
45 | panicf("The `%s` type already exists in container", provider.Key())
46 | }
47 | if !params.IsPrototype {
48 | provider = asSingleton(provider)
49 | }
50 | // add provider to graph
51 | c.graph.Add(key, provider)
52 | // parse embed parameters
53 | for _, param := range provider.ParameterList() {
54 | if param.embed {
55 | embed := newProviderEmbed(param)
56 | c.graph.Add(embed.Key(), embed)
57 | }
58 | }
59 | // provide parameter bag
60 | if len(params.Parameters) != 0 {
61 | parameterBugProvider := createParameterBugProvider(provider.Key(), params.Parameters)
62 | c.graph.Add(parameterBugProvider.Key(), parameterBugProvider)
63 | }
64 | // process interfaces
65 | for _, iface := range params.Interfaces {
66 | c.processProviderInterface(provider, iface)
67 | }
68 | }
69 |
70 | // Compile compiles the container. It iterates over all nodes
71 | // in graph and register their parameters.
72 | func (c *Container) Compile() {
73 | graphProvider := func() *Graph { return &Graph{graph: c.graph.DOTGraph()} }
74 | interactorProvider := func() Interactor { return c }
75 | c.Provide(graphProvider)
76 | c.Provide(interactorProvider)
77 | for _, node := range c.graph.Nodes() {
78 | c.registerProviderParameters(node.Value.(internalProvider))
79 | }
80 | if err := c.graph.CheckCycles(); err != nil {
81 | panic(err.Error())
82 | }
83 | c.compiled = true
84 | }
85 |
86 | // Extract builds instance of target type and fills target pointer.
87 | func (c *Container) Extract(target interface{}, options ...ExtractOption) error {
88 | params := ExtractParams{}
89 | for _, opt := range options {
90 | opt.apply(¶ms)
91 | }
92 | if !c.compiled {
93 | return fmt.Errorf("container not compiled")
94 | }
95 | if target == nil {
96 | return fmt.Errorf("extract target must be a pointer, got `nil`")
97 | }
98 | if !reflection.IsPtr(target) {
99 | return fmt.Errorf("extract target must be a pointer, got `%s`", reflect.TypeOf(target))
100 | }
101 | typ := reflect.TypeOf(target)
102 | param := parameter{
103 | name: params.Name,
104 | res: typ.Elem(),
105 | embed: isEmbedParameter(typ),
106 | }
107 | value, err := param.ResolveValue(c)
108 | if err != nil {
109 | return err
110 | }
111 | targetValue := reflect.ValueOf(target).Elem()
112 | targetValue.Set(value)
113 | return nil
114 | }
115 |
116 | // Invoke calls provided function.
117 | func (c *Container) Invoke(fn interface{}, options ...InvokeOption) error {
118 | params := InvokeParams{}
119 | for _, opt := range options {
120 | opt.apply(¶ms)
121 | }
122 | if !c.compiled {
123 | return fmt.Errorf("container not compiled")
124 | }
125 | invoker, err := newInvoker(fn)
126 | if err != nil {
127 | return err
128 | }
129 | return invoker.Invoke(c)
130 | }
131 |
132 | // Cleanup runs destructors in order that was been created.
133 | func (c *Container) Cleanup() {
134 | for _, cleanup := range c.cleanups {
135 | cleanup()
136 | }
137 | }
138 |
139 | // processProviderInterface represents instances as interfaces and groups.
140 | func (c *Container) processProviderInterface(provider internalProvider, as interface{}) {
141 | // create interface from provider
142 | iface := newProviderInterface(provider, as)
143 | key := iface.Key()
144 | if c.graph.Exists(key) {
145 | stub := newProviderStub(key, "have several implementations")
146 | c.graph.Replace(key, stub)
147 | } else {
148 | // add interface node
149 | c.graph.Add(key, iface)
150 | }
151 | // create group
152 | group := newProviderGroup(key)
153 | groupKey := group.Key()
154 | // check exists
155 | if c.graph.Exists(groupKey) {
156 | // if exists use existing group
157 | node := c.graph.Get(groupKey)
158 | group = node.Value.(*providerGroup)
159 | } else {
160 | // else add new group to graph
161 | c.graph.Add(groupKey, group)
162 | }
163 | // add provider reference into group
164 | providerKey := provider.Key()
165 | group.Add(providerKey)
166 | }
167 |
168 | // registerProviderParameters registers provider parameters in a dependency graph.
169 | func (c *Container) registerProviderParameters(p internalProvider) {
170 | for _, param := range p.ParameterList() {
171 | provider, exists := param.ResolveProvider(c)
172 | if exists {
173 | c.graph.Edge(provider.Key(), p.Key())
174 | continue
175 | }
176 | if !exists && !param.optional {
177 | panicf("%s: dependency %s not exists in container", p.Key(), param)
178 | }
179 | }
180 | }
181 |
--------------------------------------------------------------------------------
/di/container_test.go:
--------------------------------------------------------------------------------
1 | package di_test
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "net"
7 | "net/http"
8 | "reflect"
9 | "testing"
10 |
11 | "github.com/stretchr/testify/require"
12 |
13 | "github.com/defval/inject/v2/di"
14 | "github.com/defval/inject/v2/di/internal/ditest"
15 | )
16 |
17 | func TestContainerCompileErrors(t *testing.T) {
18 | t.Run("dependency cycle cause panic", func(t *testing.T) {
19 | c := NewTestContainer(t)
20 | c.MustProvide(ditest.NewCycleFooBar)
21 | c.MustProvide(ditest.NewBar)
22 | c.MustCompileError("the graph cannot be cyclic")
23 | })
24 |
25 | t.Run("not existing dependency cause compile error", func(t *testing.T) {
26 | c := NewTestContainer(t)
27 | c.MustProvide(ditest.NewBar)
28 | c.MustCompileError("*ditest.Bar: dependency *ditest.Foo not exists in container")
29 | })
30 |
31 | t.Run("not existing non pointer dependency cause compile error", func(t *testing.T) {
32 | c := NewTestContainer(t)
33 | type TestStruct struct {
34 | }
35 |
36 | c.MustProvide(func(s TestStruct) bool {
37 | return true
38 | })
39 |
40 | require.PanicsWithValue(t, "bool: dependency di_test.TestStruct not exists in container", func() {
41 | c.Compile()
42 | })
43 | })
44 | }
45 |
46 | func TestContainerProvideErrors(t *testing.T) {
47 | t.Run("provide string cause panic", func(t *testing.T) {
48 | c := NewTestContainer(t)
49 | c.MustProvideError("string", "The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `string`")
50 | })
51 |
52 | t.Run("provide nil cause panic", func(t *testing.T) {
53 | c := NewTestContainer(t)
54 | c.MustProvideError(nil, "The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `nil`")
55 | })
56 |
57 | t.Run("provide struct pointer cause panic", func(t *testing.T) {
58 | c := NewTestContainer(t)
59 | c.MustProvideError(&ditest.Foo{}, "The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `*ditest.Foo`")
60 | })
61 |
62 | t.Run("provide constructor without result cause panic", func(t *testing.T) {
63 | c := NewTestContainer(t)
64 | c.MustProvideError(ditest.ConstructorWithoutResult, "The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `github.com/defval/inject/v2/di/internal/ditest.ConstructorWithoutResult`")
65 | })
66 |
67 | t.Run("provide constructor with many results cause panic", func(t *testing.T) {
68 | c := NewTestContainer(t)
69 | c.MustProvideError(ditest.ConstructorWithManyResults, "The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `github.com/defval/inject/v2/di/internal/ditest.ConstructorWithManyResults`")
70 | })
71 |
72 | t.Run("provide constructor with incorrect result error argument", func(t *testing.T) {
73 | c := NewTestContainer(t)
74 | c.MustProvideError(ditest.ConstructorWithIncorrectResultError, "The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `github.com/defval/inject/v2/di/internal/ditest.ConstructorWithIncorrectResultError`")
75 | })
76 |
77 | t.Run("provide duplicate", func(t *testing.T) {
78 | c := NewTestContainer(t)
79 | c.MustProvide(ditest.NewFoo)
80 | c.MustProvideError(ditest.NewFoo, "The `*ditest.Foo` type already exists in container")
81 | })
82 |
83 | t.Run("provide as not implemented interface cause error", func(t *testing.T) {
84 | c := NewTestContainer(t)
85 | c.MustProvide(ditest.NewFoo)
86 | c.MustProvideError(ditest.NewBar, "*ditest.Bar not implement ditest.Barer", new(ditest.Barer))
87 | })
88 |
89 | t.Run("provide as not interface cause error", func(t *testing.T) {
90 | c := NewTestContainer(t)
91 | c.MustProvide(ditest.NewFoo)
92 | c.MustProvideError(ditest.NewBar, "*ditest.Foo: not a pointer to interface", new(ditest.Foo))
93 | })
94 | }
95 |
96 | func TestContainerExtractErrors(t *testing.T) {
97 | t.Run("container panic on trying extract before compilation", func(t *testing.T) {
98 | c := NewTestContainer(t)
99 | foo := &ditest.Foo{}
100 | c.MustProvide(ditest.CreateFooConstructor(foo))
101 | var extracted *ditest.Foo
102 | c.MustExtractError(&extracted, "container not compiled")
103 | })
104 |
105 | t.Run("extract into string cause error", func(t *testing.T) {
106 | c := NewTestContainer(t)
107 | c.MustProvide(ditest.NewFoo)
108 | c.MustCompile()
109 | c.MustExtractError("string", "extract target must be a pointer, got `string`")
110 | })
111 |
112 | t.Run("extract into struct cause error", func(t *testing.T) {
113 | c := NewTestContainer(t)
114 | c.MustProvide(ditest.NewFoo)
115 | c.MustCompile()
116 | c.MustExtractError(struct{}{}, "extract target must be a pointer, got `struct {}`")
117 | })
118 |
119 | t.Run("extract into nil cause error", func(t *testing.T) {
120 | c := NewTestContainer(t)
121 | c.MustProvide(ditest.NewFoo)
122 | c.MustCompile()
123 | c.MustExtractError(nil, "extract target must be a pointer, got `nil`")
124 | })
125 |
126 | t.Run("container does not find type because its named", func(t *testing.T) {
127 | c := NewTestContainer(t)
128 | foo := &ditest.Foo{}
129 | c.MustProvideWithName("foo", ditest.CreateFooConstructor(foo))
130 | c.MustCompile()
131 |
132 | var extracted *ditest.Foo
133 | c.MustExtractError(&extracted, "*ditest.Foo: not exists in container")
134 | })
135 |
136 | t.Run("extract returns error because dependency constructing failed", func(t *testing.T) {
137 | c := NewTestContainer(t)
138 | c.MustProvide(ditest.CreateFooConstructorWithError(errors.New("internal error")))
139 | c.MustProvide(ditest.NewBar)
140 | c.MustCompile()
141 | var bar *ditest.Bar
142 | c.MustExtractError(&bar, "*ditest.Foo: internal error")
143 | })
144 |
145 | t.Run("extract interface with multiple implementations cause error", func(t *testing.T) {
146 | c := NewTestContainer(t)
147 | c.MustProvide(ditest.NewFoo)
148 | c.MustProvide(ditest.NewBar, new(ditest.Fooer))
149 | c.MustProvide(ditest.NewBaz, new(ditest.Fooer))
150 | c.MustCompile()
151 |
152 | var extracted ditest.Fooer
153 | c.MustExtractError(&extracted, "ditest.Fooer: have several implementations")
154 | })
155 | }
156 |
157 | func TestContainerInvokeErrors(t *testing.T) {
158 | t.Run("invoke function with incorrect signature cause error", func(t *testing.T) {
159 | c := NewTestContainer(t)
160 | c.MustCompile()
161 | c.MustInvokeError(func() *ditest.Foo {
162 | return nil
163 | }, "the invoke function must be a function like `func([dep1, dep2, ...]) [error]`, got `func() *ditest.Foo`")
164 | })
165 |
166 | t.Run("invoke function with undefined dependency cause error", func(t *testing.T) {
167 | c := NewTestContainer(t)
168 | c.MustCompile()
169 | c.MustInvokeError(func(foo *ditest.Foo) {}, "could not resolve invoke parameters: *ditest.Foo: not exists in container")
170 | })
171 |
172 | t.Run("invoke before compile cause error", func(t *testing.T) {
173 | c := NewTestContainer(t)
174 | c.MustInvokeError(func() {}, "container not compiled")
175 | })
176 | }
177 |
178 | func TestContainerProvide(t *testing.T) {
179 | t.Run("container successfully accept simple constructor", func(t *testing.T) {
180 | c := NewTestContainer(t)
181 | c.MustProvide(ditest.NewFoo)
182 | })
183 |
184 | t.Run("container successfully accept constructor with error", func(t *testing.T) {
185 | c := NewTestContainer(t)
186 | c.MustProvide(ditest.CreateFooConstructorWithError(nil))
187 | })
188 |
189 | t.Run("container successfully accept constructor with cleanup function", func(t *testing.T) {
190 | c := NewTestContainer(t)
191 |
192 | cleanup := func() {}
193 | c.MustProvide(ditest.CreateFooConstructorWithCleanup(cleanup))
194 | })
195 |
196 | }
197 |
198 | func TestContainerExtract(t *testing.T) {
199 | t.Run("container extract correct pointer", func(t *testing.T) {
200 | c := NewTestContainer(t)
201 | foo := &ditest.Foo{}
202 | c.MustProvide(ditest.CreateFooConstructor(foo))
203 | c.MustCompile()
204 |
205 | var extracted *ditest.Foo
206 | c.MustExtractPtr(foo, &extracted)
207 | })
208 |
209 | t.Run("container extract same pointer on each extraction", func(t *testing.T) {
210 | c := NewTestContainer(t)
211 | foo := &ditest.Foo{}
212 | c.MustProvide(ditest.CreateFooConstructor(foo))
213 | c.MustCompile()
214 |
215 | var extracted1 *ditest.Foo
216 | c.MustExtractPtr(foo, &extracted1)
217 |
218 | var extracted2 *ditest.Foo
219 | c.MustExtractPtr(foo, &extracted2)
220 | })
221 |
222 | t.Run("container extract instance if error is nil", func(t *testing.T) {
223 | c := NewTestContainer(t)
224 | c.MustProvide(ditest.CreateFooConstructorWithError(nil))
225 | c.MustCompile()
226 |
227 | var extracted *ditest.Foo
228 | c.MustExtract(&extracted)
229 | })
230 |
231 | t.Run("container extract instance if cleanup and error is nil", func(t *testing.T) {
232 | c := NewTestContainer(t)
233 |
234 | c.MustProvide(ditest.CreateFooConstructorWithCleanupAndError(nil, nil))
235 | c.MustCompile()
236 |
237 | var extracted *ditest.Foo
238 | c.MustExtract(&extracted)
239 | })
240 |
241 | t.Run("container extract correct named pointer", func(t *testing.T) {
242 | c := NewTestContainer(t)
243 | foo := &ditest.Foo{}
244 | c.MustProvideWithName("foo", ditest.CreateFooConstructor(foo))
245 | c.MustCompile()
246 |
247 | var extracted *ditest.Foo
248 | c.MustExtractWithName("foo", &extracted)
249 | })
250 |
251 | t.Run("container extract correct interface implementation", func(t *testing.T) {
252 | c := NewTestContainer(t)
253 | bar := &ditest.Bar{}
254 | c.MustProvide(ditest.NewFoo)
255 | c.MustProvide(ditest.CreateBarConstructor(bar), new(ditest.Fooer))
256 | c.MustCompile()
257 |
258 | var extracted ditest.Fooer
259 | c.MustExtractPtr(bar, &extracted)
260 | })
261 |
262 | t.Run("container creates group from interface and extract it", func(t *testing.T) {
263 | c := NewTestContainer(t)
264 | c.MustProvide(ditest.NewFoo)
265 | c.MustProvide(ditest.NewBar, new(ditest.Fooer))
266 | c.MustProvide(ditest.NewBaz, new(ditest.Fooer))
267 | c.MustCompile()
268 |
269 | var group []ditest.Fooer
270 | c.MustExtract(&group)
271 | require.Len(t, group, 2)
272 | })
273 |
274 | t.Run("container extract new instance of prototype by each extraction", func(t *testing.T) {
275 | c := NewTestContainer(t)
276 | c.MustProvide(ditest.NewFoo)
277 | c.MustProvidePrototype(ditest.NewBar)
278 | c.MustCompile()
279 |
280 | var extracted1 *ditest.Bar
281 | c.MustExtract(&extracted1)
282 | var extracted2 *ditest.Bar
283 | c.MustExtract(&extracted2)
284 |
285 | c.MustNotEqualPointer(extracted1, extracted2)
286 | })
287 |
288 | t.Run("container resolve interactor", func(t *testing.T) {
289 | c := NewTestContainer(t)
290 | foo := ditest.NewFoo()
291 | c.MustProvide(ditest.CreateFooConstructor(foo))
292 | c.MustCompile()
293 | var interactor di.Interactor
294 | c.MustExtract(&interactor)
295 | var extractedFoo *ditest.Foo
296 | require.NoError(t, interactor.Extract(&extractedFoo))
297 | c.MustEqualPointer(foo, extractedFoo)
298 | })
299 | }
300 |
301 | func TestContainerResolve(t *testing.T) {
302 | t.Run("container resolve correct argument", func(t *testing.T) {
303 | c := NewTestContainer(t)
304 | foo := &ditest.Foo{}
305 | c.MustProvide(ditest.CreateFooConstructor(foo))
306 | c.MustProvide(ditest.NewBar)
307 | c.MustCompile()
308 |
309 | var bar *ditest.Bar
310 | c.MustExtract(&bar)
311 | c.MustEqualPointer(foo, bar.Foo())
312 | })
313 |
314 | t.Run("container resolve correct interface implementation", func(t *testing.T) {
315 | c := NewTestContainer(t)
316 |
317 | foo := ditest.NewFoo()
318 | bar := ditest.NewBar(foo)
319 |
320 | c.MustProvide(ditest.CreateFooConstructor(foo))
321 | c.MustProvide(ditest.CreateBarConstructor(bar), new(ditest.Fooer))
322 | c.MustProvide(ditest.NewQux)
323 | c.MustCompile()
324 |
325 | var qux *ditest.Qux
326 | c.MustExtract(&qux)
327 | c.MustEqualPointer(bar, qux.Fooer())
328 | })
329 |
330 | t.Run("container resolve correct group", func(t *testing.T) {
331 | c := NewTestContainer(t)
332 |
333 | c.MustProvide(ditest.NewFoo)
334 | c.MustProvide(ditest.NewBar, new(ditest.Fooer))
335 | c.MustProvide(ditest.NewBaz, new(ditest.Fooer))
336 | c.MustProvide(ditest.NewFooerGroup)
337 | c.MustCompile()
338 |
339 | var bar *ditest.Bar
340 | c.MustExtract(&bar)
341 |
342 | var baz *ditest.Baz
343 | c.MustExtract(&baz)
344 |
345 | var group *ditest.FooerGroup
346 | c.MustExtract(&group)
347 | require.Len(t, group.Fooers(), 2)
348 | c.MustEqualPointer(bar, group.Fooers()[0])
349 | c.MustEqualPointer(baz, group.Fooers()[1])
350 | })
351 | }
352 |
353 | func TestContainerResolveEmbedParameters(t *testing.T) {
354 | t.Run("container resolve embed parameters", func(t *testing.T) {
355 | c := NewTestContainer(t)
356 | foo := ditest.NewFoo()
357 | bar := ditest.NewBar(foo)
358 | c.MustProvide(ditest.CreateFooConstructor(foo))
359 | c.MustProvide(ditest.CreateBarConstructor(bar))
360 | c.MustProvide(ditest.NewBazFromParameters)
361 | c.MustCompile()
362 |
363 | var extracted *ditest.Baz
364 | c.MustExtract(&extracted)
365 | c.MustEqualPointer(foo, extracted.Foo())
366 | c.MustEqualPointer(bar, extracted.Bar())
367 | })
368 |
369 | t.Run("container skip optional parameter", func(t *testing.T) {
370 | c := NewTestContainer(t)
371 | foo := ditest.NewFoo()
372 | c.MustProvide(ditest.CreateFooConstructor(foo))
373 | c.MustProvide(ditest.NewBazFromParameters)
374 | c.MustCompile()
375 |
376 | var extracted *ditest.Baz
377 | c.MustExtract(&extracted)
378 | c.MustEqualPointer(foo, extracted.Foo())
379 | require.Nil(t, extracted.Bar())
380 | })
381 |
382 | t.Run("container resolve optional not existing group as nil", func(t *testing.T) {
383 | c := NewTestContainer(t)
384 | type Params struct {
385 | di.Parameter
386 | Handlers []http.Handler `di:"optional"`
387 | }
388 | c.MustProvide(func(params Params) bool {
389 | return params.Handlers == nil
390 | })
391 | c.MustCompile()
392 | var extracted bool
393 | c.MustExtract(&extracted)
394 | require.True(t, extracted)
395 | })
396 |
397 | t.Run("container skip private fields in parameter", func(t *testing.T) {
398 | c := NewTestContainer(t)
399 | type Param struct {
400 | di.Parameter
401 | private []http.Handler `di:"optional"`
402 | Addrs []net.Addr `di:"optional"`
403 | HaveNotTag string
404 | }
405 | c.MustProvide(func(param Param) bool {
406 | return param.Addrs == nil
407 | })
408 | c.MustCompile()
409 | var extracted bool
410 | c.MustExtract(&extracted)
411 | require.True(t, extracted)
412 | })
413 | }
414 |
415 | func TestContainerInvoke(t *testing.T) {
416 | t.Run("container call invoke function", func(t *testing.T) {
417 | c := NewTestContainer(t)
418 | c.MustCompile()
419 | var invokeCalled bool
420 | c.MustInvoke(func() {
421 | invokeCalled = true
422 | })
423 | require.True(t, invokeCalled)
424 | })
425 |
426 | t.Run("container resolve dependencies in invoke function", func(t *testing.T) {
427 | c := NewTestContainer(t)
428 | foo := ditest.NewFoo()
429 | c.MustProvide(ditest.CreateFooConstructor(foo))
430 | c.MustCompile()
431 | c.MustInvoke(func(invokeFoo *ditest.Foo) {
432 | c.MustEqualPointer(foo, invokeFoo)
433 | })
434 | })
435 |
436 | t.Run("container invoke return correct error", func(t *testing.T) {
437 | c := NewTestContainer(t)
438 | c.MustProvide(ditest.NewFoo)
439 | c.Compile()
440 | c.MustInvokeError(func(foo *ditest.Foo) error {
441 | return errors.New("invoke error")
442 | }, "invoke error")
443 | })
444 |
445 | t.Run("container invoke with nil error", func(t *testing.T) {
446 | c := NewTestContainer(t)
447 | c.MustProvide(ditest.NewFoo)
448 | c.Compile()
449 | c.MustInvoke(func(foo *ditest.Foo) error {
450 | return nil
451 | })
452 | })
453 | }
454 |
455 | func TestContainerResolveParameterBag(t *testing.T) {
456 | t.Run("container extract correct parameter bag for type", func(t *testing.T) {
457 | c := NewTestContainer(t)
458 |
459 | c.Provide(ditest.NewFooWithParameters, di.ProvideParams{
460 | Parameters: di.ParameterBag{
461 | "name": "test",
462 | },
463 | })
464 |
465 | c.MustCompile()
466 |
467 | var foo *ditest.Foo
468 | err := c.Extract(&foo)
469 |
470 | require.NoError(t, err)
471 | require.Equal(t, "test", foo.Name)
472 | })
473 |
474 | t.Run("container extract correct parameter bag for named type", func(t *testing.T) {
475 | c := NewTestContainer(t)
476 |
477 | c.Provide(ditest.NewFooWithParameters, di.ProvideParams{
478 | Name: "named",
479 | Parameters: di.ParameterBag{
480 | "name": "test",
481 | },
482 | })
483 |
484 | c.MustCompile()
485 |
486 | var foo *ditest.Foo
487 | err := c.Extract(&foo, di.ExtractParams{
488 | Name: "named",
489 | })
490 |
491 | require.NoError(t, err)
492 | require.Equal(t, "test", foo.Name)
493 | })
494 | }
495 |
496 | func TestContainerCleanup(t *testing.T) {
497 | t.Run("container run cleanup function after container close", func(t *testing.T) {
498 | c := NewTestContainer(t)
499 | var cleanupCalled bool
500 | c.MustProvide(ditest.CreateFooConstructorWithCleanup(func() { cleanupCalled = true }))
501 | c.MustCompile()
502 |
503 | var extracted *ditest.Foo
504 | c.MustExtract(&extracted)
505 | c.Cleanup()
506 |
507 | require.True(t, cleanupCalled)
508 | })
509 |
510 | t.Run("cleanup run in correct order", func(t *testing.T) {
511 | c := NewTestContainer(t)
512 | var cleanupCalls []string
513 | c.MustProvide(func(bar *ditest.Bar) (*ditest.Foo, func()) {
514 | return &ditest.Foo{}, func() { cleanupCalls = append(cleanupCalls, "foo") }
515 | })
516 | c.MustProvide(func() (*ditest.Bar, func()) {
517 | return &ditest.Bar{}, func() { cleanupCalls = append(cleanupCalls, "bar") }
518 | })
519 | c.MustCompile()
520 |
521 | var foo *ditest.Foo
522 | c.MustExtract(&foo)
523 | c.Cleanup()
524 | require.Equal(t, []string{"bar", "foo"}, cleanupCalls)
525 | })
526 |
527 | t.Run("cleanup for every prototyped instance", func(t *testing.T) {
528 | c := NewTestContainer(t)
529 | var cleanupCalls []string
530 | c.Provide(func() (*ditest.Foo, func()) {
531 | return &ditest.Foo{}, func() {
532 | cleanupCalls = append(cleanupCalls, fmt.Sprintf("foo_%d", len(cleanupCalls)))
533 | }
534 | }, di.ProvideParams{
535 | IsPrototype: true,
536 | })
537 | c.MustCompile()
538 | var foo1, foo2 *ditest.Foo
539 | c.MustExtract(&foo1)
540 | c.MustExtract(&foo2)
541 | c.Cleanup()
542 | require.Equal(t, []string{"foo_0", "foo_1"}, cleanupCalls)
543 | })
544 | }
545 |
546 | func TestContainer_GraphVisualizing(t *testing.T) {
547 | t.Run("graph", func(t *testing.T) {
548 | c := NewTestContainer(t)
549 |
550 | c.MustProvide(ditest.NewLogger)
551 | c.MustProvide(ditest.NewServer)
552 | c.MustProvide(ditest.NewRouter, new(http.Handler))
553 | c.MustProvide(ditest.NewAccountController, new(ditest.Controller))
554 | c.MustProvide(ditest.NewAuthController, new(ditest.Controller))
555 | c.MustCompile()
556 |
557 | var graph *di.Graph
558 | require.NoError(t, c.Extract(&graph))
559 |
560 | fmt.Println(graph.String())
561 |
562 | require.Equal(t, `digraph {
563 | subgraph cluster_s3 {
564 | ID = "cluster_s3";
565 | bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded";
566 | n9[color="#46494C",fontcolor="white",fontname="COURIER",label="*di.Graph",shape="box",style="filled"];
567 | n10[color="#46494C",fontcolor="white",fontname="COURIER",label="di.Interactor",shape="box",style="filled"];
568 |
569 | }subgraph cluster_s2 {
570 | ID = "cluster_s2";
571 | bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded";
572 | n6[color="#46494C",fontcolor="white",fontname="COURIER",label="*ditest.AccountController",shape="box",style="filled"];
573 | n8[color="#46494C",fontcolor="white",fontname="COURIER",label="*ditest.AuthController",shape="box",style="filled"];
574 | n7[color="#E54B4B",fontcolor="white",fontname="COURIER",label="[]ditest.Controller",shape="doubleoctagon",style="filled"];
575 | n4[color="#E5984B",fontcolor="white",fontname="COURIER",label="ditest.RouterParams",shape="box",style="filled"];
576 |
577 | }subgraph cluster_s0 {
578 | ID = "cluster_s0";
579 | bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded";
580 | n1[color="#46494C",fontcolor="white",fontname="COURIER",label="*log.Logger",shape="box",style="filled"];
581 |
582 | }subgraph cluster_s1 {
583 | ID = "cluster_s1";
584 | bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded";
585 | n3[color="#46494C",fontcolor="white",fontname="COURIER",label="*http.ServeMux",shape="box",style="filled"];
586 | n2[color="#46494C",fontcolor="white",fontname="COURIER",label="*http.Server",shape="box",style="filled"];
587 | n5[color="#2589BD",fontcolor="white",fontname="COURIER",label="http.Handler",style="filled"];
588 |
589 | }splines="ortho";
590 | n6->n7[color="#949494"];
591 | n8->n7[color="#949494"];
592 | n3->n5[color="#949494"];
593 | n1->n2[color="#949494"];
594 | n1->n3[color="#949494"];
595 | n1->n6[color="#949494"];
596 | n1->n8[color="#949494"];
597 | n7->n4[color="#949494"];
598 | n4->n3[color="#949494"];
599 | n5->n2[color="#949494"];
600 |
601 | }`, graph.String())
602 | })
603 | }
604 |
605 | // NewTestContainer
606 | func NewTestContainer(t *testing.T) *TestContainer {
607 | return &TestContainer{t, di.New()}
608 | }
609 |
610 | // TestContainer
611 | type TestContainer struct {
612 | t *testing.T
613 | *di.Container
614 | }
615 |
616 | func (c *TestContainer) MustProvide(provider interface{}, as ...interface{}) {
617 | require.NotPanics(c.t, func() {
618 | c.Provide(provider, di.ProvideParams{
619 | Interfaces: as,
620 | })
621 | }, "provide should not panic")
622 | }
623 |
624 | func (c *TestContainer) MustProvidePrototype(provider interface{}, as ...interface{}) {
625 | require.NotPanics(c.t, func() {
626 | c.Provide(provider, di.ProvideParams{
627 | Interfaces: as,
628 | IsPrototype: true,
629 | })
630 | })
631 | }
632 |
633 | func (c *TestContainer) MustProvideWithName(name string, provider interface{}, as ...interface{}) {
634 | require.NotPanics(c.t, func() {
635 | c.Provide(provider, di.ProvideParams{
636 | Name: name,
637 | Interfaces: as,
638 | })
639 | })
640 | }
641 |
642 | func (c *TestContainer) MustProvideError(provider interface{}, msg string, as ...interface{}) {
643 | require.PanicsWithValue(c.t, msg, func() {
644 | c.Provide(provider, di.ProvideParams{
645 | Interfaces: as,
646 | })
647 | })
648 | }
649 |
650 | func (c *TestContainer) MustCompile() {
651 | require.NotPanics(c.t, func() {
652 | c.Compile()
653 | })
654 | }
655 |
656 | func (c *TestContainer) MustCompileError(msg string) {
657 | require.PanicsWithValue(c.t, msg, func() {
658 | c.Compile()
659 | })
660 | }
661 |
662 | func (c *TestContainer) MustExtract(target interface{}) {
663 | require.NoError(c.t, c.Extract(target))
664 | }
665 |
666 | func (c *TestContainer) MustExtractWithName(name string, target interface{}) {
667 | require.NoError(c.t, c.Extract(target, di.ExtractParams{
668 | Name: name,
669 | }))
670 | }
671 |
672 | func (c *TestContainer) MustExtractError(target interface{}, msg string) {
673 | require.EqualError(c.t, c.Extract(target, di.ExtractParams{}), msg)
674 | }
675 |
676 | func (c *TestContainer) MustExtractWithNameError(name string, target interface{}, msg string) {
677 | require.EqualError(c.t, c.Extract(target, di.ExtractParams{
678 | Name: name,
679 | }), msg)
680 | }
681 |
682 | // MustExtractPtr extract value from container into target and check that target and expected pointers are equal.
683 | func (c *TestContainer) MustExtractPtr(expected, target interface{}) {
684 | c.MustExtract(target)
685 |
686 | // indirect
687 | actual := reflect.ValueOf(target).Elem().Interface()
688 | c.MustEqualPointer(expected, actual)
689 | }
690 |
691 | func (c *TestContainer) MustExtractPtrWithName(expected interface{}, name string, target interface{}) {
692 | c.MustExtractWithName(name, target)
693 |
694 | actual := reflect.ValueOf(target).Elem().Interface()
695 | c.MustEqualPointer(expected, actual)
696 | }
697 |
698 | func (c *TestContainer) MustInvoke(fn interface{}) {
699 | require.NoError(c.t, c.Invoke(fn))
700 | }
701 |
702 | func (c *TestContainer) MustInvokeError(fn interface{}, msg string) {
703 | require.EqualError(c.t, c.Invoke(fn), msg)
704 | }
705 |
706 | func (c *TestContainer) MustEqualPointer(expected interface{}, actual interface{}) {
707 | require.Equal(c.t,
708 | fmt.Sprintf("%p", actual),
709 | fmt.Sprintf("%p", expected),
710 | "actual and expected pointers should be equal",
711 | )
712 | }
713 |
714 | func (c *TestContainer) MustNotEqualPointer(expected interface{}, actual interface{}) {
715 | require.NotEqual(c.t,
716 | fmt.Sprintf("%p", actual),
717 | fmt.Sprintf("%p", expected),
718 | "actual and expected pointers should not be equal",
719 | )
720 | }
721 |
--------------------------------------------------------------------------------
/di/dot.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | "io"
5 |
6 | "github.com/emicklei/dot"
7 | )
8 |
9 | // Graph
10 | type Graph struct {
11 | graph *dot.Graph
12 | }
13 |
14 | func (g *Graph) WriteTo(writer io.Writer) {
15 | g.graph.Write(writer)
16 | }
17 |
18 | func (g *Graph) String() string {
19 | return g.graph.String()
20 | }
21 |
--------------------------------------------------------------------------------
/di/errors.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import "fmt"
4 |
5 | // ErrParameterProvideFailed
6 | type ErrParameterProvideFailed struct {
7 | k key
8 | err error
9 | }
10 |
11 | func (e ErrParameterProvideFailed) Error() string {
12 | return fmt.Sprintf("%s: %s", e.k, e.err)
13 | }
14 |
15 | // ErrParameterProviderNotFound
16 | type ErrParameterProviderNotFound struct {
17 | param parameter
18 | }
19 |
20 | func (e ErrParameterProviderNotFound) Error() string {
21 | return fmt.Sprintf("%s: not exists in container", e.param)
22 | }
23 |
--------------------------------------------------------------------------------
/di/internal/ditest/bar.go:
--------------------------------------------------------------------------------
1 | package ditest
2 |
3 | // Bar
4 | type Bar struct {
5 | foo *Foo
6 | }
7 |
8 | // NewBar
9 | func NewBar(foo *Foo) *Bar {
10 | return &Bar{
11 | foo: foo,
12 | }
13 | }
14 |
15 | // CreateBarConstructor
16 | func CreateBarConstructor(bar *Bar) func(foo *Foo) *Bar {
17 | return func(foo *Foo) *Bar {
18 | bar.foo = foo
19 | return bar
20 | }
21 | }
22 |
23 | func (b *Bar) Foo() *Foo { return b.foo }
24 |
--------------------------------------------------------------------------------
/di/internal/ditest/baz.go:
--------------------------------------------------------------------------------
1 | package ditest
2 |
3 | import "github.com/defval/inject/v2/di"
4 |
5 | // Baz
6 | type Baz struct {
7 | foo *Foo
8 | bar *Bar
9 | }
10 |
11 | // NewBaz
12 | func NewBaz(foo *Foo, bar *Bar) *Baz {
13 | return &Baz{
14 | foo: foo,
15 | bar: bar,
16 | }
17 | }
18 |
19 | // BazParameters
20 | type BazParameters struct {
21 | di.Parameter
22 |
23 | Foo *Foo `di:""`
24 | Bar *Bar `di:"optional"`
25 | }
26 |
27 | // NewBazFromParameters
28 | func NewBazFromParameters(params BazParameters) *Baz {
29 | return &Baz{
30 | foo: params.Foo,
31 | bar: params.Bar,
32 | }
33 | }
34 |
35 | func (b *Baz) Foo() *Foo { return b.foo }
36 | func (b *Baz) Bar() *Bar { return b.bar }
37 |
--------------------------------------------------------------------------------
/di/internal/ditest/foo.go:
--------------------------------------------------------------------------------
1 | package ditest
2 |
3 | import "github.com/defval/inject/v2/di"
4 |
// Foo is the basic test fixture. Name is populated by NewFooWithParameters.
type Foo struct {
	Name string
}

// NewFoo returns a new, zero-valued Foo.
func NewFoo() *Foo {
	return new(Foo)
}
14 |
15 | // NewFooWithParameters
16 | func NewFooWithParameters(parameters di.ParameterBag) *Foo {
17 | return &Foo{Name: parameters.RequireString("name")}
18 | }
19 |
20 | // NewCycleFooBar
21 | func NewCycleFooBar(bar *Bar) *Foo {
22 | return &Foo{}
23 | }
24 |
25 | // CreateFooConstructor
26 | func CreateFooConstructor(foo *Foo) func() *Foo {
27 | return func() *Foo {
28 | return foo
29 | }
30 | }
31 |
32 | // CreateFooConstructorWithError
33 | func CreateFooConstructorWithError(err error) func() (*Foo, error) {
34 | return func() (foo *Foo, e error) {
35 | return &Foo{}, err
36 | }
37 | }
38 |
39 | // CreateFooConstructorWithCleanup
40 | func CreateFooConstructorWithCleanup(cleanup func()) func() (*Foo, func()) {
41 | return func() (foo *Foo, i func()) {
42 | return &Foo{}, cleanup
43 | }
44 | }
45 |
46 | // CreateFooConstructorWithCleanupAndError
47 | func CreateFooConstructorWithCleanupAndError(cleanup func(), err error) func() (*Foo, func(), error) {
48 | return func() (foo *Foo, i func(), e error) {
49 | return &Foo{}, cleanup, err
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/di/internal/ditest/fooer_group.go:
--------------------------------------------------------------------------------
1 | package ditest
2 |
3 | // FooerGroup
4 | type FooerGroup struct {
5 | fooers []Fooer
6 | }
7 |
8 | // NewFooerGroup
9 | func NewFooerGroup(fooers []Fooer) *FooerGroup {
10 | return &FooerGroup{fooers: fooers}
11 | }
12 |
13 | func (g *FooerGroup) Fooers() []Fooer {
14 | return g.fooers
15 | }
16 |
--------------------------------------------------------------------------------
/di/internal/ditest/full.go:
--------------------------------------------------------------------------------
1 | package ditest
2 |
3 | import (
4 | "log"
5 | "net/http"
6 | "os"
7 |
8 | "github.com/defval/inject/v2/di"
9 | )
10 |
// NewLogger returns a logger that writes to stdout with no prefix and no
// flags, and logs "Logger loaded!" once it exists.
func NewLogger() *log.Logger {
	logger := log.New(os.Stdout, "", 0)
	logger.Println("Logger loaded!")
	return logger
}

// NewServer returns an *http.Server serving handler and logs its creation.
func NewServer(logger *log.Logger, handler http.Handler) *http.Server {
	server := &http.Server{Handler: handler}
	logger.Println("Server created!")
	return server
}
26 |
27 | // RouterParams
28 | type RouterParams struct {
29 | di.Parameter
30 | Controllers []Controller `di:"optional"`
31 | }
32 |
33 | // NewRouter
34 | func NewRouter(logger *log.Logger, params RouterParams) *http.ServeMux {
35 | logger.Println("Create router!")
36 | defer logger.Println("Router created!")
37 |
38 | mux := &http.ServeMux{}
39 |
40 | for _, ctrl := range params.Controllers {
41 | ctrl.RegisterRoutes(mux)
42 | }
43 |
44 | return mux
45 | }
46 |
// Controller is implemented by anything that attaches HTTP routes to a
// *http.ServeMux; NewRouter calls RegisterRoutes on each injected Controller.
type Controller interface {
	// RegisterRoutes adds the controller's handlers to mux.
	RegisterRoutes(mux *http.ServeMux)
}
51 |
// AccountController is a sample Controller that only logs when its routes
// are registered.
type AccountController struct {
	Logger *log.Logger // destination for the registration message
}
56 |
57 | // NewAccountController
58 | func NewAccountController(logger *log.Logger) *AccountController {
59 | return &AccountController{Logger: logger}
60 | }
61 |
62 | // RegisterRoutes
63 | func (c *AccountController) RegisterRoutes(mux *http.ServeMux) {
64 | c.Logger.Println("AccountController registered!")
65 |
66 | // register your routes
67 | }
68 |
// AuthController is a sample Controller that only logs when its routes are
// registered.
type AuthController struct {
	Logger *log.Logger // destination for the registration message
}
73 |
74 | // NewAuthController
75 | func NewAuthController(logger *log.Logger) *AuthController {
76 | return &AuthController{Logger: logger}
77 | }
78 |
79 | // RegisterRoutes
80 | func (c *AuthController) RegisterRoutes(mux *http.ServeMux) {
81 | c.Logger.Println("AuthController registered!")
82 |
83 | // register your routes
84 | }
85 |
--------------------------------------------------------------------------------
/di/internal/ditest/incorrect.go:
--------------------------------------------------------------------------------
1 | package ditest
2 |
// ConstructorWithoutResult has no return values at all; presumably used to
// check that the container rejects constructors that produce nothing.
func ConstructorWithoutResult() {}
7 |
8 | // ConstructorWithManyResults
9 | func ConstructorWithManyResults() (*Foo, *Bar, error) {
10 | return &Foo{}, &Bar{}, nil
11 | }
12 |
13 | // ConstructorWithIncorrectResultError
14 | func ConstructorWithIncorrectResultError() (*Foo, *Bar) {
15 | return &Foo{}, &Bar{}
16 | }
17 |
--------------------------------------------------------------------------------
/di/internal/ditest/interfaces.go:
--------------------------------------------------------------------------------
1 | package ditest
2 |
3 | // Fooer
4 | type Fooer interface {
5 | Foo() *Foo
6 | }
7 |
8 | // Barer
9 | type Barer interface {
10 | Bar() *Bar
11 | }
12 |
13 | // Bazer
14 | type Bazer interface {
15 | Baz() *Baz
16 | }
17 |
--------------------------------------------------------------------------------
/di/internal/ditest/qux.go:
--------------------------------------------------------------------------------
1 | package ditest
2 |
3 | // Qux
4 | type Qux struct {
5 | fooer Fooer
6 | }
7 |
8 | // NewQux
9 | func NewQux(foo Fooer) *Qux {
10 | return &Qux{
11 | fooer: foo,
12 | }
13 | }
14 |
15 | func (q *Qux) Fooer() Fooer { return q.fooer }
16 |
--------------------------------------------------------------------------------
/di/internal/graphkv/directed_graph.go:
--------------------------------------------------------------------------------
1 | package graphkv
2 |
3 | // directedGraph is a graph supporting directed edges between nodes.
4 | type directedGraph struct {
5 | *graph
6 | edges *directedEdgeList
7 | }
8 |
9 | // newDirectedGraph creates a graph of nodes with directed edges.
10 | func newDirectedGraph() *directedGraph {
11 | return &directedGraph{
12 | graph: newGraph(),
13 | edges: newDirectedEdgeList(),
14 | }
15 | }
16 |
17 | // Copy returns a clone of the directed graph.
18 | func (g *directedGraph) Copy() *directedGraph {
19 | return &directedGraph{
20 | graph: g.graph.Copy(),
21 | edges: g.edges.Copy(),
22 | }
23 | }
24 |
25 | // EdgeCount returns the number of direced edges between nodes.
26 | func (g *directedGraph) EdgeCount() int {
27 | return g.edges.Count()
28 | }
29 |
30 | // AddEdge adds the edge to the graph.
31 | func (g *directedGraph) AddEdge(from Key, to Key) {
32 | // prevent adding an edge referring to missing nodes
33 | if !g.NodeExists(from) {
34 | g.AddNode(from)
35 | }
36 | if !g.NodeExists(to) {
37 | g.AddNode(to)
38 | }
39 |
40 | g.edges.Add(from, to)
41 | }
42 |
43 | // RemoveEdge removes the edge from the graph.
44 | func (g *directedGraph) RemoveEdge(from Key, to Key) {
45 | g.edges.Remove(from, to)
46 | }
47 |
48 | // HasEdges determines whether the graph contains any edges to or from the node.
49 | func (g *directedGraph) HasEdges(node Key) bool {
50 | if g.HasIncomingEdges(node) {
51 | return true
52 | }
53 | return g.HasOutgoingEdges(node)
54 | }
55 |
56 | // EdgeExists checks whether the edge exists within the graph.
57 | func (g *directedGraph) EdgeExists(from Key, to Key) bool {
58 | return g.edges.Exists(from, to)
59 | }
60 |
61 | // HasIncomingEdges checks whether the graph contains any directed
62 | // edges pointing to the node.
63 | func (g *directedGraph) HasIncomingEdges(node Key) bool {
64 | return g.edges.HasIncomingEdges(node)
65 | }
66 |
67 | // IncomingEdges returns the nodes belonging to directed edges pointing
68 | // towards the specified node.
69 | func (g *directedGraph) IncomingEdges(node Key) []Key {
70 | return g.edges.IncomingEdges(node)
71 | }
72 |
73 | // IncomingEdgeCount returns the number of edges pointing from the specified
74 | // node (indegree).
75 | func (g *directedGraph) IncomingEdgeCount(node Key) int {
76 | return g.edges.IncomingEdgeCount(node)
77 | }
78 |
79 | // HasOutgoingEdges checks whether the graph contains any directed
80 | // edges pointing from the node.
81 | func (g *directedGraph) HasOutgoingEdges(node Key) bool {
82 | return g.edges.HasOutgoingEdges(node)
83 | }
84 |
85 | // OutgoingEdges returns the nodes belonging to directed edges pointing
86 | // from the specified node.
87 | func (g *directedGraph) OutgoingEdges(node Key) []Key {
88 | return g.edges.OutgoingEdges(node)
89 | }
90 |
91 | // OutgoingEdgeCount returns the number of edges pointing from the specified
92 | // node (outdegree).
93 | func (g *directedGraph) OutgoingEdgeCount(node Key) int {
94 | return g.edges.OutgoingEdgeCount(node)
95 | }
96 |
97 | // RootNodes finds the entry-point nodes to the graph, i.e. those without
98 | // incoming edges.
99 | func (g *directedGraph) RootNodes() []Key {
100 | results := make([]Key, 0)
101 | for _, node := range g.Nodes() {
102 | if !g.HasIncomingEdges(node) {
103 | results = append(results, node)
104 | }
105 | }
106 | return results
107 | }
108 |
109 | // IsolatedNodes finds independent nodes in the graph, i.e. those without edges.
110 | func (g *directedGraph) IsolatedNodes() []Key {
111 | results := make([]Key, 0)
112 | for _, node := range g.Nodes() {
113 | if !g.HasEdges(node) {
114 | results = append(results, node)
115 | }
116 | }
117 | return results
118 | }
119 |
120 | // AdjacencyMatrix returns a matrix indicating whether pairs of nodes are
121 | // adjacent or not within the graph.
122 | func (g *directedGraph) AdjacencyMatrix() map[Key]map[Key]bool {
123 | matrix := make(map[Key]map[Key]bool, g.NodeCount())
124 | for _, a := range g.Nodes() {
125 | matrix[a] = make(map[Key]bool, g.NodeCount())
126 |
127 | for _, b := range g.Nodes() {
128 | matrix[a][b] = g.EdgeExists(a, b)
129 | }
130 | }
131 | return matrix
132 | }
133 |
134 | // RemoveTransitives removes any transitive edges so that as fewest possible
135 | // edges exist while matching the reachability of the original graph.
136 | func (g *directedGraph) RemoveTransitives() {
137 | for _, a := range g.Nodes() {
138 | for _, b := range g.Nodes() {
139 | if !g.EdgeExists(a, b) {
140 | continue
141 | }
142 | for _, c := range g.Nodes() {
143 | if g.EdgeExists(b, c) {
144 | g.RemoveEdge(a, c)
145 | }
146 | }
147 | }
148 | }
149 | }
150 |
--------------------------------------------------------------------------------
/di/internal/graphkv/directed_graph_test.go:
--------------------------------------------------------------------------------
1 | package graphkv
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/assert"
7 | )
8 |
9 | func newTestDirectedGraph() *directedGraph {
10 | graph := newDirectedGraph()
11 | graph.AddNodes("A", "B", "C", "D")
12 | return graph
13 | }
14 |
15 | func TestNewDirectedGraph(t *testing.T) {
16 | graph := newDirectedGraph()
17 | assert.NotNil(t, graph, "graph should not be nil")
18 | assert.Zero(t, graph.NodeCount(), "graph.NodeCount() should equal zero")
19 | assert.Empty(t, graph.Nodes(), "graph.Nodes() should equal empty")
20 | assert.Zero(t, graph.EdgeCount(), "graph.EdgeCount() should equal zero")
21 | }
22 |
23 | func TestDirectedGraphAddEdge(t *testing.T) {
24 | graph := newTestDirectedGraph()
25 | graph.AddEdge("A", "B")
26 | graph.AddEdge("B", "D")
27 | graph.AddEdge("C", "B")
28 |
29 | assert.Equal(t, 3, graph.EdgeCount(), "graph.EdgeCount() should equal 3")
30 | assert.True(t, graph.EdgeExists("A", "B"), "graph.EdgeExists(A, B) should equal true")
31 | assert.True(t, graph.EdgeExists("B", "D"), "graph.EdgeExists(B, D) should equal true")
32 | assert.True(t, graph.EdgeExists("C", "B"), "graph.EdgeExists(C, B) should equal true")
33 | }
34 |
35 | func TestDirectedGraphAddEdgeDuplicate(t *testing.T) {
36 | graph := newTestDirectedGraph()
37 | graph.AddEdge("A", "B")
38 | graph.AddEdge("B", "C")
39 | graph.AddEdge("B", "C")
40 |
41 | assert.Equal(t, 2, graph.EdgeCount(), "graph.EdgeCount() should equal 2")
42 | assert.True(t, graph.EdgeExists("A", "B"), "graph.EdgeExists(A, B) should equal true")
43 | assert.True(t, graph.EdgeExists("B", "C"), "graph.EdgeExists(B, C) should equal true")
44 | }
45 |
46 | func TestDirectedGraphAddEdgeMissingNodes(t *testing.T) {
47 | graph := newDirectedGraph()
48 | graph.AddEdge("A", "B")
49 | graph.AddEdge("B", "C")
50 |
51 | assert.Equal(t, 3, graph.NodeCount(), "graph.NodeCount() should equal 2")
52 | assert.Equal(t, 2, graph.EdgeCount(), "graph.EdgeCount() should equal 2")
53 | assert.True(t, graph.EdgeExists("A", "B"), "graph.EdgeExists(A, B) should equal true")
54 | assert.True(t, graph.EdgeExists("B", "C"), "graph.EdgeExists(B, C) should equal true")
55 | }
56 |
57 | func TestDirectedGraphRemoveEdge(t *testing.T) {
58 | graph := newTestDirectedGraph()
59 | graph.AddEdge("A", "B")
60 | graph.AddEdge("B", "D")
61 | graph.AddEdge("C", "B")
62 | graph.RemoveEdge("A", "B")
63 | graph.RemoveEdge("C", "B")
64 |
65 | assert.Equal(t, 1, graph.EdgeCount(), "graph.EdgeCount() should equal 1")
66 | assert.False(t, graph.EdgeExists("A", "B"), "graph.EdgeExists(A, B) should equal false")
67 | assert.False(t, graph.EdgeExists("C", "B"), "graph.EdgeExists(C, B) should equal false")
68 | assert.True(t, graph.EdgeExists("B", "D"), "graph.EdgeExists(B, D) should equal true")
69 | }
70 |
71 | func TestDirectedGraphRemoveEdgeMissing(t *testing.T) {
72 | graph := newTestDirectedGraph()
73 | graph.AddEdge("C", "B")
74 | graph.RemoveEdge("D", "A")
75 | graph.RemoveEdge("C", "B")
76 |
77 | assert.Zero(t, graph.EdgeCount(), "graph.EdgeCount() should equal zero")
78 | }
79 |
80 | func TestDirectedGraphHasEdges(t *testing.T) {
81 | graph := newTestDirectedGraph()
82 | graph.AddEdge("A", "C")
83 |
84 | assert.True(t, graph.HasEdges("A"), "graph.HasEdges(A) should equal true")
85 | assert.False(t, graph.HasEdges("B"), "graph.HasEdges(B) should equal false")
86 | assert.True(t, graph.HasEdges("C"), "graph.HasEdges(C) should equal true")
87 | assert.False(t, graph.HasEdges("D"), "graph.HasEdges(D) should equal false")
88 | }
89 |
90 | func TestDirectedGraphIncomingEdgeCount(t *testing.T) {
91 | graph := newTestDirectedGraph()
92 | graph.AddEdge("A", "C")
93 | graph.AddEdge("B", "C")
94 |
95 | assert.Zero(t, graph.IncomingEdgeCount("A"), "graph.IncomingEdgeCount(A) should equal 0")
96 | assert.Zero(t, graph.IncomingEdgeCount("B"), "graph.IncomingEdgeCount(B) should equal 0")
97 | assert.Equal(t, 2, graph.IncomingEdgeCount("C"), "graph.IncomingEdgeCount(C) should equal 1")
98 | assert.Zero(t, graph.IncomingEdgeCount("D"), "graph.IncomingEdgeCount(D) should equal 0")
99 | }
100 |
101 | func TestDirectedGraphOutgoingEdgeCount(t *testing.T) {
102 | graph := newTestDirectedGraph()
103 | graph.AddEdge("A", "B")
104 | graph.AddEdge("A", "C")
105 |
106 | assert.Equal(t, 2, graph.OutgoingEdgeCount("A"), "graph.OutgoingEdgeCount(A) should equal 2")
107 | assert.Zero(t, graph.OutgoingEdgeCount("B"), "graph.OutgoingEdgeCount(B) should equal 0")
108 | assert.Zero(t, graph.OutgoingEdgeCount("C"), "graph.OutgoingEdgeCount(C) should equal 0")
109 | assert.Zero(t, graph.OutgoingEdgeCount("D"), "graph.OutgoingEdgeCount(D) should equal 0")
110 | }
111 |
112 | func TestDirectedGraphRootNodes(t *testing.T) {
113 | graph := newTestDirectedGraph()
114 | graph.AddEdge("A", "B")
115 | graph.AddEdge("B", "C")
116 | graph.AddEdge("D", "C")
117 | graph.AddEdge("E", "C")
118 | graph.AddEdge("F", "E")
119 |
120 | assert.Equal(t, []Key{"A", "D", "F"}, graph.RootNodes(), "graph.RootNodes() should equal [A, D, F]")
121 | }
122 |
123 | func TestDirectedGraphIsolatedNodes(t *testing.T) {
124 | graph := newTestDirectedGraph()
125 | graph.AddEdge("A", "C")
126 |
127 | assert.Equal(t, []Key{"B", "D"}, graph.IsolatedNodes(), "graph.IsolatedNodes() should equal [B, D]")
128 | }
129 |
130 | func TestDirectedGraphAdjacencyMatrix(t *testing.T) {
131 | graph := newTestDirectedGraph()
132 | graph.AddEdge("A", "C")
133 | graph.AddEdge("A", "B")
134 | graph.AddEdge("B", "D")
135 | graph.AddEdge("C", "A")
136 | graph.AddEdge("D", "D")
137 |
138 | expected := map[interface{}]map[interface{}]bool{
139 | "A": map[interface{}]bool{"A": false, "B": true, "C": true, "D": false},
140 | "B": map[interface{}]bool{"A": false, "B": false, "C": false, "D": true},
141 | "C": map[interface{}]bool{"A": true, "B": false, "C": false, "D": false},
142 | "D": map[interface{}]bool{"D": true, "A": false, "B": false, "C": false},
143 | }
144 |
145 | assert.Equal(t, expected, graph.AdjacencyMatrix(), "graph.AdjacencyMatrix() should equal [B, D]")
146 | }
147 |
--------------------------------------------------------------------------------
/di/internal/graphkv/edge.go:
--------------------------------------------------------------------------------
1 | package graphkv
2 |
3 | type directedEdgeList struct {
4 | outgoingEdges map[Key]*nodeList
5 | incomingEdges map[Key]*nodeList
6 | }
7 |
8 | func newDirectedEdgeList() *directedEdgeList {
9 | return &directedEdgeList{
10 | outgoingEdges: make(map[Key]*nodeList),
11 | incomingEdges: make(map[Key]*nodeList),
12 | }
13 | }
14 |
15 | func (l *directedEdgeList) Copy() *directedEdgeList {
16 | outgoingEdges := make(map[Key]*nodeList, len(l.outgoingEdges))
17 | for node, edges := range l.outgoingEdges {
18 | outgoingEdges[node] = edges.Copy()
19 | }
20 |
21 | incomingEdges := make(map[Key]*nodeList, len(l.incomingEdges))
22 | for node, edges := range l.incomingEdges {
23 | incomingEdges[node] = edges.Copy()
24 | }
25 |
26 | return &directedEdgeList{
27 | outgoingEdges: outgoingEdges,
28 | incomingEdges: incomingEdges,
29 | }
30 | }
31 |
32 | func (l *directedEdgeList) Count() int {
33 | return len(l.outgoingEdges)
34 | }
35 |
36 | func (l *directedEdgeList) HasOutgoingEdges(node Key) bool {
37 | _, ok := l.outgoingEdges[node]
38 | return ok
39 | }
40 |
41 | func (l *directedEdgeList) OutgoingEdgeCount(node Key) int {
42 | if list := l.outgoingNodeList(node, false); list != nil {
43 | return list.Count()
44 | }
45 | return 0
46 | }
47 |
48 | func (l *directedEdgeList) outgoingNodeList(node Key, create bool) *nodeList {
49 | if list, ok := l.outgoingEdges[node]; ok {
50 | return list
51 | }
52 | if create {
53 | list := newNodeList()
54 | l.outgoingEdges[node] = list
55 | return list
56 | }
57 | return nil
58 | }
59 |
60 | func (l *directedEdgeList) OutgoingEdges(node Key) []Key {
61 | if list := l.outgoingNodeList(node, false); list != nil {
62 | return list.Nodes()
63 | }
64 | return nil
65 | }
66 |
67 | func (l *directedEdgeList) HasIncomingEdges(node Key) bool {
68 | _, ok := l.incomingEdges[node]
69 | return ok
70 | }
71 |
72 | func (l *directedEdgeList) IncomingEdgeCount(node Key) int {
73 | if list := l.incomingNodeList(node, false); list != nil {
74 | return list.Count()
75 | }
76 | return 0
77 | }
78 |
79 | func (l *directedEdgeList) incomingNodeList(node Key, create bool) *nodeList {
80 | if list, ok := l.incomingEdges[node]; ok {
81 | return list
82 | }
83 | if create {
84 | list := newNodeList()
85 | l.incomingEdges[node] = list
86 | return list
87 | }
88 | return nil
89 | }
90 |
91 | func (l *directedEdgeList) IncomingEdges(node Key) []Key {
92 | if list := l.incomingNodeList(node, false); list != nil {
93 | return list.Nodes()
94 | }
95 | return nil
96 | }
97 |
98 | func (l *directedEdgeList) Add(from Key, to Key) {
99 | l.outgoingNodeList(from, true).Add(to)
100 | l.incomingNodeList(to, true).Add(from)
101 | }
102 |
103 | func (l *directedEdgeList) Remove(from Key, to Key) {
104 | if list := l.outgoingNodeList(from, false); list != nil {
105 | list.Remove(to)
106 |
107 | if list.Count() == 0 {
108 | delete(l.outgoingEdges, from)
109 | }
110 | }
111 | if list := l.incomingNodeList(to, false); list != nil {
112 | list.Remove(from)
113 |
114 | if list.Count() == 0 {
115 | delete(l.incomingEdges, to)
116 | }
117 | }
118 | }
119 |
120 | func (l *directedEdgeList) Exists(from Key, to Key) bool {
121 | if list := l.outgoingNodeList(from, false); list != nil {
122 | return list.Exists(to)
123 | }
124 | return false
125 | }
126 |
--------------------------------------------------------------------------------
/di/internal/graphkv/errors.go:
--------------------------------------------------------------------------------
1 | package graphkv
2 |
3 | import "fmt"
4 |
5 | // ErrKeyAlreadyExists
6 | type ErrKeyAlreadyExists struct {
7 | Key Key
8 | }
9 |
10 | func (e ErrKeyAlreadyExists) Error() string {
11 | return fmt.Sprintf("%s already exists", e.Key)
12 | }
13 |
14 | // ErrNodeNotExists
15 | type ErrNodeNotExists struct {
16 | Key Key
17 | }
18 |
19 | // ErrNodeNotExists
20 | func (e ErrNodeNotExists) Error() string {
21 | return fmt.Sprintf("%s not exists", e.Key)
22 | }
23 |
--------------------------------------------------------------------------------
/di/internal/graphkv/graph.go:
--------------------------------------------------------------------------------
1 | package graphkv
2 |
3 | type graph struct {
4 | nodes *nodeList
5 | }
6 |
7 | func newGraph() *graph {
8 | return &graph{
9 | nodes: newNodeList(),
10 | }
11 | }
12 |
13 | // Copy returns a clone of the graph.
14 | func (g *graph) Copy() *graph {
15 | return &graph{
16 | nodes: g.nodes.Copy(),
17 | }
18 | }
19 |
20 | // Nodes returns the graph's nodes.
21 | // The slice is mutable for performance reasons but should not be mutated.
22 | func (g *graph) Nodes() []Key {
23 | return g.nodes.Nodes()
24 | }
25 |
26 | // NodeCount returns the number of nodes.
27 | func (g *graph) NodeCount() int {
28 | return g.nodes.Count()
29 | }
30 |
31 | // AddNode inserts the specified node into the graph.
32 | // A node can be any value, e.g. int, string, pointer to a struct, map etc.
33 | // Duplicate nodes are ignored.
34 | func (g *graph) AddNode(node Key) {
35 | g.AddNodes(node)
36 | }
37 |
38 | // AddNodes inserts the specified nodes into the graph.
39 | // A node can be any value, e.g. int, string, pointer to a struct, map etc.
40 | // Duplicate nodes are ignored.
41 | func (g *graph) AddNodes(nodes ...Key) {
42 | g.nodes.Add(nodes...)
43 | }
44 |
45 | // RemoveNode removes the specified nodes from the graph.
46 | // If the node does not exist within the graph the call will fail silently.
47 | func (g *graph) RemoveNode(node Key) {
48 | g.RemoveNodes(node)
49 | }
50 |
51 | // RemoveNodes removes the specified nodes from the graph.
52 | // If a node does not exist within the graph the call will fail silently.
53 | func (g *graph) RemoveNodes(nodes ...Key) {
54 | g.nodes.Remove(nodes...)
55 | }
56 |
57 | // NodeExists determines whether the specified node exists within the graph.
58 | func (g *graph) NodeExists(node Key) bool {
59 | return g.nodes.Exists(node)
60 | }
61 |
--------------------------------------------------------------------------------
/di/internal/graphkv/graph_test.go:
--------------------------------------------------------------------------------
1 | package graphkv
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestNewGraph(t *testing.T) {
12 | graph := newGraph()
13 | assert.NotNil(t, graph, "graph should not be nil")
14 | assert.Zero(t, graph.NodeCount(), "graph.NodeCount() should equal zero")
15 | assert.Empty(t, graph.Nodes(), "graph.Nodes() should equal empty")
16 | }
17 |
18 | func TestGraphAddNode(t *testing.T) {
19 | graph := newGraph()
20 | graph.AddNode("A")
21 | graph.AddNode("B")
22 | graph.AddNode("C")
23 |
24 | assert.Equal(t, 3, graph.NodeCount(), "graph.NodeCount() should equal 3")
25 | assert.Equal(t, []interface{}{"A", "B", "C"}, graph.Nodes(), "graph.Nodes() should equal [A, B, C]")
26 | }
27 |
28 | func TestGraphAddNodeDuplicate(t *testing.T) {
29 | graph := newGraph()
30 | graph.AddNode("A")
31 | graph.AddNode("B")
32 | graph.AddNode("C")
33 | graph.AddNode("A")
34 |
35 | assert.Equal(t, 3, graph.NodeCount(), "graph.NodeCount() should equal 3")
36 | assert.Equal(t, []interface{}{"A", "B", "C"}, graph.Nodes(), "graph.Nodes() should equal [A, B, C]")
37 | }
38 |
39 | func BenchmarkGraphAddNodes(b *testing.B) {
40 | for i := 12.0; i <= 20; i++ {
41 | count := int(math.Pow(2, i))
42 |
43 | b.Run(fmt.Sprintf("%d", count), func(b *testing.B) {
44 | graph := newGraph()
45 | for i := 0; i < count; i++ {
46 | graph.AddNode(i)
47 | }
48 | })
49 | }
50 | }
51 |
52 | func TestGraphAddNodes(t *testing.T) {
53 | graph := newGraph()
54 | graph.AddNodes("A", "B", "C")
55 |
56 | assert.Equal(t, 3, graph.NodeCount(), "graph.NodeCount() should equal 3")
57 | assert.Equal(t, []interface{}{"A", "B", "C"}, graph.Nodes(), "graph.Nodes() should equal [A, B, C]")
58 | }
59 |
60 | func TestGraphRemoveNode(t *testing.T) {
61 | graph := newGraph()
62 | graph.AddNode("A")
63 | graph.AddNode("B")
64 | graph.AddNode("C")
65 | graph.AddNode("D")
66 | graph.RemoveNode("A")
67 | graph.RemoveNode("C")
68 |
69 | assert.Equal(t, 2, graph.NodeCount(), "graph.NodeCount() should equal 2")
70 | assert.Equal(t, []interface{}{"B", "D"}, graph.Nodes(), "graph.Nodes() should equal [B, D]")
71 | }
72 |
73 | func TestGraphRemoveNodeMissing(t *testing.T) {
74 | graph := newGraph()
75 | graph.AddNode("A")
76 | graph.AddNode("B")
77 | graph.AddNode("C")
78 | graph.AddNode("D")
79 | graph.RemoveNode("A")
80 | graph.RemoveNode("A")
81 | graph.RemoveNode("E")
82 |
83 | assert.Equal(t, 3, graph.NodeCount(), "graph.NodeCount() should equal 2")
84 | assert.Equal(t, []interface{}{"B", "C", "D"}, graph.Nodes(), "graph.Nodes() should equal [B, C, D]")
85 | }
86 |
87 | func TestGraphRemoveNodes(t *testing.T) {
88 | graph := newGraph()
89 | graph.AddNode("A")
90 | graph.AddNode("B")
91 | graph.AddNode("C")
92 | graph.AddNode("D")
93 | graph.RemoveNodes("A", "C")
94 |
95 | assert.Equal(t, 2, graph.NodeCount(), "graph.NodeCount() should equal 2")
96 | assert.Equal(t, []interface{}{"B", "D"}, graph.Nodes(), "graph.Nodes() should equal [B, D]")
97 | }
98 |
99 | func TestGraphNodeExists(t *testing.T) {
100 | graph := newGraph()
101 | assert.False(t, graph.NodeExists("A"), "graph.NodeExists(\"A\") should equal false")
102 | assert.False(t, graph.NodeExists("B"), "graph.NodeExists(\"B\") should equal false")
103 |
104 | graph.AddNode("A")
105 | assert.True(t, graph.NodeExists("A"), "graph.NodeExists(\"A\") should equal true")
106 | assert.False(t, graph.NodeExists("B"), "graph.NodeExists(\"B\") should equal false")
107 |
108 | graph.RemoveNode("A")
109 | assert.False(t, graph.NodeExists("A"), "graph.NodeExists(\"A\") should equal false")
110 | assert.False(t, graph.NodeExists("B"), "graph.NodeExists(\"B\") should equal false")
111 | }
112 |
--------------------------------------------------------------------------------
/di/internal/graphkv/graphkv.go:
--------------------------------------------------------------------------------
1 | package graphkv
2 |
3 | import (
4 | "github.com/emicklei/dot"
5 | )
6 |
7 | // Node
8 | type Node struct {
9 | Key Key
10 | Value interface{}
11 | }
12 |
13 | // Graph
14 | type Graph struct {
15 | dag *directedGraph
16 | values map[Key]interface{}
17 | }
18 |
19 | // New
20 | // AddNode
21 | // NodeExists
22 | // Sort
23 | // AddEdge
24 | func New() *Graph {
25 | return &Graph{
26 | dag: newDirectedGraph(),
27 | values: map[Key]interface{}{},
28 | }
29 | }
30 |
31 | // Get
32 | func (g *Graph) Get(key Key) Node {
33 | return Node{Key: key, Value: g.values[key]}
34 | }
35 |
36 | // Replace
37 | func (g *Graph) Replace(key Key, value interface{}) {
38 | g.values[key] = value
39 | }
40 |
41 | // Add
42 | func (g *Graph) Add(key Key, value interface{}) {
43 | g.dag.AddNode(key)
44 | g.values[key] = value
45 | }
46 |
47 | // Edge
48 | func (g *Graph) Edge(from Key, to Key) {
49 | g.dag.AddEdge(from, to)
50 | }
51 |
52 | // Exists
53 | func (g *Graph) Exists(key Key) bool {
54 | return g.dag.NodeExists(key)
55 | }
56 |
57 | // Nodes
58 | func (g *Graph) Nodes() []Node {
59 | var nodes []Node
60 | for _, key := range g.dag.Nodes() {
61 | nodes = append(nodes, Node{key, g.values[key]})
62 | }
63 | return nodes
64 | }
65 |
66 | // CheckCycles
67 | func (g *Graph) CheckCycles() error {
68 | _, err := g.dag.DFSSort()
69 | return err // todo: errors
70 | }
71 |
72 | // DOTGraph
73 | func (g *Graph) DOTGraph() *dot.Graph {
74 | return g.dag.DOTGraph()
75 | }
76 |
--------------------------------------------------------------------------------
/di/internal/graphkv/graphkv_test.go:
--------------------------------------------------------------------------------
1 | package graphkv
2 |
3 | // todo: tests
4 |
--------------------------------------------------------------------------------
/di/internal/graphkv/key.go:
--------------------------------------------------------------------------------
1 | package graphkv
2 |
// Key identifies a graph node. It is an alias for interface{}. Keys are used
// as map keys, so their dynamic values must be comparable.
type Key = interface{}
5 |
6 | type nodeList struct {
7 | nodes []Key
8 | set map[Key]bool
9 | }
10 |
11 | func newNodeList() *nodeList {
12 | return &nodeList{
13 | nodes: make([]Key, 0),
14 | set: make(map[Key]bool),
15 | }
16 | }
17 |
18 | func (l *nodeList) Copy() *nodeList {
19 | nodes := make([]Key, len(l.nodes))
20 | copy(nodes, l.nodes)
21 |
22 | set := make(map[Key]bool, len(nodes))
23 | for _, node := range nodes {
24 | set[node] = true
25 | }
26 |
27 | return &nodeList{
28 | nodes: nodes,
29 | set: set,
30 | }
31 | }
32 |
33 | func (l *nodeList) Nodes() []Key {
34 | return l.nodes
35 | }
36 |
37 | func (l *nodeList) Count() int {
38 | return len(l.nodes)
39 | }
40 |
41 | func (l *nodeList) Exists(node Key) bool {
42 | _, ok := l.set[node]
43 | return ok
44 | }
45 |
46 | func (l *nodeList) Add(nodes ...Key) {
47 | for _, node := range nodes {
48 | if l.Exists(node) {
49 | continue
50 | }
51 |
52 | l.nodes = append(l.nodes, node)
53 | l.set[node] = true
54 | }
55 | }
56 |
57 | func (l *nodeList) Remove(nodes ...Key) {
58 | for i := len(l.nodes) - 1; i >= 0; i-- {
59 | for j, node := range nodes {
60 | if l.nodes[i] == node {
61 | copy(l.nodes[i:], l.nodes[i+1:])
62 | l.nodes[len(l.nodes)-1] = nil
63 | l.nodes = l.nodes[:len(l.nodes)-1]
64 |
65 | delete(l.set, node)
66 |
67 | copy(nodes[j:], nodes[j+1:])
68 | nodes[len(nodes)-1] = nil
69 | nodes = nodes[:len(nodes)-1]
70 |
71 | break
72 | }
73 | }
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/di/internal/graphkv/output.go:
--------------------------------------------------------------------------------
1 | package graphkv
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/emicklei/dot"
7 | )
8 |
9 | // NodeVisualizer
10 | type NodeVisualizer interface {
11 | Visualize(node *dot.Node)
12 | SubGraph() string
13 | IsAlwaysVisible() bool
14 | }
15 |
16 | // DOTGraph returns a textual representation of the graph in the DOT graph
17 | // description language.
18 | func (g *directedGraph) DOTGraph() *dot.Graph {
19 | root := dot.NewGraph(dot.Directed)
20 | root.Attr("splines", "ortho")
21 |
22 | subgraphs := make(map[string]*dot.Graph)
23 | itemsByNode := make(map[Key]dot.Node)
24 | for _, node := range g.Nodes() {
25 | nv := node.(NodeVisualizer)
26 |
27 | if !g.HasOutgoingEdges(node) && !nv.IsAlwaysVisible() {
28 | continue
29 | }
30 |
31 | name := fmt.Sprintf("%s", node)
32 | subgraph, ok := subgraphs[nv.SubGraph()]
33 | if !ok {
34 | subgraph = root.Subgraph(nv.SubGraph(), dot.ClusterOption{})
35 | subgraphs[nv.SubGraph()] = subgraph
36 | applySubGraphStyle(subgraph)
37 | }
38 | item := subgraph.Node(name)
39 | nv.Visualize(&item)
40 | itemsByNode[node] = item
41 |
42 | }
43 |
44 | for fromNode, fromItem := range itemsByNode {
45 | for _, toNode := range g.OutgoingEdges(fromNode) {
46 | if toItem, ok := itemsByNode[toNode]; ok {
47 | root.Edge(fromItem, toItem).Attr("color", "#949494")
48 | }
49 | }
50 | }
51 |
52 | return root
53 | }
54 |
55 | func applySubGraphStyle(graph *dot.Graph) {
56 | graph.Attr("label", "")
57 | graph.Attr("style", "rounded")
58 | graph.Attr("bgcolor", "#E8E8E8")
59 | graph.Attr("color", "lightgrey")
60 | graph.Attr("fontname", "COURIER")
61 | graph.Attr("fontcolor", "#46494C")
62 | }
63 |
--------------------------------------------------------------------------------
/di/internal/graphkv/output_test.go:
--------------------------------------------------------------------------------
1 | package graphkv
2 |
--------------------------------------------------------------------------------
/di/internal/graphkv/sort.go:
--------------------------------------------------------------------------------
1 | package graphkv
2 |
3 | import (
4 | "errors"
5 | )
6 |
// ErrCyclicGraph is returned by DFSSorter when the graph contains a cycle and
// so has no topological order.
var ErrCyclicGraph = errors.New("the graph cannot be cyclic")
11 |
12 | // DFSSorter topologically sorts a directed graph's nodes based on the
13 | // directed edges between them using the Depth-first search algorithm.
14 | type DFSSorter struct {
15 | graph *directedGraph
16 | sorted []Key
17 | visiting map[Key]bool
18 | discovered map[Key]bool
19 | }
20 |
21 | // NewDFSSorter returns a new DFS sorter.
22 | func NewDFSSorter(graph *directedGraph) *DFSSorter {
23 | return &DFSSorter{
24 | graph: graph,
25 | }
26 | }
27 |
28 | func (s *DFSSorter) init() {
29 | s.sorted = make([]Key, 0, s.graph.NodeCount())
30 | s.visiting = make(map[Key]bool)
31 | s.discovered = make(map[Key]bool, s.graph.NodeCount())
32 | }
33 |
34 | // Sort returns the sorted nodes.
35 | func (s *DFSSorter) Sort() ([]Key, error) {
36 | s.init()
37 |
38 | // > while there are unmarked nodes do
39 | for _, node := range s.graph.Nodes() {
40 | if err := s.visit(node); err != nil {
41 | return nil, err
42 | }
43 | }
44 |
45 | // as the nodes were appended to the slice for performance reasons,
46 | // rather than prepended as correctly stated by the algorithm,
47 | // we need to reverse the sorted slice
48 | for i, j := 0, len(s.sorted)-1; i < j; i, j = i+1, j-1 {
49 | s.sorted[i], s.sorted[j] = s.sorted[j], s.sorted[i]
50 | }
51 |
52 | return s.sorted, nil
53 | }
54 |
55 | // See https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
56 | func (s *DFSSorter) visit(node Key) error {
57 | // > if n has a permanent mark then return
58 | if discovered, ok := s.discovered[node]; ok && discovered {
59 | return nil
60 | }
61 | // > if n has a temporary mark then stop (not a DAG)
62 | if visiting, ok := s.visiting[node]; ok && visiting {
63 | return ErrCyclicGraph
64 | }
65 |
66 | // > mark n temporarily
67 | s.visiting[node] = true
68 |
69 | // > for each node m with an edge from n to m do
70 | for _, outgoing := range s.graph.OutgoingEdges(node) {
71 | if err := s.visit(outgoing); err != nil {
72 | return err
73 | }
74 | }
75 |
76 | s.discovered[node] = true
77 | delete(s.visiting, node)
78 |
79 | s.sorted = append(s.sorted, node)
80 | return nil
81 | }
82 |
83 | // DFSSort returns the graph's nodes in topological order based on the
84 | // directed edges between them using the Depth-first search algorithm.
85 | func (g *directedGraph) DFSSort() ([]Key, error) {
86 | sorter := NewDFSSorter(g)
87 | return sorter.Sort()
88 | }
89 |
// ErrDependencyOrder is returned by CoffmanGrahamSorter.Sort when it reaches
// a node before one of the nodes with an edge into it has been assigned a
// layer, i.e. the order it works from is not topological.
var ErrDependencyOrder = errors.New("the topological dependency order is incorrect")
94 |
95 | // CoffmanGrahamSorter sorts a graph's nodes into a sequence of levels,
96 | // arranging so that a node which comes after another in the order is
97 | // assigned to a lower level, and that a level never exceeds the width.
98 | // See https://en.wikipedia.org/wiki/Coffman–Graham_algorithm
99 | type CoffmanGrahamSorter struct {
100 | graph *directedGraph
101 | width int
102 | }
103 |
104 | // NewCoffmanGrahamSorter returns a new Coffman-Graham sorter.
105 | func NewCoffmanGrahamSorter(graph *directedGraph, width int) *CoffmanGrahamSorter {
106 | return &CoffmanGrahamSorter{
107 | graph: graph,
108 | width: width,
109 | }
110 | }
111 |
112 | // Sort returns the sorted nodes.
113 | func (s *CoffmanGrahamSorter) Sort() ([][]Key, error) {
114 | // create a copy of the graph and remove transitive edges
115 | reduced := s.graph.Copy()
116 | reduced.RemoveTransitives()
117 |
118 | // topologically sort the graph nodes
119 | nodes, err := reduced.DFSSort()
120 | if err != nil {
121 | return nil, err
122 | }
123 |
124 | layers := make([][]Key, 0)
125 | levels := make(map[Key]int, len(nodes))
126 |
127 | for _, node := range nodes {
128 | dependantLevel := -1
129 | for _, dependant := range reduced.IncomingEdges(node) {
130 | level, ok := levels[dependant]
131 | if !ok {
132 | return nil, ErrDependencyOrder
133 | }
134 | if level > dependantLevel {
135 | dependantLevel = level
136 | }
137 | }
138 |
139 | level := -1
140 | // find the first unfilled layer outgoing the dependent layer
141 | // skip this if the dependent layer is the last
142 | if dependantLevel < len(layers)-1 {
143 | for i := dependantLevel + 1; i < len(layers); i++ {
144 | // ensure the layer doesn't exceed the desired width
145 | if len(layers[i]) < s.width {
146 | level = i
147 | break
148 | }
149 | }
150 | }
151 | // create a new layer new none was found
152 | if level == -1 {
153 | layers = append(layers, make([]Key, 0, 1))
154 | level = len(layers) - 1
155 | }
156 |
157 | layers[level] = append(layers[level], node)
158 | levels[node] = level
159 | }
160 |
161 | return layers, nil
162 | }
163 |
164 | // CoffmanGrahamSort sorts the graph's nodes into a sequence of levels,
165 | // arranging so that a node which comes after another in the order is
166 | // assigned to a lower level, and that a level never exceeds the specified width.
167 | func (g *directedGraph) CoffmanGrahamSort(width int) ([][]Key, error) {
168 | sorter := NewCoffmanGrahamSorter(g, width)
169 | return sorter.Sort()
170 | }
171 |
--------------------------------------------------------------------------------
/di/internal/graphkv/sort_test.go:
--------------------------------------------------------------------------------
1 | package graphkv
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/assert"
7 | )
8 |
9 | func TestDFSSorter(t *testing.T) {
10 | graph := newDirectedGraph()
11 | graph.AddNodes(0, 1, 2, 3, 4, 5, 6, 7)
12 | graph.AddEdge(0, 2)
13 | graph.AddEdge(1, 2)
14 | graph.AddEdge(1, 5)
15 | graph.AddEdge(1, 6)
16 | graph.AddEdge(2, 5)
17 | graph.AddEdge(3, 5)
18 | graph.AddEdge(5, 6)
19 | graph.AddEdge(5, 7)
20 |
21 | sorted, err := graph.DFSSort()
22 |
23 | assert.NoError(t, err, "graph.DFSSort() error should be nil")
24 | assert.Equal(t, []Key{4, 3, 1, 0, 2, 5, 7, 6}, sorted, "graph.DFSSort() nodes should equal [4, 3, 1, 0, 2, 5, 7, 6]")
25 | }
26 |
27 | func TestDFSSorterCyclic(t *testing.T) {
28 | graph := newDirectedGraph()
29 | graph.AddNodes(0, 1)
30 | graph.AddEdge(0, 1)
31 | graph.AddEdge(1, 0)
32 |
33 | sorted, err := graph.DFSSort()
34 |
35 | assert.EqualError(t, err, ErrCyclicGraph.Error(), "graph.DFSSort() error should be ErrCyclicGraph")
36 | assert.Nil(t, sorted, "graph.DFSSort() nodes should be nil")
37 | }
38 |
39 | func TestCoffmanGrahamSorter(t *testing.T) {
40 | graph := newDirectedGraph()
41 |
42 | graph.AddNodes(0, 1, 2, 3, 4, 5, 6, 7, 8)
43 | graph.AddEdge(0, 2)
44 | graph.AddEdge(0, 5)
45 | graph.AddEdge(1, 2)
46 | graph.AddEdge(2, 3)
47 | graph.AddEdge(2, 4)
48 | graph.AddEdge(3, 6)
49 | graph.AddEdge(4, 6)
50 | graph.AddEdge(5, 7)
51 | graph.AddEdge(6, 7)
52 | graph.AddEdge(6, 8)
53 |
54 | sorted, err := graph.CoffmanGrahamSort(2)
55 |
56 | assert.NoError(t, err, "graph.CoffmanGrahamSort(2)0 error should be nil")
57 | assert.Equal(t, [][]Key{
58 | []Key{1, 0},
59 | []Key{5, 2},
60 | []Key{4, 3},
61 | []Key{6},
62 | []Key{8, 7},
63 | }, sorted, "graph.CoffmanGrahamSort(2) nodes should equal [[1, 0], [5, 2], [4, 3], [6], [8, 7]]")
64 | }
65 |
66 | func TestCoffmanGrahamSorterCyclic(t *testing.T) {
67 | graph := newDirectedGraph()
68 |
69 | graph.AddNodes(0, 1, 2, 3, 4, 5, 6, 7, 8)
70 | graph.AddEdge(0, 2)
71 | graph.AddEdge(0, 5)
72 | graph.AddEdge(1, 2)
73 | graph.AddEdge(2, 0) // cyclic edge
74 | graph.AddEdge(2, 3)
75 | graph.AddEdge(2, 4)
76 | graph.AddEdge(3, 6)
77 | graph.AddEdge(4, 6)
78 | graph.AddEdge(5, 7)
79 | graph.AddEdge(6, 7)
80 | graph.AddEdge(6, 8)
81 |
82 | sorted, err := graph.CoffmanGrahamSort(2)
83 |
84 | assert.EqualError(t, err, ErrCyclicGraph.Error(), "graph.CoffmanGrahamSort(2) error should be ErrCyclicGraph")
85 | assert.Nil(t, sorted, "graph.CoffmanGrahamSort(2) nodes should be nil")
86 | }
87 |
--------------------------------------------------------------------------------
/di/internal/reflection/func.go:
--------------------------------------------------------------------------------
1 | package reflection
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | "runtime"
7 | )
8 |
// IsFunc reports whether value holds a function, including a typed nil
// function. A nil interface value is not a function.
func IsFunc(value interface{}) bool {
	return reflect.ValueOf(value).Kind() == reflect.Func
}

// Func describes a function value obtained via InspectFunction. The embedded
// reflect.Type exposes the signature (NumIn, In, NumOut, Out) and the
// embedded reflect.Value allows calling it (Call).
type Func struct {
	// Name is the function name reported by runtime.FuncForPC, or "" when
	// the runtime cannot name it (e.g. a nil function value).
	Name string
	reflect.Type
	reflect.Value
}

// InspectFunction returns a Func describing fn. It panics with
// "<kind>: not a function" if fn is not a function, and with
// "<nil>: not a function" if fn is a nil interface value.
func InspectFunction(fn interface{}) *Func {
	if !IsFunc(fn) {
		// reflect.TypeOf(nil) is nil, so describe a nil interface explicitly
		// instead of calling Kind on a nil reflect.Type.
		kind := "<nil>"
		if fn != nil {
			kind = reflect.TypeOf(fn).Kind().String()
		}
		panic(fmt.Sprintf("%s: not a function", kind))
	}

	val := reflect.ValueOf(fn)
	// FuncForPC returns nil for a nil function's zero PC; (*runtime.Func).Name
	// handles a nil receiver by returning "".
	fnpc := runtime.FuncForPC(val.Pointer())

	return &Func{
		Name:  fnpc.Name(),
		Type:  val.Type(),
		Value: val,
	}
}
36 |
--------------------------------------------------------------------------------
/di/internal/reflection/iface.go:
--------------------------------------------------------------------------------
1 | package reflection
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | )
7 |
// InspectInterfacePtr returns the interface type that iface points to. iface
// must be a pointer to an interface type, e.g. new(io.Reader) or
// (*io.Reader)(nil). Any other value, including nil, causes a panic with the
// message "<type>: not a pointer to interface".
func InspectInterfacePtr(iface interface{}) *Interface {
	typ := reflect.TypeOf(iface)
	// reflect.TypeOf(nil) returns nil; check it before calling methods on typ.
	if typ == nil || typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Interface {
		panic(fmt.Sprintf("%v: not a pointer to interface", typ)) // todo: improve message
	}

	elem := typ.Elem()
	return &Interface{
		Name: elem.Name(),
		Type: elem,
	}
}

// Interface describes an interface type: Name is its declared name ("" for
// unnamed interface types) and Type is its reflect.Type.
type Interface struct {
	Name string
	Type reflect.Type
}
26 |
--------------------------------------------------------------------------------
/di/internal/reflection/reflection.go:
--------------------------------------------------------------------------------
1 | package reflection
2 |
3 | import "reflect"
4 |
// errorInterface is the reflect.Type of the built-in error interface.
var errorInterface = reflect.TypeOf((*error)(nil)).Elem()

// IsError reports whether typ implements the error interface.
func IsError(typ reflect.Type) bool {
	return typ.Implements(errorInterface)
}

// IsCleanup reports whether typ is a function type with no parameters and no
// results, i.e. it has the shape of func().
func IsCleanup(typ reflect.Type) bool {
	if typ.Kind() != reflect.Func {
		return false
	}
	return typ.NumIn() == 0 && typ.NumOut() == 0
}

// IsPtr reports whether value holds a pointer, including a typed nil
// pointer. A nil interface value is not a pointer.
func IsPtr(value interface{}) bool {
	switch reflect.ValueOf(value).Kind() {
	case reflect.Ptr:
		return true
	default:
		return false
	}
}
21 |
--------------------------------------------------------------------------------
/di/invoker.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 |
7 | "github.com/defval/inject/v2/di/internal/reflection"
8 | )
9 |
10 | type invokerType int
11 |
12 | const (
13 | invokerUnknown invokerType = iota
14 | invokerStd // func (deps) {}
15 | invokerError // func (deps) error {}
16 | )
17 |
18 | func determineInvokerType(fn *reflection.Func) (invokerType, error) {
19 | if fn.NumOut() == 0 {
20 | return invokerStd, nil
21 | }
22 | if fn.NumOut() == 1 && reflection.IsError(fn.Out(0)) {
23 | return invokerError, nil
24 | }
25 | return invokerUnknown, fmt.Errorf("the invoke function must be a function like `func([dep1, dep2, ...]) [error]`, got `%s`", fn.Type)
26 | }
27 |
28 | type invoker struct {
29 | typ invokerType
30 | fn *reflection.Func
31 | }
32 |
33 | func newInvoker(fn interface{}) (*invoker, error) {
34 | if fn == nil {
35 | return nil, fmt.Errorf("the invoke function must be a function like `func([dep1, dep2, ...]) [error]`, got `%s`", "nil")
36 | }
37 | if !reflection.IsFunc(fn) {
38 | return nil, fmt.Errorf("the invoke function must be a function like `func([dep1, dep2, ...]) [error]`, got `%s`", reflect.ValueOf(fn).Type())
39 | }
40 | ifn := reflection.InspectFunction(fn)
41 | typ, err := determineInvokerType(ifn)
42 | if err != nil {
43 | return nil, err
44 | }
45 | return &invoker{
46 | typ: typ,
47 | fn: reflection.InspectFunction(fn),
48 | }, nil
49 | }
50 |
51 | func (i *invoker) Invoke(c *Container) error {
52 | plist := i.parameters()
53 | values, err := plist.Resolve(c)
54 | if err != nil {
55 | return fmt.Errorf("could not resolve invoke parameters: %s", err)
56 | }
57 | results := i.fn.Call(values)
58 | if len(results) == 0 {
59 | return nil
60 | }
61 | if results[0].Interface() == nil {
62 | return nil
63 | }
64 | return results[0].Interface().(error)
65 | }
66 |
67 | func (i *invoker) parameters() parameterList {
68 | var plist parameterList
69 | for j := 0; j < i.fn.NumIn(); j++ {
70 | ptype := i.fn.In(j)
71 | p := parameter{
72 | res: ptype,
73 | optional: false,
74 | embed: isEmbedParameter(ptype),
75 | }
76 | plist = append(plist, p)
77 | }
78 | return plist
79 | }
80 |
--------------------------------------------------------------------------------
/di/key.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 |
7 | "github.com/emicklei/dot"
8 | )
9 |
10 | // key is a id of provider in container
11 | type key struct {
12 | name string
13 | res reflect.Type
14 | typ providerType
15 | }
16 |
17 | // String represent resultKey as string.
18 | func (k key) String() string {
19 | if k.name == "" {
20 | return fmt.Sprintf("%s", k.res)
21 | }
22 | return fmt.Sprintf("%s[%s]", k.res, k.name)
23 | }
24 |
25 | // IsAlwaysVisible
26 | func (k key) IsAlwaysVisible() bool {
27 | return k.typ == ptConstructor
28 | }
29 |
30 | // Package
31 | func (k key) SubGraph() string {
32 | var pkg string
33 | switch k.res.Kind() {
34 | case reflect.Slice, reflect.Ptr:
35 | pkg = k.res.Elem().PkgPath()
36 | default:
37 | pkg = k.res.PkgPath()
38 | }
39 |
40 | return pkg
41 | }
42 |
43 | // Visualize
44 | func (k key) Visualize(node *dot.Node) {
45 | node.Label(k.String())
46 | node.Attr("fontname", "COURIER")
47 | node.Attr("style", "filled")
48 | node.Attr("fontcolor", "white")
49 | switch k.typ {
50 | case ptConstructor:
51 | node.Attr("shape", "box")
52 | node.Attr("color", "#46494C")
53 | case ptGroup:
54 | node.Attr("shape", "doubleoctagon")
55 | node.Attr("color", "#E54B4B")
56 | case ptInterface:
57 | node.Attr("color", "#2589BD")
58 | case ptEmbedParameter:
59 | node.Attr("shape", "box")
60 | node.Attr("color", "#E5984B")
61 | }
62 | }
63 |
--------------------------------------------------------------------------------
/di/options.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | // ExtractOption
4 | type ProvideOption interface {
5 | apply(params *ProvideParams)
6 | }
7 |
8 | type provideOption func(params *ProvideParams)
9 |
10 | func (o provideOption) apply(params *ProvideParams) {
11 | o(params)
12 | }
13 |
14 | // ProvideParams is a `Provide()` method options. Name is a unique identifier of type instance. Provider is a constructor
15 | // function. Interfaces is a interface that implements a provider result type.
16 | type ProvideParams struct {
17 | Name string
18 | Interfaces []interface{}
19 | Parameters ParameterBag
20 | IsPrototype bool
21 | }
22 |
23 | func (p ProvideParams) apply(params *ProvideParams) {
24 | *params = p
25 | }
26 |
27 | // As
28 | func As(interfaces ...interface{}) ProvideOption {
29 | return provideOption(func(params *ProvideParams) {
30 | params.Interfaces = append(params.Interfaces, interfaces...)
31 | })
32 | }
33 |
type (
	// InvokeParams holds the options of an Invoke call; it has no fields.
	InvokeParams struct{}

	// InvokeOption customises an Invoke call by modifying its InvokeParams.
	InvokeOption interface {
		apply(params *InvokeParams)
	}

	// ExtractParams holds the options of an Extract call. NOTE(review): Name
	// presumably selects a named instance, mirroring ProvideParams.Name.
	ExtractParams struct {
		Name string
	}

	// ExtractOption customises an Extract call by modifying its ExtractParams.
	ExtractOption interface {
		apply(params *ExtractParams)
	}

	// extractOption adapts a plain function to the ExtractOption interface.
	extractOption func(params *ExtractParams)
)

// apply makes InvokeParams an InvokeOption by overwriting params with p.
func (p InvokeParams) apply(params *InvokeParams) {
	*params = p
}

// apply makes ExtractParams an ExtractOption by overwriting params with p.
func (p ExtractParams) apply(params *ExtractParams) {
	*params = p
}

// apply calls the wrapped function.
func (fn extractOption) apply(params *ExtractParams) {
	fn(params)
}
65 |
--------------------------------------------------------------------------------
/di/panic.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import "fmt"
4 |
// panicf panics with the string produced by fmt.Sprintf(format, a...).
func panicf(format string, a ...interface{}) {
	msg := fmt.Sprintf(format, a...)
	panic(msg)
}
8 |
--------------------------------------------------------------------------------
/di/parameter.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | "reflect"
5 | )
6 |
7 | // Parameter
8 | type Parameter struct {
9 | internalParameter
10 | }
11 |
12 | // parameterRequired
13 | type parameter struct {
14 | name string
15 | res reflect.Type
16 | optional bool
17 | embed bool
18 | }
19 |
20 | func (p parameter) String() string {
21 | return key{name: p.name, res: p.res}.String()
22 | }
23 |
24 | // ResolveProvider resolves parameter provider
25 | func (p parameter) ResolveProvider(c *Container) (internalProvider, bool) {
26 | for _, pt := range providerLookupSequence {
27 | k := key{
28 | name: p.name,
29 | res: p.res,
30 | typ: pt,
31 | }
32 | if !c.graph.Exists(k) {
33 | continue
34 | }
35 | node := c.graph.Get(k)
36 | return node.Value.(internalProvider), true
37 | }
38 | return nil, false
39 | }
40 |
41 | func (p parameter) ResolveValue(c *Container) (reflect.Value, error) {
42 | provider, exists := p.ResolveProvider(c)
43 | if !exists && p.optional {
44 | return reflect.New(p.res).Elem(), nil
45 | }
46 | if !exists {
47 | return reflect.Value{}, ErrParameterProviderNotFound{param: p}
48 | }
49 | pl := provider.ParameterList()
50 | values, err := pl.Resolve(c)
51 | if err != nil {
52 | return reflect.Value{}, err
53 | }
54 | value, cleanup, err := provider.Provide(values...)
55 | if err != nil {
56 | return value, ErrParameterProvideFailed{k: provider.Key(), err: err}
57 | }
58 | if cleanup != nil {
59 | c.cleanups = append(c.cleanups, cleanup)
60 | }
61 | return value, nil
62 | }
63 |
// isEmbedParameter reports whether typ is a struct type implementing
// internalParameter, e.g. a struct embedding Parameter. Pointer types are
// not accepted.
func isEmbedParameter(typ reflect.Type) bool {
	if typ.Kind() != reflect.Struct {
		return false
	}
	return typ.Implements(parameterInterface)
}

// internalParameter is a marker interface. Its method is unexported, so
// types outside this package can only satisfy it by embedding a type that
// does, such as Parameter.
type internalParameter interface {
	isDependencyInjectionParameter()
}

// parameterInterface is the reflect.Type of internalParameter.
var parameterInterface = reflect.TypeOf((*internalParameter)(nil)).Elem()
76 |
--------------------------------------------------------------------------------
/di/parameter_bag.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | )
7 |
8 | // createParameterBugProvider
9 | func createParameterBugProvider(key key, parameters ParameterBag) internalProvider {
10 | return newProviderConstructor(key.String(), func() ParameterBag { return parameters })
11 | }
12 |
13 | // parameterBagType
14 | var parameterBagType = reflect.TypeOf(ParameterBag{})
15 |
// ParameterBag is a set of named values that a constructor can receive as
// an argument. The typed getters report ok=false when the key is missing or
// holds another type. The Require* variants panic instead.
type ParameterBag map[string]interface{}

// Exists reports whether key is present in the bag, even with a nil value.
func (b ParameterBag) Exists(key string) bool {
	_, ok := b[key]
	return ok
}

// Get returns the value stored under key and whether the key was present.
func (b ParameterBag) Get(key string) (interface{}, bool) {
	value, ok := b[key]
	return value, ok
}

// String returns the string stored under key; ok is false if the key is
// missing or holds another type.
func (b ParameterBag) String(key string) (string, bool) {
	value, ok := b[key].(string)
	return value, ok
}

// Int64 returns the int64 stored under key; ok is false if the key is
// missing or holds another type.
func (b ParameterBag) Int64(key string) (int64, bool) {
	value, ok := b[key].(int64)
	return value, ok
}

// Int returns the int stored under key; ok is false if the key is missing or
// holds another type.
func (b ParameterBag) Int(key string) (int, bool) {
	value, ok := b[key].(int)
	return value, ok
}

// Float64 returns the float64 stored under key; ok is false if the key is
// missing or holds another type.
func (b ParameterBag) Float64(key string) (float64, bool) {
	value, ok := b[key].(float64)
	return value, ok
}

// Require returns the value stored under key. It panics with
// "value for string key `<key>` not found" if the key is absent.
func (b ParameterBag) Require(key string) interface{} {
	value, ok := b[key]
	if !ok {
		panic(fmt.Sprintf("value for string key `%s` not found", key))
	}
	return value
}

// RequireString returns the string stored under key. It panics with the
// Require message if the key is missing, and with a type-mismatch message
// if the key holds a non-string value.
func (b ParameterBag) RequireString(key string) string {
	value, ok := b.Require(key).(string)
	if !ok {
		panic(b.typeMismatch(key, "string"))
	}
	return value
}

// RequireInt64 returns the int64 stored under key, panicking like
// RequireString on a missing key or a type mismatch.
func (b ParameterBag) RequireInt64(key string) int64 {
	value, ok := b.Require(key).(int64)
	if !ok {
		panic(b.typeMismatch(key, "int64"))
	}
	return value
}

// RequireInt returns the int stored under key, panicking like RequireString
// on a missing key or a type mismatch.
func (b ParameterBag) RequireInt(key string) int {
	value, ok := b.Require(key).(int)
	if !ok {
		panic(b.typeMismatch(key, "int"))
	}
	return value
}

// RequireFloat64 returns the float64 stored under key, panicking like
// RequireString on a missing key or a type mismatch.
func (b ParameterBag) RequireFloat64(key string) float64 {
	value, ok := b.Require(key).(float64)
	if !ok {
		panic(b.typeMismatch(key, "float64"))
	}
	return value
}

// typeMismatch builds the panic message for a key that exists but holds a
// value whose dynamic type is not want. A missing key is reported by Require
// with a "not found" message instead, so the two cases stay distinguishable.
func (b ParameterBag) typeMismatch(key, want string) string {
	return fmt.Sprintf("value for string key `%s` has type %T, expected %s", key, b[key], want)
}
99 |
--------------------------------------------------------------------------------
/di/parameter_bag_test.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/require"
7 | )
8 |
9 | func TestParameterBag_Get(t *testing.T) {
10 | t.Run("key exists", func(t *testing.T) {
11 | pb := ParameterBag{
12 | "get": "get",
13 | }
14 |
15 | v, ok := pb.Get("get")
16 | require.Equal(t, "get", v)
17 | require.True(t, ok)
18 | })
19 |
20 | t.Run("key not exists", func(t *testing.T) {
21 | pb := ParameterBag{}
22 |
23 | v, ok := pb.Get("get")
24 | require.Equal(t, nil, v)
25 | require.False(t, ok)
26 | })
27 | }
28 |
29 | func TestParameterBag_GetType(t *testing.T) {
30 | t.Run("key exists", func(t *testing.T) {
31 | pb := ParameterBag{
32 | "string": "string",
33 | "int64": int64(64),
34 | "int": int(64),
35 | "float64": float64(64),
36 | }
37 |
38 | s, ok := pb.String("string")
39 | require.Equal(t, "string", s)
40 | require.True(t, ok)
41 |
42 | i64, ok := pb.Int64("int64")
43 | require.Equal(t, int64(64), i64)
44 | require.True(t, ok)
45 |
46 | i, ok := pb.Int("int")
47 | require.Equal(t, int(64), i)
48 | require.True(t, ok)
49 |
50 | f64, ok := pb.Float64("float64")
51 | require.Equal(t, float64(64), f64)
52 | require.True(t, ok)
53 | })
54 |
55 | t.Run("key not exists", func(t *testing.T) {
56 | pb := ParameterBag{}
57 |
58 | s, ok := pb.String("string")
59 | require.Equal(t, "", s)
60 | require.False(t, ok)
61 |
62 | i64, ok := pb.Int64("int64")
63 | require.Equal(t, int64(0), i64)
64 | require.False(t, ok)
65 |
66 | i, ok := pb.Int("int")
67 | require.Equal(t, int(0), i)
68 | require.False(t, ok)
69 |
70 | f64, ok := pb.Float64("float64")
71 | require.Equal(t, float64(0), f64)
72 | require.False(t, ok)
73 | })
74 | }
75 |
76 | func TestParameterBag_Exists(t *testing.T) {
77 | pb := ParameterBag{}
78 |
79 | require.False(t, pb.Exists("not existing key"))
80 | }
81 |
82 | func TestParameterBag_Require(t *testing.T) {
83 | t.Run("key exists", func(t *testing.T) {
84 | pb := ParameterBag{
85 | "require": "require",
86 | }
87 |
88 | value := pb.Require("require")
89 | require.Equal(t, "require", value)
90 | })
91 |
92 | t.Run("key not exists", func(t *testing.T) {
93 | pb := ParameterBag{}
94 |
95 | require.PanicsWithValue(t, "value for string key `not existing key` not found", func() {
96 | pb.Require("not existing key")
97 | })
98 | })
99 | }
100 |
101 | func TestParameterBag_RequireTypes(t *testing.T) {
102 | t.Run("key exists", func(t *testing.T) {
103 | pb := ParameterBag{
104 | "string": "string",
105 | "int64": int64(64),
106 | "int": int(64),
107 | "float64": float64(64),
108 | }
109 |
110 | s := pb.RequireString("string")
111 | require.Equal(t, "string", s)
112 |
113 | i64 := pb.RequireInt64("int64")
114 | require.Equal(t, int64(64), i64)
115 |
116 | i := pb.RequireInt("int")
117 | require.Equal(t, int(64), i)
118 |
119 | f64 := pb.RequireFloat64("float64")
120 | require.Equal(t, float64(64), f64)
121 | })
122 |
123 | t.Run("key not exists", func(t *testing.T) {
124 | pb := ParameterBag{}
125 |
126 | require.PanicsWithValue(t, "value for string key `string` not found", func() {
127 | pb.RequireString("string")
128 | })
129 |
130 | require.PanicsWithValue(t, "value for string key `int64` not found", func() {
131 | pb.RequireInt64("int64")
132 | })
133 |
134 | require.PanicsWithValue(t, "value for string key `int` not found", func() {
135 | pb.RequireInt("int")
136 | })
137 |
138 | require.PanicsWithValue(t, "value for string key `float64` not found", func() {
139 | pb.RequireFloat64("float64")
140 | })
141 | })
142 | }
143 |
--------------------------------------------------------------------------------
/di/parameter_list.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import "reflect"
4 |
5 | // parameterList
6 | type parameterList []parameter
7 |
8 | // ResolveValues loads all parameters presented in parameter list.
9 | func (pl parameterList) Resolve(c *Container) ([]reflect.Value, error) {
10 | var values []reflect.Value
11 | for _, p := range pl {
12 | value, err := p.ResolveValue(c)
13 | if err != nil {
14 | return nil, err
15 | }
16 | values = append(values, value)
17 | }
18 |
19 | return values, nil
20 | }
21 |
--------------------------------------------------------------------------------
/di/provider.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import "reflect"
4 |
5 | // provider lookup sequence
6 | var providerLookupSequence = []providerType{ptConstructor, ptInterface, ptGroup, ptEmbedParameter}
7 |
8 | // providerType
9 | type providerType int
10 |
11 | const (
12 | ptUnknown providerType = iota
13 | ptConstructor
14 | ptInterface
15 | ptGroup
16 | ptEmbedParameter
17 | )
18 |
19 | // provider
20 | type internalProvider interface {
21 | // The identity of result type.
22 | Key() key
23 | // ParameterList returns array of dependencies.
24 | ParameterList() parameterList
25 | // Provide provides value from provided parameters.
26 | Provide(values ...reflect.Value) (reflect.Value, func(), error)
27 | }
28 |
--------------------------------------------------------------------------------
/di/provider_ctor.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "reflect"
7 |
8 | "github.com/defval/inject/v2/di/internal/reflection"
9 | )
10 |
// ctorType classifies a constructor by its result signature.
type ctorType int

const (
	ctorUnknown      ctorType = iota // signature not recognised
	ctorStd                          // func(deps...) result
	ctorError                        // func(deps...) (result, error)
	ctorCleanup                      // func(deps...) (result, func())
	ctorCleanupError                 // func(deps...) (result, func(), error)
)
20 |
21 | // newProviderConstructor
22 | func newProviderConstructor(name string, ctor interface{}) *providerConstructor {
23 | if ctor == nil {
24 | panicf("The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `%s`", "nil")
25 | }
26 | if !reflection.IsFunc(ctor) {
27 | panicf("The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `%s`", reflect.ValueOf(ctor).Type())
28 | }
29 | fn := reflection.InspectFunction(ctor)
30 | ctorType := determineCtorType(fn)
31 | return &providerConstructor{
32 | name: name,
33 | ctor: fn,
34 | ctorType: ctorType,
35 | }
36 | }
37 |
38 | // providerConstructor
39 | type providerConstructor struct {
40 | name string
41 | ctor *reflection.Func
42 | ctorType ctorType
43 | clean *reflection.Func
44 | }
45 |
46 | func (c providerConstructor) Key() key {
47 | return key{
48 | name: c.name,
49 | res: c.ctor.Out(0),
50 | typ: ptConstructor,
51 | }
52 | }
53 |
54 | func (c providerConstructor) ParameterList() parameterList {
55 | var plist parameterList
56 | for i := 0; i < c.ctor.NumIn(); i++ {
57 | ptype := c.ctor.In(i)
58 | var name string
59 | if ptype == parameterBagType {
60 | name = c.Key().String()
61 | }
62 | p := parameter{
63 | name: name,
64 | res: ptype,
65 | optional: false,
66 | embed: isEmbedParameter(ptype),
67 | }
68 | plist = append(plist, p)
69 | }
70 | return plist
71 | }
72 |
73 | // Provide
74 | func (c *providerConstructor) Provide(values ...reflect.Value) (reflect.Value, func(), error) {
75 | out := callResult(c.ctor.Call(values))
76 | switch c.ctorType {
77 | case ctorStd:
78 | return out.instance(), nil, nil
79 | case ctorError:
80 | return out.instance(), nil, out.error(1)
81 | case ctorCleanup:
82 | return out.instance(), out.cleanup(), nil
83 | case ctorCleanupError:
84 | return out.instance(), out.cleanup(), out.error(2)
85 | }
86 | return reflect.Value{}, nil, errors.New("you found a bug, please create new issue for " +
87 | "this: https://github.com/defval/inject/issues/new")
88 | }
89 |
90 | // determineCtorType
91 | func determineCtorType(fn *reflection.Func) ctorType {
92 | if fn.NumOut() == 1 {
93 | return ctorStd
94 | }
95 | if fn.NumOut() == 2 {
96 | if reflection.IsError(fn.Out(1)) {
97 | return ctorError
98 | }
99 | if reflection.IsCleanup(fn.Out(1)) {
100 | return ctorCleanup
101 | }
102 | }
103 | if fn.NumOut() == 3 && reflection.IsCleanup(fn.Out(1)) && reflection.IsError(fn.Out(2)) {
104 | return ctorCleanupError
105 | }
106 | panic(fmt.Sprintf("The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `%s`", fn.Name))
107 | }
108 |
// callResult holds the values returned by reflect.Value.Call on a constructor.
type callResult []reflect.Value

// instance returns the constructed value, which is always the first result.
func (r callResult) instance() reflect.Value {
	return r[0]
}

// cleanup returns the second result as a func(), or nil if it is a nil
// function. The value is converted first, so named cleanup types such as
// `type Cleanup func()` (accepted by reflection.IsCleanup) do not panic in
// the type assertion.
func (r callResult) cleanup() func() {
	fn := r[1]
	if fn.IsNil() {
		return nil
	}
	return fn.Convert(reflect.TypeOf((func())(nil))).Interface().(func())
}

// error returns the result at position as an error, or nil if it is a nil
// value of a nillable type. Non-nillable error types (e.g. a struct
// implementing error) are returned as-is; calling IsNil on them would panic.
func (r callResult) error(position int) error {
	v := r[position]
	switch v.Kind() {
	case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan:
		if v.IsNil() {
			return nil
		}
	}
	return v.Interface().(error)
}
129 |
--------------------------------------------------------------------------------
/di/provider_embed.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | "reflect"
5 | "strings"
6 | )
7 |
8 | // createStructProvider
9 | func newProviderEmbed(p parameter) *providerEmbed {
10 | var embedType reflect.Type
11 | if p.res.Kind() == reflect.Ptr {
12 | embedType = p.res.Elem()
13 | } else {
14 | embedType = p.res
15 | }
16 |
17 | return &providerEmbed{
18 | key: key{
19 | name: p.name,
20 | res: p.res,
21 | typ: ptEmbedParameter,
22 | },
23 | embedType: embedType,
24 | embedValue: reflect.New(embedType).Elem(),
25 | }
26 | }
27 |
28 | type providerEmbed struct {
29 | key key
30 | embedType reflect.Type
31 | embedValue reflect.Value
32 | }
33 |
34 | func (p *providerEmbed) Key() key {
35 | return p.key
36 | }
37 |
38 | func (p *providerEmbed) ParameterList() parameterList {
39 | var plist parameterList
40 | for i := 0; i < p.embedType.NumField(); i++ {
41 | name, optional, isDependency := p.inspectFieldTag(i)
42 | if !isDependency {
43 | continue
44 | }
45 | field := p.embedType.Field(i)
46 | plist = append(plist, parameter{
47 | name: name,
48 | res: field.Type,
49 | optional: optional,
50 | embed: isEmbedParameter(field.Type),
51 | })
52 | }
53 | return plist
54 | }
55 |
56 | func (p *providerEmbed) Provide(values ...reflect.Value) (reflect.Value, func(), error) {
57 | for i, offset := 0, 0; i < p.embedType.NumField(); i++ {
58 | _, _, isDependency := p.inspectFieldTag(i)
59 | if !isDependency {
60 | offset++
61 | continue
62 | }
63 |
64 | p.embedValue.Field(i).Set(values[i-offset])
65 | }
66 |
67 | return p.embedValue, nil, nil
68 | }
69 |
70 | func (p *providerEmbed) inspectFieldTag(num int) (name string, optional bool, isDependency bool) {
71 | fieldType := p.embedType.Field(num)
72 | fieldValue := p.embedValue.Field(num)
73 | tag, tagExists := fieldType.Tag.Lookup("di")
74 | if !tagExists || !fieldValue.CanSet() {
75 | return "", false, false
76 | }
77 | name, optional = p.parseTag(tag)
78 | return name, optional, true
79 | }
80 |
81 | func (p *providerEmbed) parseTag(tag string) (name string, optional bool) {
82 | options := strings.Split(tag, ",")
83 | if len(options) == 0 {
84 | return "", false
85 | }
86 | if len(options) == 1 && options[0] == "optional" {
87 | return "", true
88 | }
89 | if len(options) == 1 {
90 | return options[0], false
91 | }
92 | if len(options) == 2 && options[1] == "optional" {
93 | return options[0], true
94 | }
95 | panic("incorrect di tag")
96 | }
97 |
--------------------------------------------------------------------------------
/di/provider_group.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | "reflect"
5 | )
6 |
7 | // newProviderGroup creates new group from provided resultKey.
8 | func newProviderGroup(k key) *providerGroup {
9 | ifaceKey := key{
10 | res: reflect.SliceOf(k.res),
11 | typ: ptGroup,
12 | }
13 |
14 | return &providerGroup{
15 | result: ifaceKey,
16 | pl: parameterList{},
17 | }
18 | }
19 |
20 | // providerGroup
21 | type providerGroup struct {
22 | result key
23 | pl parameterList
24 | }
25 |
26 | // Add
27 | func (i *providerGroup) Add(k key) {
28 | i.pl = append(i.pl, parameter{
29 | name: k.name,
30 | res: k.res,
31 | optional: false,
32 | embed: false,
33 | })
34 | }
35 |
36 | // resultKey
37 | func (i providerGroup) Key() key {
38 | return i.result
39 | }
40 |
41 | // parameters
42 | func (i providerGroup) ParameterList() parameterList {
43 | return i.pl
44 | }
45 |
46 | // Provide
47 | func (i providerGroup) Provide(values ...reflect.Value) (reflect.Value, func(), error) {
48 | group := reflect.New(i.result.res).Elem()
49 | return reflect.Append(group, values...), nil, nil
50 | }
51 |
--------------------------------------------------------------------------------
/di/provider_iface.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | "reflect"
5 |
6 | "github.com/defval/inject/v2/di/internal/reflection"
7 | )
8 |
9 | // newProviderInterface
10 | func newProviderInterface(provider internalProvider, as interface{}) *providerInterface {
11 | iface := reflection.InspectInterfacePtr(as)
12 | if !provider.Key().res.Implements(iface.Type) {
13 | panicf("%s not implement %s", provider.Key(), iface.Type)
14 | }
15 | return &providerInterface{
16 | res: key{
17 | name: provider.Key().name,
18 | res: iface.Type,
19 | typ: ptInterface,
20 | },
21 | provider: provider,
22 | }
23 | }
24 |
25 | // providerInterface
26 | type providerInterface struct {
27 | res key
28 | provider internalProvider
29 | }
30 |
31 | func (i *providerInterface) Key() key {
32 | return i.res
33 | }
34 |
35 | func (i *providerInterface) ParameterList() parameterList {
36 | var plist parameterList
37 | plist = append(plist, parameter{
38 | name: i.provider.Key().name,
39 | res: i.provider.Key().res,
40 | optional: false,
41 | embed: false,
42 | })
43 | return plist
44 | }
45 |
46 | func (i *providerInterface) Provide(values ...reflect.Value) (reflect.Value, func(), error) {
47 | return values[0], nil, nil
48 | }
49 |
--------------------------------------------------------------------------------
/di/provider_stub.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | )
7 |
8 | // providerStub
9 | type providerStub struct {
10 | msg string
11 | res key
12 | }
13 |
14 | // newProviderStub
15 | func newProviderStub(k key, msg string) *providerStub {
16 | return &providerStub{res: k, msg: msg}
17 | }
18 |
19 | func (m *providerStub) Key() key {
20 | return m.res
21 | }
22 |
23 | func (m *providerStub) ParameterList() parameterList {
24 | return parameterList{}
25 | }
26 |
27 | func (m *providerStub) Provide(values ...reflect.Value) (reflect.Value, func(), error) {
28 | return reflect.Value{}, nil, fmt.Errorf(m.msg)
29 | }
30 |
--------------------------------------------------------------------------------
/di/singleton.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | "reflect"
5 | )
6 |
7 | // asSingleton creates a singleton wrapper.
8 | func asSingleton(provider internalProvider) *singletonWrapper {
9 | return &singletonWrapper{internalProvider: provider}
10 | }
11 |
12 | // singletonWrapper is a embedParamProvider wrapper. Stores provided value for prevent reinitialization.
13 | type singletonWrapper struct {
14 | internalProvider // source provider
15 | value reflect.Value // value cache
16 | }
17 |
18 | // Provide
19 | func (s *singletonWrapper) Provide(values ...reflect.Value) (reflect.Value, func(), error) {
20 | if s.value.IsValid() {
21 | return s.value, nil, nil
22 | }
23 | value, cleanup, err := s.internalProvider.Provide(values...)
24 | s.value = value
25 |
26 | return value, cleanup, err
27 | }
28 |
--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
1 | // The MIT License (MIT)
2 | //
3 | // Copyright (c) 2019 Maxim Bovtunov
4 | //
5 | // Permission is hereby granted, free of charge, to any person obtaining a copy
6 | // of this software and associated documentation files (the "Software"), to deal
7 | // in the Software without restriction, including without limitation the rights
8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | // copies of the Software, and to permit persons to whom the Software is
10 | // furnished to do so, subject to the following conditions:
11 | //
12 | // The above copyright notice and this permission notice shall be included in all
13 | // copies or substantial portions of the Software.
14 | //
15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | // SOFTWARE.
22 |
23 | /*
24 | Package inject makes dependency injection easy. The container allows you to inject dependencies
25 | into constructors or structures without the need to have specified
26 | each argument manually.
27 |
28 | Provide
29 |
30 | First of all, when creating a new container, you need to describe
31 | how to create each instance of a dependency. To do this, use the container
32 | option inject.Provide().
33 |
34 | container, err := New(
35 | Provide(NewDependency),
36 | Provide(NewAnotherDependency),
37 | )
38 |
39 | func NewDependency(dependency *pkg.AnotherDependency) *pkg.Dependency {
40 | return &pkg.Dependency{
41 | dependency: dependency,
42 | }
43 | }
44 |
45 | func NewAnotherDependency() (*pkg.AnotherDependency, error) {
46 | dependency, err := initAnotherDependency()
47 | if err != nil {
48 | return nil, err
49 | }
50 | return dependency, nil
51 | }
52 |
53 | Now the container knows how to create *pkg.Dependency and *pkg.AnotherDependency.
54 | For advanced providing see inject.Provide() and inject.ProvideOption documentation.
55 |
56 | Extract
57 |
58 | After building a container, it is easy to get any previously provided type.
59 | To do this, use the container's Extract() method.
60 |
61 | var anotherDependency *pkg.AnotherDependency
62 | if err = container.Extract(&anotherDependency); err != nil {
63 | // handle error
64 | }
65 |
66 | The container collects the dependencies of *pkg.AnotherDependency, creates its instance and
67 | places it in a target pointer.
68 | For advanced extraction see Extract() and inject.ExtractOption documentation.
69 | */
70 | package inject
71 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/defval/inject/v2
2 |
3 | require (
4 | github.com/davecgh/go-spew v1.1.1 // indirect
5 | github.com/emicklei/dot v0.10.1
6 | github.com/kr/pretty v0.1.0 // indirect
7 | github.com/stretchr/testify v1.4.0
8 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
9 | gopkg.in/yaml.v2 v2.2.8 // indirect
10 | )
11 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
5 | github.com/emicklei/dot v0.10.1 h1:bkzvwgIhhw/cuxxnJy5/5+ZL3GnhFxFfv0eolHtWE2w=
6 | github.com/emicklei/dot v0.10.1/go.mod h1:kZg82Ikwc4pqb31Ct2yb0B7RUqxh3JESIXw2uWSv/xY=
7 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
8 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
9 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
10 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
11 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
12 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
13 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
14 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
15 | github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
16 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
17 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
18 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
19 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
20 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
21 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
22 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
23 | gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
24 | gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
25 |
--------------------------------------------------------------------------------
/graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d3fvxl/inject/f8416f73ad0c0ce9487ef3ceae5035a86c465ee7/graph.png
--------------------------------------------------------------------------------
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d3fvxl/inject/f8416f73ad0c0ce9487ef3ceae5035a86c465ee7/logo.png
--------------------------------------------------------------------------------
/options.go:
--------------------------------------------------------------------------------
1 | package inject
2 |
3 | import "github.com/defval/inject/v2/di"
4 |
5 | // OPTIONS
6 |
7 | // Option configures container. See inject.Provide(), inject.Bundle(), inject.Replace().
8 | type Option interface {
9 | apply(*Container)
10 | }
11 |
12 | // Provide returns container option that explains how to create an instance of a type inside a container.
13 | //
14 | // The first argument is the constructor function. A constructor is a function that creates an instance of the required
15 | // type. It can take an unlimited number of arguments needed to create an instance - the first returned value.
16 | //
17 | // func NewServer(mux *http.ServeMux) *http.Server {
18 | // return &http.Server{
19 | // Handle: mux,
20 | // }
21 | // }
22 | //
23 | // Optionally, you can return a cleanup function and initializing error.
24 | //
25 | // func NewServer(mux *http.ServeMux) (*http.Server, cleanup func(), err error) {
26 | // if time.Now().Day = 1 {
27 | // return nil, nil, errors.New("the server is down on the first day of a month")
28 | // }
29 | //
30 | // server := &http.Server{
31 | // Handler: mux,
32 | // }
33 | //
34 | // cleanup := func() {
35 | // _ = server.Close()
36 | // }
37 | //
38 | // return server, cleanup, nil
39 | // }
40 | //
41 | // Other function signatures will cause error.
42 | func Provide(provider interface{}, options ...ProvideOption) Option {
43 | return option(func(container *Container) {
44 | // todo: add provider
45 | var params = di.ProvideParams{
46 | Parameters: map[string]interface{}{},
47 | }
48 |
49 | for _, opt := range options {
50 | opt.apply(¶ms)
51 | }
52 | container.providers = append(container.providers, provide{
53 | provider: provider,
54 | params: params,
55 | })
56 | })
57 | }
58 |
59 | // Bundle group together container options.
60 | //
61 | // accountBundle := inject.Bundle(
62 | // inject.Provide(NewAccountController),
63 | // inject.Provide(NewAccountRepository),
64 | // )
65 | //
66 | // authBundle := inject.Bundle(
67 | // inject.Provide(NewAuthController),
68 | // inject.Provide(NewAuthRepository),
69 | // )
70 | //
71 | // container, _ := New(
72 | // accountBundle,
73 | // authBundle,
74 | // )
75 | func Bundle(options ...Option) Option {
76 | return option(func(container *Container) {
77 | for _, opt := range options {
78 | opt.apply(container)
79 | }
80 | })
81 | }
82 |
83 | // ProvideOption modifies default provide behavior. See inject.WithName(), inject.As(), inject.Prototype().
84 | type ProvideOption interface {
85 | apply(params *di.ProvideParams)
86 | }
87 |
88 | // WithName sets string identifier for provided value.
89 | //
90 | // inject.Provide(&http.Server{}, inject.WithName("first"))
91 | // inject.Provide(&http.Server{}, inject.WithName("second"))
92 | //
93 | // container.Extract(&server, inject.Name("second"))
94 | func WithName(name string) ProvideOption {
95 | return provideOption(func(provider *di.ProvideParams) {
96 | provider.Name = name
97 | })
98 | }
99 |
100 | // As specifies interfaces that implement provider instance. Provide with As() automatically checks that constructor
101 | // result implements interface and creates slice group with it.
102 | //
103 | // Provide(&http.ServerMux{}, inject.As(new(http.Handler)))
104 | //
105 | // var handler http.Handler
106 | // container.Extract(&handler) // extract as interface
107 | //
108 | // var handlers []http.Handler
109 | // container.Extract(&handlers) // extract group
110 | func As(ifaces ...interface{}) ProvideOption {
111 | return provideOption(func(provider *di.ProvideParams) {
112 | provider.Interfaces = append(provider.Interfaces, ifaces...)
113 |
114 | })
115 | }
116 |
117 | // Prototype modifies Provide() behavior. By default, each type resolves as a singleton. This option sets that
118 | // each type resolving creates a new instance of the type.
119 | //
120 | // Provide(&http.Server{], inject.Prototype())
121 | //
122 | // var server1 *http.Server
123 | // var server2 *http.Server
124 | // container.Extract(&server1, &server2)
125 | func Prototype() ProvideOption {
126 | return provideOption(func(provider *di.ProvideParams) {
127 | provider.IsPrototype = true
128 | })
129 | }
130 |
131 | // ParameterBag is a provider parameter bag. It stores a construction parameters. It is a alternative way to
132 | // configure type.
133 | //
134 | // inject.Provide(NewServer, inject.ParameterBag{
135 | // "addr": ":8080",
136 | // })
137 | //
138 | // NewServer(pb inject.ParameterBag) *http.Server {
139 | // return &http.Server{
140 | // Addr: pb.RequireString("addr"),
141 | // }
142 | // }
143 | type ParameterBag map[string]interface{}
144 |
145 | func (p ParameterBag) apply(provider *di.ProvideParams) {
146 | for k, v := range p {
147 | provider.Parameters[k] = v
148 | }
149 | }
150 |
151 | // ExtractOption modifies default extract behavior. See inject.Name().
152 | type ExtractOption interface {
153 | apply(params *di.ExtractParams)
154 | }
155 |
156 | // EXTRACT OPTIONS.
157 |
158 | // Name specify definition name.
159 | func Name(name string) ExtractOption {
160 | return extractOption(func(eo *di.ExtractParams) {
161 | eo.Name = name
162 | })
163 | }
164 |
165 | type option func(container *Container)
166 |
167 | func (o option) apply(container *Container) { o(container) }
168 |
169 | type provideOption func(provider *di.ProvideParams)
170 |
171 | func (o provideOption) apply(provider *di.ProvideParams) { o(provider) }
172 |
173 | type extractOption func(eo *di.ExtractParams)
174 |
175 | func (o extractOption) apply(eo *di.ExtractParams) { o(eo) }
176 |
177 | type extractOptions struct {
178 | name string
179 | target interface{}
180 | }
181 |
--------------------------------------------------------------------------------
/options_test.go:
--------------------------------------------------------------------------------
1 | package inject
2 |
3 | import (
4 | "net/http"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/require"
8 |
9 | "github.com/defval/inject/v2/di"
10 | )
11 |
12 | // TestProvideOptions
13 | func TestProvideOptions(t *testing.T) {
14 | opts := &di.ProvideParams{
15 | Parameters: map[string]interface{}{},
16 | }
17 |
18 | for _, opt := range []ProvideOption{
19 | WithName("test"),
20 | As(new(http.Handler)),
21 | Prototype(),
22 | ParameterBag{
23 | "test": "test",
24 | },
25 | } {
26 | opt.apply(opts)
27 | }
28 |
29 | require.Equal(t, &di.ProvideParams{
30 | Name: "test",
31 | Interfaces: []interface{}{new(http.Handler)},
32 | IsPrototype: true,
33 | Parameters: map[string]interface{}{
34 | "test": "test",
35 | },
36 | }, opts)
37 | }
38 |
39 | func TestExtractOptions(t *testing.T) {
40 | opts := &di.ExtractParams{}
41 |
42 | for _, opt := range []ExtractOption{
43 | Name("test"),
44 | } {
45 | opt.apply(opts)
46 | }
47 |
48 | require.Equal(t, &di.ExtractParams{
49 | Name: "test",
50 | }, opts)
51 | }
52 |
--------------------------------------------------------------------------------