├── LICENSE.txt
├── example_test.go
├── README.md
├── vars.go
├── flags.go
└── flags_test.go
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2016 Uber Technologies, Inc.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
21 |
--------------------------------------------------------------------------------
/example_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2016 Uber Technologies, Inc.
2 | //
3 | // Permission is hereby granted, free of charge, to any person obtaining a copy
4 | // of this software and associated documentation files (the "Software"), to deal
5 | // in the Software without restriction, including without limitation the rights
6 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | // copies of the Software, and to permit persons to whom the Software is
8 | // furnished to do so, subject to the following conditions:
9 | //
10 | // The above copyright notice and this permission notice shall be included in
11 | // all copies or substantial portions of the Software.
12 | //
13 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | // THE SOFTWARE.
20 |
21 | package flags
22 |
23 | import (
24 | "fmt"
25 | "time"
26 | )
27 |
28 | func ExampleParseArgs() {
29 | type Logging struct {
30 | Interval int
31 | Path string
32 | }
33 | type Socket struct {
34 | ReadTimeout time.Duration `yaml:"read_timeout"`
35 | WriteTimeout time.Duration
36 | }
37 |
38 | type TCP struct {
39 | ReadTimeout time.Duration
40 | Socket
41 | }
42 |
43 | type Network struct {
44 | ReadTimeout time.Duration
45 | WriteTimeout time.Duration
46 | TCP
47 | }
48 |
49 | type Cfg struct {
50 | Logging
51 | Network
52 | }
53 |
54 | // this is just an example, normally one would use packages like yaml to
55 | // populate the struct rather than manually create it.
56 | c := &Cfg{
57 | Logging: Logging{Interval: 3, Path: "/tmp"},
58 | Network: Network{
59 | TCP: TCP{
60 | ReadTimeout: time.Duration(10) * time.Millisecond,
61 | Socket: Socket{
62 | ReadTimeout: time.Duration(10) * time.Millisecond,
63 | },
64 | },
65 | },
66 | }
67 |
68 | fmt.Printf("loaded config 'network.tcp.socket.read_timeout' is %s\n", c.Network.TCP.Socket.ReadTimeout)
69 | args := []string{"--logging.interval", "2", "--network.tcp.socket.read_timeout", "50ms"}
70 | ParseArgs(c, args)
71 | fmt.Printf("after override 'network.tcp.socket.read_timeout' is %s\n", c.Network.TCP.Socket.ReadTimeout)
72 | // Output:
73 | // loaded config 'network.tcp.socket.read_timeout' is 10ms
74 | // after override 'network.tcp.socket.read_timeout' is 50ms
75 | }
76 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # flagoverride
2 | An automatic way of creating command line options to override fields from a struct.
3 |
4 | ## Installation
5 | `go get -u github.com/uber-go/flagoverride`
6 |
7 |
8 | ## Overview
9 | Typically, if one wants to load from a config file (e.g. yaml), one has to
10 | define a proper struct, then load values into it (e.g. yaml.Unmarshal()).
11 | However, there are situations where we want to load most of the configs from
12 | the file and to override some of the configs.
13 |
14 | Let's say we use a yaml to config our Db connections and upon start of the
15 | application we load from the yaml file to get the necessary parameters to
16 | create the connection. Our base.yaml looks like this:
17 |
18 | ```yaml
19 | base.yaml
20 | ---
21 | mysql:
22 | user: 'foo'
23 | password: 'xxxxxx'
24 | mysql_defaults_file: ./mysql_defaults.ini
25 | mysql_socket_path: /var/run/mysqld/mysqld.sock
26 | ... more config options ...
27 | ```
28 |
29 | we want to load all the configs from it but we want to provide some
30 | flexibility for the program to connect via a different db user. We could
31 | define a --user command flag then after loading the yaml file, we override
32 | the user field with what we get from --user flag.
33 |
34 | If there are many overrides like this, manually defining these flags is
35 | tedious. This package provides an automatic way to define these overrides:
36 | given a struct, it creates flags that are named after
37 | the field names of the struct. If one of these flags is set via the command
38 | line, the struct will be modified in-place to reflect the value from the command
39 | line, thereby overriding the values of the fields in the struct.
40 |
41 | YAML is just used as an example here. In practice, one can use any struct to define flags.
42 |
43 | Let's say we have our configuration object as the following:
44 |
45 | ```go
46 | type logging struct {
47 | Interval int
48 | Path string
49 | }
50 |
51 | type socket struct {
52 | ReadTimeout time.Duration
53 | WriteTimeout time.Duration
54 | }
55 |
56 | type tcp struct {
57 | ReadTimeout time.Duration
58 | socket
59 | }
60 |
61 | type network struct {
62 | ReadTimeout time.Duration
63 | WriteTimeout time.Duration
64 | tcp
65 | }
66 |
67 | type Cfg struct {
68 | logging
69 | network
70 | }
71 | ```
72 |
73 | The following code:
74 |
75 | ```go
76 | func main() {
77 | c := &Cfg{}
78 | flags.ParseArgs(c, os.Args[1:])
79 | }
80 | ```
81 |
82 | will create the following flags:
83 |
84 | ```
85 | -logging.interval int
86 | logging.interval
87 | -logging.path string
88 | logging.path
89 | -network.readtimeout duration
90 | network.readtimeout
91 | -network.tcp.readtimeout duration
92 | network.tcp.readtimeout
93 | -network.tcp.socket.readtimeout duration
94 | network.tcp.socket.readtimeout
95 | -network.tcp.socket.writetimeout duration
96 | network.tcp.socket.writetimeout
97 | -network.writetimeout duration
98 | network.writetimeout
99 | ```
100 |
101 | flags to subcommands are naturally supported.
102 |
103 | ```go
104 | func main() {
105 | cmd := os.Args[1]
106 | switch cmd {
107 | case "new":
108 | c1 := &Cfg1{}
109 | ParseArgs(c1, os.Args[2:])
110 | case "update":
111 | c2 := &Cfg2{}
112 | ParseArgs(c2, os.Args[2:])
113 |
114 | ... more sub commands ...
115 | }
116 | }
117 | ```
118 |
119 | One can set Flatten to true when calling `NewFlagMakerAdv`, in which case,
120 | flags are created without namespacing. For example,
121 |
122 | ```go
123 | type auth struct {
124 | Token string
125 | Tag float64
126 | }
127 |
128 | type credentials struct {
129 | User string
130 | Password string
131 | auth
132 | }
133 |
134 | type database struct {
135 | DBName string
136 | TableName string
137 | credentials
138 | }
139 |
140 | type Cfg struct {
141 | logging
142 | database
143 | }
144 |
145 | func main() {
146 | c := &Cfg{}
147 | flags.NewFlagMakerAdv(&flags.FlagMakingOptions{UseLowerCase: true, Flatten: true, TagName: "yaml"}).ParseArgs(c, os.Args[1:])
148 | }
149 | ```
150 |
151 | will create the following flags:
152 |
153 | ```
154 | -dbname string
155 | dbname
156 | -interval int
157 | interval
158 | -password string
159 | password
160 | -path string
161 | path
162 | -tablename string
163 | tablename
164 | -tag float
165 | tag
166 | -token string
167 | token
168 | -user string
169 | user
170 | ```
171 |
172 | Please be aware that the usual Go flag creation rules apply, i.e., if there are
173 | duplicate flag names (which is more likely to happen in the flattened case
174 | unless the caller takes due diligence to create the struct properly), it panics.
175 |
176 | Note that not all types can have command line flags created for them.
177 |
178 | `map`, `channel` and function type will not define a flag corresponding to the field.
179 |
180 | Pointer types are properly handled and slice type will create multi-value command line flags.
181 |
182 | That is, e.g. if a field foo's type is `[]int`, one can use
183 | --foo 10 --foo 15 --foo 20 to override this field value to be
184 | `[]int{10, 15, 20}`. For now, only `[]int`, `[]string` and `[]float64` are supported in this fashion.
185 |
186 |
187 | Released under the [MIT License](LICENSE.txt).
188 |
--------------------------------------------------------------------------------
/vars.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2016 Uber Technologies, Inc.
2 | //
3 | // Permission is hereby granted, free of charge, to any person obtaining a copy
4 | // of this software and associated documentation files (the "Software"), to deal
5 | // in the Software without restriction, including without limitation the rights
6 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | // copies of the Software, and to permit persons to whom the Software is
8 | // furnished to do so, subject to the following conditions:
9 | //
10 | // The above copyright notice and this permission notice shall be included in
11 | // all copies or substantial portions of the Software.
12 | //
13 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | // THE SOFTWARE.
20 |
21 | package flags
22 |
23 | import (
24 | "fmt"
25 | "strconv"
26 | )
27 |
// Flag values for the fixed-size numeric types. Each one is a named version
// of the matching primitive, so a pointer to the primitive can be
// reinterpreted as a value carrying Set, Get and String methods.
type (
	int8Value   int8
	int16Value  int16
	int32Value  int32
	f32Value    float32
	uint8Value  uint8
	uint16Value uint16
	uint32Value uint32
)

// ---- int8 ----

// newInt8Value views p as an *int8Value; writes through the result land in *p.
func newInt8Value(p *int8) *int8Value { return (*int8Value)(p) }

// Set parses s as a base-10 signed integer that fits in 8 bits. On a parse
// error the stored value is left untouched and the error is returned.
func (v *int8Value) Set(s string) error {
	parsed, err := strconv.ParseInt(s, 10, 8)
	if err == nil {
		*v = int8Value(parsed)
	}
	return err
}

// Get returns the current value as an int8.
func (v *int8Value) Get() interface{} { return int8(*v) }

// String renders the current value with fmt's default formatting.
func (v *int8Value) String() string { return fmt.Sprint(*v) }

// ---- int16 ----

// newInt16Value views p as an *int16Value; writes through the result land in *p.
func newInt16Value(p *int16) *int16Value { return (*int16Value)(p) }

// Set parses s as a base-10 signed integer that fits in 16 bits. On a parse
// error the stored value is left untouched and the error is returned.
func (v *int16Value) Set(s string) error {
	parsed, err := strconv.ParseInt(s, 10, 16)
	if err == nil {
		*v = int16Value(parsed)
	}
	return err
}

// Get returns the current value as an int16.
func (v *int16Value) Get() interface{} { return int16(*v) }

// String renders the current value with fmt's default formatting.
func (v *int16Value) String() string { return fmt.Sprint(*v) }

// ---- int32 ----

// newInt32Value views p as an *int32Value; writes through the result land in *p.
func newInt32Value(p *int32) *int32Value { return (*int32Value)(p) }

// Set parses s as a base-10 signed integer that fits in 32 bits. On a parse
// error the stored value is left untouched and the error is returned.
func (v *int32Value) Set(s string) error {
	parsed, err := strconv.ParseInt(s, 10, 32)
	if err == nil {
		*v = int32Value(parsed)
	}
	return err
}

// Get returns the current value as an int32.
func (v *int32Value) Get() interface{} { return int32(*v) }

// String renders the current value with fmt's default formatting.
func (v *int32Value) String() string { return fmt.Sprint(*v) }

// ---- float32 ----

// newFloat32Value views p as an *f32Value; writes through the result land in *p.
func newFloat32Value(p *float32) *f32Value { return (*f32Value)(p) }

// Set parses s as a floating point number with 32-bit precision. On a parse
// error the stored value is left untouched and the error is returned.
func (v *f32Value) Set(s string) error {
	parsed, err := strconv.ParseFloat(s, 32)
	if err == nil {
		*v = f32Value(parsed)
	}
	return err
}

// Get returns the current value as a float32.
func (v *f32Value) Get() interface{} { return float32(*v) }

// String renders the current value with fmt's default formatting.
func (v *f32Value) String() string { return fmt.Sprint(*v) }

// ---- uint8 ----

// newUint8Value views p as a *uint8Value; writes through the result land in *p.
func newUint8Value(p *uint8) *uint8Value { return (*uint8Value)(p) }

// Set parses s as a base-10 unsigned integer that fits in 8 bits. On a parse
// error the stored value is left untouched and the error is returned.
func (v *uint8Value) Set(s string) error {
	parsed, err := strconv.ParseUint(s, 10, 8)
	if err == nil {
		*v = uint8Value(parsed)
	}
	return err
}

// Get returns the current value as a uint8.
func (v *uint8Value) Get() interface{} { return uint8(*v) }

// String renders the current value with fmt's default formatting.
func (v *uint8Value) String() string { return fmt.Sprint(*v) }

// ---- uint16 ----

// newUint16Value views p as a *uint16Value; writes through the result land in *p.
func newUint16Value(p *uint16) *uint16Value { return (*uint16Value)(p) }

// Set parses s as a base-10 unsigned integer that fits in 16 bits. On a parse
// error the stored value is left untouched and the error is returned.
func (v *uint16Value) Set(s string) error {
	parsed, err := strconv.ParseUint(s, 10, 16)
	if err == nil {
		*v = uint16Value(parsed)
	}
	return err
}

// Get returns the current value as a uint16.
func (v *uint16Value) Get() interface{} { return uint16(*v) }

// String renders the current value with fmt's default formatting.
func (v *uint16Value) String() string { return fmt.Sprint(*v) }

// ---- uint32 ----

// newUint32Value views p as a *uint32Value; writes through the result land in *p.
func newUint32Value(p *uint32) *uint32Value { return (*uint32Value)(p) }

// Set parses s as a base-10 unsigned integer that fits in 32 bits. On a parse
// error the stored value is left untouched and the error is returned.
func (v *uint32Value) Set(s string) error {
	parsed, err := strconv.ParseUint(s, 10, 32)
	if err == nil {
		*v = uint32Value(parsed)
	}
	return err
}

// Get returns the current value as a uint32.
func (v *uint32Value) Get() interface{} { return uint32(*v) }

// String renders the current value with fmt's default formatting.
func (v *uint32Value) String() string { return fmt.Sprint(*v) }
147 |
// strSlice is a multi-value flag: every occurrence of the flag appends one
// string to the []string it points at.
type strSlice struct {
	// s is the destination slice.
	s *[]string
	// set records whether Set has been called. The first call discards the
	// values the slice held before parsing, so command line values replace
	// the loaded ones instead of being appended to them.
	set bool
}

// newStringSlice returns a strSlice that writes into *p.
func newStringSlice(p *[]string) *strSlice {
	return &strSlice{
		s:   p,
		set: false,
	}
}

// Set appends str to the destination slice, clearing the slice first if this
// is the first value seen. It never fails.
func (s *strSlice) Set(str string) error {
	if !s.set {
		// Start from a nil slice instead of truncating with [:0]: truncation
		// keeps the old backing array, so the appends below would overwrite
		// elements of any other slice sharing that array.
		*s.s = nil
		s.set = true
	}
	*s.s = append(*s.s, str)
	return nil
}

// Get returns the destination slice, or a nil slice for a zero-valued
// receiver.
func (s *strSlice) Get() interface{} {
	if s == nil || s.s == nil {
		return []string(nil)
	}
	return []string(*s.s)
}

// String formats the destination slice with fmt's default formatting. The
// flag package may call String on a zero-valued receiver, so a missing
// destination is rendered like an empty slice instead of being dereferenced.
func (s *strSlice) String() string {
	if s == nil || s.s == nil {
		return "[]"
	}
	return fmt.Sprintf("%v", *s.s)
}
178 |
// intSlice is a multi-value flag: every occurrence of the flag appends one
// int to the []int it points at.
type intSlice struct {
	// s is the destination slice.
	s *[]int
	// set records whether a value has been accepted yet. The first accepted
	// value discards what the slice held before parsing.
	set bool
}

// newIntSlice returns an intSlice that writes into *p.
func newIntSlice(p *[]int) *intSlice {
	return &intSlice{
		s:   p,
		set: false,
	}
}

// Set parses str with strconv.Atoi and appends the result to the destination
// slice, clearing the slice first if this is the first accepted value. A
// parse error is returned as is and leaves the slice untouched.
func (is *intSlice) Set(str string) error {
	i, err := strconv.Atoi(str)
	if err != nil {
		return err
	}
	if !is.set {
		// Start from a nil slice instead of truncating with [:0]: truncation
		// keeps the old backing array, so the appends below would overwrite
		// elements of any other slice sharing that array.
		*is.s = nil
		is.set = true
	}
	*is.s = append(*is.s, i)
	return nil
}

// Get returns the destination slice, or a nil slice for a zero-valued
// receiver.
func (is *intSlice) Get() interface{} {
	if is == nil || is.s == nil {
		return []int(nil)
	}
	return []int(*is.s)
}

// String formats the destination slice with fmt's default formatting. The
// flag package may call String on a zero-valued receiver, so a missing
// destination is rendered like an empty slice instead of being dereferenced.
func (is *intSlice) String() string {
	if is == nil || is.s == nil {
		return "[]"
	}
	return fmt.Sprintf("%v", *is.s)
}
212 |
// float64Slice is a multi-value flag: every occurrence of the flag appends
// one float64 to the []float64 it points at.
type float64Slice struct {
	// s is the destination slice.
	s *[]float64
	// set records whether a value has been accepted yet. The first accepted
	// value discards what the slice held before parsing.
	set bool
}

// newFloat64Slice returns a float64Slice that writes into *p.
func newFloat64Slice(p *[]float64) *float64Slice {
	return &float64Slice{
		s:   p,
		set: false,
	}
}

// Set parses str as a 64-bit float and appends the result to the destination
// slice, clearing the slice first if this is the first accepted value. A
// parse error is returned as is and leaves the slice untouched.
func (fs *float64Slice) Set(str string) error {
	f, err := strconv.ParseFloat(str, 64)
	if err != nil {
		return err
	}
	if !fs.set {
		// Start from a nil slice instead of truncating with [:0]: truncation
		// keeps the old backing array, so the appends below would overwrite
		// elements of any other slice sharing that array.
		*fs.s = nil
		fs.set = true
	}
	*fs.s = append(*fs.s, f)
	return nil
}

// Get returns the destination slice, or a nil slice for a zero-valued
// receiver.
func (fs *float64Slice) Get() interface{} {
	if fs == nil || fs.s == nil {
		return []float64(nil)
	}
	return []float64(*fs.s)
}

// String formats the destination slice with fmt's default formatting. The
// flag package may call String on a zero-valued receiver, so a missing
// destination is rendered like an empty slice instead of being dereferenced.
func (fs *float64Slice) String() string {
	if fs == nil || fs.s == nil {
		return "[]"
	}
	return fmt.Sprintf("%v", *fs.s)
}
246 |
--------------------------------------------------------------------------------
/flags.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2016 Uber Technologies, Inc.
2 | //
3 | // Permission is hereby granted, free of charge, to any person obtaining a copy
4 | // of this software and associated documentation files (the "Software"), to deal
5 | // in the Software without restriction, including without limitation the rights
6 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | // copies of the Software, and to permit persons to whom the Software is
8 | // furnished to do so, subject to the following conditions:
9 | //
10 | // The above copyright notice and this permission notice shall be included in
11 | // all copies or substantial portions of the Software.
12 | //
13 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | // THE SOFTWARE.
20 |
21 | // Package flags provides an interface for automatically creating command line
22 | // options from a struct.
23 | //
24 | // Typically, if one wants to load from a yaml, one has to define a proper
25 | // struct, then yaml.Unmarshal(), this is all good. However, there are
26 | // situations where we want to load most of the configs from the file but
27 | // overriding some configs.
28 | //
29 | // Let's say we use a yaml to config our Db connections and upon start of the
30 | // application we load from the yaml file to get the necessary parameters to
31 | // create the connection. Our base.yaml looks like this
32 | //
33 | // base.yaml
34 | // ---
35 | // mysql:
36 | // user: 'foo'
37 | // password: 'xxxxxx'
38 | // mysql_defaults_file: ./mysql_defaults.ini
39 | // mysql_socket_path: /var/run/mysqld/mysqld.sock
40 | // ... more config options ...
41 | //
42 | // we want to load all the configs from it but we want to provide some
43 | // flexibility for the program to connect via a different db user. We could
44 | // define a --user command flag then after loading the yaml file, we override
45 | // the user field with what we get from --user flag.
46 | //
47 | // If there are many overrides like this, manually defining these flags is
48 | // tedious. This package provides an automatic way to define these overrides:
49 | // given a struct, it creates flags that are named after
50 | // the field names of the struct. If one of these flags is set via the command
51 | // line, the struct will be modified in-place to reflect the value from the command
52 | // line, thereby overriding the values of the fields in the struct.
53 | //
54 | // YAML is just used as an example here. In practice, one can use any struct
55 | // to define flags.
56 | //
57 | // Let's say we have our configuration object as the following.
58 | //
59 | // type logging struct {
60 | // Interval int
61 | // Path string
62 | // }
63 | //
64 | // type socket struct {
65 | // ReadTimeout time.Duration
66 | // WriteTimeout time.Duration
67 | // }
68 | //
69 | // type tcp struct {
70 | // ReadTimeout time.Duration
71 | // socket
72 | // }
73 | //
74 | // type network struct {
75 | // ReadTimeout time.Duration
76 | // WriteTimeout time.Duration
77 | // tcp
78 | // }
79 | //
80 | // type Cfg struct {
81 | // logging
82 | // network
83 | // }
84 | //
85 | // The following code
86 | //
87 | // func main() {
88 | // c := &Cfg{}
89 | // flags.ParseArgs(c, os.Args[1:])
90 | // }
91 | //
92 | // will create the following flags
93 | //
94 | // -logging.interval int
95 | // logging.interval
96 | // -logging.path string
97 | // logging.path
98 | // -network.readtimeout duration
99 | // network.readtimeout
100 | // -network.tcp.readtimeout duration
101 | // network.tcp.readtimeout
102 | // -network.tcp.socket.readtimeout duration
103 | // network.tcp.socket.readtimeout
104 | // -network.tcp.socket.writetimeout duration
105 | // network.tcp.socket.writetimeout
106 | // -network.writetimeout duration
107 | // network.writetimeout
108 | //
109 | // flags to subcommands are naturally supported.
110 | //
111 | // func main() {
112 | // cmd := os.Args[1]
113 | // switch cmd {
114 | // case "new":
115 | // c1 := &Cfg1{}
116 | // ParseArgs(c1, os.Args[2:])
117 | // case "update":
118 | // c2 := &Cfg2{}
119 | // ParseArgs(c2, os.Args[2:])
120 | //
121 | // ... more sub commands ...
122 | // }
123 | // }
124 | //
125 | // One can set Flatten to true when calling NewFlagMakerAdv, in which case,
126 | // flags are created without namespacing. For example,
127 | //
128 | // type auth struct {
129 | // Token string
130 | // Tag float64
131 | // }
132 | //
133 | // type credentials struct {
134 | // User string
135 | // Password string
136 | // auth
137 | // }
138 | //
139 | // type database struct {
140 | // DBName string
141 | // TableName string
142 | // credentials
143 | // }
144 | //
145 | // type Cfg struct {
146 | // logging
147 | // database
148 | // }
149 | //
150 | // func main() {
151 | // c := &Cfg{}
152 | // flags.NewFlagMakerAdv(&flags.FlagMakingOptions{UseLowerCase: true, Flatten: true, TagName: "yaml"}).ParseArgs(c, os.Args[1:])
153 | // }
154 | //
155 | // will create the following flags
156 | // -dbname string
157 | // dbname
158 | // -interval int
159 | // interval
160 | // -password string
161 | // password
162 | // -path string
163 | // path
164 | // -tablename string
165 | // tablename
166 | // -tag float
167 | // tag
168 | // -token string
169 | // token
170 | // -user string
171 | // user
172 | //
173 | // Please be aware that the usual Go flag creation rules apply, i.e., if there are
174 | // duplicate flag names (which is more likely to happen in the flattened case
175 | // unless the caller takes due diligence to create the struct properly), it panics.
176 | //
177 | //
178 | // Note that not all types can have command line flags created for them. map, channel
179 | // and function types will not define a flag corresponding to the field. Pointer
180 | // types are properly handled and slice type will create multi-value command
181 | // line flags. That is, e.g. if a field foo's type is []int, one can use
182 | // --foo 10 --foo 15 --foo 20 to override this field value to be
183 | // []int{10, 15, 20}. For now, only []int, []string and []float64 are supported
184 | // in this fashion.
185 | package flags
186 |
187 | import (
188 | "flag"
189 | "fmt"
190 | "reflect"
191 | "strings"
192 | "time"
193 | )
194 |
// FlagMakingOptions controls how a FlagMaker names the flags it defines.
type FlagMakingOptions struct {
	// UseLowerCase lower-cases every flag name rather than using the field
	// name/tag name directly.
	UseLowerCase bool
	// Flatten drops the namespace from flag names. When false, a flag is
	// named after the dot-joined path of the fields leading to it (e.g.
	// "network.tcp.readtimeout"); when true only the innermost name is used
	// ("readtimeout"), which makes duplicate flag names more likely.
	Flatten bool
	// TagName selects a struct tag whose value, if present, is used as the
	// flag name instead of the field name. The purpose is that, for yaml/json
	// parsing we often have something like
	//   Foobar string `yaml:"host_name"`
	// in which case the flag will be named 'host_name' rather than 'foobar'.
	TagName string
}
207 |
208 | // FlagMaker enumerate all the exported fields of a struct recursively
209 | // and create corresponding command line flags. For anonymous fields,
210 | // they are only enumerated if they are pointers to structs.
211 | // Usual GoLang flag rules apply, e.g. duplicated flag names leads to
212 | // panic.
213 | type FlagMaker struct {
214 | opts *FlagMakingOptions
215 | // We don't consume os.Args directly unless told to.
216 | fs *flag.FlagSet
217 | }
218 |
219 | // NewFlagMaker creates a default FlagMaker which creates namespaced flags
220 | func NewFlagMaker() *FlagMaker {
221 | return NewFlagMakerAdv(&FlagMakingOptions{
222 | UseLowerCase: true,
223 | Flatten: false,
224 | TagName: "yaml"})
225 | }
226 |
227 | // NewFlagMakerAdv gives full control to create flags.
228 | func NewFlagMakerAdv(options *FlagMakingOptions) *FlagMaker {
229 | return &FlagMaker{
230 | opts: options,
231 | fs: flag.NewFlagSet("xFlags", flag.ContinueOnError),
232 | }
233 | }
234 |
235 | // ParseArgs parses the string arguments which should not contain the program name.
236 | //
237 | // obj is the struct to populate. args are the command line arguments,
238 | // typically obtained from os.Args.
239 | func ParseArgs(obj interface{}, args []string) ([]string, error) {
240 | fm := NewFlagMaker()
241 | return fm.ParseArgs(obj, args)
242 | }
243 |
// PrintDefaults prints the default value and type of the defined flags.
// It just calls PrintDefaults on the FlagMaker's own flag.FlagSet, so only
// flags defined on that set are listed.
func (fm *FlagMaker) PrintDefaults() {
	fm.fs.PrintDefaults()
}
249 |
250 | // ParseArgs parses the arguments based on the FlagMaker's setting.
251 | func (fm *FlagMaker) ParseArgs(obj interface{}, args []string) ([]string, error) {
252 | v := reflect.ValueOf(obj)
253 | if v.Kind() != reflect.Ptr {
254 | return args, fmt.Errorf("top level object must be a pointer. %v is passed", v.Type())
255 | }
256 | if v.IsNil() {
257 | return args, fmt.Errorf("top level object cannot be nil")
258 | }
259 |
260 | switch e := v.Elem(); e.Kind() {
261 | case reflect.Struct:
262 | fm.enumerateAndCreate("", e)
263 | case reflect.Interface:
264 | if e.Elem().Kind() == reflect.Ptr {
265 | fm.enumerateAndCreate("", e)
266 | } else {
267 | return args, fmt.Errorf("interface must have pointer underlying type. %v is passed", v.Type())
268 | }
269 | default:
270 | return args, fmt.Errorf("object must be a pointer to struct or interface. %v is passed", v.Type())
271 | }
272 |
273 | err := fm.fs.Parse(args)
274 | return fm.fs.Args(), err
275 | }
276 |
277 | func (fm *FlagMaker) enumerateAndCreate(prefix string, value reflect.Value) {
278 | switch value.Kind() {
279 | case
280 | // do no create flag for these types
281 | reflect.Map,
282 | reflect.Uintptr,
283 | reflect.UnsafePointer,
284 | reflect.Array,
285 | reflect.Chan,
286 | reflect.Func:
287 | return
288 | case reflect.Slice:
289 | // only support slice of strings, ints and float64s
290 | switch value.Type().Elem().Kind() {
291 | case reflect.String:
292 | fm.defineStringSlice(prefix, value)
293 | case reflect.Int:
294 | fm.defineIntSlice(prefix, value)
295 | case reflect.Float64:
296 | fm.defineFloat64Slice(prefix, value)
297 | }
298 | return
299 | case
300 | // Basic value types
301 | reflect.String,
302 | reflect.Bool,
303 | reflect.Float32, reflect.Float64,
304 | reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
305 | reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
306 | fm.defineFlag(prefix, value)
307 | return
308 | case reflect.Interface:
309 | if !value.IsNil() {
310 | fm.enumerateAndCreate(prefix, value.Elem())
311 | }
312 | return
313 | case reflect.Ptr:
314 | if value.IsNil() {
315 | value.Set(reflect.New(value.Type().Elem()))
316 | }
317 | fm.enumerateAndCreate(prefix, value.Elem())
318 | return
319 | case reflect.Struct:
320 | // keep going
321 | default:
322 | panic(fmt.Sprintf("unknown reflected kind %v", value.Kind()))
323 | }
324 |
325 | numFields := value.NumField()
326 | tt := value.Type()
327 |
328 | for i := 0; i < numFields; i++ {
329 | stField := tt.Field(i)
330 | // Skip unexported fields, as only exported fields can be set. This is similar to how json and yaml work.
331 | if stField.PkgPath != "" && !stField.Anonymous {
332 | continue
333 | }
334 | if stField.Anonymous && fm.getUnderlyingType(stField.Type).Kind() != reflect.Struct {
335 | continue
336 | }
337 | field := value.Field(i)
338 | optName := fm.getName(stField)
339 | if len(prefix) > 0 && !fm.opts.Flatten {
340 | optName = prefix + "." + optName
341 | }
342 | fm.enumerateAndCreate(optName, field)
343 | }
344 | }
345 |
346 | func (fm *FlagMaker) getName(field reflect.StructField) string {
347 | name := field.Tag.Get(fm.opts.TagName)
348 | if len(name) == 0 {
349 | if field.Anonymous {
350 | name = fm.getUnderlyingType(field.Type).Name()
351 | } else {
352 | name = field.Name
353 | }
354 | }
355 | if fm.opts.UseLowerCase {
356 | return strings.ToLower(name)
357 | }
358 | return name
359 | }
360 |
361 | func (fm *FlagMaker) getUnderlyingType(ttype reflect.Type) reflect.Type {
362 | // this only deals with *T unnamed type, other unnamed types, e.g. []int, struct{}
363 | // will return empty string.
364 | if ttype.Kind() == reflect.Ptr {
365 | return fm.getUnderlyingType(ttype.Elem())
366 | }
367 | return ttype
368 | }
369 |
370 | // Each object has its type (which prescribes the possible operations/methods
371 | // that could be invoked); it also has an underlying 'kind': int, float, struct etc.
372 | // Since users can freely define types, one 'kind' of object may correspond to
373 | // many types. We cannot do type assertion because types of the same kind are still
374 | // different types. Instead, we convert to the primitive types that correspond
375 | // to the kinds and create flag vars. One thing to know is that the whole point
376 | // of the defineFlag() method is to define flag.Vars that point to certain fields
377 | // of the struct so that command line values can modify the struct. We cannot
378 | // define a flag var pointing to an arbitrary 'free' variable.
379 |
// Pointer types of the primitive kinds, one per kind that gets a flag.
// A pointer to a field of a user-defined type is converted to the entry of
// the matching kind before it is handed to the flag package.
// (I wish GoLang had macro...)
var (
	stringPtrType = reflect.TypeOf(new(string))
	boolPtrType   = reflect.TypeOf(new(bool))

	float32PtrType = reflect.TypeOf(new(float32))
	float64PtrType = reflect.TypeOf(new(float64))

	intPtrType   = reflect.TypeOf(new(int))
	int8PtrType  = reflect.TypeOf(new(int8))
	int16PtrType = reflect.TypeOf(new(int16))
	int32PtrType = reflect.TypeOf(new(int32))
	int64PtrType = reflect.TypeOf(new(int64))

	uintPtrType   = reflect.TypeOf(new(uint))
	uint8PtrType  = reflect.TypeOf(new(uint8))
	uint16PtrType = reflect.TypeOf(new(uint16))
	uint32PtrType = reflect.TypeOf(new(uint32))
	uint64PtrType = reflect.TypeOf(new(uint64))
)
397 |
398 | func (fm *FlagMaker) defineFlag(name string, value reflect.Value) {
399 | // v must be scalar, otherwise panic
400 | ptrValue := value.Addr()
401 | switch value.Kind() {
402 | case reflect.String:
403 | v := ptrValue.Convert(stringPtrType).Interface().(*string)
404 | fm.fs.StringVar(v, name, value.String(), name)
405 | case reflect.Bool:
406 | v := ptrValue.Convert(boolPtrType).Interface().(*bool)
407 | fm.fs.BoolVar(v, name, value.Bool(), name)
408 | case reflect.Int:
409 | v := ptrValue.Convert(intPtrType).Interface().(*int)
410 | fm.fs.IntVar(v, name, int(value.Int()), name)
411 | case reflect.Int8:
412 | v := ptrValue.Convert(int8PtrType).Interface().(*int8)
413 | fm.fs.Var(newInt8Value(v), name, name)
414 | case reflect.Int16:
415 | v := ptrValue.Convert(int16PtrType).Interface().(*int16)
416 | fm.fs.Var(newInt16Value(v), name, name)
417 | case reflect.Int32:
418 | v := ptrValue.Convert(int32PtrType).Interface().(*int32)
419 | fm.fs.Var(newInt32Value(v), name, name)
420 | case reflect.Int64:
421 | switch v := ptrValue.Interface().(type) {
422 | case *int64:
423 | fm.fs.Int64Var(v, name, value.Int(), name)
424 | case *time.Duration:
425 | fm.fs.DurationVar(v, name, value.Interface().(time.Duration), name)
426 | default:
427 | // (TODO) if one type defines time.Duration, we'll create a int64 flag for it.
428 | // Find some acceptible way to deal with it.
429 | vv := ptrValue.Convert(int64PtrType).Interface().(*int64)
430 | fm.fs.Int64Var(vv, name, value.Int(), name)
431 | }
432 | case reflect.Float32:
433 | v := ptrValue.Convert(float32PtrType).Interface().(*float32)
434 | fm.fs.Var(newFloat32Value(v), name, name)
435 | case reflect.Float64:
436 | v := ptrValue.Convert(float64PtrType).Interface().(*float64)
437 | fm.fs.Float64Var(v, name, value.Float(), name)
438 | case reflect.Uint:
439 | v := ptrValue.Convert(uintPtrType).Interface().(*uint)
440 | fm.fs.UintVar(v, name, uint(value.Uint()), name)
441 | case reflect.Uint8:
442 | v := ptrValue.Convert(uint8PtrType).Interface().(*uint8)
443 | fm.fs.Var(newUint8Value(v), name, name)
444 | case reflect.Uint16:
445 | v := ptrValue.Convert(uint16PtrType).Interface().(*uint16)
446 | fm.fs.Var(newUint16Value(v), name, name)
447 | case reflect.Uint32:
448 | v := ptrValue.Convert(uint32PtrType).Interface().(*uint32)
449 | fm.fs.Var(newUint32Value(v), name, name)
450 | case reflect.Uint64:
451 | v := ptrValue.Convert(uint64PtrType).Interface().(*uint64)
452 | fm.fs.Uint64Var(v, name, value.Uint(), name)
453 | }
454 | }
455 |
456 | func (fm *FlagMaker) defineStringSlice(name string, value reflect.Value) {
457 | ptrValue := value.Addr().Interface().(*[]string)
458 | fm.fs.Var(newStringSlice(ptrValue), name, name)
459 | }
460 |
461 | func (fm *FlagMaker) defineIntSlice(name string, value reflect.Value) {
462 | ptrValue := value.Addr().Interface().(*[]int)
463 | fm.fs.Var(newIntSlice(ptrValue), name, name)
464 | }
465 |
466 | func (fm *FlagMaker) defineFloat64Slice(name string, value reflect.Value) {
467 | ptrValue := value.Addr().Interface().(*[]float64)
468 | fm.fs.Var(newFloat64Slice(ptrValue), name, name)
469 | }
470 |
--------------------------------------------------------------------------------
/flags_test.go:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2016 Uber Technologies, Inc.
2 | //
3 | // Permission is hereby granted, free of charge, to any person obtaining a copy
4 | // of this software and associated documentation files (the "Software"), to deal
5 | // in the Software without restriction, including without limitation the rights
6 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | // copies of the Software, and to permit persons to whom the Software is
8 | // furnished to do so, subject to the following conditions:
9 | //
10 | // The above copyright notice and this permission notice shall be included in
11 | // all copies or substantial portions of the Software.
12 | //
13 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | // THE SOFTWARE.
20 |
21 | package flags
22 |
23 | import (
24 | "flag"
25 | "testing"
26 | "time"
27 |
28 | "github.com/stretchr/testify/assert"
29 | )
30 |
// logging is a fixture embedded in Cfg1 and Cfg2.
type logging struct {
	Interval int
	Path     string
}

// socket is the innermost level of the embedded chain used by Cfg1.
type socket struct {
	ReadTimeout  time.Duration
	WriteTimeout time.Duration
}

// tcp embeds socket and declares its own ReadTimeout, so the same field name
// exists at two nesting levels.
type tcp struct {
	ReadTimeout time.Duration
	socket
}

// network embeds tcp and again declares ReadTimeout/WriteTimeout, adding a
// third level carrying the same field names.
type network struct {
	ReadTimeout  time.Duration
	WriteTimeout time.Duration
	tcp
}

// Cfg1 is the top-level config for TestFlagMakerExample; it consists solely
// of embedded structs.
type Cfg1 struct {
	logging
	network
}
56 |
57 | func TestFlagMakerExample(t *testing.T) {
58 | cfg := Cfg1{}
59 |
60 | args := []string{
61 | "--network.tcp.socket.readtimeout", "5ms",
62 | "--network.tcp.readtimeout", "3ms",
63 | "-logging.path", "/var/log",
64 | }
65 | args, err := ParseArgs(cfg, args)
66 | assert.False(t, err == nil)
67 | args, err = ParseArgs(&cfg, args)
68 | assert.True(t, err == nil)
69 | assert.Equal(t, 0, len(args))
70 |
71 | expected := Cfg1{
72 | network: network{
73 | tcp: tcp{
74 | ReadTimeout: time.Duration(3) * time.Millisecond,
75 | socket: socket{
76 | ReadTimeout: time.Duration(5) * time.Millisecond,
77 | },
78 | },
79 | },
80 | logging: logging{
81 | Path: "/var/log",
82 | },
83 | }
84 | assert.Equal(t, expected, cfg)
85 | }
86 |
// auth is the innermost level of the embedded chain used by Cfg2.
type auth struct {
	Token string
	Tag   float64
}

// credentials embeds auth.
type credentials struct {
	User     string
	Password string
	auth
}

// database embeds credentials.
type database struct {
	DBName    string
	TableName string
	credentials
}

// Cfg2 is the top-level config for TestFlagMakerExampleFlattened. Besides the
// embedded structs it carries an embedded *string, which that test leaves nil.
type Cfg2 struct {
	logging
	database
	*string
}
109 |
110 | func TestFlagMakerExampleFlattened(t *testing.T) {
111 | cfg := Cfg2{}
112 |
113 | args := []string{
114 | "--dbname", "db1",
115 | "--token", "abcd",
116 | "-tag=3.14",
117 | "-path", "/var/log",
118 | }
119 |
120 | fm := NewFlagMakerAdv(&FlagMakingOptions{true, true, "not-care"})
121 | args, err := fm.ParseArgs(&cfg, args)
122 |
123 | assert.True(t, err == nil)
124 | assert.Equal(t, 0, len(args))
125 |
126 | expected := Cfg2{}
127 | expected.Tag = 3.14
128 | expected.DBName = "db1"
129 | expected.Token = "abcd"
130 | expected.Path = "/var/log"
131 |
132 | assert.Equal(t, expected, cfg)
133 | }
134 |
// C4 is the innermost struct of the C1 fixture; it is embedded in C3.
type C4 struct {
	TableName string
}

// C3 embeds C4.
type C3 struct {
	DBName string
	C4
}

// C2 embeds C3 and is used as the named field Credential of C1.
type C2 struct {
	User     string
	Password int64
	Tag      int8
	C3
}

// C1 is the config used by TestFlagMakerBasic. It mixes scalars, a duration,
// the three slice types exercised by the slice tests, a nested struct and
// unexported fields.
type C1 struct {
	// Name carries a yaml tag; TestFlagMakerBasic sets it via --label.
	Name       string `yaml:"label"`
	Value      int
	Float      float64
	Timeout    time.Duration
	Hosts      []string
	Ports      []int
	Weights    []float64
	Credential C2

	// some unexported fields
	opentimeout time.Duration
	localhost   string
}
165 |
166 | func TestFlagMakerBasic(t *testing.T) {
167 | c := &C1{
168 | Name: "basic",
169 | Value: 10,
170 | Float: 7.4,
171 | Timeout: time.Duration(10) * time.Millisecond,
172 | Hosts: []string{"host1", "host2"},
173 | Ports: []int{89, 90},
174 | Credential: C2{
175 | User: "user",
176 | Password: 1234,
177 | Tag: 20,
178 | C3: C3{
179 | DBName: "db1",
180 | C4: C4{
181 | TableName: "t1",
182 | },
183 | },
184 | },
185 | opentimeout: time.Duration(3) * time.Microsecond,
186 | localhost: "weird.host",
187 | }
188 | args := []string{
189 | "--label", "advanced", "-float", "5.1", "-timeout", "5ms", "--ports", "22", "--ports", "43",
190 | "--credential.user", "uber", "--credential.tag", "80", "--credential.c3.dbname", "db2",
191 | "--credential.c3.c4.tablename", "t2"}
192 | args, err := ParseArgs(c, args)
193 | assert.Equal(t, nil, err, "should be no error")
194 | assert.Equal(t, 0, len(args), "should be no arg left")
195 |
196 | expected := *c
197 | // only these fields should be modified by parsing the arguments.
198 | expected.Credential.User = "uber"
199 | expected.Credential.Password = int64(1234)
200 | expected.Credential.Tag = int8(80)
201 | expected.Name = "advanced"
202 | expected.Credential.C3.DBName = "db2"
203 | expected.Credential.C3.TableName = "t2"
204 | expected.Ports = []int{22, 43}
205 |
206 | assert.Equal(t, &expected, c)
207 | }
208 |
// CTypes has one field for each scalar kind that TestFlagMakerTypes round-trips
// through the parser: string, bool, both float widths and every signed and
// unsigned integer width.
type CTypes struct {
	Strval  string
	Bval    bool
	F32val  float32
	F64val  float64
	Ival    int
	I8val   int8
	I16val  int16
	I32val  int32
	I64val  int64
	UIval   uint
	UI8val  uint8
	UI16val uint16
	UI32val uint32
	UI64val uint64
}
225 |
226 | func TestFlagMakerTypes(t *testing.T) {
227 | /* Check all of the types */
228 | refCtypes := &CTypes{
229 | Strval: "string value",
230 | Bval: true,
231 | F32val: 3.1415927, // <- Max PI for 32 bit float
232 | F64val: 3.141592653589793, // <- Max PI for 64 bit float
233 | /* The rest of these use the highest value for the type */
234 | Ival: int(0x7fffffffffffffff),
235 | I8val: int8(0x7f),
236 | I16val: int16(0x7fff),
237 | I32val: int32(0x7fffffff),
238 | I64val: int64(0x7fffffffffffffff),
239 | UIval: uint(0xffffffffffffffff),
240 | UI8val: uint8(0xff),
241 | UI16val: uint16(0xffff),
242 | UI32val: uint32(0xffffffff),
243 | UI64val: uint64(0xffffffffffffffff),
244 | }
245 | parseCtypes := &CTypes{}
246 | args := []string{
247 | "-strval", "string value", "--bval",
248 | "-f32val", "3.1415927", "--f64val", "3.141592653589793",
249 | "--ival", "9223372036854775807", "--i8val", "127", "--i16val", "32767",
250 | "-i32val", "2147483647", "--i64val", "9223372036854775807",
251 | "--uival", "18446744073709551615", "--ui8val", "255", "--ui16val", "65535",
252 | "-ui32val", "4294967295", "--ui64val", "18446744073709551615"}
253 | args, err := ParseArgs(parseCtypes, args)
254 | assert.Equal(t, nil, err, "should be no error")
255 | assert.Equal(t, parseCtypes, refCtypes)
256 | }
257 |
// D1 holds a string and an int slice behind several levels of pointer
// indirection.
type D1 struct {
	F1 ****string
	F2 ***[]int
}

// D2 embeds D1 and adds two more multi-level pointer fields whose names (F1,
// F2) shadow the ones promoted from D1.
type D2 struct {
	F1 **[]float64
	F2 *****bool
	D1
}

// D3 embeds D2 and adds plain scalar and slice fields.
type D3 struct {
	D2
	F3    uint
	F4    int64
	Hosts []string
}

// DD embeds D1, D2 and D3 side by side, so D1 and D2 are reachable through
// more than one path. TestFlagMakerComplex addresses fields by explicit
// dotted paths such as d3.d2.d1.f2.
type DD struct {
	D1
	D2
	D3
}
281 |
282 | func TestFlagMakerComplex(t *testing.T) {
283 | d := DD{}
284 | args := []string{"-d2.f1", "1.2", "-d3.d2.d1.f2", "45", "-d2.f2", "-d2.f1", "4.2", "-d3.d2.d1.f2", "56", "-d2.f1", "7.4", "-d3.d2.d1.f2", "78"}
285 | args, err := ParseArgs(&d, args)
286 | assert.Equal(t, nil, err, "unexpected error")
287 | assert.Equal(t, 0, len(args))
288 | assert.Equal(t, true, *****d.D2.F2)
289 | assert.Equal(t, []int{45, 56, 78}, ***d.D3.D2.D1.F2)
290 | assert.Equal(t, []float64{1.2, 4.2, 7.4}, **d.D2.F1)
291 | }
292 |
// I1 is the interface through which the interface tests hand their configs to
// ParseArgs.
type I1 interface {
	Method1() string
}

// S1 implements I1 with a pointer receiver, so only *S1 satisfies I1.
type S1 struct {
	Host    string
	ignore  int
	Weights []float64
	F       int8
}

// Method1 returns the Host field.
func (s *S1) Method1() string {
	return s.Host
}

// S2 implements I1 with a value receiver, so an I1 can hold an S2 by value.
// TestFlagMakerPtrToIntf uses that to provoke the "interface must have pointer
// underlying type." error.
type S2 struct {
	Open   bool
	Volume float64
}

// Method1 returns a fixed string.
func (s S2) Method1() string { return "haha" }
314 |
315 | func TestFlagMakerInterface(t *testing.T) {
316 | var s I1 = &S1{ignore: 12}
317 | args := []string{
318 | "-host", "test.local", "-f", "16",
319 | }
320 | args, err := ParseArgs(s, args)
321 | assert.Equal(t, nil, err)
322 | expected := S1{
323 | Host: "test.local",
324 | F: int8(16),
325 | ignore: 12,
326 | }
327 | assert.Equal(t, "test.local", s.Method1())
328 | assert.Equal(t, &expected, s)
329 | }
330 |
331 | func TestFlagMakerPtrToIntf(t *testing.T) {
332 | s := &S1{}
333 | var i2 I1 = s
334 | args := []string{"--weights", "9.3", "--host", "www", "--weights", "10.0"}
335 | out, err := ParseArgs(&i2, args)
336 | assert.Nil(t, err)
337 | assert.Equal(t, 0, len(out))
338 | assert.Equal(t, "www", s.Host)
339 | assert.Equal(t, []float64{9.3, 10.0}, s.Weights)
340 |
341 | s2 := S2{}
342 | var i3 I1 = s2
343 | args = []string{"--open", "--volume", "9.3"}
344 | out, err = ParseArgs(&i3, args)
345 | assert.Error(t, err)
346 | assert.Contains(t, err.Error(), "interface must have pointer underlying type.")
347 | assert.Equal(t, 3, len(out))
348 | }
349 |
// Cfg3 embeds D3 by value and D2 by pointer. D3 itself embeds D2, so D2 is
// present twice at different depths; TestFlagMakerNested reads the shallower,
// pointer-embedded one through cfg.F1.
type Cfg3 struct {
	D3
	*D2
}
354 |
355 | func TestFlagMakerNested(t *testing.T) {
356 | cfg := Cfg3{}
357 | args := []string{"-d3.hosts", "h1.com", "-d2.f1", "1.2", "-d3.hosts", "h2.com", "-d2.f1", "4.2", "-d3.hosts", "h3.com", "-d2.f1", "7.4"}
358 | args, err := ParseArgs(&cfg, args)
359 | assert.Equal(t, nil, err)
360 | assert.Equal(t, 0, len(args))
361 |
362 | assert.Equal(t, []string{"h1.com", "h2.com", "h3.com"}, cfg.Hosts)
363 | assert.Equal(t, []float64{1.2, 4.2, 7.4}, **cfg.F1)
364 | }
365 |
// Cfg4 combines a named pointer field with two embedded built-in types. The
// embedded fields are named after their types (string, int) and are therefore
// unexported.
type Cfg4 struct {
	Name *string
	*string
	int
}
371 |
372 | func TestFlagMakerUnnamedFields(t *testing.T) {
373 | c := Cfg4{int: 4}
374 | args := []string{"--name=haha"}
375 | args, err := ParseArgs(&c, args)
376 | assert.Equal(t, nil, err)
377 | assert.Equal(t, 0, len(args))
378 | ss := "haha"
379 | expected := Cfg4{Name: &ss, int: 4}
380 | assert.Equal(t, expected, c)
381 | }
382 |
// The following test ensures that we can properly create flags for
// user-defined non-struct types. The kind of an object and the type of an
// object are different. See the comments of defineFlag().

// Named scalar types: each has the kind of its underlying built-in type but a
// distinct type identity.
type String string
type Int int
type Int8 int8
type Int32 int32
type Int16 int16
type Int64 int64
type Float float64
type Float32 float32
type Uint uint
type Uint8 uint8
type Uint16 uint16
type Uint32 uint32
type Uint64 uint64
type Bool bool

// Named pointer types pointing at the named scalar types above.
type PString *String
type PInt *Int
type PInt8 *Int8
type PInt16 *Int16
type PInt32 *Int32
type PInt64 *Int64
type PFloat *Float
type PUint *Uint
type PUint8 *Uint8
type PUint16 *Uint16
type PUint32 *Uint32
type PUint64 *Uint64

// PBool points at the built-in bool, unlike the other P* types, which point at
// the named scalar types.
type PBool *bool

// Cfg5 pairs every named scalar type with a pointer variant of it; it is the
// config used by TestFlagMakerTypeDef.
type Cfg5 struct {
	S    String
	PS   PString
	I    Int
	PI   PInt
	I8   Int8
	PI8  PInt8
	I16  Int16
	PI16 PInt16
	I32  Int32
	PI32 PInt32
	I64  Int64
	PI64 PInt64
	F    Float
	PF   PFloat
	U    Uint
	U8   Uint8
	PU8  PUint8
	U16  Uint16
	PU16 PUint16
	U32  Uint32
	PU32 PUint32
	PU   PUint
	U64  Uint64
	PU64 PUint64
	B    Bool
	PB   PBool
	F32  Float32
	// PF32 is an unnamed pointer to a named type; there is no PFloat32 type.
	PF32 *Float32
}
444 |
445 | func TestFlagMakerTypeDef(t *testing.T) {
446 | cfg := &Cfg5{}
447 | args := []string{"--s", "hehe", "--ps", "good",
448 | "-i", "33", "--pi", "44",
449 | "--i64", "55", "--pi64", "66",
450 | "--f", "5.7", "--pf", "6.7",
451 | "--u", "10", "--pu", "20",
452 | "--u64", "30", "--pu64", "40",
453 | "--i8", "20", "--pi8", "10",
454 | "--i16", "400", "-pi16", "500",
455 | "--i32", "600", "-pi32", "700",
456 | "--u8", "20", "--pu8", "10",
457 | "--u16", "400", "-pu16", "500",
458 | "--u32", "600", "-pu32", "700",
459 | "--f32", "10.1", "-pf32", "11.1",
460 | "--b=true", "--pb"}
461 | args, err := ParseArgs(cfg, args)
462 |
463 | assert.Equal(t, nil, err)
464 | assert.Equal(t, 0, len(args))
465 |
466 | assert.Equal(t, "hehe", string(cfg.S))
467 | assert.Equal(t, "good", string(*cfg.PS))
468 |
469 | assert.Equal(t, 33, int(cfg.I))
470 | assert.Equal(t, 44, int(*cfg.PI))
471 |
472 | assert.Equal(t, int8(20), int8(cfg.I8))
473 | assert.Equal(t, int8(10), int8(*cfg.PI8))
474 |
475 | assert.Equal(t, int16(400), int16(cfg.I16))
476 | assert.Equal(t, int16(500), int16(*cfg.PI16))
477 |
478 | assert.Equal(t, int32(600), int32(cfg.I32))
479 | assert.Equal(t, int32(700), int32(*cfg.PI32))
480 |
481 | assert.Equal(t, int64(55), int64(cfg.I64))
482 | assert.Equal(t, int64(66), int64(*cfg.PI64))
483 |
484 | assert.Equal(t, 5.7, float64(cfg.F))
485 | assert.Equal(t, 6.7, float64(*cfg.PF))
486 |
487 | assert.Equal(t, float32(10.1), float32(cfg.F32))
488 | assert.Equal(t, float32(11.1), float32(*cfg.PF32))
489 |
490 | assert.Equal(t, uint(10), uint(cfg.U))
491 | assert.Equal(t, uint(20), uint(*cfg.PU))
492 |
493 | assert.Equal(t, uint8(20), uint8(cfg.U8))
494 | assert.Equal(t, uint8(10), uint8(*cfg.PU8))
495 |
496 | assert.Equal(t, uint16(400), uint16(cfg.U16))
497 | assert.Equal(t, uint16(500), uint16(*cfg.PU16))
498 |
499 | assert.Equal(t, uint32(600), uint32(cfg.U32))
500 | assert.Equal(t, uint32(700), uint32(*cfg.PU32))
501 |
502 | assert.Equal(t, uint64(30), uint64(cfg.U64))
503 | assert.Equal(t, uint64(40), uint64(*cfg.PU64))
504 |
505 | assert.Equal(t, true, bool(cfg.B))
506 | assert.Equal(t, true, bool(*cfg.PB))
507 | }
508 |
509 | func TestFlagMakerInvalidInput(t *testing.T) {
510 | var cfg *Cfg5
511 | args := []string{"--s", "hehe", "--ps", "good",
512 | "-i", "33", "--pi", "44",
513 | "--i64", "55", "--pi64", "66",
514 | }
515 | out, err := ParseArgs(cfg, args)
516 | assert.Error(t, err)
517 | assert.Equal(t, "top level object cannot be nil", err.Error())
518 | assert.Equal(t, len(args), len(out))
519 |
520 | var cfg2 Cfg4
521 | out, err = ParseArgs(cfg2, args)
522 | assert.Error(t, err)
523 | assert.Contains(t, err.Error(), "top level object must be a pointer")
524 | assert.Equal(t, len(args), len(out))
525 | }
526 |
527 | func TestFlagMakerUnsupportedTypes(t *testing.T) {
528 | cases := []struct {
529 | cfg interface{}
530 | args []string
531 | }{
532 | {&struct {
533 | Env map[string]string
534 | Level int
535 | }{}, []string{"--level", "10", "--env", "hh,fgg,10"}},
536 | {&struct {
537 | Env chan int
538 | Level int
539 | }{}, []string{"--level", "10", "--env", "hh,fgg,10"}},
540 | {&struct {
541 | Env func(int) string
542 | Level int
543 | }{}, []string{"--level", "10", "--env", "hh,fgg,10"}},
544 | }
545 |
546 | for _, c := range cases {
547 | out, err := ParseArgs(c.cfg, c.args)
548 | assert.Error(t, err)
549 | assert.Contains(t, err.Error(), "flag provided but not defined")
550 | assert.Equal(t, 1, len(out))
551 | }
552 | }
553 |
554 | func TestFlagMakerInvalidValue(t *testing.T) {
555 | cases := []struct {
556 | cfg interface{}
557 | args []string
558 | }{
559 | {&struct{ Level int }{}, []string{"--level", "haha"}},
560 | {&struct{ Level int8 }{}, []string{"--level", "haha"}},
561 | {&struct{ Level int16 }{}, []string{"--level", "haha"}},
562 | {&struct{ Level int32 }{}, []string{"--level", "haha"}},
563 | {&struct{ Level int64 }{}, []string{"--level", "haha"}},
564 | {&struct{ Level uint8 }{}, []string{"--level", "haha"}},
565 | {&struct{ Level uint16 }{}, []string{"--level", "haha"}},
566 | {&struct{ Level uint32 }{}, []string{"--level", "haha"}},
567 | {&struct{ Level uint64 }{}, []string{"--level", "haha"}},
568 | {&struct{ Level float32 }{}, []string{"--level", "haha"}},
569 | {&struct{ Level float64 }{}, []string{"--level", "haha"}},
570 | }
571 |
572 | for _, c := range cases {
573 | out, err := ParseArgs(c.cfg, c.args)
574 | assert.Error(t, err)
575 | assert.Contains(t, err.Error(), "invalid value")
576 | // args are consumed even thought the value is invalid
577 | assert.Equal(t, 0, len(out))
578 | }
579 | }
580 |
581 | // slice
582 |
583 | func TestFlagMakerStringSlice(t *testing.T) {
584 | type C struct {
585 | Hosts []string
586 | }
587 | cases := []struct {
588 | cfg *C
589 | args, expected []string
590 | }{
591 | {&C{}, []string{"--hosts", "h1", "--hosts", "h2", "--hosts", "h3"}, []string{"h1", "h2", "h3"}},
592 | {&C{[]string{}}, []string{"--hosts", "h1", "--hosts", "h2", "--hosts", "h3"}, []string{"h1", "h2", "h3"}},
593 | {&C{}, []string{}, nil},
594 | {&C{[]string{}}, []string{}, []string{}},
595 | {&C{[]string{"l1", "l2"}}, []string{}, []string{"l1", "l2"}},
596 | {&C{[]string{"l1", "l2"}}, []string{"--hosts", "ok"}, []string{"ok"}},
597 | }
598 | for _, c := range cases {
599 | args, err := ParseArgs(c.cfg, c.args)
600 | assert.Nil(t, err)
601 | assert.Equal(t, 0, len(args))
602 | assert.Equal(t, c.expected, c.cfg.Hosts)
603 | }
604 | }
605 |
606 | func TestFlagMakerIntSlice(t *testing.T) {
607 | type C struct {
608 | Levels []int
609 | }
610 |
611 | cases := []struct {
612 | cfg *C
613 | args []string
614 | expected []int
615 | }{
616 | {&C{}, []string{"--levels", "8", "--levels", "9", "--levels", "10"}, []int{8, 9, 10}},
617 | {&C{[]int{}}, []string{"--levels", "8", "--levels", "9", "--levels", "10"}, []int{8, 9, 10}},
618 | {&C{}, []string{}, nil},
619 | {&C{[]int{}}, []string{}, []int{}},
620 | {&C{[]int{11, 12}}, []string{}, []int{11, 12}},
621 | {&C{[]int{11, 12}}, []string{"--levels", "5"}, []int{5}},
622 | }
623 | for _, c := range cases {
624 | args, err := ParseArgs(c.cfg, c.args)
625 | assert.Nil(t, err)
626 | assert.Equal(t, 0, len(args))
627 | assert.Equal(t, c.expected, c.cfg.Levels)
628 | }
629 | }
630 |
631 | func TestFlagMakerFloatSlice(t *testing.T) {
632 | type C struct {
633 | Levels []float64
634 | }
635 | cases := []struct {
636 | cfg *C
637 | args []string
638 | expected []float64
639 | }{
640 | {&C{}, []string{"--levels", "8.9", "--levels", "9.9", "--levels", "10.9"}, []float64{8.9, 9.9, 10.9}},
641 | {&C{[]float64{}}, []string{"--levels", "8.9", "--levels", "9.9", "--levels", "10.9"}, []float64{8.9, 9.9, 10.9}},
642 | {&C{}, []string{}, nil},
643 | {&C{[]float64{}}, []string{}, []float64{}},
644 | {&C{[]float64{11.3, 12.3}}, []string{}, []float64{11.3, 12.3}},
645 | {&C{[]float64{11.3, 12.3}}, []string{"--levels", "5.1"}, []float64{5.1}},
646 | }
647 | for _, c := range cases {
648 | args, err := ParseArgs(c.cfg, c.args)
649 | assert.Nil(t, err)
650 | assert.Equal(t, 0, len(args))
651 | assert.Equal(t, c.expected, c.cfg.Levels)
652 | }
653 | }
654 |
655 | func TestFlagMakerInvalidSlice(t *testing.T) {
656 | type C struct {
657 | Levels []int
658 | Weights []float64
659 | }
660 | cases := []struct {
661 | cfg *C
662 | args []string
663 | l []int
664 | w []float64
665 | }{
666 | {
667 | // invalid flag values won't modify the struct
668 | &C{Levels: []int{2, 3}, Weights: []float64{2.4, 5.6}},
669 | []string{"--levels", "7ax", "--levels", "10", "--weights", "u8.2"},
670 | []int{2, 3},
671 | []float64{2.4, 5.6},
672 | },
673 | {
674 | // however, valid values before invalid flag values WILL clear slice
675 | &C{Levels: []int{2, 3}, Weights: []float64{2.4, 5.6}},
676 | []string{"--weights", "1.1", "--levels", "10", "--weights", "u8.2", "--levels", "abc"},
677 | []int{10},
678 | []float64{1.1},
679 | },
680 | }
681 |
682 | for _, c := range cases {
683 | _, err := ParseArgs(c.cfg, c.args)
684 | assert.Error(t, err)
685 | assert.Equal(t, c.l, c.cfg.Levels)
686 | assert.Equal(t, c.w, c.cfg.Weights)
687 | }
688 | }
689 |
690 | func TestFlagMakerVarGet(t *testing.T) {
691 | var i8 int8 = 3
692 | var i16 int16 = 4
693 | var i32 int32 = 5
694 | var f32 float32 = 10.9
695 | var u8 uint8 = 22
696 | var u16 uint16 = 30
697 | var u32 uint32 = 55
698 | is := []int{1, 40, 30}
699 | ss := []string{"haha", "xx"}
700 | fs := []float64{242.66, 7565.23, 234.67}
701 | cases := []struct {
702 | getter flag.Getter
703 | expected interface{}
704 | }{
705 | {newInt8Value(&i8), i8},
706 | {newInt16Value(&i16), i16},
707 | {newInt32Value(&i32), i32},
708 | {newFloat32Value(&f32), f32},
709 | {newUint8Value(&u8), u8},
710 | {newUint16Value(&u16), u16},
711 | {newUint32Value(&u32), u32},
712 | {newStringSlice(&ss), ss},
713 | {newIntSlice(&is), is},
714 | {newFloat64Slice(&fs), fs},
715 | }
716 |
717 | for _, c := range cases {
718 | assert.Equal(t, c.expected, c.getter.Get())
719 | }
720 | }
721 |
--------------------------------------------------------------------------------