├── LICENSE.txt ├── example_test.go ├── README.md ├── vars.go ├── flags.go └── flags_test.go /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016 Uber Technologies, Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | 21 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 Uber Technologies, Inc. 
2 | // 3 | // Permission is hereby granted, free of charge, to any person obtaining a copy 4 | // of this software and associated documentation files (the "Software"), to deal 5 | // in the Software without restriction, including without limitation the rights 6 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | // copies of the Software, and to permit persons to whom the Software is 8 | // furnished to do so, subject to the following conditions: 9 | // 10 | // The above copyright notice and this permission notice shall be included in 11 | // all copies or substantial portions of the Software. 12 | // 13 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | // THE SOFTWARE. 20 | 21 | package flags 22 | 23 | import ( 24 | "fmt" 25 | "time" 26 | ) 27 | 28 | func ExampleParseArgs() { 29 | type Logging struct { 30 | Interval int 31 | Path string 32 | } 33 | type Socket struct { 34 | ReadTimeout time.Duration `yaml:"read_timeout"` 35 | WriteTimeout time.Duration 36 | } 37 | 38 | type TCP struct { 39 | ReadTimeout time.Duration 40 | Socket 41 | } 42 | 43 | type Network struct { 44 | ReadTimeout time.Duration 45 | WriteTimeout time.Duration 46 | TCP 47 | } 48 | 49 | type Cfg struct { 50 | Logging 51 | Network 52 | } 53 | 54 | // this is just an example, normally one would use packages like yaml to 55 | // populate the struct rather than manually create it. 
56 | c := &Cfg{ 57 | Logging: Logging{Interval: 3, Path: "/tmp"}, 58 | Network: Network{ 59 | TCP: TCP{ 60 | ReadTimeout: time.Duration(10) * time.Millisecond, 61 | Socket: Socket{ 62 | ReadTimeout: time.Duration(10) * time.Millisecond, 63 | }, 64 | }, 65 | }, 66 | } 67 | 68 | fmt.Printf("loaded config 'network.tcp.socket.read_timeout' is %s\n", c.Network.TCP.Socket.ReadTimeout) 69 | args := []string{"--logging.interval", "2", "--network.tcp.socket.read_timeout", "50ms"} 70 | ParseArgs(c, args) 71 | fmt.Printf("after override 'network.tcp.socket.read_timeout' is %s\n", c.Network.TCP.Socket.ReadTimeout) 72 | // Output: 73 | // loaded config 'network.tcp.socket.read_timeout' is 10ms 74 | // after override 'network.tcp.socket.read_timeout' is 50ms 75 | } 76 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 2 | An automatic way of creating command line options to override fields from a struct. 3 | 4 | ## Installation 5 | `go get -u github.com/uber-go/flagoverride` 6 | 7 | 8 | ## Overview 9 | Typically, if one wants to load from a config file (e.g. yaml), one has to 10 | define a proper struct, then load values into it (e.g. yaml.Unmarshal()). 11 | However, there are situations where we want to load most of the configs from 12 | the file and to override some of the configs. 13 | 14 | Let's say we use a yaml to config our Db connections and upon start of the 15 | application we load from the yaml file to get the necessary parameters to 16 | create the connection. Our base.yaml looks like this: 17 | 18 | ```yaml 19 | base.yaml 20 | --- 21 | mysql: 22 | user: 'foo' 23 | password: 'xxxxxx' 24 | mysql_defaults_file: ./mysql_defaults.ini 25 | mysql_socket_path: /var/run/mysqld/mysqld.sock 26 | ... more config options ... 
27 | ``` 28 | 29 | We want to load all the configs from it, but we also want to provide some 30 | flexibility for the program to connect via a different db user. We could 31 | define a --user command flag and then, after loading the yaml file, override 32 | the user field with what we get from the --user flag. 33 | 34 | If there are many overrides like this, manually defining these flags is 35 | tedious. This package provides an automatic way to define these overrides: 36 | given a struct, it'll create all the flags, which are named using 37 | the field names of the struct. If one of these flags is set via the command 38 | line, the struct will be modified in-place to reflect the value from the command 39 | line, and therefore the values of the fields in the struct are overridden. 40 | 41 | YAML is just used as an example here. In practice, one can use any struct to define flags. 42 | 43 | Let's say our configuration object is defined as follows: 44 | 45 | ```go 46 | type logging struct { 47 | Interval int 48 | Path string 49 | } 50 | 51 | type socket struct { 52 | ReadTimeout time.Duration 53 | WriteTimeout time.Duration 54 | } 55 | 56 | type tcp struct { 57 | ReadTimeout time.Duration 58 | socket 59 | } 60 | 61 | type network struct { 62 | ReadTimeout time.Duration 63 | WriteTimeout time.Duration 64 | tcp 65 | } 66 | 67 | type Cfg struct { 68 | logging 69 | network 70 | } 71 | ``` 72 | 73 | The following code: 74 | 75 | ```go 76 | func main() { 77 | c := &Cfg{} 78 | flags.ParseArgs(c, os.Args[1:]) 79 | } 80 | ``` 81 | 82 | will create the following flags: 83 | 84 | ``` 85 | -logging.interval int 86 | logging.interval 87 | -logging.path string 88 | logging.path 89 | -network.readtimeout duration 90 | network.readtimeout 91 | -network.tcp.readtimeout duration 92 | network.tcp.readtimeout 93 | -network.tcp.socket.readtimeout duration 94 | network.tcp.socket.readtimeout 95 | -network.tcp.socket.writetimeout duration 96 | network.tcp.socket.writetimeout 97 | -network.writetimeout
duration 98 | network.writetimeout 99 | ``` 100 | 101 | Flags for subcommands are naturally supported: 102 | 103 | ```go 104 | func main() { 105 | cmd := os.Args[1] 106 | switch cmd { 107 | case "new": 108 | c1 := &Cfg1{} 109 | flags.ParseArgs(c1, os.Args[2:]) 110 | case "update": 111 | c2 := &Cfg2{} 112 | flags.ParseArgs(c2, os.Args[2:]) 113 | 114 | ... more sub commands ... 115 | } 116 | } 117 | ``` 118 | 119 | One can set Flatten to true when calling `NewFlagMakerAdv`, in which case 120 | flags are created without namespacing. For example, 121 | 122 | ```go 123 | type auth struct { 124 | Token string 125 | Tag float64 126 | } 127 | 128 | type credentials struct { 129 | User string 130 | Password string 131 | auth 132 | } 133 | 134 | type database struct { 135 | DBName string 136 | TableName string 137 | credentials 138 | } 139 | 140 | type Cfg struct { 141 | logging 142 | database 143 | } 144 | 145 | func main() { 146 | c := &Cfg{} 147 | flags.NewFlagMakerAdv(&flags.FlagMakingOptions{UseLowerCase: true, Flatten: true}).ParseArgs(c, os.Args[1:]) 148 | } 149 | ``` 150 | 151 | will create the following flags: 152 | 153 | ``` 154 | -dbname string 155 | dbname 156 | -interval int 157 | interval 158 | -password string 159 | password 160 | -path string 161 | path 162 | -tablename string 163 | tablename 164 | -tag float 165 | tag 166 | -token string 167 | token 168 | -user string 169 | user 170 | ``` 171 | 172 | Please be aware that the usual Go flag creation rules apply, i.e., if there are 173 | duplicate flag names (which is more likely in the flattened case unless 174 | the caller takes care to structure the struct properly), it panics. 175 | 176 | Note that command line flags cannot be created for every type. 177 | 178 | Fields of `map`, `chan`, array and function types do not get a corresponding flag. 179 | 180 | Pointer types are properly handled and slice types create multi-value command line flags. 181 | 182 | That is, e.g.
if a field foo's type is `[]int`, one can use 183 | --foo 10 --foo 15 --foo 20 to override this field value to be 184 | `[]int{10, 15, 20}`. For now, only `[]int`, `[]string` and `[]float64` are supported in this fashion. 185 | 186 |
187 | Released under the [MIT License](LICENSE.txt). 188 | -------------------------------------------------------------------------------- /vars.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 Uber Technologies, Inc. 2 | // 3 | // Permission is hereby granted, free of charge, to any person obtaining a copy 4 | // of this software and associated documentation files (the "Software"), to deal 5 | // in the Software without restriction, including without limitation the rights 6 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | // copies of the Software, and to permit persons to whom the Software is 8 | // furnished to do so, subject to the following conditions: 9 | // 10 | // The above copyright notice and this permission notice shall be included in 11 | // all copies or substantial portions of the Software. 12 | // 13 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | // THE SOFTWARE. 
20 | 21 | package flags 22 | 23 | import ( 24 | "fmt" 25 | "strconv" 26 | ) 27 | 28 | // additional types 29 | type int8Value int8 30 | type int16Value int16 31 | type int32Value int32 32 | type f32Value float32 33 | type uint8Value uint8 34 | type uint32Value uint32 35 | type uint16Value uint16 36 | 37 | // Var handlers for each of the types 38 | func newInt8Value(p *int8) *int8Value { 39 | return (*int8Value)(p) 40 | } 41 | 42 | func newInt16Value(p *int16) *int16Value { 43 | return (*int16Value)(p) 44 | } 45 | 46 | func newInt32Value(p *int32) *int32Value { 47 | return (*int32Value)(p) 48 | } 49 | 50 | func newFloat32Value(p *float32) *f32Value { 51 | return (*f32Value)(p) 52 | } 53 | 54 | func newUint8Value(p *uint8) *uint8Value { 55 | return (*uint8Value)(p) 56 | } 57 | 58 | func newUint16Value(p *uint16) *uint16Value { 59 | return (*uint16Value)(p) 60 | } 61 | 62 | func newUint32Value(p *uint32) *uint32Value { 63 | return (*uint32Value)(p) 64 | } 65 | 66 | // Setters for each of the types 67 | func (f *int8Value) Set(s string) error { 68 | v, err := strconv.ParseInt(s, 10, 8) 69 | if err != nil { 70 | return err 71 | } 72 | *f = int8Value(v) 73 | return nil 74 | } 75 | 76 | func (f *int16Value) Set(s string) error { 77 | v, err := strconv.ParseInt(s, 10, 16) 78 | if err != nil { 79 | return err 80 | } 81 | *f = int16Value(v) 82 | return nil 83 | } 84 | 85 | func (f *int32Value) Set(s string) error { 86 | v, err := strconv.ParseInt(s, 10, 32) 87 | if err != nil { 88 | return err 89 | } 90 | *f = int32Value(v) 91 | return nil 92 | } 93 | 94 | func (f *f32Value) Set(s string) error { 95 | v, err := strconv.ParseFloat(s, 32) 96 | if err != nil { 97 | return err 98 | } 99 | *f = f32Value(v) 100 | return nil 101 | } 102 | 103 | func (f *uint8Value) Set(s string) error { 104 | v, err := strconv.ParseUint(s, 10, 8) 105 | if err != nil { 106 | return err 107 | } 108 | *f = uint8Value(v) 109 | return nil 110 | } 111 | 112 | func (f *uint16Value) Set(s string) error { 
113 | v, err := strconv.ParseUint(s, 10, 16) 114 | if err != nil { 115 | return err 116 | } 117 | *f = uint16Value(v) 118 | return nil 119 | } 120 | 121 | func (f *uint32Value) Set(s string) error { 122 | v, err := strconv.ParseUint(s, 10, 32) 123 | if err != nil { 124 | return err 125 | } 126 | *f = uint32Value(v) 127 | return nil 128 | } 129 | 130 | // Getters for each of the types 131 | func (f *int8Value) Get() interface{} { return int8(*f) } 132 | func (f *int16Value) Get() interface{} { return int16(*f) } 133 | func (f *int32Value) Get() interface{} { return int32(*f) } 134 | func (f *f32Value) Get() interface{} { return float32(*f) } 135 | func (f *uint8Value) Get() interface{} { return uint8(*f) } 136 | func (f *uint16Value) Get() interface{} { return uint16(*f) } 137 | func (f *uint32Value) Get() interface{} { return uint32(*f) } 138 | 139 | // Stringers for each of the types 140 | func (f *int8Value) String() string { return fmt.Sprintf("%v", *f) } 141 | func (f *int16Value) String() string { return fmt.Sprintf("%v", *f) } 142 | func (f *int32Value) String() string { return fmt.Sprintf("%v", *f) } 143 | func (f *f32Value) String() string { return fmt.Sprintf("%v", *f) } 144 | func (f *uint8Value) String() string { return fmt.Sprintf("%v", *f) } 145 | func (f *uint16Value) String() string { return fmt.Sprintf("%v", *f) } 146 | func (f *uint32Value) String() string { return fmt.Sprintf("%v", *f) } 147 | 148 | // string slice 149 | 150 | type strSlice struct { 151 | s *[]string 152 | set bool // if there a flag defined via command line, the slice will be cleared first. 
153 | } 154 | 155 | func newStringSlice(p *[]string) *strSlice { 156 | return &strSlice{ 157 | s: p, 158 | set: false, 159 | } 160 | } 161 | 162 | func (s *strSlice) Set(str string) error { 163 | if !s.set { 164 | *s.s = (*s.s)[:0] 165 | s.set = true 166 | } 167 | *s.s = append(*s.s, str) 168 | return nil 169 | } 170 | 171 | func (s *strSlice) Get() interface{} { 172 | return []string(*s.s) 173 | } 174 | 175 | func (s *strSlice) String() string { 176 | return fmt.Sprintf("%v", *s.s) 177 | } 178 | 179 | // int slice 180 | type intSlice struct { 181 | s *[]int 182 | set bool 183 | } 184 | 185 | func newIntSlice(p *[]int) *intSlice { 186 | return &intSlice{ 187 | s: p, 188 | set: false, 189 | } 190 | } 191 | 192 | func (is *intSlice) Set(str string) error { 193 | i, err := strconv.Atoi(str) 194 | if err != nil { 195 | return err 196 | } 197 | if !is.set { 198 | *is.s = (*is.s)[:0] 199 | is.set = true 200 | } 201 | *is.s = append(*is.s, i) 202 | return nil 203 | } 204 | 205 | func (is *intSlice) Get() interface{} { 206 | return []int(*is.s) 207 | } 208 | 209 | func (is *intSlice) String() string { 210 | return fmt.Sprintf("%v", *is.s) 211 | } 212 | 213 | // float64 slice 214 | type float64Slice struct { 215 | s *[]float64 216 | set bool 217 | } 218 | 219 | func newFloat64Slice(p *[]float64) *float64Slice { 220 | return &float64Slice{ 221 | s: p, 222 | set: false, 223 | } 224 | } 225 | 226 | func (is *float64Slice) Set(str string) error { 227 | i, err := strconv.ParseFloat(str, 64) 228 | if err != nil { 229 | return err 230 | } 231 | if !is.set { 232 | *is.s = (*is.s)[:0] 233 | is.set = true 234 | } 235 | *is.s = append(*is.s, i) 236 | return nil 237 | } 238 | 239 | func (is *float64Slice) Get() interface{} { 240 | return []float64(*is.s) 241 | } 242 | 243 | func (is *float64Slice) String() string { 244 | return fmt.Sprintf("%v", *is.s) 245 | } 246 | -------------------------------------------------------------------------------- /flags.go: 
--------------------------------------------------------------------------------
// Copyright (c) 2016 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

// Package flags provides an interface for automatically creating command line
// options from a struct.
//
// Typically, if one wants to load configuration from a YAML file, one defines
// a suitable struct and then calls yaml.Unmarshal(). However, there are
// situations where we want to load most of the configuration from the file
// but override some of it.
//
// Let's say we use a YAML file to configure our DB connections, and upon start
// of the application we load it to get the parameters needed to create the
// connection. Our base.yaml looks like this:
//
//	base.yaml
//	---
//	mysql:
//	  user: 'foo'
//	  password: 'xxxxxx'
//	  mysql_defaults_file: ./mysql_defaults.ini
//	  mysql_socket_path: /var/run/mysqld/mysqld.sock
//	... more config options ...
//
// We want to load all the configuration from it, but we also want the program
// to be able to connect as a different DB user. We could define a --user
// command line flag and, after loading the YAML file, override the user field
// with the value of --user.
//
// If there are many overrides like this, defining these flags manually is
// tedious. This package provides an automatic way to define such overrides:
// given a struct, it creates flags named after the struct's field names. If
// one of these flags is set on the command line, the struct is modified in
// place to reflect the value from the command line, so the corresponding field
// value is overridden.
//
// YAML is just used as an example here. In practice, one can use any struct
// to define flags.
//
// Let's say our configuration object is defined as follows:
//
//	type logging struct {
//		Interval int
//		Path     string
//	}
//
//	type socket struct {
//		ReadTimeout  time.Duration
//		WriteTimeout time.Duration
//	}
//
//	type tcp struct {
//		ReadTimeout time.Duration
//		socket
//	}
//
//	type network struct {
//		ReadTimeout  time.Duration
//		WriteTimeout time.Duration
//		tcp
//	}
//
//	type Cfg struct {
//		logging
//		network
//	}
//
// The following code
//
//	func main() {
//		c := &Cfg{}
//		flags.ParseArgs(c, os.Args[1:])
//	}
//
// will create the following flags:
//
//	-logging.interval int
//	    logging.interval
//	-logging.path string
//	    logging.path
//	-network.readtimeout duration
//	    network.readtimeout
//	-network.tcp.readtimeout duration
//	    network.tcp.readtimeout
//	-network.tcp.socket.readtimeout duration
//	    network.tcp.socket.readtimeout
//	-network.tcp.socket.writetimeout duration
//	    network.tcp.socket.writetimeout
//	-network.writetimeout duration
//	    network.writetimeout
//
// Flags for subcommands are naturally supported:
//
//	func main() {
//		cmd := os.Args[1]
//		switch cmd {
//		case "new":
//			c1 := &Cfg1{}
//			flags.ParseArgs(c1, os.Args[2:])
//		case "update":
//			c2 := &Cfg2{}
//			flags.ParseArgs(c2, os.Args[2:])
//
//		... more sub commands ...
//		}
//	}
//
// One can set Flatten to true when calling NewFlagMakerAdv, in which case
// flags are created without namespacing. For example,
//
//	type auth struct {
//		Token string
//		Tag   float64
//	}
//
//	type credentials struct {
//		User     string
//		Password string
//		auth
//	}
//
//	type database struct {
//		DBName    string
//		TableName string
//		credentials
//	}
//
//	type Cfg struct {
//		logging
//		database
//	}
//
//	func main() {
//		c := &Cfg{}
//		fm := flags.NewFlagMakerAdv(&flags.FlagMakingOptions{
//			UseLowerCase: true,
//			Flatten:      true,
//		})
//		fm.ParseArgs(c, os.Args[1:])
//	}
//
// will create the following flags:
//
//	-dbname string
//	    dbname
//	-interval int
//	    interval
//	-password string
//	    password
//	-path string
//	    path
//	-tablename string
//	    tablename
//	-tag float
//	    tag
//	-token string
//	    token
//	-user string
//	    user
//
// Please be aware that the usual Go flag creation rules apply, i.e., if there
// are duplicate flag names (which is more likely in the flattened case unless
// the caller takes care to structure the struct properly), it panics.
//
// Note that flags cannot be created for every type. Map, channel, array and
// function fields do not get a corresponding flag. Pointer types are properly
// handled, and slice types create multi-value command line flags. That is,
// e.g. if a field foo's type is []int, one can use --foo 10 --foo 15 --foo 20
// to override this field value to be []int{10, 15, 20}. For now, only []int,
// []string and []float64 are supported in this fashion.
package flags

import (
	"flag"
	"fmt"
	"reflect"
	"strings"
	"time"
)

// FlagMakingOptions controls how a FlagMaker defines flags.
196 | type FlagMakingOptions struct { 197 | // Use lower case flag names rather than the field name/tag name directly. 198 | UseLowerCase bool 199 | // Create flags in namespaced fashion 200 | Flatten bool 201 | // If there is a struct tag named 'TagName', use its value as the flag name. 202 | // The purpose is that, for yaml/json parsing we often have something like 203 | // Foobar string `yaml:"host_name"`, in which case the flag will be named 204 | // 'host_name' rather than 'foobar'. 205 | TagName string 206 | } 207 | 208 | // FlagMaker enumerate all the exported fields of a struct recursively 209 | // and create corresponding command line flags. For anonymous fields, 210 | // they are only enumerated if they are pointers to structs. 211 | // Usual GoLang flag rules apply, e.g. duplicated flag names leads to 212 | // panic. 213 | type FlagMaker struct { 214 | opts *FlagMakingOptions 215 | // We don't consume os.Args directly unless told to. 216 | fs *flag.FlagSet 217 | } 218 | 219 | // NewFlagMaker creates a default FlagMaker which creates namespaced flags 220 | func NewFlagMaker() *FlagMaker { 221 | return NewFlagMakerAdv(&FlagMakingOptions{ 222 | UseLowerCase: true, 223 | Flatten: false, 224 | TagName: "yaml"}) 225 | } 226 | 227 | // NewFlagMakerAdv gives full control to create flags. 228 | func NewFlagMakerAdv(options *FlagMakingOptions) *FlagMaker { 229 | return &FlagMaker{ 230 | opts: options, 231 | fs: flag.NewFlagSet("xFlags", flag.ContinueOnError), 232 | } 233 | } 234 | 235 | // ParseArgs parses the string arguments which should not contain the program name. 236 | // 237 | // obj is the struct to populate. args are the command line arguments, 238 | // typically obtained from os.Args. 239 | func ParseArgs(obj interface{}, args []string) ([]string, error) { 240 | fm := NewFlagMaker() 241 | return fm.ParseArgs(obj, args) 242 | } 243 | 244 | // PrintDefaults prints the default value and type of defined flags. 
245 | // It just calls the standard 'flag' package's PrintDefaults. 246 | func (fm *FlagMaker) PrintDefaults() { 247 | fm.fs.PrintDefaults() 248 | } 249 | 250 | // ParseArgs parses the arguments based on the FlagMaker's setting. 251 | func (fm *FlagMaker) ParseArgs(obj interface{}, args []string) ([]string, error) { 252 | v := reflect.ValueOf(obj) 253 | if v.Kind() != reflect.Ptr { 254 | return args, fmt.Errorf("top level object must be a pointer. %v is passed", v.Type()) 255 | } 256 | if v.IsNil() { 257 | return args, fmt.Errorf("top level object cannot be nil") 258 | } 259 | 260 | switch e := v.Elem(); e.Kind() { 261 | case reflect.Struct: 262 | fm.enumerateAndCreate("", e) 263 | case reflect.Interface: 264 | if e.Elem().Kind() == reflect.Ptr { 265 | fm.enumerateAndCreate("", e) 266 | } else { 267 | return args, fmt.Errorf("interface must have pointer underlying type. %v is passed", v.Type()) 268 | } 269 | default: 270 | return args, fmt.Errorf("object must be a pointer to struct or interface. 
%v is passed", v.Type()) 271 | } 272 | 273 | err := fm.fs.Parse(args) 274 | return fm.fs.Args(), err 275 | } 276 | 277 | func (fm *FlagMaker) enumerateAndCreate(prefix string, value reflect.Value) { 278 | switch value.Kind() { 279 | case 280 | // do no create flag for these types 281 | reflect.Map, 282 | reflect.Uintptr, 283 | reflect.UnsafePointer, 284 | reflect.Array, 285 | reflect.Chan, 286 | reflect.Func: 287 | return 288 | case reflect.Slice: 289 | // only support slice of strings, ints and float64s 290 | switch value.Type().Elem().Kind() { 291 | case reflect.String: 292 | fm.defineStringSlice(prefix, value) 293 | case reflect.Int: 294 | fm.defineIntSlice(prefix, value) 295 | case reflect.Float64: 296 | fm.defineFloat64Slice(prefix, value) 297 | } 298 | return 299 | case 300 | // Basic value types 301 | reflect.String, 302 | reflect.Bool, 303 | reflect.Float32, reflect.Float64, 304 | reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 305 | reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 306 | fm.defineFlag(prefix, value) 307 | return 308 | case reflect.Interface: 309 | if !value.IsNil() { 310 | fm.enumerateAndCreate(prefix, value.Elem()) 311 | } 312 | return 313 | case reflect.Ptr: 314 | if value.IsNil() { 315 | value.Set(reflect.New(value.Type().Elem())) 316 | } 317 | fm.enumerateAndCreate(prefix, value.Elem()) 318 | return 319 | case reflect.Struct: 320 | // keep going 321 | default: 322 | panic(fmt.Sprintf("unknown reflected kind %v", value.Kind())) 323 | } 324 | 325 | numFields := value.NumField() 326 | tt := value.Type() 327 | 328 | for i := 0; i < numFields; i++ { 329 | stField := tt.Field(i) 330 | // Skip unexported fields, as only exported fields can be set. This is similar to how json and yaml work. 
331 | if stField.PkgPath != "" && !stField.Anonymous { 332 | continue 333 | } 334 | if stField.Anonymous && fm.getUnderlyingType(stField.Type).Kind() != reflect.Struct { 335 | continue 336 | } 337 | field := value.Field(i) 338 | optName := fm.getName(stField) 339 | if len(prefix) > 0 && !fm.opts.Flatten { 340 | optName = prefix + "." + optName 341 | } 342 | fm.enumerateAndCreate(optName, field) 343 | } 344 | } 345 | 346 | func (fm *FlagMaker) getName(field reflect.StructField) string { 347 | name := field.Tag.Get(fm.opts.TagName) 348 | if len(name) == 0 { 349 | if field.Anonymous { 350 | name = fm.getUnderlyingType(field.Type).Name() 351 | } else { 352 | name = field.Name 353 | } 354 | } 355 | if fm.opts.UseLowerCase { 356 | return strings.ToLower(name) 357 | } 358 | return name 359 | } 360 | 361 | func (fm *FlagMaker) getUnderlyingType(ttype reflect.Type) reflect.Type { 362 | // this only deals with *T unnamed type, other unnamed types, e.g. []int, struct{} 363 | // will return empty string. 364 | if ttype.Kind() == reflect.Ptr { 365 | return fm.getUnderlyingType(ttype.Elem()) 366 | } 367 | return ttype 368 | } 369 | 370 | // Each object has its type (which prescribes the possible operations/methods 371 | // could be invoked; it also has an underlying 'kind', int, float, struct etc. 372 | // Since user can freely define types, one 'kind' of object may correpond to 373 | // many types. We cannot do type assertion because types of same kind are still 374 | // different types. Instead, we convert to the primitive types that corresponds 375 | // to the kinds and create flag vars. One thing to know is that, the whole point 376 | // of defineFlag() method is to define flag.Vars that points to certain field 377 | // of the struct so that command line values can modify the struct. We cannot 378 | // define a flag var pointing to arbitrary 'free' varible. 379 | 380 | // I wish GoLang had macro... 
381 | var ( 382 | stringPtrType = reflect.TypeOf((*string)(nil)) 383 | boolPtrType = reflect.TypeOf((*bool)(nil)) 384 | float32PtrType = reflect.TypeOf((*float32)(nil)) 385 | float64PtrType = reflect.TypeOf((*float64)(nil)) 386 | intPtrType = reflect.TypeOf((*int)(nil)) 387 | int8PtrType = reflect.TypeOf((*int8)(nil)) 388 | int16PtrType = reflect.TypeOf((*int16)(nil)) 389 | int32PtrType = reflect.TypeOf((*int32)(nil)) 390 | int64PtrType = reflect.TypeOf((*int64)(nil)) 391 | uintPtrType = reflect.TypeOf((*uint)(nil)) 392 | uint8PtrType = reflect.TypeOf((*uint8)(nil)) 393 | uint16PtrType = reflect.TypeOf((*uint16)(nil)) 394 | uint32PtrType = reflect.TypeOf((*uint32)(nil)) 395 | uint64PtrType = reflect.TypeOf((*uint64)(nil)) 396 | ) 397 | 398 | func (fm *FlagMaker) defineFlag(name string, value reflect.Value) { 399 | // v must be scalar, otherwise panic 400 | ptrValue := value.Addr() 401 | switch value.Kind() { 402 | case reflect.String: 403 | v := ptrValue.Convert(stringPtrType).Interface().(*string) 404 | fm.fs.StringVar(v, name, value.String(), name) 405 | case reflect.Bool: 406 | v := ptrValue.Convert(boolPtrType).Interface().(*bool) 407 | fm.fs.BoolVar(v, name, value.Bool(), name) 408 | case reflect.Int: 409 | v := ptrValue.Convert(intPtrType).Interface().(*int) 410 | fm.fs.IntVar(v, name, int(value.Int()), name) 411 | case reflect.Int8: 412 | v := ptrValue.Convert(int8PtrType).Interface().(*int8) 413 | fm.fs.Var(newInt8Value(v), name, name) 414 | case reflect.Int16: 415 | v := ptrValue.Convert(int16PtrType).Interface().(*int16) 416 | fm.fs.Var(newInt16Value(v), name, name) 417 | case reflect.Int32: 418 | v := ptrValue.Convert(int32PtrType).Interface().(*int32) 419 | fm.fs.Var(newInt32Value(v), name, name) 420 | case reflect.Int64: 421 | switch v := ptrValue.Interface().(type) { 422 | case *int64: 423 | fm.fs.Int64Var(v, name, value.Int(), name) 424 | case *time.Duration: 425 | fm.fs.DurationVar(v, name, value.Interface().(time.Duration), name) 426 | default: 427 
| // (TODO) if one type defines time.Duration, we'll create a int64 flag for it. 428 | // Find some acceptible way to deal with it. 429 | vv := ptrValue.Convert(int64PtrType).Interface().(*int64) 430 | fm.fs.Int64Var(vv, name, value.Int(), name) 431 | } 432 | case reflect.Float32: 433 | v := ptrValue.Convert(float32PtrType).Interface().(*float32) 434 | fm.fs.Var(newFloat32Value(v), name, name) 435 | case reflect.Float64: 436 | v := ptrValue.Convert(float64PtrType).Interface().(*float64) 437 | fm.fs.Float64Var(v, name, value.Float(), name) 438 | case reflect.Uint: 439 | v := ptrValue.Convert(uintPtrType).Interface().(*uint) 440 | fm.fs.UintVar(v, name, uint(value.Uint()), name) 441 | case reflect.Uint8: 442 | v := ptrValue.Convert(uint8PtrType).Interface().(*uint8) 443 | fm.fs.Var(newUint8Value(v), name, name) 444 | case reflect.Uint16: 445 | v := ptrValue.Convert(uint16PtrType).Interface().(*uint16) 446 | fm.fs.Var(newUint16Value(v), name, name) 447 | case reflect.Uint32: 448 | v := ptrValue.Convert(uint32PtrType).Interface().(*uint32) 449 | fm.fs.Var(newUint32Value(v), name, name) 450 | case reflect.Uint64: 451 | v := ptrValue.Convert(uint64PtrType).Interface().(*uint64) 452 | fm.fs.Uint64Var(v, name, value.Uint(), name) 453 | } 454 | } 455 | 456 | func (fm *FlagMaker) defineStringSlice(name string, value reflect.Value) { 457 | ptrValue := value.Addr().Interface().(*[]string) 458 | fm.fs.Var(newStringSlice(ptrValue), name, name) 459 | } 460 | 461 | func (fm *FlagMaker) defineIntSlice(name string, value reflect.Value) { 462 | ptrValue := value.Addr().Interface().(*[]int) 463 | fm.fs.Var(newIntSlice(ptrValue), name, name) 464 | } 465 | 466 | func (fm *FlagMaker) defineFloat64Slice(name string, value reflect.Value) { 467 | ptrValue := value.Addr().Interface().(*[]float64) 468 | fm.fs.Var(newFloat64Slice(ptrValue), name, name) 469 | } 470 | -------------------------------------------------------------------------------- /flags_test.go: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 Uber Technologies, Inc. 2 | // 3 | // Permission is hereby granted, free of charge, to any person obtaining a copy 4 | // of this software and associated documentation files (the "Software"), to deal 5 | // in the Software without restriction, including without limitation the rights 6 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | // copies of the Software, and to permit persons to whom the Software is 8 | // furnished to do so, subject to the following conditions: 9 | // 10 | // The above copyright notice and this permission notice shall be included in 11 | // all copies or substantial portions of the Software. 12 | // 13 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | // THE SOFTWARE. 
20 | 21 | package flags 22 | 23 | import ( 24 | "flag" 25 | "testing" 26 | "time" 27 | 28 | "github.com/stretchr/testify/assert" 29 | ) 30 | 31 | type logging struct { 32 | Interval int 33 | Path string 34 | } 35 | 36 | type socket struct { 37 | ReadTimeout time.Duration 38 | WriteTimeout time.Duration 39 | } 40 | 41 | type tcp struct { 42 | ReadTimeout time.Duration 43 | socket 44 | } 45 | 46 | type network struct { 47 | ReadTimeout time.Duration 48 | WriteTimeout time.Duration 49 | tcp 50 | } 51 | 52 | type Cfg1 struct { 53 | logging 54 | network 55 | } 56 | 57 | func TestFlagMakerExample(t *testing.T) { 58 | cfg := Cfg1{} 59 | 60 | args := []string{ 61 | "--network.tcp.socket.readtimeout", "5ms", 62 | "--network.tcp.readtimeout", "3ms", 63 | "-logging.path", "/var/log", 64 | } 65 | args, err := ParseArgs(cfg, args) 66 | assert.False(t, err == nil) 67 | args, err = ParseArgs(&cfg, args) 68 | assert.True(t, err == nil) 69 | assert.Equal(t, 0, len(args)) 70 | 71 | expected := Cfg1{ 72 | network: network{ 73 | tcp: tcp{ 74 | ReadTimeout: time.Duration(3) * time.Millisecond, 75 | socket: socket{ 76 | ReadTimeout: time.Duration(5) * time.Millisecond, 77 | }, 78 | }, 79 | }, 80 | logging: logging{ 81 | Path: "/var/log", 82 | }, 83 | } 84 | assert.Equal(t, expected, cfg) 85 | } 86 | 87 | type auth struct { 88 | Token string 89 | Tag float64 90 | } 91 | 92 | type credentials struct { 93 | User string 94 | Password string 95 | auth 96 | } 97 | 98 | type database struct { 99 | DBName string 100 | TableName string 101 | credentials 102 | } 103 | 104 | type Cfg2 struct { 105 | logging 106 | database 107 | *string 108 | } 109 | 110 | func TestFlagMakerExampleFlattened(t *testing.T) { 111 | cfg := Cfg2{} 112 | 113 | args := []string{ 114 | "--dbname", "db1", 115 | "--token", "abcd", 116 | "-tag=3.14", 117 | "-path", "/var/log", 118 | } 119 | 120 | fm := NewFlagMakerAdv(&FlagMakingOptions{true, true, "not-care"}) 121 | args, err := fm.ParseArgs(&cfg, args) 122 | 123 | 
assert.True(t, err == nil) 124 | assert.Equal(t, 0, len(args)) 125 | 126 | expected := Cfg2{} 127 | expected.Tag = 3.14 128 | expected.DBName = "db1" 129 | expected.Token = "abcd" 130 | expected.Path = "/var/log" 131 | 132 | assert.Equal(t, expected, cfg) 133 | } 134 | 135 | type C4 struct { 136 | TableName string 137 | } 138 | 139 | type C3 struct { 140 | DBName string 141 | C4 142 | } 143 | 144 | type C2 struct { 145 | User string 146 | Password int64 147 | Tag int8 148 | C3 149 | } 150 | 151 | type C1 struct { 152 | Name string `yaml:"label"` 153 | Value int 154 | Float float64 155 | Timeout time.Duration 156 | Hosts []string 157 | Ports []int 158 | Weights []float64 159 | Credential C2 160 | 161 | // some unexported fields 162 | opentimeout time.Duration 163 | localhost string 164 | } 165 | 166 | func TestFlagMakerBasic(t *testing.T) { 167 | c := &C1{ 168 | Name: "basic", 169 | Value: 10, 170 | Float: 7.4, 171 | Timeout: time.Duration(10) * time.Millisecond, 172 | Hosts: []string{"host1", "host2"}, 173 | Ports: []int{89, 90}, 174 | Credential: C2{ 175 | User: "user", 176 | Password: 1234, 177 | Tag: 20, 178 | C3: C3{ 179 | DBName: "db1", 180 | C4: C4{ 181 | TableName: "t1", 182 | }, 183 | }, 184 | }, 185 | opentimeout: time.Duration(3) * time.Microsecond, 186 | localhost: "weird.host", 187 | } 188 | args := []string{ 189 | "--label", "advanced", "-float", "5.1", "-timeout", "5ms", "--ports", "22", "--ports", "43", 190 | "--credential.user", "uber", "--credential.tag", "80", "--credential.c3.dbname", "db2", 191 | "--credential.c3.c4.tablename", "t2"} 192 | args, err := ParseArgs(c, args) 193 | assert.Equal(t, nil, err, "should be no error") 194 | assert.Equal(t, 0, len(args), "should be no arg left") 195 | 196 | expected := *c 197 | // only these fields should be modified by parsing the arguments. 
198 | expected.Credential.User = "uber" 199 | expected.Credential.Password = int64(1234) 200 | expected.Credential.Tag = int8(80) 201 | expected.Name = "advanced" 202 | expected.Credential.C3.DBName = "db2" 203 | expected.Credential.C3.TableName = "t2" 204 | expected.Ports = []int{22, 43} 205 | 206 | assert.Equal(t, &expected, c) 207 | } 208 | 209 | type CTypes struct { 210 | Strval string 211 | Bval bool 212 | F32val float32 213 | F64val float64 214 | Ival int 215 | I8val int8 216 | I16val int16 217 | I32val int32 218 | I64val int64 219 | UIval uint 220 | UI8val uint8 221 | UI16val uint16 222 | UI32val uint32 223 | UI64val uint64 224 | } 225 | 226 | func TestFlagMakerTypes(t *testing.T) { 227 | /* Check all of the types */ 228 | refCtypes := &CTypes{ 229 | Strval: "string value", 230 | Bval: true, 231 | F32val: 3.1415927, // <- Max PI for 32 bit float 232 | F64val: 3.141592653589793, // <- Max PI for 64 bit float 233 | /* The rest of these use the highest value for the type */ 234 | Ival: int(0x7fffffffffffffff), 235 | I8val: int8(0x7f), 236 | I16val: int16(0x7fff), 237 | I32val: int32(0x7fffffff), 238 | I64val: int64(0x7fffffffffffffff), 239 | UIval: uint(0xffffffffffffffff), 240 | UI8val: uint8(0xff), 241 | UI16val: uint16(0xffff), 242 | UI32val: uint32(0xffffffff), 243 | UI64val: uint64(0xffffffffffffffff), 244 | } 245 | parseCtypes := &CTypes{} 246 | args := []string{ 247 | "-strval", "string value", "--bval", 248 | "-f32val", "3.1415927", "--f64val", "3.141592653589793", 249 | "--ival", "9223372036854775807", "--i8val", "127", "--i16val", "32767", 250 | "-i32val", "2147483647", "--i64val", "9223372036854775807", 251 | "--uival", "18446744073709551615", "--ui8val", "255", "--ui16val", "65535", 252 | "-ui32val", "4294967295", "--ui64val", "18446744073709551615"} 253 | args, err := ParseArgs(parseCtypes, args) 254 | assert.Equal(t, nil, err, "should be no error") 255 | assert.Equal(t, parseCtypes, refCtypes) 256 | } 257 | 258 | type D1 struct { 259 | F1 
****string 260 | F2 ***[]int 261 | } 262 | 263 | type D2 struct { 264 | F1 **[]float64 265 | F2 *****bool 266 | D1 267 | } 268 | 269 | type D3 struct { 270 | D2 271 | F3 uint 272 | F4 int64 273 | Hosts []string 274 | } 275 | 276 | type DD struct { 277 | D1 278 | D2 279 | D3 280 | } 281 | 282 | func TestFlagMakerComplex(t *testing.T) { 283 | d := DD{} 284 | args := []string{"-d2.f1", "1.2", "-d3.d2.d1.f2", "45", "-d2.f2", "-d2.f1", "4.2", "-d3.d2.d1.f2", "56", "-d2.f1", "7.4", "-d3.d2.d1.f2", "78"} 285 | args, err := ParseArgs(&d, args) 286 | assert.Equal(t, nil, err, "unexpected error") 287 | assert.Equal(t, 0, len(args)) 288 | assert.Equal(t, true, *****d.D2.F2) 289 | assert.Equal(t, []int{45, 56, 78}, ***d.D3.D2.D1.F2) 290 | assert.Equal(t, []float64{1.2, 4.2, 7.4}, **d.D2.F1) 291 | } 292 | 293 | type I1 interface { 294 | Method1() string 295 | } 296 | 297 | type S1 struct { 298 | Host string 299 | ignore int 300 | Weights []float64 301 | F int8 302 | } 303 | 304 | func (s *S1) Method1() string { 305 | return s.Host 306 | } 307 | 308 | type S2 struct { 309 | Open bool 310 | Volume float64 311 | } 312 | 313 | func (s S2) Method1() string { return "haha" } 314 | 315 | func TestFlagMakerInterface(t *testing.T) { 316 | var s I1 = &S1{ignore: 12} 317 | args := []string{ 318 | "-host", "test.local", "-f", "16", 319 | } 320 | args, err := ParseArgs(s, args) 321 | assert.Equal(t, nil, err) 322 | expected := S1{ 323 | Host: "test.local", 324 | F: int8(16), 325 | ignore: 12, 326 | } 327 | assert.Equal(t, "test.local", s.Method1()) 328 | assert.Equal(t, &expected, s) 329 | } 330 | 331 | func TestFlagMakerPtrToIntf(t *testing.T) { 332 | s := &S1{} 333 | var i2 I1 = s 334 | args := []string{"--weights", "9.3", "--host", "www", "--weights", "10.0"} 335 | out, err := ParseArgs(&i2, args) 336 | assert.Nil(t, err) 337 | assert.Equal(t, 0, len(out)) 338 | assert.Equal(t, "www", s.Host) 339 | assert.Equal(t, []float64{9.3, 10.0}, s.Weights) 340 | 341 | s2 := S2{} 342 | var i3 I1 = 
s2 343 | args = []string{"--open", "--volume", "9.3"} 344 | out, err = ParseArgs(&i3, args) 345 | assert.Error(t, err) 346 | assert.Contains(t, err.Error(), "interface must have pointer underlying type.") 347 | assert.Equal(t, 3, len(out)) 348 | } 349 | 350 | type Cfg3 struct { 351 | D3 352 | *D2 353 | } 354 | 355 | func TestFlagMakerNested(t *testing.T) { 356 | cfg := Cfg3{} 357 | args := []string{"-d3.hosts", "h1.com", "-d2.f1", "1.2", "-d3.hosts", "h2.com", "-d2.f1", "4.2", "-d3.hosts", "h3.com", "-d2.f1", "7.4"} 358 | args, err := ParseArgs(&cfg, args) 359 | assert.Equal(t, nil, err) 360 | assert.Equal(t, 0, len(args)) 361 | 362 | assert.Equal(t, []string{"h1.com", "h2.com", "h3.com"}, cfg.Hosts) 363 | assert.Equal(t, []float64{1.2, 4.2, 7.4}, **cfg.F1) 364 | } 365 | 366 | type Cfg4 struct { 367 | Name *string 368 | *string 369 | int 370 | } 371 | 372 | func TestFlagMakerUnnamedFields(t *testing.T) { 373 | c := Cfg4{int: 4} 374 | args := []string{"--name=haha"} 375 | args, err := ParseArgs(&c, args) 376 | assert.Equal(t, nil, err) 377 | assert.Equal(t, 0, len(args)) 378 | ss := "haha" 379 | expected := Cfg4{Name: &ss, int: 4} 380 | assert.Equal(t, expected, c) 381 | } 382 | 383 | // The following test ensures that we can properly create flags for user 384 | // defined non-struct types. The kind of an object and the type of an 385 | // object is different. See the comments of defineFlag(). 
386 | type String string 387 | type Int int 388 | type Int8 int8 389 | type Int32 int32 390 | type Int16 int16 391 | type Int64 int64 392 | type Float float64 393 | type Float32 float32 394 | type Uint uint 395 | type Uint8 uint8 396 | type Uint16 uint16 397 | type Uint32 uint32 398 | type Uint64 uint64 399 | type Bool bool 400 | type PString *String 401 | type PInt *Int 402 | type PInt8 *Int8 403 | type PInt16 *Int16 404 | type PInt32 *Int32 405 | type PInt64 *Int64 406 | type PFloat *Float 407 | type PUint *Uint 408 | type PUint8 *Uint8 409 | type PUint16 *Uint16 410 | type PUint32 *Uint32 411 | type PUint64 *Uint64 412 | type PBool *bool 413 | 414 | type Cfg5 struct { 415 | S String 416 | PS PString 417 | I Int 418 | PI PInt 419 | I8 Int8 420 | PI8 PInt8 421 | I16 Int16 422 | PI16 PInt16 423 | I32 Int32 424 | PI32 PInt32 425 | I64 Int64 426 | PI64 PInt64 427 | F Float 428 | PF PFloat 429 | U Uint 430 | U8 Uint8 431 | PU8 PUint8 432 | U16 Uint16 433 | PU16 PUint16 434 | U32 Uint32 435 | PU32 PUint32 436 | PU PUint 437 | U64 Uint64 438 | PU64 PUint64 439 | B Bool 440 | PB PBool 441 | F32 Float32 442 | PF32 *Float32 443 | } 444 | 445 | func TestFlagMakerTypeDef(t *testing.T) { 446 | cfg := &Cfg5{} 447 | args := []string{"--s", "hehe", "--ps", "good", 448 | "-i", "33", "--pi", "44", 449 | "--i64", "55", "--pi64", "66", 450 | "--f", "5.7", "--pf", "6.7", 451 | "--u", "10", "--pu", "20", 452 | "--u64", "30", "--pu64", "40", 453 | "--i8", "20", "--pi8", "10", 454 | "--i16", "400", "-pi16", "500", 455 | "--i32", "600", "-pi32", "700", 456 | "--u8", "20", "--pu8", "10", 457 | "--u16", "400", "-pu16", "500", 458 | "--u32", "600", "-pu32", "700", 459 | "--f32", "10.1", "-pf32", "11.1", 460 | "--b=true", "--pb"} 461 | args, err := ParseArgs(cfg, args) 462 | 463 | assert.Equal(t, nil, err) 464 | assert.Equal(t, 0, len(args)) 465 | 466 | assert.Equal(t, "hehe", string(cfg.S)) 467 | assert.Equal(t, "good", string(*cfg.PS)) 468 | 469 | assert.Equal(t, 33, int(cfg.I)) 470 | 
assert.Equal(t, 44, int(*cfg.PI)) 471 | 472 | assert.Equal(t, int8(20), int8(cfg.I8)) 473 | assert.Equal(t, int8(10), int8(*cfg.PI8)) 474 | 475 | assert.Equal(t, int16(400), int16(cfg.I16)) 476 | assert.Equal(t, int16(500), int16(*cfg.PI16)) 477 | 478 | assert.Equal(t, int32(600), int32(cfg.I32)) 479 | assert.Equal(t, int32(700), int32(*cfg.PI32)) 480 | 481 | assert.Equal(t, int64(55), int64(cfg.I64)) 482 | assert.Equal(t, int64(66), int64(*cfg.PI64)) 483 | 484 | assert.Equal(t, 5.7, float64(cfg.F)) 485 | assert.Equal(t, 6.7, float64(*cfg.PF)) 486 | 487 | assert.Equal(t, float32(10.1), float32(cfg.F32)) 488 | assert.Equal(t, float32(11.1), float32(*cfg.PF32)) 489 | 490 | assert.Equal(t, uint(10), uint(cfg.U)) 491 | assert.Equal(t, uint(20), uint(*cfg.PU)) 492 | 493 | assert.Equal(t, uint8(20), uint8(cfg.U8)) 494 | assert.Equal(t, uint8(10), uint8(*cfg.PU8)) 495 | 496 | assert.Equal(t, uint16(400), uint16(cfg.U16)) 497 | assert.Equal(t, uint16(500), uint16(*cfg.PU16)) 498 | 499 | assert.Equal(t, uint32(600), uint32(cfg.U32)) 500 | assert.Equal(t, uint32(700), uint32(*cfg.PU32)) 501 | 502 | assert.Equal(t, uint64(30), uint64(cfg.U64)) 503 | assert.Equal(t, uint64(40), uint64(*cfg.PU64)) 504 | 505 | assert.Equal(t, true, bool(cfg.B)) 506 | assert.Equal(t, true, bool(*cfg.PB)) 507 | } 508 | 509 | func TestFlagMakerInvalidInput(t *testing.T) { 510 | var cfg *Cfg5 511 | args := []string{"--s", "hehe", "--ps", "good", 512 | "-i", "33", "--pi", "44", 513 | "--i64", "55", "--pi64", "66", 514 | } 515 | out, err := ParseArgs(cfg, args) 516 | assert.Error(t, err) 517 | assert.Equal(t, "top level object cannot be nil", err.Error()) 518 | assert.Equal(t, len(args), len(out)) 519 | 520 | var cfg2 Cfg4 521 | out, err = ParseArgs(cfg2, args) 522 | assert.Error(t, err) 523 | assert.Contains(t, err.Error(), "top level object must be a pointer") 524 | assert.Equal(t, len(args), len(out)) 525 | } 526 | 527 | func TestFlagMakerUnsupportedTypes(t *testing.T) { 528 | cases := []struct { 
529 | cfg interface{} 530 | args []string 531 | }{ 532 | {&struct { 533 | Env map[string]string 534 | Level int 535 | }{}, []string{"--level", "10", "--env", "hh,fgg,10"}}, 536 | {&struct { 537 | Env chan int 538 | Level int 539 | }{}, []string{"--level", "10", "--env", "hh,fgg,10"}}, 540 | {&struct { 541 | Env func(int) string 542 | Level int 543 | }{}, []string{"--level", "10", "--env", "hh,fgg,10"}}, 544 | } 545 | 546 | for _, c := range cases { 547 | out, err := ParseArgs(c.cfg, c.args) 548 | assert.Error(t, err) 549 | assert.Contains(t, err.Error(), "flag provided but not defined") 550 | assert.Equal(t, 1, len(out)) 551 | } 552 | } 553 | 554 | func TestFlagMakerInvalidValue(t *testing.T) { 555 | cases := []struct { 556 | cfg interface{} 557 | args []string 558 | }{ 559 | {&struct{ Level int }{}, []string{"--level", "haha"}}, 560 | {&struct{ Level int8 }{}, []string{"--level", "haha"}}, 561 | {&struct{ Level int16 }{}, []string{"--level", "haha"}}, 562 | {&struct{ Level int32 }{}, []string{"--level", "haha"}}, 563 | {&struct{ Level int64 }{}, []string{"--level", "haha"}}, 564 | {&struct{ Level uint8 }{}, []string{"--level", "haha"}}, 565 | {&struct{ Level uint16 }{}, []string{"--level", "haha"}}, 566 | {&struct{ Level uint32 }{}, []string{"--level", "haha"}}, 567 | {&struct{ Level uint64 }{}, []string{"--level", "haha"}}, 568 | {&struct{ Level float32 }{}, []string{"--level", "haha"}}, 569 | {&struct{ Level float64 }{}, []string{"--level", "haha"}}, 570 | } 571 | 572 | for _, c := range cases { 573 | out, err := ParseArgs(c.cfg, c.args) 574 | assert.Error(t, err) 575 | assert.Contains(t, err.Error(), "invalid value") 576 | // args are consumed even thought the value is invalid 577 | assert.Equal(t, 0, len(out)) 578 | } 579 | } 580 | 581 | // slice 582 | 583 | func TestFlagMakerStringSlice(t *testing.T) { 584 | type C struct { 585 | Hosts []string 586 | } 587 | cases := []struct { 588 | cfg *C 589 | args, expected []string 590 | }{ 591 | {&C{}, 
[]string{"--hosts", "h1", "--hosts", "h2", "--hosts", "h3"}, []string{"h1", "h2", "h3"}}, 592 | {&C{[]string{}}, []string{"--hosts", "h1", "--hosts", "h2", "--hosts", "h3"}, []string{"h1", "h2", "h3"}}, 593 | {&C{}, []string{}, nil}, 594 | {&C{[]string{}}, []string{}, []string{}}, 595 | {&C{[]string{"l1", "l2"}}, []string{}, []string{"l1", "l2"}}, 596 | {&C{[]string{"l1", "l2"}}, []string{"--hosts", "ok"}, []string{"ok"}}, 597 | } 598 | for _, c := range cases { 599 | args, err := ParseArgs(c.cfg, c.args) 600 | assert.Nil(t, err) 601 | assert.Equal(t, 0, len(args)) 602 | assert.Equal(t, c.expected, c.cfg.Hosts) 603 | } 604 | } 605 | 606 | func TestFlagMakerIntSlice(t *testing.T) { 607 | type C struct { 608 | Levels []int 609 | } 610 | 611 | cases := []struct { 612 | cfg *C 613 | args []string 614 | expected []int 615 | }{ 616 | {&C{}, []string{"--levels", "8", "--levels", "9", "--levels", "10"}, []int{8, 9, 10}}, 617 | {&C{[]int{}}, []string{"--levels", "8", "--levels", "9", "--levels", "10"}, []int{8, 9, 10}}, 618 | {&C{}, []string{}, nil}, 619 | {&C{[]int{}}, []string{}, []int{}}, 620 | {&C{[]int{11, 12}}, []string{}, []int{11, 12}}, 621 | {&C{[]int{11, 12}}, []string{"--levels", "5"}, []int{5}}, 622 | } 623 | for _, c := range cases { 624 | args, err := ParseArgs(c.cfg, c.args) 625 | assert.Nil(t, err) 626 | assert.Equal(t, 0, len(args)) 627 | assert.Equal(t, c.expected, c.cfg.Levels) 628 | } 629 | } 630 | 631 | func TestFlagMakerFloatSlice(t *testing.T) { 632 | type C struct { 633 | Levels []float64 634 | } 635 | cases := []struct { 636 | cfg *C 637 | args []string 638 | expected []float64 639 | }{ 640 | {&C{}, []string{"--levels", "8.9", "--levels", "9.9", "--levels", "10.9"}, []float64{8.9, 9.9, 10.9}}, 641 | {&C{[]float64{}}, []string{"--levels", "8.9", "--levels", "9.9", "--levels", "10.9"}, []float64{8.9, 9.9, 10.9}}, 642 | {&C{}, []string{}, nil}, 643 | {&C{[]float64{}}, []string{}, []float64{}}, 644 | {&C{[]float64{11.3, 12.3}}, []string{}, 
[]float64{11.3, 12.3}}, 645 | {&C{[]float64{11.3, 12.3}}, []string{"--levels", "5.1"}, []float64{5.1}}, 646 | } 647 | for _, c := range cases { 648 | args, err := ParseArgs(c.cfg, c.args) 649 | assert.Nil(t, err) 650 | assert.Equal(t, 0, len(args)) 651 | assert.Equal(t, c.expected, c.cfg.Levels) 652 | } 653 | } 654 | 655 | func TestFlagMakerInvalidSlice(t *testing.T) { 656 | type C struct { 657 | Levels []int 658 | Weights []float64 659 | } 660 | cases := []struct { 661 | cfg *C 662 | args []string 663 | l []int 664 | w []float64 665 | }{ 666 | { 667 | // invalid flag values won't modify the struct 668 | &C{Levels: []int{2, 3}, Weights: []float64{2.4, 5.6}}, 669 | []string{"--levels", "7ax", "--levels", "10", "--weights", "u8.2"}, 670 | []int{2, 3}, 671 | []float64{2.4, 5.6}, 672 | }, 673 | { 674 | // however, valid values before invalid flag values WILL clear slice 675 | &C{Levels: []int{2, 3}, Weights: []float64{2.4, 5.6}}, 676 | []string{"--weights", "1.1", "--levels", "10", "--weights", "u8.2", "--levels", "abc"}, 677 | []int{10}, 678 | []float64{1.1}, 679 | }, 680 | } 681 | 682 | for _, c := range cases { 683 | _, err := ParseArgs(c.cfg, c.args) 684 | assert.Error(t, err) 685 | assert.Equal(t, c.l, c.cfg.Levels) 686 | assert.Equal(t, c.w, c.cfg.Weights) 687 | } 688 | } 689 | 690 | func TestFlagMakerVarGet(t *testing.T) { 691 | var i8 int8 = 3 692 | var i16 int16 = 4 693 | var i32 int32 = 5 694 | var f32 float32 = 10.9 695 | var u8 uint8 = 22 696 | var u16 uint16 = 30 697 | var u32 uint32 = 55 698 | is := []int{1, 40, 30} 699 | ss := []string{"haha", "xx"} 700 | fs := []float64{242.66, 7565.23, 234.67} 701 | cases := []struct { 702 | getter flag.Getter 703 | expected interface{} 704 | }{ 705 | {newInt8Value(&i8), i8}, 706 | {newInt16Value(&i16), i16}, 707 | {newInt32Value(&i32), i32}, 708 | {newFloat32Value(&f32), f32}, 709 | {newUint8Value(&u8), u8}, 710 | {newUint16Value(&u16), u16}, 711 | {newUint32Value(&u32), u32}, 712 | {newStringSlice(&ss), ss}, 713 | 
{newIntSlice(&is), is}, 714 | {newFloat64Slice(&fs), fs}, 715 | } 716 | 717 | for _, c := range cases { 718 | assert.Equal(t, c.expected, c.getter.Get()) 719 | } 720 | } 721 | --------------------------------------------------------------------------------