├── .gitignore ├── LICENSE ├── README.md └── src ├── demo ├── a.go └── b.go ├── di └── container.go └── main.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | pkg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright © 2018 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # golang实现DI容器
2 | 基于反射实现依赖注入容器
3 |
4 | ## Features
5 |
6 | + 注册/获取依赖
7 | + 基于tag自动注入依赖
8 |
9 | ## 依赖类型
10 |
11 | + 单例依赖
12 | + 实例依赖
13 |
14 | ## License
15 |
16 | MIT
17 |
--------------------------------------------------------------------------------
/src/demo/a.go:
--------------------------------------------------------------------------------
1 | package demo
2 |
3 | import (
4 | 	"database/sql"
5 | )
6 |
7 | // A is a demo consumer whose tagged fields are filled in by
8 | // di.Container.Ensure: Db and Db1 both receive the "db" singleton, while B
9 | // and B1 each receive their own value from the "b" prototype factory.
10 | type A struct {
11 | 	Db  *sql.DB `di:"db"`
12 | 	Db1 *sql.DB `di:"db"`
13 | 	B   *B      `di:"b,prototype"`
14 | 	B1  *B      `di:"b,prototype"`
15 | }
16 |
17 | // NewA returns an A with no dependencies set; pass it to Ensure to inject them.
18 | func NewA() *A {
19 | 	return &A{}
20 | }
21 |
22 | // Version returns the first column of the first row of "SELECT VERSION()"
23 | // run against Db, or "" if the query returns no rows.
24 | func (p *A) Version() (string, error) {
25 | 	rows, err := p.Db.Query("SELECT VERSION() as version")
26 | 	if err != nil {
27 | 		return "", err
28 | 	}
29 | 	defer rows.Close()
30 |
31 | 	var version string
32 | 	if rows.Next() {
33 | 		if err := rows.Scan(&version); err != nil {
34 | 			return "", err
35 | 		}
36 | 	}
37 | 	if err := rows.Err(); err != nil {
38 | 		return "", err
39 | 	}
40 | 	return version, nil
41 | }
42 |
--------------------------------------------------------------------------------
/src/demo/b.go:
--------------------------------------------------------------------------------
1 | package demo
2 |
3 | import "time"
4 |
5 | // B is a demo prototype dependency. Name records its creation time so that
6 | // separately created instances can be told apart.
7 | type B struct {
8 | 	Name string
9 | }
10 |
11 | // NewB returns a new B whose Name is the current time.
12 | func NewB() *B {
13 | 	return &B{
14 | 		Name: time.Now().String(),
15 | 	}
16 | }
17 |
--------------------------------------------------------------------------------
/src/di/container.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 | 	"reflect"
7 | 	"sort"
8 | 	"strings"
9 | 	"sync"
10 | )
11 |
12 | var (
13 | 	// ErrFactoryNotFound is returned by GetPrototype when no factory is
14 | 	// registered under the requested name.
15 | 	ErrFactoryNotFound = errors.New("factory not found")
16 | )
17 |
18 | // factory builds a new prototype instance each time it is called.
19 | type factory = func() (interface{}, error)
20 |
21 | // Container is a reflection-based dependency-injection container holding
22 | // named singletons (one shared value per name) and named prototype factories
23 | // (invoked anew on every lookup). All access to the two maps is guarded by
24 | // the embedded mutex, so the methods may be called from multiple goroutines;
25 | // do not call them while already holding that mutex yourself.
26 | type Container struct {
27 | 	sync.Mutex
28 | 	singletons map[string]interface{}
29 | 	factories  map[string]factory
30 | }
31 |
32 | // NewContainer returns an empty, ready-to-use Container.
33 | func NewContainer() *Container {
34 | 	return &Container{
35 | 		singletons: make(map[string]interface{}),
36 | 		factories:  make(map[string]factory),
37 | 	}
38 | }
39 |
40 | // SetSingleton registers singleton under name, replacing any previous value.
41 | func (p *Container) SetSingleton(name string, singleton interface{}) {
42 | 	p.Lock()
43 | 	p.singletons[name] = singleton
44 | 	p.Unlock()
45 | }
46 |
47 | // GetSingleton returns the singleton registered under name, or nil if there
48 | // is none.
49 | func (p *Container) GetSingleton(name string) interface{} {
50 | 	// Reads take the lock as well: an unguarded map read racing with
51 | 	// SetSingleton is a data race.
52 | 	p.Lock()
53 | 	defer p.Unlock()
54 | 	return p.singletons[name]
55 | }
56 |
57 | // GetPrototype calls the factory registered under name and returns its
58 | // result. It returns ErrFactoryNotFound if no factory is registered.
59 | func (p *Container) GetPrototype(name string) (interface{}, error) {
60 | 	p.Lock()
61 | 	factory, ok := p.factories[name]
62 | 	p.Unlock()
63 | 	if !ok {
64 | 		return nil, ErrFactoryNotFound
65 | 	}
66 | 	// The factory runs without the lock held so it can use the container
67 | 	// itself without deadlocking.
68 | 	return factory()
69 | }
70 |
71 | // SetPrototype registers factory under name, replacing any previous factory.
72 | func (p *Container) SetPrototype(name string, factory factory) {
73 | 	p.Lock()
74 | 	p.factories[name] = factory
75 | 	p.Unlock()
76 | }
77 |
78 | // Ensure injects dependencies into the struct that instance points to.
79 | //
80 | // A field tagged `di:"name"` receives the singleton registered under name; a
81 | // field tagged `di:"name,prototype"` receives a fresh value from the factory
82 | // registered under name. Fields without a di tag are left untouched.
83 | //
84 | // Rather than panicking, Ensure returns an error when instance is not a
85 | // non-nil pointer to a struct, a tagged field is unexported, a dependency is
86 | // missing or nil, a factory fails, or a dependency's type is not assignable
87 | // to its field. Fields injected before the failing one keep their values.
88 | func (p *Container) Ensure(instance interface{}) error {
89 | 	ptr := reflect.ValueOf(instance)
90 | 	if ptr.Kind() != reflect.Ptr || ptr.IsNil() || ptr.Elem().Kind() != reflect.Struct {
91 | 		return fmt.Errorf("di: Ensure needs a non-nil pointer to a struct, got %T", instance)
92 | 	}
93 | 	ele := ptr.Elem()
94 | 	elemType := ele.Type()
95 | 	for i := 0; i < elemType.NumField(); i++ {
96 | 		fieldType := elemType.Field(i)
97 | 		tag := fieldType.Tag.Get("di")
98 | 		diName := p.injectName(tag)
99 | 		if diName == "" {
100 | 			continue
101 | 		}
102 | 		field := ele.Field(i)
103 | 		if !field.CanSet() {
104 | 			return fmt.Errorf("di: field %s is unexported and cannot be injected", fieldType.Name)
105 | 		}
106 | 		var (
107 | 			diInstance interface{}
108 | 			err        error
109 | 		)
110 | 		if p.isSingleton(tag) {
111 | 			diInstance = p.GetSingleton(diName)
112 | 		} else {
113 | 			diInstance, err = p.GetPrototype(diName)
114 | 		}
115 | 		if err != nil {
116 | 			return err
117 | 		}
118 | 		if diInstance == nil {
119 | 			return errors.New(diName + " dependency not found")
120 | 		}
121 | 		value := reflect.ValueOf(diInstance)
122 | 		if !value.Type().AssignableTo(fieldType.Type) {
123 | 			return fmt.Errorf("di: dependency %s has type %s, which is not assignable to field %s of type %s",
124 | 				diName, value.Type(), fieldType.Name, fieldType.Type)
125 | 		}
126 | 		field.Set(value)
127 | 	}
128 | 	return nil
129 | }
130 |
131 | // injectName returns the dependency name of a di tag: its first
132 | // comma-separated element with surrounding spaces removed. An empty tag
133 | // yields "", which Ensure treats as "do not inject".
134 | func (p *Container) injectName(tag string) string {
135 | 	// strings.Split with a non-empty separator always returns at least one
136 | 	// element, so [0] is always in range.
137 | 	return strings.TrimSpace(strings.Split(tag, ",")[0])
138 | }
139 |
140 | // isSingleton reports whether tag requests a singleton, i.e. it does not
141 | // carry the "prototype" option.
142 | func (p *Container) isSingleton(tag string) bool {
143 | 	return !p.isPrototype(tag)
144 | }
145 |
146 | // isPrototype reports whether any comma-separated element of tag equals
147 | // "prototype", ignoring surrounding spaces.
148 | func (p *Container) isPrototype(tag string) bool {
149 | 	for _, opt := range strings.Split(tag, ",") {
150 | 		if strings.TrimSpace(opt) == "prototype" {
151 | 			return true
152 | 		}
153 | 	}
154 | 	return false
155 | }
156 |
157 | // String lists the registered singletons and factories, one entry per line,
158 | // with each group sorted by name so the output is deterministic.
159 | func (p *Container) String() string {
160 | 	p.Lock()
161 | 	defer p.Unlock()
162 |
163 | 	lines := make([]string, 0, len(p.singletons)+len(p.factories)+2)
164 | 	lines = append(lines, "singletons:")
165 | 	names := make([]string, 0, len(p.singletons))
166 | 	for name := range p.singletons {
167 | 		names = append(names, name)
168 | 	}
169 | 	sort.Strings(names)
170 | 	for _, name := range names {
171 | 		lines = append(lines, describeEntry(name, p.singletons[name]))
172 | 	}
173 |
174 | 	lines = append(lines, "factories:")
175 | 	names = names[:0]
176 | 	for name := range p.factories {
177 | 		names = append(names, name)
178 | 	}
179 | 	sort.Strings(names)
180 | 	for _, name := range names {
181 | 		lines = append(lines, describeEntry(name, p.factories[name]))
182 | 	}
183 | 	return strings.Join(lines, "\n")
184 | }
185 |
186 | // describeEntry formats one String line as "  name: address type". The
187 | // address is the value's own pointer for pointer-like kinds, so entries that
188 | // share an instance show the same address; other kinds print "-", and a nil
189 | // entry prints "- <nil>".
190 | func describeEntry(name string, item interface{}) string {
191 | 	if item == nil {
192 | 		return fmt.Sprintf("  %s: - <nil>", name)
193 | 	}
194 | 	v := reflect.ValueOf(item)
195 | 	addr := "-"
196 | 	switch v.Kind() {
197 | 	case reflect.Ptr, reflect.Func, reflect.Map, reflect.Slice, reflect.Chan, reflect.UnsafePointer:
198 | 		addr = fmt.Sprintf("%x", v.Pointer())
199 | 	}
200 | 	return fmt.Sprintf("  %s: %s %s", name, addr, v.Type().String())
201 | }
202 |
--------------------------------------------------------------------------------
/src/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"database/sql"
5 | 	"fmt"
6 | 	"os"
7 |
8 | 	"demo"
9 | 	"di"
10 |
11 | 	_ "github.com/go-sql-driver/mysql"
12 | )
13 |
14 | // main runs the demo and exits with status 1 if anything fails.
15 | func main() {
16 | 	if err := run(); err != nil {
17 | 		fmt.Printf("error: %s\n", err.Error())
18 | 		os.Exit(1)
19 | 	}
20 | }
21 |
22 | // run builds a demo A through the DI container and prints the pointers its
23 | // fields were injected with. Errors are returned rather than handled here so
24 | // that main reports every failure the same way and deferred cleanup runs.
25 | func run() error {
26 | 	container := di.NewContainer()
27 | 	db, err := sql.Open("mysql", "root:root@tcp(localhost)/sampledb")
28 | 	if err != nil {
29 | 		return err
30 | 	}
31 | 	defer db.Close()
32 |
33 | 	container.SetSingleton("db", db)
34 | 	container.SetPrototype("b", func() (interface{}, error) {
35 | 		return demo.NewB(), nil
36 | 	})
37 |
38 | 	a := demo.NewA()
39 | 	if err := container.Ensure(a); err != nil {
40 | 		return err
41 | 	}
42 | 	// Print the injected pointers themselves: Db and Db1 should match
43 | 	// (singleton) while B and B1 should differ (prototype). Printing &a.B
44 | 	// would show the addresses of the struct fields, which always differ.
45 | 	fmt.Printf("db: %p\ndb1: %p\nb: %p\nb1: %p\n", a.Db, a.Db1, a.B, a.B1)
46 | 	return nil
47 | }
48 |
--------------------------------------------------------------------------------