├── LICENSE ├── README.md ├── bean_definition.go ├── bean_lifecycle.go ├── bean_lifecycle_interface.go ├── container_interface.go ├── di.go ├── example └── main.go ├── global.go ├── go.mod ├── go.sum ├── logger.go ├── util.go ├── value_store.go └── van ├── cast.go ├── store.go └── van.go /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Cheivin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # di 2 | 3 | `di`是一个简易版本的Go依赖注入实现 4 | 5 | [文档地址](https://cheivin.gitbook.io/di/) 6 | 7 | ## 特性 8 | 9 | * 支持手动注册bean实例 10 | * 支持注册bean类型原型,由DI容器自动实例化并托管bean实例 11 | * 支持根据名称、类型获取DI容器托管的bean实例 12 | * 支持根据类型手动生成新的bean实例并返回 13 | * 支持配置项注入并转换成对应的基本类型 14 | * 支持匿名字段的bean注入 15 | 16 | ## 特别鸣谢 17 | 18 | [![JetBrains](https://raw.githubusercontent.com/kainonly/ngx-bit/main/resource/jetbrains.svg)](https://www.jetbrains.com/?from=cheivin) 19 | 20 | 感谢 [JetBrains](https://www.jetbrains.com/?from=cheivin) 提供的开源开发许可证。 21 | -------------------------------------------------------------------------------- /bean_definition.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strings" 7 | ) 8 | 9 | type ( 10 | // definition describes a bean type managed by the container 11 | definition struct { 12 | Name string // bean name 13 | Type reflect.Type // struct type (instanceBean allocates a *Type from it) 14 | awareMap map[string]aware // fieldName -> dependency to inject (aware tag) 15 | valueMap map[string]aware // fieldName -> configuration key (value tag) 16 | } 17 | 18 | // aware describes one field that needs injection 19 | aware struct { 20 | Name string // bean name (awareMap) or property key (valueMap) 21 | Type reflect.Type // declared field type 22 | IsPtr bool // whether the field is a struct pointer 23 | IsInterface bool // whether the field is an interface 24 | Anonymous bool // whether the field is anonymous (embedded) 25 | Omitempty bool // skip injection when the dependency does not exist 26 | } 27 | ) 28 | // newDefinition scans the fields of prototype (a struct type): pointer/interface fields tagged `aware` become awareMap entries, basic-kind fields with a non-empty `value` tag become valueMap entries, other fields are ignored. An aware tag of "omitempty" or a ",omitempty" suffix marks the dependency optional. 29 | func (container *di) newDefinition(beanName string, prototype reflect.Type) definition { 30 | def := definition{Name: beanName, Type: prototype} 31 | awareMap := map[string]aware{} 32 | valueMap := map[string]aware{} 33 | for i := 0; i < prototype.NumField(); i++ { 34 | field := prototype.Field(i) 35 | switch field.Type.Kind() { 36 | case reflect.Ptr, reflect.Interface, reflect.Struct: 37 | if awareName, ok := field.Tag.Lookup("aware"); ok { 38 | omitempty := false 39 | switch { 40 | case strings.EqualFold(awareName, "omitempty"): 41 | omitempty = true 42 | awareName = "" 43 | case 
strings.HasSuffix(awareName, ",omitempty"): 44 | omitempty = true 45 | awareName = strings.TrimSuffix(awareName, ",omitempty") 46 | } 47 | // resolve the bean name per field kind and record the dependency; pointers to interfaces and plain structs are rejected with a panic wrapping ErrDefinition 48 | switch field.Type.Kind() { 49 | case reflect.Ptr: 50 | if reflect.Interface == field.Type.Elem().Kind() { 51 | panic(fmt.Errorf("%w: aware bean not accept interface pointer for %s.%s", ErrDefinition, prototype.String(), field.Name)) 52 | } 53 | tmpBean := reflect.New(field.Type.Elem()).Interface() // throwaway instance, only used for the interface checks below 54 | if awareName == "" { 55 | switch tmpBean.(type) { 56 | case BeanName: // use the non-empty BeanName() result as the bean name to inject 57 | if name := tmpBean.(BeanName).BeanName(); name != "" { 58 | awareName = name 59 | } 60 | } 61 | } 62 | if awareName == "" { 63 | // fall back to the name derived from the field type 64 | awareName = GetBeanName(field.Type) 65 | } 66 | // anonymous (embedded) fields must not implement lifecycle interfaces 67 | if field.Anonymous { 68 | errInterface := checkAnonymousFieldBean(tmpBean) 69 | if errInterface != "" { 70 | container.log.Fatal(fmt.Sprintf("%s: %s(%s) as anonymous field in %s(%s.%s) can not implements %s", 71 | ErrBean, awareName, field.Type.String(), 72 | def.Name, def.Type.String(), field.Name, 73 | errInterface, 74 | )) 75 | } 76 | } 77 | 78 | // record the injection info 79 | awareMap[field.Name] = aware{ 80 | Name: awareName, 81 | Type: field.Type, 82 | IsPtr: true, 83 | Anonymous: field.Anonymous, 84 | Omitempty: omitempty, 85 | } 86 | case reflect.Interface: 87 | // default to the name derived from the field type 88 | if awareName == "" { 89 | awareName = GetBeanName(field.Type) 90 | } 91 | // record the injection info 92 | awareMap[field.Name] = aware{ 93 | Name: awareName, 94 | Type: field.Type, 95 | IsPtr: false, 96 | IsInterface: true, 97 | Anonymous: field.Anonymous, 98 | Omitempty: omitempty, 99 | } 100 | case reflect.Struct: 101 | panic(fmt.Errorf("%w: aware bean not accept struct for %s.%s", ErrDefinition, prototype.String(), field.Name)) 102 | } 103 | } 104 | case reflect.String, reflect.Bool, 105 | reflect.Float64, reflect.Float32, 106 | reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, 107 | reflect.Uint, reflect.Uint64, reflect.Uint32, 
reflect.Uint16, reflect.Uint8: 108 | if property, ok := field.Tag.Lookup("value"); ok { 109 | if property != "" { 110 | valueMap[field.Name] = aware{ 111 | Name: property, 112 | Type: field.Type, 113 | } 114 | } 115 | } 116 | default: 117 | // ignore other kinds 118 | } 119 | } 120 | def.awareMap = awareMap 121 | def.valueMap = valueMap 122 | return def 123 | } 124 | // getValueDefinition builds a definition named prototype.Name() holding only valueMap: basic-kind fields with a non-empty `value` tag, using the same kind list as newDefinition. 125 | func (container *di) getValueDefinition(prototype reflect.Type) definition { 126 | def := definition{Name: prototype.Name(), Type: prototype} 127 | valueMap := map[string]aware{} 128 | for i := 0; i < prototype.NumField(); i++ { 129 | field := prototype.Field(i) 130 | switch field.Type.Kind() { 131 | case reflect.String, reflect.Bool, 132 | reflect.Float64, reflect.Float32, 133 | reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, 134 | reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: 135 | if property, ok := field.Tag.Lookup("value"); ok { 136 | if property != "" { 137 | valueMap[field.Name] = aware{ 138 | Name: property, 139 | Type: field.Type, 140 | } 141 | } 142 | } 143 | default: 144 | // ignore other kinds 145 | } 146 | } 147 | def.valueMap = valueMap 148 | return def 149 | } 150 | 151 | // checkAnonymousFieldBean returns the name of the first lifecycle interface (in the order checked below) that awareBean implements, or "" if it implements none. 152 | func checkAnonymousFieldBean(awareBean interface{}) string { 153 | // anonymous fields must not implement lifecycle interfaces such as BeanConstruct/PreInitialize/AfterPropertiesSet/Initialized/Disposable 154 | switch awareBean.(type) { 155 | case BeanConstruct: 156 | return "BeanConstruct" 157 | case BeanConstructWithContainer: 158 | return "BeanConstructWithContainer" 159 | case PreInitialize: 160 | return "PreInitialize" 161 | case PreInitializeWithContainer: 162 | return "PreInitializeWithContainer" 163 | case AfterPropertiesSet: 164 | return "AfterPropertiesSet" 165 | case AfterPropertiesSetWithContainer: 166 | return "AfterPropertiesSetWithContainer" 167 | case Initialized: 168 | return "Initialized" 169 | case InitializedWithContainer: 170 | return
"InitializedWithContainer" 171 | case Disposable: 172 | return "Disposable" 173 | case DisposableWithContainer: 174 | return "DisposableWithContainer" 175 | default: 176 | return "" 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /bean_lifecycle.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "fmt" 5 | "github.com/cheivin/di/van" 6 | "reflect" 7 | "unsafe" 8 | ) 9 | 10 | // wireValue sets the value-tagged fields of bean from the value store using prefix + tag as the key; nil values are skipped and a failed van.Cast is fatal. 11 | func (container *di) wireValue(bean reflect.Value, def definition, prefix string) { 12 | if len(def.valueMap) > 0 { 13 | container.log.Info(fmt.Sprintf("wire value for bean %s(%s)", def.Name, def.Type.String())) 14 | } 15 | for filedName, valueInfo := range def.valueMap { 16 | valueName := prefix + valueInfo.Name 17 | value := container.valueStore.Get(valueName) 18 | if value == nil { 19 | continue 20 | } 21 | castValue, err := van.Cast(value, valueInfo.Type) 22 | if err != nil { 23 | container.log.Fatal(fmt.Sprintf("%s: %s(%s) wire value failed for %s(%s.%s), %s", 24 | ErrBean, valueName, valueInfo.Type.String(), 25 | def.Name, def.Type.String(), filedName, 26 | err.Error(), 27 | )) 28 | return 29 | } 30 | val := reflect.ValueOf(castValue) 31 | // set the field; in unsafe mode it is written through an unsafe.Pointer (reflect.NewAt) so unexported fields are settable too 32 | if container.unsafe { 33 | container.log.Debug(fmt.Sprintf("wire value for %s(%s.%s) in unsafe mode", 34 | def.Name, def.Type.String(), filedName, 35 | )) 36 | field := bean.FieldByName(filedName) 37 | field = reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem() 38 | field.Set(val) 39 | } else { 40 | container.log.Debug(fmt.Sprintf("wire value for %s(%s.%s)", 41 | def.Name, def.Type.String(), filedName, 42 | )) 43 | bean.FieldByName(filedName).Set(val) 44 | } 45 | } 46 | } 47 | 48 | // instanceBean allocates a new *def.Type via reflection and wires its configuration values into it. 49 | func (container *di) instanceBean(def definition) interface{} { 50 | container.log.Debug(fmt.Sprintf("reflect instance for %s(%s)", def.Name,
def.Type.String())) 51 | prototype := reflect.New(def.Type).Interface() 52 | // inject configuration values 53 | container.wireValue(reflect.ValueOf(prototype).Elem(), def, "") 54 | return prototype 55 | } 56 | 57 | // constructBean calls the prototype's BeanConstructWithContainer or BeanConstruct hook, if it implements one. 58 | func (container *di) constructBean(beanName string, prototype interface{}) { 59 | switch p := prototype.(type) { 60 | case BeanConstructWithContainer: 61 | container.log.Debug(fmt.Sprintf("call lifecycle interface BeanConstructWithContainer for %s(%T)", beanName, prototype)) 62 | p.BeanConstruct(container) 63 | case BeanConstruct: 64 | container.log.Debug(fmt.Sprintf("call lifecycle interface BeanConstruct for %s(%T)", beanName, prototype)) 65 | p.BeanConstruct() 66 | } 67 | } 68 | 69 | // processBean runs the PreInitialize hook, injects the aware dependencies (wireBean) and runs the AfterPropertiesSet hook; prototype must be a pointer to a struct. 70 | func (container *di) processBean(prototype interface{}, def definition) interface{} { 71 | // hook before injection 72 | switch p := prototype.(type) { 73 | case PreInitializeWithContainer: 74 | container.log.Debug(fmt.Sprintf("call lifecycle interface PreInitializeWithContainer for %s(%s)", def.Name, def.Type.String())) 75 | p.PreInitialize(container) 76 | case PreInitialize: 77 | container.log.Debug(fmt.Sprintf("call lifecycle interface PreInitialize for %s(%s)", def.Name, def.Type.String())) 78 | p.PreInitialize() 79 | } 80 | 81 | bean := reflect.ValueOf(prototype).Elem() 82 | container.wireBean(bean, def) 83 | 84 | // hook after injection 85 | switch p := prototype.(type) { 86 | case AfterPropertiesSetWithContainer: 87 | container.log.Debug(fmt.Sprintf("call lifecycle interface AfterPropertiesSetWithContainer for %s(%s)", def.Name, def.Type.String())) 88 | p.AfterPropertiesSet(container) 89 | case AfterPropertiesSet: 90 | container.log.Debug(fmt.Sprintf("call lifecycle interface AfterPropertiesSet for %s(%s)", def.Name, def.Type.String())) 91 | p.AfterPropertiesSet() 92 | } 93 |
return prototype 94 | } 95 | 96 | // findBeanByName looks beanName up in beanMap first and then in prototypeMap. 97 | func (container *di) findBeanByName(beanName string) (awareBean interface{}, ok bool) { 98 | // registered (or already wired) beans 99 | if awareBean, ok = container.beanMap[beanName]; !ok { 100 | // prototypes created during Load 101 | awareBean, ok = container.prototypeMap[beanName] 102 | } 103 | return 104 | } 105 | // BeanWithName pairs a bean with the name it is stored under. 106 | type BeanWithName struct { 107 | Name string 108 | Bean interface{} 109 | } 110 | // findBeanByType returns every bean found via findBeanByName whose dynamic type is assignable to beanType, in registration (beanSort) order. 111 | func (container *di) findBeanByType(beanType reflect.Type) []BeanWithName { 112 | var beans []BeanWithName 113 | // walk bean names in registration order 114 | for e := container.beanSort.Front(); e != nil; e = e.Next() { 115 | findBeanName := e.Value.(string) 116 | 117 | if prototype, ok := container.findBeanByName(findBeanName); ok { 118 | if reflect.TypeOf(prototype).AssignableTo(beanType) { 119 | container.log.Info(fmt.Sprintf("find interface %s implemented by %s(%T)", 120 | beanType.String(), findBeanName, prototype, 121 | )) 122 | beans = append(beans, BeanWithName{Name: findBeanName, Bean: prototype}) 123 | } 124 | } 125 | } 126 | return beans 127 | } 128 | 129 | // wireBean injects every aware dependency of def into bean, which must be addressable (processBean passes the Elem of the prototype pointer). 130 | func (container *di) wireBean(bean reflect.Value, def definition) { 131 | if len(def.awareMap) > 0 { 132 | container.log.Info(fmt.Sprintf("wire field for bean %s(%s)", def.Name, def.Type.String())) 133 | } 134 | for fieldName, awareInfo := range def.awareMap { 135 | var awareBean interface{} 136 | var ok bool 137 | 138 | // look the dependency up by bean name 139 | awareBean, ok = container.findBeanByName(awareInfo.Name) 140 | // interface fields fall back to a by-type search; the last match in registration order wins 141 | if awareInfo.IsInterface && !ok { 142 | awareBeans := container.findBeanByType(awareInfo.Type) 143 | if len(awareBeans) > 0 { 144 | selectBean := awareBeans[len(awareBeans)-1] 145 | awareBean = selectBean.Bean 146 | ok = true 147 | container.log.Info(fmt.Sprintf("%s(%T) will be set to %s(%s.%s)", 148 | selectBean.Name, awareBean, 149 | def.Name, def.Type.String(), fieldName, 150 | )) 151 | } 152 | } 153 | 154 | injectInfo := &InjectInfo{ 155 | Bean:
awareBean, 156 | BeanName: awareInfo.Name, 157 | Type: awareInfo.Type, 158 | IsPtr: awareInfo.IsPtr, 159 | IsInterface: awareInfo.IsInterface, 160 | Anonymous: awareInfo.Anonymous, 161 | Omitempty: awareInfo.Omitempty, 162 | } 163 | // Injector is checked on the bean's address so pointer-receiver implementations are honoured too; it may supply, replace or clear (nil) the dependency 164 | if injector, isInjector := bean.Addr().Interface().(Injector); isInjector { 165 | injector.BeanInject(container, injectInfo) 166 | awareBean = injectInfo.Bean 167 | ok = awareBean != nil 168 | } 169 | if !ok { 170 | if awareInfo.Omitempty { 171 | container.log.Warn(fmt.Sprintf("Omitempty: dependent bean %s not found for %s(%s.%s)", 172 | awareInfo.Name, 173 | def.Name, 174 | def.Type.String(), 175 | fieldName)) 176 | continue 177 | } 178 | container.log.Fatal(fmt.Sprintf("%s: %s notfound for %s(%s.%s)", 179 | ErrBean, 180 | awareInfo.Name, 181 | def.Name, 182 | def.Type.String(), 183 | fieldName)) 184 | // stop like the other fatal paths below, in case a custom Log does not exit; reflecting on a nil bean would panic 185 | return 186 | } 187 | value := reflect.ValueOf(awareBean) 188 | 189 | // type check 190 | if awareInfo.IsPtr { // pointer field 191 | if !value.Type().AssignableTo(awareInfo.Type) { 192 | container.log.Fatal(fmt.Sprintf("%s: %s(%s) not match for %s(%s.%s) need type %s", 193 | ErrBean, 194 | awareInfo.Name, value.Type().String(), 195 | def.Name, 196 | def.Type.String(), 197 | fieldName, 198 | awareInfo.Type.String(), 199 | )) 200 | return 201 | } 202 | } else { // interface field 203 | if !value.Type().Implements(awareInfo.Type) { 204 | container.log.Fatal(fmt.Sprintf("%s: %s(%s) not implements interface %s for %s(%s.%s)", 205 | ErrBean, 206 | awareInfo.Name, value.Type().String(), 207 | awareInfo.Type.String(), 208 | def.Name, 209 | def.Type.String(), 210 | fieldName, 211 | )) 212 | return 213 | } 214 | } 215 | 216 | // set the field; unsafe mode writes through an unsafe.Pointer so unexported fields are settable too 217 | if container.unsafe { 218 | if awareInfo.Anonymous { 219 | container.log.Debug(fmt.Sprintf("wire anonymous field for %s(%s.%s) in unsafe mode", 220 | def.Name, def.Type.String(), fieldName, 221 | )) 222 | } else { 223 | container.log.Debug(fmt.Sprintf("wire field for %s(%s.%s) in unsafe mode", 224 | def.Name, def.Type.String(), fieldName, 225 | )) 226 | } 227 |
228 | field := bean.FieldByName(fieldName) 229 | field = reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem() 230 | field.Set(value) 231 | } else { 232 | if awareInfo.Anonymous { 233 | container.log.Debug(fmt.Sprintf("wire anonymous field for %s(%s.%s)", 234 | def.Name, def.Type.String(), fieldName, 235 | )) 236 | } else { 237 | container.log.Debug(fmt.Sprintf("wire field for %s(%s.%s)", 238 | def.Name, def.Type.String(), fieldName, 239 | )) 240 | } 241 | 242 | bean.FieldByName(fieldName).Set(value) 243 | } 244 | } 245 | } 246 | 247 | // initializedBean calls the bean's InitializedWithContainer or Initialized hook, if it implements one. 248 | func (container *di) initializedBean(beanName string, bean interface{}) { 249 | switch b := bean.(type) { 250 | case InitializedWithContainer: 251 | container.log.Debug(fmt.Sprintf("call lifecycle interface InitializedWithContainer for %s(%T)", beanName, bean)) 252 | b.Initialized(container) 253 | case Initialized: 254 | container.log.Debug(fmt.Sprintf("call lifecycle interface Initialized for %s(%T)", beanName, bean)) 255 | b.Initialized() 256 | } 257 | } 258 | 259 | // destroyBean calls the bean's DisposableWithContainer or Disposable hook, if it implements one. 260 | func (container *di) destroyBean(beanName string, bean interface{}) { 261 | switch b := bean.(type) { 262 | case DisposableWithContainer: 263 | container.log.Debug(fmt.Sprintf("call lifecycle interface DisposableWithContainer for %s(%T)", beanName, bean)) 264 | b.Destroy(container) 265 | case Disposable: 266 | container.log.Debug(fmt.Sprintf("call lifecycle interface Disposable for %s(%T)", beanName, bean)) 267 | b.Destroy() 268 | } 269 | } 270 | -------------------------------------------------------------------------------- /bean_lifecycle_interface.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import "reflect" 4 | 5 | type ( 6 | // BeanName lets a type supply its own bean name 7 | BeanName interface { 8 | BeanName() string 9 | } 10 | 11 | // BeanConstruct is called when the bean instance is created 12 | BeanConstruct interface {
BeanConstruct interface { 13 | BeanConstruct() 14 | } 15 | 16 | // BeanConstructWithContainer is called when the bean instance is created and receives the container 17 | BeanConstructWithContainer interface { 18 | BeanConstruct(DI) 19 | } 20 | 21 | // PreInitialize is called before the bean's dependencies are injected 22 | PreInitialize interface { 23 | PreInitialize() 24 | } 25 | 26 | // PreInitializeWithContainer is called before the bean's dependencies are injected and receives the container 27 | PreInitializeWithContainer interface { 28 | PreInitialize(DI) 29 | } 30 | // InjectInfo describes one dependency field being injected; an Injector may change Bean 31 | InjectInfo struct { 32 | Bean interface{} // dependency found so far (nil when none was found) 33 | BeanName string // bean name the field asks for 34 | Type reflect.Type // declared field type 35 | IsPtr bool // whether the field is a struct pointer 36 | IsInterface bool // whether the field is an interface 37 | Anonymous bool // whether the field is anonymous (embedded) 38 | Omitempty bool // skip injection when the dependency does not exist 39 | } 40 | // Injector is called by the container for each dependency field of the bean before that field is set 41 | Injector interface { 42 | BeanInject(di DI, info *InjectInfo) 43 | } 44 | 45 | // AfterPropertiesSet is called after the bean's dependencies have been injected 46 | AfterPropertiesSet interface { 47 | AfterPropertiesSet() 48 | } 49 | 50 | // AfterPropertiesSetWithContainer is called after the bean's dependencies have been injected and receives the container 51 | AfterPropertiesSetWithContainer interface { 52 | AfterPropertiesSet(DI) 53 | } 54 | 55 | // Initialized runs after dependency injection has completed; it can be read as the "DI finished loading" notification. 56 | Initialized interface { 57 | Initialized() 58 | } 59 | 60 | // InitializedWithContainer runs after dependency injection has completed; it can be read as the "DI finished loading" notification. 61 | InitializedWithContainer interface { 62 | Initialized(DI) 63 | } 64 | 65 | // Disposable is called when the bean is destroyed 66 | Disposable interface { 67 | Destroy() 68 | } 69 | // DisposableWithContainer is called when the bean is destroyed and receives the container 70 | DisposableWithContainer interface { 71 | Destroy(DI) 72 | } 73 | ) 74 | -------------------------------------------------------------------------------- /container_interface.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import "context" 4 | // DI is the dependency-injection container API; New returns its default implementation. 5 | type DI interface { 6 | DebugMode(bool) DI // toggles the logger's debug mode 7 | // Log replaces the container's logger. 8 | Log(log Log) DI 9 | // RegisterBean registers an already-constructed bean (a pointer) under a name derived from its type. 10 | RegisterBean(bean interface{}) DI 11 | // RegisterNamedBean registers an already-constructed bean under name. 12 | RegisterNamedBean(name string, bean interface{}) DI 13 | // Provide adds a bean prototype that the container instantiates and wires during Load. 14 | Provide(prototype interface{}) DI 15 | // ProvideNamedBean is Provide with an explicit bean name. 16 | ProvideNamedBean(beanName string, prototype interface{}) DI 17 | // GetBean returns the registered or loaded bean named beanName. 18 | GetBean(beanName string)
(bean interface{}, ok bool) 19 | // GetByType returns the first bean whose type matches beanType. 20 | GetByType(beanType interface{}) (bean interface{}, ok bool) 21 | // GetByTypeAll returns all beans whose type matches beanType. 22 | GetByTypeAll(beanType interface{}) (beans []BeanWithName) 23 | // NewBean creates a new, fully wired instance of beanType that the container does not keep. 24 | NewBean(beanType interface{}) (bean interface{}) 25 | // NewBeanByName creates a new, fully wired instance from the definition provided as beanName. 26 | NewBeanByName(beanName string) (bean interface{}) 27 | 28 | UseValueStore(v ValueStore) DI 29 | 30 | Property() ValueStore 31 | 32 | SetDefaultProperty(key string, value interface{}) DI 33 | 34 | SetDefaultPropertyMap(properties map[string]interface{}) DI 35 | 36 | SetProperty(key string, value interface{}) DI 37 | 38 | SetPropertyMap(properties map[string]interface{}) DI 39 | 40 | GetProperty(key string) interface{} 41 | 42 | LoadProperties(prefix string, propertyType interface{}) interface{} 43 | // Load instantiates, wires and initializes all provided beans; it may only be called once. 44 | Load() 45 | // Serve blocks until ctx is done and then destroys the beans; Load must have been called first. 46 | Serve(ctx context.Context) 47 | // Context returns the container context (a cancellable child of Serve's ctx once Serve has started). 48 | Context() context.Context 49 | } 50 | -------------------------------------------------------------------------------- /di.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "container/list" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "github.com/cheivin/di/van" 9 | "reflect" 10 | "runtime" 11 | ) 12 | // di is the default DI implementation. 13 | type ( 14 | di struct { 15 | log Log 16 | beanDefinitionMap map[string]definition // name -> bean definition (from Provide) 17 | prototypeMap map[string]interface{} // name -> prototype instance created during Load 18 | beanMap map[string]interface{} // name -> bean instance (registered, or wired by Load) 19 | loaded bool // set once Load has run 20 | unsafe bool // set by UnsafeMode 21 | valueStore ValueStore 22 | beanSort *list.List // bean names in registration order 23 | ctx context.Context // context.Background() until Serve replaces it 24 | } 25 | ) 26 | // Context returns the container context: context.Background() until Serve installs a cancellable child of its ctx. 27 | func (container *di) Context() context.Context { 28 | return container.ctx 29 | } 30 | // Sentinel errors used in panics and fatal log messages. 31 | var ( 32 | ErrBean = errors.New("error bean") 33 | ErrDefinition = errors.New("error definition") 34 | ErrLoaded = errors.New("di loaded") // Load called twice, Provide after Load, or Serve before Load 35 | ) 36 | // New creates an empty container using stdLogger() and a van value store. 37 | func New() *di { 38 | return &di{ 39 | log: stdLogger(), 40 | beanDefinitionMap: map[string]definition{}, 41 | prototypeMap: map[string]interface{}{}, 42 | beanMap: map[string]interface{}{}, 43 | valueStore: van.New(), 44 | beanSort:
list.New(), 45 | ctx: context.Background(), 46 | } 47 | } 48 | 49 | // UnsafeMode switches unsafe mode on or off. In unsafe mode wireValue and wireBean write 50 | // fields through reflect.NewAt/unsafe.Pointer, presumably so unexported fields can be injected. 51 | // The warning is only logged when unsafe mode is being enabled. 52 | func (container *di) UnsafeMode(open bool) DI { 53 | container.unsafe = open 54 | if open { 55 | container.log.Warn("Unsafe mode enabled!") 56 | } 57 | return container 58 | } 59 | 60 | // parseBeanType returns the type behind beanType (one pointer level dereferenced) and its 61 | // bean name: the non-empty result of BeanName() on a new instance, else GetBeanName(beanType). 62 | func (container *di) parseBeanType(beanType interface{}) (prototype reflect.Type, beanName string) { 63 | prototype = reflect.Indirect(reflect.ValueOf(beanType)).Type() 64 | // a BeanName implementation on *T wins when it returns a non-empty name 65 | if named, isNamed := reflect.New(prototype).Interface().(BeanName); isNamed { 66 | if name := named.BeanName(); name != "" { 67 | container.log.Debug(fmt.Sprintf("beanName generate by interface BeanName for type %T, beanName: %s", beanType, name)) 68 | beanName = name 69 | } 70 | } 71 | if beanName == "" { 72 | beanName = GetBeanName(beanType) 73 | container.log.Debug(fmt.Sprintf("beanName generate by default for type %T, beanName: %s", beanType, beanName)) 74 | } 75 | return 76 | } 77 | 78 | // DebugMode forwards enable to the logger's DebugMode. 79 | func (container *di) DebugMode(enable bool) DI { 80 | container.log.DebugMode(enable) 81 | return container 82 | } 83 | 84 | // Log replaces the container's logger. 85 | func (container *di) Log(log Log) DI { 86 | container.log = log 87 | return container 88 | } 89 | 90 | // RegisterBean registers an already-constructed bean under a name derived from its type. 91 | func (container *di) RegisterBean(bean interface{}) DI { 92 | return container.RegisterNamedBean("", bean) 93 | } 94 | 
95 | // RegisterNamedBean registers an already-constructed bean under beanName (derived from its 96 | // type when empty). bean must be a pointer and the name must not already be used by a 97 | // registered bean or a provided definition; violations are reported through log.Fatal. 98 | func (container *di) RegisterNamedBean(beanName string, bean interface{}) DI { 99 | if !IsPtr(bean) { 100 | container.log.Fatal(fmt.Sprintf("%s: bean must be a pointer", ErrBean)) 101 | return container 102 | } 103 | if beanName == "" { 104 | _, beanName = container.parseBeanType(bean) 105 | } 106 | if _, exist := container.beanMap[beanName]; exist { 107 | container.log.Fatal(fmt.Sprintf("%s: bean %s already exists", ErrBean, beanName)) 108 | return container 109 | } 110 | // mirror ProvideNamedBean: a name that is both defined and registered would be queued twice in beanSort 111 | if existDefinition, exist := container.beanDefinitionMap[beanName]; exist { 112 | container.log.Fatal(fmt.Sprintf("%s: bean %s already defined by %s", ErrDefinition, beanName, existDefinition.Type.String())) 113 | return container 114 | } 115 | container.beanMap[beanName] = bean 116 | // remember registration order (drives wiring, lifecycle hooks and reverse-order destroy) 117 | container.beanSort.PushBack(beanName) 118 | container.log.Info(fmt.Sprintf("register bean with name: %s", beanName)) 119 | return container 120 | } 121 | 122 | // Provide adds a bean prototype under a name derived from its type; see ProvideNamedBean. 123 | func (container *di) Provide(prototype interface{}) DI { 124 | return container.ProvideNamedBean("", prototype) 125 | } 126 | 127 | // ProvideNamedBean records a definition for beanType (a struct value or pointer) under beanName, 128 | // derived from the type when empty; the container instantiates and wires it during Load. 129 | // Calling it after Load, or reusing a registered or defined name, is reported through log.Fatal. 130 | func (container *di) ProvideNamedBean(beanName string, beanType interface{}) DI { 131 | if container.loaded { 132 | container.log.Fatal(ErrLoaded.Error()) 133 | return container 134 | } 135 | prototype, derivedName := container.parseBeanType(beanType) 136 | if beanName == "" { 137 | beanName = derivedName 138 | } 139 | 140 | // reject names already used by a registered bean 141 | if _, exist := container.beanMap[beanName]; exist { 142 | container.log.Fatal(fmt.Sprintf("%s: bean %s already exists", ErrBean, beanName)) 143 | return container 144 | } 145 | // reject names already used by another definition 146 | if existDefinition, exist := container.beanDefinitionMap[beanName]; exist { 147 | container.log.Fatal(fmt.Sprintf("%s: bean %s already defined by %s", ErrDefinition, beanName, existDefinition.Type.String())) 148 | return container 149 | } 150 | container.beanDefinitionMap[beanName] = container.newDefinition(beanName, prototype) 151 | // remember registration order 152 | container.beanSort.PushBack(beanName) 153 | container.log.Info(fmt.Sprintf("provide bean with name: %s", beanName)) 154 | return container 155 | } 156 | 
157 | // GetBean returns the bean stored under beanName (registered, or wired by Load). 158 | func (container *di) GetBean(beanName string) (interface{}, bool) { 159 | bean, ok := container.beanMap[beanName] 160 | return bean, ok 161 | } 162 | 163 | // getAllByType returns the beans in beanMap whose dynamic type is assignable to the type 164 | // described by beanType, in registration order; with limitOne it stops at the first match. 165 | // beanType may be a struct value or pointer (matched as *T) or a pointer to an interface such 166 | // as new(Repo) or (*Repo)(nil) (matched against Repo; the nil form assumes IsPtr accepts typed nil pointers). 167 | func (container *di) getAllByType(beanType interface{}, limitOne bool) (beans []BeanWithName) { 168 | if beanType == nil { 169 | return nil 170 | } 171 | var typeValue reflect.Type 172 | if IsPtr(beanType) { 173 | // use the static element type: ValueOf(p).Elem() of a typed nil pointer has no Type and panics 174 | typeValue = reflect.TypeOf(beanType).Elem() 175 | if typeValue.Kind() == reflect.Struct { 176 | typeValue = reflect.PtrTo(typeValue) 177 | } 178 | } else { 179 | typeValue = reflect.PtrTo(reflect.TypeOf(beanType)) 180 | } 181 | // walk beanSort instead of ranging over beanMap so the result order is deterministic 182 | for e := container.beanSort.Front(); e != nil; e = e.Next() { 183 | name := e.Value.(string) 184 | bean, ok := container.beanMap[name] 185 | if !ok || !reflect.TypeOf(bean).AssignableTo(typeValue) { 186 | continue 187 | } 188 | beans = append(beans, BeanWithName{Name: name, Bean: bean}) 189 | if limitOne { 190 | break 191 | } 192 | } 193 | return beans 194 | } 195 | 196 | // GetByType returns the first bean, in registration order, matching beanType; see getAllByType. 197 | func (container *di) GetByType(beanType interface{}) (interface{}, bool) { 198 | beans := container.getAllByType(beanType, true) 199 | if len(beans) == 0 { 200 | return nil, false 201 | } 202 | return beans[0].Bean, true 203 | } 204 | 205 | // GetByTypeAll returns every bean matching beanType, in registration order; see getAllByType. 206 | func (container *di) GetByTypeAll(beanType interface{}) (beans []BeanWithName) { 207 | return container.getAllByType(beanType, false) 208 | } 209 | 
210 | // NewBean creates a new, fully wired instance of beanType that is not stored in the container. 211 | // A definition provided under the derived name is reused; otherwise one is built on the fly. 212 | func (container *di) NewBean(beanType interface{}) (bean interface{}) { 213 | prototype, beanName := container.parseBeanType(beanType) 214 | if _, exist := container.beanDefinitionMap[beanName]; exist { 215 | return container.NewBeanByName(beanName) 216 | } 217 | return container.newBean(container.newDefinition(beanName, prototype)) 218 | } 219 | 220 | // NewBeanByName creates a new, fully wired instance from the definition provided as beanName. 221 | // It panics with an error wrapping ErrDefinition when there is no such definition. 222 | func (container *di) NewBeanByName(beanName string) (bean interface{}) { 223 | def, ok := container.beanDefinitionMap[beanName] 224 | if !ok { 225 | panic(fmt.Errorf("%w: %s notfound", ErrDefinition, beanName)) 226 | } 227 | return container.newBean(def) 228 | } 229 | 230 | // newBean runs the per-instance lifecycle for def without storing the result in the container: 231 | // instantiate and wire values, BeanConstruct, PreInitialize/inject/AfterPropertiesSet, Initialized. 232 | // The Destroy hooks are attached as a GC finalizer, so they may never run. 233 | func (container *di) newBean(def definition) (bean interface{}) { 234 | container.log.Info(fmt.Sprintf("new bean instance %s", def.Name)) 235 | // instantiate via reflection and inject configuration values 236 | prototype := container.instanceBean(def) 237 | // BeanConstruct hooks 238 | container.constructBean(def.Name, prototype) 239 | // PreInitialize, dependency injection, AfterPropertiesSet 240 | bean = container.processBean(prototype, def) 241 | // Initialized hooks 242 | container.initializedBean(def.Name, bean) 243 | // run the Destroy hooks when the instance is garbage-collected 244 | runtime.SetFinalizer(bean, func(bean interface{}) { 245 | container.destroyBean(def.Name, bean) 246 | }) 247 | return 248 | } 249 | 
250 | // Load instantiates, constructs, wires and initializes every provided bean. 251 | // It panics with ErrLoaded when the container has already been loaded. 252 | func (container *di) Load() { 253 | if container.loaded { 254 | panic(ErrLoaded) 255 | } 256 | 257 | container.loaded = true 258 | container.initializeBeans() 259 | container.processBeans() 260 | container.initialized() 261 | } 262 | 263 | // Serve publishes a cancellable child of ctx through Context, blocks until ctx is done and 264 | // then destroys the beans in reverse registration order. It panics if Load has not been called. 265 | func (container *di) Serve(ctx context.Context) { 266 | if !container.loaded { 267 | // NOTE(review): ErrLoaded ("di loaded") is reused here for the opposite condition 268 | panic(ErrLoaded) 269 | } 270 | var cancel context.CancelFunc 271 | container.ctx, cancel = context.WithCancel(ctx) 272 | defer cancel() 273 | <-ctx.Done() 274 | container.destroyBeans() 275 | } 276 | 
277 | // initializeBeans instantiates every provided definition into prototypeMap and then calls the 278 | // BeanConstruct hooks, both in registration order. 279 | func (container *di) initializeBeans() { 280 | // create the pointer instances first (configuration values are wired here) 281 | for e := container.beanSort.Front(); e != nil; e = e.Next() { 282 | beanName := e.Value.(string) 283 | if def, ok := container.beanDefinitionMap[beanName]; ok { 284 | container.prototypeMap[beanName] = container.instanceBean(def) 285 | } 286 | } 287 | // call BeanConstruct once every prototype exists 288 | for e := container.beanSort.Front(); e != nil; e = e.Next() { 289 | beanName := e.Value.(string) 290 | if prototype, ok := container.prototypeMap[beanName]; ok { 291 | container.constructBean(beanName, prototype) 292 | } 293 | } 294 | } 295 | 
296 | // processBeans injects dependencies into each prototype in registration order and stores the 297 | // wired result in beanMap. 298 | func (container *di) processBeans() { 299 | for e := container.beanSort.Front(); e != nil; e = e.Next() { 300 | beanName := e.Value.(string) 301 | if prototype, ok := container.prototypeMap[beanName]; ok { 302 | def := container.beanDefinitionMap[beanName] 303 | // turn the prototype into a bean 304 | container.log.Info(fmt.Sprintf("initialize bean %s(%T)", def.Name, prototype)) 305 | // wired beans go into beanMap 306 | container.beanMap[beanName] = container.processBean(prototype, def) 307 | } 308 | } 309 | } 310 | 311 | // initialized calls the Initialized hooks on every bean, in registration order, once the 312 | // whole container has been wired. 313 | func (container *di) initialized() { 314 | for e := container.beanSort.Front(); e != nil; e = e.Next() { 315 | beanName := e.Value.(string) 316 | bean := container.beanMap[beanName] 317 | container.initializedBean(beanName, bean) 318 | } 319 | } 320 | 321 | // destroyBeans calls the Destroy hooks in reverse registration order and removes each bean 322 | // from beanMap. 323 | func (container *di) destroyBeans() { 324 | // destroy in reverse order 325 | for e := container.beanSort.Back(); e != nil; e = e.Prev() { 326 | beanName := e.Value.(string) 327 | if bean, ok := container.beanMap[beanName]; ok { 328 | container.destroyBean(beanName, bean) 329 | delete(container.beanMap, beanName) 330 | } 331 | } 332 | } 333 | 
-------------------------------------------------------------------------------- /example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/cheivin/di" 7 | "log" 8 | "os/signal" 9 | "syscall" 10 | "time" 11 | ) 12 | 13 | type ( 14 | DB struct { 15 | Prefix string 16 | } 17 | 18 | DB2 struct { 19 | Prefix string 20 | } 21 | 22 | DB3 struct { 23 | } 24 | 25 | UserDao struct { 26 | Db *DB `aware:"db"` 27 | Db2 *DB2 `aware:""` 28 | Db3 *DB3 `aware:"db3,omitempty"` 29 | TableName string 30 | DefaultAge int `value:"base.user.age"` 31 | DefaultName string `value:"base.user.name"` 32 | DefaultType uint8 `value:"base.user.type"` 33 | DefaultCacheTime time.Duration `value:"base.user.cache"` 34 | DefaultExpire time.Duration `value:"base.user.expire"` 35 | } 36 | 37 | WalletDao struct { 38 | Db *DB `aware:"db"` 39 | TableName string 40 | } 41 | 42 | OrderRepository interface { 43 | TableName() string 44 | } 45 | 46 | OrderDao struct { 47 | Db *DB `aware:"db"` 48 | } 49 | 50 | UserService struct { 51 | UserDao *UserDao `aware:""` 52 | Wallet *WalletDao `aware:""` 53 | OrderDao OrderRepository `aware:""` 54 | } 55 | ) 56 | 57 | func (DB2) BeanName() string { 58 | fmt.Println("获取DB2名称:db2") 59 | return "db2" 60 | } 61 | 62 | func (o *OrderDao) TableName() string { 63 | return o.Db.Prefix + "order" 64 | } 65 | 66 | func (u UserService) PreInitialize(container di.DI) { 67 | fmt.Println("依赖注入", "UserService", container.Property()) 68 | } 69 | 70 | func (u UserService) BeanInject(di di.DI, info *di.InjectInfo) { 71 | switch info.BeanName { 72 | case "orderRepository": 73 | info.Bean = &OrderDao{&DB{Prefix: "BeanInject"}} 74 | } 75 | fmt.Println("BeanInject:", info.BeanName) 76 | 77 | } 78 | 79 | func (u *UserDao) BeanName()
string { 80 | return "user" 81 | } 82 | 83 | func (u *UserDao) AfterPropertiesSet() { 84 | fmt.Println("装载完成", "UserDao") 85 | fmt.Println("userDao.DB2", u.Db2) 86 | fmt.Println("userDao.DB3", u.Db3) 87 | u.TableName = "user" 88 | } 89 | 90 | func (w *WalletDao) Initialized() { 91 | fmt.Println("加载完成", "WalletDao") 92 | w.TableName = "wallet" 93 | } 94 | 95 | func (o *OrderDao) BeanConstruct() { 96 | fmt.Println("构造实例", "OrderDao") 97 | } 98 | 99 | func (o *OrderDao) BeanName() string { 100 | return "order" 101 | } 102 | 103 | func (u *UserService) GetUserTable() string { 104 | return u.UserDao.Db.Prefix + u.UserDao.TableName 105 | } 106 | 107 | func (u *UserService) GetUserDefault() map[string]interface{} { 108 | return map[string]interface{}{ 109 | "age": u.UserDao.DefaultAge, 110 | "name": u.UserDao.DefaultName, 111 | "type": u.UserDao.DefaultType, 112 | "cache": u.UserDao.DefaultCacheTime, 113 | "expire": u.UserDao.DefaultExpire, 114 | } 115 | } 116 | 117 | func (u *UserService) GetWalletTable() string { 118 | return u.Wallet.Db.Prefix + u.Wallet.TableName 119 | } 120 | 121 | func (u *UserService) GetOrderTable() string { 122 | return u.OrderDao.TableName() 123 | } 124 | 125 | func (u *UserService) Destroy() { 126 | fmt.Println("注销实例", "UserService") 127 | } 128 | 129 | func (d *DB) Destroy() { 130 | fmt.Println("注销实例", "DB") 131 | } 132 | 133 | func main() { 134 | di.RegisterNamedBean("db", &DB{Prefix: "test_"}). 135 | RegisterBean(&DB2{Prefix: "xxx_"}). 136 | ProvideNamedBean("user", UserDao{}). 137 | Provide(WalletDao{}). 138 | Provide(OrderDao{}). 139 | ProvideNamedBean("multiOne", OrderDao{}). 140 | Provide(UserService{}). 141 | SetDefaultProperty("base.user.name", "新用户"). 142 | SetProperty("base.user.age", 25). 143 | SetProperty("base.user.name", "新注册用户"). 144 | SetProperty("base.user.type", "8"). 145 | SetProperty("base.user.cache", "30000"). 146 | SetProperty("base.user.expire", "1h").
147 | Load() 148 | 149 | bean, ok := di.GetBean("userService") 150 | if ok { 151 | log.Println(bean.(*UserService).GetUserTable()) 152 | log.Println(bean.(*UserService).GetWalletTable()) 153 | log.Println(bean.(*UserService).GetOrderTable()) 154 | log.Println(bean.(*UserService).GetUserDefault()) 155 | } 156 | // exit on SIGINT/SIGTERM: ctx is cancelled when either signal arrives 157 | ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 158 | defer stop() 159 | di.Serve(ctx) 160 | fmt.Println("容器退出") 161 | } 162 | -------------------------------------------------------------------------------- /global.go: -------------------------------------------------------------------------------- 1 | package di 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "strings" 7 | ) 8 | 9 | // g is the package-level default container behind the functions in this file. 10 | var g DI 11 | 12 | func init() { 13 | g = New() 14 | } 15 | 16 | // RegisterBean registers an already-constructed bean in the default container. 17 | func RegisterBean(bean interface{}) DI { 18 | return g.RegisterBean(bean) 19 | } 20 | 21 | // RegisterNamedBean registers an already-constructed bean under name in the default container. 22 | func RegisterNamedBean(name string, bean interface{}) DI { 23 | return g.RegisterNamedBean(name, bean) 24 | } 25 | 26 | // Provide adds a bean prototype to the default container. 27 | func Provide(prototype interface{}) DI { 28 | return g.Provide(prototype) 29 | } 30 | 31 | // ProvideNamedBean adds a bean prototype under beanName to the default container. 32 | func ProvideNamedBean(beanName string, prototype interface{}) DI { 33 | return g.ProvideNamedBean(beanName, prototype) 34 | } 35 | 36 | // GetBean looks a bean up by name in the default container. 37 | func GetBean(beanName string) (bean interface{}, ok bool) { 38 | return g.GetBean(beanName) 39 | } 40 | 41 | // GetByType returns the first bean matching beanType from the default container. 42 | func GetByType(beanType interface{}) (bean interface{}, ok bool) { 43 | return g.GetByType(beanType) 44 | } 45 | 46 | // GetByTypeAll returns every bean matching beanType from the default container. 47 | func GetByTypeAll(beanType interface{}) (beans []BeanWithName) { 48 | return g.GetByTypeAll(beanType) 49 | } 50 | 51 | // NewBean creates a new, fully wired instance of beanType using the default container. 52 | func NewBean(beanType interface{}) (bean interface{}) { 53 | return g.NewBean(beanType) 54 | } 55 | 56 | // NewBeanByName creates a new instance from the default container's definition named beanName. 57 | func NewBeanByName(beanName string) (bean interface{}) { 58 | return g.NewBeanByName(beanName) 59 | } 60 | 61 | // UseValueStore hands v to the default container's UseValueStore and returns the container. 62 | func UseValueStore(v ValueStore) DI { 63 | g.UseValueStore(v) 64 | return g 65 | } 66 | 67 | // Property returns the default container's ValueStore. 68 | func Property() ValueStore { 69 | return g.Property() 70 | } 71 | 
72 | // SetDefaultProperty calls SetDefaultProperty on the default container. 73 | func SetDefaultProperty(key string, value interface{}) DI { 74 | return g.SetDefaultProperty(key, value) 75 | } 76 | 77 | // SetDefaultPropertyMap calls SetDefaultPropertyMap on the default container. 78 | func SetDefaultPropertyMap(properties map[string]interface{}) DI { 79 | return g.SetDefaultPropertyMap(properties) 80 | } 81 | 82 | // SetProperty calls SetProperty on the default container. 83 | func SetProperty(key string, value interface{}) DI { 84 | return g.SetProperty(key, value) 85 | } 86 | 87 | // SetPropertyMap calls SetPropertyMap on the default container. 88 | func SetPropertyMap(properties map[string]interface{}) DI { 89 | return g.SetPropertyMap(properties) 90 | } 91 | 92 | // GetProperty reads a property from the default container. 93 | func GetProperty(key string) interface{} { 94 | return g.GetProperty(key) 95 | } 96 | 97 | // LoadProperties calls LoadProperties on the default container (presumably binding the properties under prefix into a value of propertyType). 98 | func LoadProperties(prefix string, propertyType interface{}) interface{} { 99 | return g.LoadProperties(prefix, propertyType) 100 | } 101 | 102 | // AutoMigrateEnv copies the environment into the default container's properties, turning 103 | // '_' in variable names into '.' (e.g. BASE_USER_NAME -> BASE.USER.NAME). 104 | func AutoMigrateEnv() { 105 | envMap := LoadEnvironment(strings.NewReplacer("_", "."), false) 106 | SetPropertyMap(envMap) 107 | } 108 | 
109 | // LoadEnvironment returns the process environment as a property map. Only variables whose 110 | // name is accepted by hasPrefix(name, prefix) are kept; with trimPrefix the matched prefix is 111 | // removed, then replacer (if non-nil) is applied to the name. Entries without '=' are skipped. 112 | func LoadEnvironment(replacer *strings.Replacer, trimPrefix bool, prefix ...string) map[string]interface{} { 113 | environ := os.Environ() 114 | envMap := make(map[string]interface{}, len(environ)) 115 | for _, env := range environ { 116 | name, value, found := strings.Cut(env, "=") 117 | if !found { 118 | continue 119 | } 120 | ok, pfx := hasPrefix(name, prefix) 121 | if !ok { 122 | continue 123 | } 124 | if trimPrefix { 125 | name = strings.TrimPrefix(name, pfx) 126 | } 127 | if replacer != nil { 128 | name = replacer.Replace(name) 129 | } 130 | envMap[name] = value 131 | } 132 | return envMap 133 | } 134 | 135 | // Load loads the default container; it panics if the container was already loaded. 136 | func Load() { 137 | g.Load() 138 | } 139 | 140 | // Serve serves the default container until ctx is done, then destroys its beans. 141 | func Serve(ctx context.Context) { 142 | g.Serve(ctx) 143 | } 144 | 145 | // LoadAndServe loads the default container and then serves it until ctx is done. 146 | func LoadAndServe(ctx context.Context) { 147 | g.Load() 148 | g.Serve(ctx) 149 | } 150 | 151 | // LoadAndServ is the original, misspelled name of LoadAndServe. 152 | // 153 | // Deprecated: use LoadAndServe instead. 154 | func LoadAndServ(ctx context.Context) { 155 | LoadAndServe(ctx) 156 | } 157 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/cheivin/di 2 | 3 | go 1.18 4 |
/go.sum:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cheivin/di/3b4c7c8d44a39675f74a363541be0994e8d88980/go.sum
--------------------------------------------------------------------------------
/logger.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | 	"io"
5 | 	"os"
6 | )
7 |
8 | // Log is the logging interface implemented by logger.
9 | type Log interface {
10 | 	DebugMode(bool)
11 | 	Debug(string)
12 | 	Info(string)
13 | 	Warn(string)
14 | 	Fatal(string)
15 | }
16 |
17 | // logger is the default Log: Debug and Info write to writer, Warn and Fatal
18 | // write to errWriter. Write errors are deliberately ignored.
19 | //
20 | // The methods use a pointer receiver so that DebugMode's assignment is seen by
21 | // later Debug calls; with a value receiver it would only modify a copy.
22 | type logger struct {
23 | 	debugMode bool
24 | 	writer    io.Writer
25 | 	errWriter io.Writer
26 | }
27 |
28 | // stdLogger returns a Log that writes to stdout/stderr with debug output off.
29 | func stdLogger() Log {
30 | 	return &logger{
31 | 		debugMode: false,
32 | 		writer:    os.Stdout,
33 | 		errWriter: os.Stderr,
34 | 	}
35 | }
36 |
37 | // DebugMode turns Debug output on or off.
38 | func (l *logger) DebugMode(b bool) {
39 | 	l.debugMode = b
40 | }
41 |
42 | // Debug writes s with a [DI-DEBUG] prefix when debug mode is on.
43 | func (l *logger) Debug(s string) {
44 | 	if !l.debugMode {
45 | 		return
46 | 	}
47 | 	_, _ = l.writer.Write([]byte("[DI-DEBUG] : " + s + "\n"))
48 | }
49 |
50 | // Info writes s with a [DI-INFO] prefix.
51 | func (l *logger) Info(s string) {
52 | 	_, _ = l.writer.Write([]byte("[DI-INFO] : " + s + "\n"))
53 | }
54 |
55 | // Warn writes s with a [DI-WARN] prefix to errWriter.
56 | func (l *logger) Warn(s string) {
57 | 	_, _ = l.errWriter.Write([]byte("[DI-WARN] : " + s + "\n"))
58 | }
59 |
60 | // Fatal writes s with a [DI-FATAL] prefix to errWriter and exits with status 1.
61 | func (l *logger) Fatal(s string) {
62 | 	_, _ = l.errWriter.Write([]byte("[DI-FATAL] : " + s + "\n"))
63 | 	os.Exit(1)
64 | }
65 |
--------------------------------------------------------------------------------
/util.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import (
4 | 	"reflect"
5 | 	"strings"
6 | 	"unicode"
7 | 	"unicode/utf8"
8 | )
9 |
10 | // IsPtr reports whether o's dynamic type is a pointer. It panics if o is nil.
11 | func IsPtr(o interface{}) bool {
12 | 	return reflect.TypeOf(o).Kind() == reflect.Ptr
13 | }
14 |
15 | // GetBeanName derives the default bean name from o, which may be a value,
16 | // a pointer, or a reflect.Type: the type name (pointer stripped once) with
17 | // its first letter lower-cased, e.g. UserService -> userService.
18 | // Unnamed types (anonymous structs and the like) yield "".
19 | func GetBeanName(o interface{}) (name string) {
20 | 	if t, ok := o.(reflect.Type); ok {
21 | 		if t.Kind() == reflect.Ptr {
22 | 			t = t.Elem()
23 | 		}
24 | 		name = t.Name()
25 | 	} else {
26 | 		name = reflect.Indirect(reflect.ValueOf(o)).Type().Name()
27 | 	}
28 | 	if name == "" {
29 | 		// Unnamed types have nothing to lower-case; slicing "" would panic.
30 | 		return name
31 | 	}
32 | 	// Lower-case the first rune rather than the first byte so that a
33 | 	// non-ASCII leading letter stays valid UTF-8.
34 | 	first, size := utf8.DecodeRuneInString(name)
35 | 	return string(unicode.ToLower(first)) + name[size:]
36 | }
37 |
38 | // in reports whether target is an element of array. It scans linearly and
39 | // leaves the caller's slice untouched (sorting first would reorder it).
40 | func in(target string, array []string) bool {
41 | 	for _, s := range array {
42 | 		if s == target {
43 | 			return true
44 | 		}
45 | 	}
46 | 	return false
47 | }
48 |
49 | // hasPrefix reports whether s starts with any of prefixes and returns the
50 | // first matching prefix. An empty prefixes list matches every s with "".
51 | func hasPrefix(s string, prefixes []string) (bool, string) {
52 | 	if len(prefixes) == 0 {
53 | 		return true, ""
54 | 	}
55 | 	for _, p := range prefixes {
56 | 		if strings.HasPrefix(s, p) {
57 | 			return true, p
58 | 		}
59 | 	}
60 | 	return false, ""
61 | }
62 |
--------------------------------------------------------------------------------
/value_store.go:
--------------------------------------------------------------------------------
1 | package di
2 |
3 | import "reflect"
4 |
5 | // ValueStore is the property backend read and written by the container's
6 | // Property, SetProperty, SetDefaultProperty and GetProperty methods.
7 | type ValueStore interface {
8 | 	// SetDefault stores a default value for key.
9 | 	SetDefault(key string, value interface{})
10 |
11 | 	// Set stores value for key.
12 | 	Set(key string, value interface{})
13 |
14 | 	// Get returns the value stored for key.
15 | 	Get(key string) (val interface{})
16 |
17 | 	// GetAll returns every stored property.
18 | 	GetAll() map[string]interface{}
19 | }
20 |
21 | // UseValueStore replaces the container's property store with v.
22 | func (container *di) UseValueStore(v ValueStore) DI {
23 | 	container.valueStore = v
24 | 	return container
25 | }
26 |
27 | // Property returns the container's property store.
28 | func (container *di) Property() ValueStore {
29 | 	return container.valueStore
30 | }
31 |
32 | // SetDefaultProperty calls SetDefault(key, value) on the property store.
33 | func (container *di) SetDefaultProperty(key string, value interface{}) DI {
34 | 	container.valueStore.SetDefault(key, value)
35 | 	return container
36 | }
37 |
38 | // SetDefaultPropertyMap calls SetDefault for every entry of properties.
39 | func (container *di) SetDefaultPropertyMap(properties map[string]interface{}) DI {
40 | 	for key, value := range properties {
41 | 		container.valueStore.SetDefault(key, value)
42 | 	}
43 | 	return container
44 | }
45 |
46 | // SetProperty calls Set(key, value) on the property store.
47 | func (container *di) SetProperty(key string, value interface{}) DI {
48 | 	container.valueStore.Set(key, value)
49 | 	return container
50 | }
51 |
52 | // SetPropertyMap calls Set for every entry of properties.
53 | func (container *di) SetPropertyMap(properties map[string]interface{}) DI {
54 | 	for key, value := range properties {
55 | 		container.valueStore.Set(key, value)
56 | 	}
57 | 	return container
58 | }
59 |
60 | // GetProperty returns the property store's value for key.
61 | func (container *di) GetProperty(key string) interface{} {
62 | 	return container.valueStore.Get(key)
63 | }
64 |
65 | // LoadProperties creates a new value of propertyType's type (dereferencing a
66 | // pointer first), fills it via wireValue using prefix, and returns it by value.
67 | func (container *di) LoadProperties(prefix string, propertyType interface{}) interface{} {
68 | 	prototype := reflect.Indirect(reflect.ValueOf(propertyType)).Type()
69 | 	def := container.getValueDefinition(prototype)
70 | 	bean := reflect.New(def.Type)
71 | 	container.wireValue(bean.Elem(), def, prefix)
72 | 	return bean.Elem().Interface()
73 | }
74 |
--------------------------------------------------------------------------------
/van/cast.go:
--------------------------------------------------------------------------------
1 | package van
2 |
3 | import (
4 | 	"fmt"
5 | 	"reflect"
6 | 	"strconv"
7 | 	"strings"
8 | 	"time"
9 | )
10 |
11 | // indirect unwraps v for conversion: a pointer is dereferenced once and a
12 | // reflect.Value is replaced by the value it holds. Nil input, a nil pointer,
13 | // or an invalid reflect.Value yields nil instead of panicking.
14 | func indirect(v interface{}) interface{} {
15 | 	value := reflect.Indirect(reflect.ValueOf(v))
16 | 	if !value.IsValid() {
17 | 		return nil
18 | 	}
19 | 	if val, ok := value.Interface().(reflect.Value); ok {
20 | 		if !val.IsValid() {
21 | 			return nil
22 | 		}
23 | 		return val.Interface()
24 | 	}
25 | 	return value.Interface()
26 | }
27 |
28 | // isMap reports whether v, after indirect, is a map of any key/value type.
29 | func isMap(v interface{}) bool {
30 | 	v = indirect(v)
31 | 	return reflect.ValueOf(v).Kind() == reflect.Map
32 | }
33 |
34 | // toString formats v (after indirect) when it is a string, bool, builtin
35 | // numeric type or time.Duration; any other type yields "".
36 | func toString(v interface{}) string {
37 | 	v = indirect(v)
38 | 	switch s := v.(type) {
39 | 	case string:
40 | 		return s
41 | 	case bool:
42 | 		return strconv.FormatBool(s)
43 | 	case float64:
44 | 		return strconv.FormatFloat(s, 'f', -1, 64)
45 | 	case float32:
46 | 		return strconv.FormatFloat(float64(s), 'f', -1, 32)
47 | 	case int:
48 | 		return strconv.Itoa(s)
49 | 	case int64:
50 | 		return strconv.FormatInt(s, 10)
51 | 	case int32:
52 | 		return strconv.Itoa(int(s))
53 | 	case int16:
54 | 		return strconv.FormatInt(int64(s), 10)
55 | 	case int8:
56 | 		return strconv.FormatInt(int64(s), 10)
57 | 	case uint:
58 | 		return strconv.FormatUint(uint64(s), 10)
59 | 	case uint64:
60 | 		return strconv.FormatUint(s, 10)
61 | 	case uint32:
62 | 		return strconv.FormatUint(uint64(s), 10)
63 | 	case uint16:
64 | 		return strconv.FormatUint(uint64(s), 10)
65 | 	case uint8:
66 | 		return strconv.FormatUint(uint64(s), 10)
67 | 	case time.Duration:
68 | 		return s.String()
69 | 	default:
70 | 		return ""
71 | 	}
72 | }
73 |
74 | // typeDuration is the reflect.Type of time.Duration, which Cast parses specially.
75 | var typeDuration = reflect.TypeOf(time.Nanosecond)
76 |
77 | // Cast converts v to a value of type typ.
78 | //
79 | // A string typ receives toString(v). Otherwise a v whose type is convertible
80 | // to typ is converted directly (except for time.Duration); anything else is
81 | // formatted with toString and parsed according to typ's kind. time.Duration
82 | // accepts Go duration syntax ("1h", "30s") or a bare integer of milliseconds.
83 | // Named types over a supported kind (e.g. `type Level int`) receive a value of
84 | // typ itself. Unsupported kinds yield (nil, nil). A nil v, or a nil pointer,
85 | // is reported as an error.
86 | func Cast(v interface{}, typ reflect.Type) (interface{}, error) {
87 | 	v = indirect(v)
88 | 	if v == nil {
89 | 		return nil, fmt.Errorf("van: cannot cast nil to %s", typ)
90 | 	}
91 | 	to, err := castKind(v, typ)
92 | 	if err != nil {
93 | 		return nil, err
94 | 	}
95 | 	if to == nil {
96 | 		return nil, nil
97 | 	}
98 | 	// castKind yields the builtin type for typ's kind; convert it so values
99 | 	// are assignable to fields declared with a named type.
100 | 	if rv := reflect.ValueOf(to); rv.Type() != typ && rv.Type().ConvertibleTo(typ) {
101 | 		to = rv.Convert(typ).Interface()
102 | 	}
103 | 	return to, nil
104 | }
105 |
106 | // castKind performs Cast's conversion, returning a value of the builtin type
107 | // matching typ's kind (or time.Duration), or (nil, nil) for unsupported kinds.
108 | func castKind(v interface{}, typ reflect.Type) (interface{}, error) {
109 | 	if typ.Kind() == reflect.String {
110 | 		return toString(v), nil
111 | 	}
112 | 	value := reflect.ValueOf(v)
113 | 	if value.Type().ConvertibleTo(typ) && typ != typeDuration {
114 | 		return value.Convert(typ).Interface(), nil
115 | 	}
116 | 	s := toString(v)
117 | 	switch typ.Kind() {
118 | 	case reflect.Bool:
119 | 		return strconv.ParseBool(s)
120 | 	case reflect.Float64:
121 | 		return strconv.ParseFloat(s, 64)
122 | 	case reflect.Float32:
123 | 		f, err := strconv.ParseFloat(s, 32)
124 | 		if err != nil {
125 | 			return nil, err
126 | 		}
127 | 		return float32(f), nil
128 | 	case reflect.Int:
129 | 		i, err := strconv.ParseInt(s, 10, 0)
130 | 		if err != nil {
131 | 			return nil, err
132 | 		}
133 | 		return int(i), nil
134 | 	case reflect.Int64:
135 | 		if typ == typeDuration {
136 | 			return parseDuration(s)
137 | 		}
138 | 		return strconv.ParseInt(s, 10, 64)
139 | 	case reflect.Int32:
140 | 		i, err := strconv.ParseInt(s, 10, 32)
141 | 		if err != nil {
142 | 			return nil, err
143 | 		}
144 | 		return int32(i), nil
145 | 	case reflect.Int16:
146 | 		i, err := strconv.ParseInt(s, 10, 16)
147 | 		if err != nil {
148 | 			return nil, err
149 | 		}
150 | 		return int16(i), nil
151 | 	case reflect.Int8:
152 | 		i, err := strconv.ParseInt(s, 10, 8)
153 | 		if err != nil {
154 | 			return nil, err
155 | 		}
156 | 		return int8(i), nil
157 | 	case reflect.Uint:
158 | 		u, err := strconv.ParseUint(s, 10, 0)
159 | 		if err != nil {
160 | 			return nil, err
161 | 		}
162 | 		return uint(u), nil
163 | 	case reflect.Uint64:
164 | 		return strconv.ParseUint(s, 10, 64)
165 | 	case reflect.Uint32:
166 | 		u, err := strconv.ParseUint(s, 10, 32)
167 | 		if err != nil {
168 | 			return nil, err
169 | 		}
170 | 		return uint32(u), nil
171 | 	case reflect.Uint16:
172 | 		u, err := strconv.ParseUint(s, 10, 16)
173 | 		if err != nil {
174 | 			return nil, err
175 | 		}
176 | 		return uint16(u), nil
177 | 	case reflect.Uint8:
178 | 		u, err := strconv.ParseUint(s, 10, 8)
179 | 		if err != nil {
180 | 			return nil, err
181 | 		}
182 | 		return uint8(u), nil
183 | 	default:
184 | 		return nil, nil
185 | 	}
186 | }
187 |
188 | // parseDuration parses s as a Go duration ("1h30m"). A bare integer such as
189 | // "30000", which time.ParseDuration rejects for its missing unit, is read as
190 | // a number of milliseconds.
191 | func parseDuration(s string) (time.Duration, error) {
192 | 	d, err := time.ParseDuration(s)
193 | 	if err == nil {
194 | 		return d, nil
195 | 	}
196 | 	if !strings.HasPrefix(err.Error(), "time: missing unit in duration") {
197 | 		return 0, err
198 | 	}
199 | 	ms, err := strconv.ParseInt(s, 10, 64)
200 | 	if err != nil {
201 | 		return 0, err
202 | 	}
203 | 	return time.Duration(ms) * time.Millisecond, nil
204 | }
205 |
--------------------------------------------------------------------------------
/van/store.go:
--------------------------------------------------------------------------------
1 | package van
2 |
3 | import (
4 | 	"reflect"
5 | 	"strings"
6 | )
7 |
8 | // store is a case-insensitive property tree: keys are lower-cased and split on
9 | // separator, and each segment is one level of nested map[string]interface{}.
10 | type store struct {
11 | 	separator string
12 | 	tree      map[string]interface{}
13 | }
14 |
15 | // newStore returns an empty store whose keys are split on separator.
16 | func newStore(separator string) *store {
17 | 	return &store{separator: separator, tree: make(map[string]interface{})}
18 | }
19 |
20 | // toCaseInsensitiveMap converts a map (or pointer to a map) into a nested
21 | // map[string]interface{} tree: each key is stringified with toString,
22 | // lower-cased and split on separator into nested levels, and map values are
23 | // converted recursively.
24 | func toCaseInsensitiveMap(value interface{}, separator string) map[string]interface{} {
25 | 	m := make(map[string]interface{})
26 | 	// Indirect so that a *map, which isMap accepts, does not make MapRange panic.
27 | 	iter := reflect.Indirect(reflect.ValueOf(value)).MapRange()
28 | 	for iter.Next() {
29 | 		key := strings.ToLower(toString(iter.Key()))
30 | 		val := iter.Value()
31 | 		keyPath := strings.Split(key, separator)
32 | 		// For a single-segment key the parent is m itself.
33 | 		parent := deepSearchIfAbsent(m, keyPath[:len(keyPath)-1])
34 | 		lastKey := keyPath[len(keyPath)-1]
35 | 		if isMap(val) {
36 | 			parent[lastKey] = toCaseInsensitiveMap(val.Interface(), separator)
37 | 		} else {
38 | 			parent[lastKey] = val.Interface()
39 | 		}
40 | 	}
41 | 	return m
42 | }
43 |
44 | // copyStringMap returns a deep copy of origin: nested map[string]interface{}
45 | // values are copied recursively and all other values are shared.
46 | func copyStringMap(origin map[string]interface{}) map[string]interface{} {
47 | 	m := make(map[string]interface{}, len(origin))
48 | 	for key, val := range origin {
49 | 		if sub, ok := val.(map[string]interface{}); ok {
50 | 			m[key] = copyStringMap(sub)
51 | 		} else {
52 | 			m[key] = val
53 | 		}
54 | 	}
55 | 	return m
56 | }
57 |
58 | // mergeStringMap merges source into target in place. When both hold a
59 | // map[string]interface{} under the same key they are merged recursively; in
60 | // every other case, including a map on one side and a scalar on the other,
61 | // source's value replaces target's, so source always wins.
62 | func mergeStringMap(source map[string]interface{}, target map[string]interface{}) {
63 | 	for key, sv := range source {
64 | 		sourceTree, sourceIsTree := sv.(map[string]interface{})
65 | 		targetTree, targetIsTree := target[key].(map[string]interface{})
66 | 		if sourceIsTree && targetIsTree {
67 | 			mergeStringMap(sourceTree, targetTree)
68 | 			continue
69 | 		}
70 | 		target[key] = sv
71 | 	}
72 | }
73 |
74 | // deepSearchIfAbsent walks path down from tree and returns the map at its end,
75 | // creating missing levels and replacing non-map values in the way with empty maps.
76 | func deepSearchIfAbsent(tree map[string]interface{}, path []string) map[string]interface{} {
77 | 	for _, key := range path {
78 | 		sub, ok := tree[key].(map[string]interface{})
79 | 		if !ok {
80 | 			sub = make(map[string]interface{})
81 | 			tree[key] = sub
82 | 		}
83 | 		tree = sub
84 | 	}
85 | 	return tree
86 | }
87 |
88 | // deepSearch follows path through nested map[string]interface{} levels starting
89 | // at v and returns what it finds, or nil when a level is missing.
90 | func deepSearch(v interface{}, path []string) interface{} {
91 | 	for _, key := range path {
92 | 		tree, ok := v.(map[string]interface{})
93 | 		if !ok {
94 | 			// A scalar (or nil) has no children, so a longer path finds nothing;
95 | 			// returning the scalar would answer "a.b" with the value of "a".
96 | 			return nil
97 | 		}
98 | 		v = tree[key]
99 | 	}
100 | 	return v
101 | }
102 |
103 | // Set stores value under key (case-insensitive, split on separator). A map
104 | // value is converted with toCaseInsensitiveMap and merged into an existing
105 | // subtree at key; any other combination replaces what was stored there.
106 | func (s *store) Set(key string, value interface{}) {
107 | 	key = strings.ToLower(key)
108 | 	if isMap(value) {
109 | 		value = toCaseInsensitiveMap(value, s.separator)
110 | 	}
111 | 	keyPath := strings.Split(key, s.separator)
112 | 	lastKey := keyPath[len(keyPath)-1]
113 | 	tree := deepSearchIfAbsent(s.tree, keyPath[:len(keyPath)-1])
114 |
115 | 	newTree, newIsTree := value.(map[string]interface{})
116 | 	oldTree, oldIsTree := tree[lastKey].(map[string]interface{})
117 | 	if newIsTree && oldIsTree {
118 | 		mergeStringMap(newTree, oldTree)
119 | 		return
120 | 	}
121 | 	tree[lastKey] = value
122 | }
123 |
124 | // Get returns the value stored under key (case-insensitive), or nil.
125 | func (s *store) Get(key string) interface{} {
126 | 	key = strings.ToLower(key)
127 | 	keyPath := strings.Split(key, s.separator)
128 | 	return deepSearch(s.tree, keyPath)
129 | }
130 |
131 | // GetAll returns the store's underlying tree itself, not a copy.
132 | func (s *store) GetAll() map[string]interface{} {
133 | 	return s.tree
134 | }
135 |
--------------------------------------------------------------------------------
/van/van.go:
--------------------------------------------------------------------------------
1 | package van
2 |
3 | // Van is a two-layer property store: values from Set take precedence over
4 | // values from SetDefault. Keys are case-insensitive; "." separates levels.
5 | type Van struct {
6 | 	defaults *store
7 | 	override *store
8 | }
9 |
10 | // New returns an empty Van using "." as the key separator.
11 | func New() *Van {
12 | 	separator := "."
13 | 	return &Van{defaults: newStore(separator), override: newStore(separator)}
14 | }
15 |
16 | // SetDefault stores a fallback value for key.
17 | func (v *Van) SetDefault(key string, value interface{}) {
18 | 	v.defaults.Set(key, value)
19 | }
20 |
21 | // Set stores an overriding value for key.
22 | func (v *Van) Set(key string, value interface{}) {
23 | 	v.override.Set(key, value)
24 | }
25 |
26 | // Get returns the overriding value for key, falling back to the default
27 | // when no override (or a nil one) is stored.
28 | func (v *Van) Get(key string) (val interface{}) {
29 | 	val = v.override.Get(key)
30 | 	if val == nil {
31 | 		val = v.defaults.Get(key)
32 | 	}
33 | 	return val
34 | }
35 |
36 | // GetAll returns a fresh tree of every property in which overrides take
37 | // precedence over defaults. Nested maps are copied from both layers, so
38 | // modifying the result does not alter either store.
39 | func (v *Van) GetAll() map[string]interface{} {
40 | 	merged := copyStringMap(v.defaults.GetAll())
41 | 	mergeStringMap(copyStringMap(v.override.GetAll()), merged)
42 | 	return merged
43 | }
44 |
--------------------------------------------------------------------------------