├── Makefile ├── README.md ├── awss3.go ├── awss3_test.go ├── databases └── schema.sql ├── docker-compose.yml ├── httpreq.go ├── httpreq_test.go ├── limiter.go ├── limiter_test.go ├── main.go ├── mysqlop.go ├── mysqlop_test.go └── testingo_test.go /Makefile: -------------------------------------------------------------------------------- 1 | MYSQLTEST_PORT = 33061 2 | MYSQLTEST_USER = root 3 | MYSQLTEST_PASS = root 4 | 5 | test: 6 | @docker-compose up -d --renew-anon-volumes mysqltest 7 | @docker-compose up waitformysqltest 8 | @docker-compose up testdatafiller 9 | env MYSQLTEST_PORT=$(MYSQLTEST_PORT) MYSQLTEST_USER=$(MYSQLTEST_USER) MYSQLTEST_PASS=$(MYSQLTEST_PASS) go test 10 | @docker-compose down -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Golang 单元测试实践 2 | -------------------- 3 | 4 | 每个严谨的项目都应该有单元测试,发现程序中的问题,保障程序现在和未来的正确性。我们新加入一个项目时,常被要求给现有代码加一些单元测试;自己的代码写到一定程度后,也希望加一些单元测试看看有没有问题。这时往往发现没法在不改动现有代码的情况下添加单元测试,这就引出一个很尴尬的问题~~ 不是所有代码都可以方便测试的~~ 5 | 6 | 比如这个例子: 7 | 8 | ``` 9 | func AddPerson(name string) error { 10 | db, _ := sqlx.Open("mysql", "...dsn...") 11 | _, err := db.Exec("INSERT INTO person (name) VALUES (?)", name) 12 | return err 13 | } 14 | ``` 15 | 16 | 在函数中写死了 MySQL 的连接方式,硬要写单元测试的话,会污染生产环境的数据库。 17 | 18 | 还有其它一些情况,比如从很多外部依赖获取数据并处理,输入和结果过于复杂。 19 | 20 | 一般来说,没法测试的代码都是不太好的代码,它们往往没有合理组织,不灵活,甚至错误百出。直接说明怎样的代码可方便测试有点难,但我们可以通过看看各种情况下怎样合理地测试,反推怎样写出方便测试的代码。 21 | 22 | 本文主要说明 Golang 单元测试用到的工具以及一些方法,包括: 23 | 24 | * 使用 Table Driven 的方式写测试代码 25 | * 使用 testify/assert 简化条件判断 26 | * 使用 testify/mock 隔离第三方依赖或者复杂调用 27 | * mock http request 28 | * stub redis 29 | * stub MySQL 30 | 31 | ### 使用 Table Driven 的方式写测试代码 32 | 33 | 测试一个 routine 分几个步骤:准备数据,调用 routine,判断返回。还要测试不同的情况。如果每种情况都手工写一次代码的话,会很繁琐,使用 Table Driven 的方式能让测试代码看起来简洁易懂不少。 34 | 35 | 比如要测试一个取模运算的 routine: 36 | 37 | ``` 38 | func Mod(a, b int) (r int, err error) { 39 | if b == 0 { 40 | return 0, fmt.Errorf("mod by zero") 41 | } 42 | return a%b, nil 43 | } 44 | ``` 45 | 46 | 可以这样测试: 47 | 48 | ``` 49 | func TestMod(t *testing.T) { 50 | tests := []struct { 51 | a int 52 | b int 53 | r int 54 | hasErr bool 55 | }{ 56 | {a: 42, b: 9, r: 6, hasErr: false}, 57 | {a: -1, b: 9, r: 8, hasErr: false}, 58 | {a: -1, b: -9, r: -1, hasErr: false}, 59 | {a: 42, b: 0, r: 0, hasErr: true}, 60 | } 61 | 62 | for row, test := range tests { 63 | r, err := Mod(test.a, test.b) 64 | if test.hasError { 65 | if err == nil { 66 | t.Errorf("should have error, row: %d", row) 67 | } 68 | continue 69 | } 70 | if err != nil { 71 | t.Errorf("should not have error, row: %d", row) 72 | } 73 | if r != test.r { 74 | t.Errorf("r is expected to be %d but now %d, row: %d", test.r, r, row) 75 | } 76 | } 77 | } 78 | ``` 79 | 80 | 以后有新的边缘情况,也可以很方便地添加到测试用例。 81 | 82 | ### 使用 testify/assert 简化条件判断 83 | 84 | 上面例子中很多 if xxx { t.Errorf(...) } 的代码,复杂,语义不清晰。使用 github.com/stretchr/testify 的 assert 可以简化这些代码。上面的 for 循环可以简化成下面这样: 85 | 86 | ``` 87 | import "github.com/stretchr/testify/assert" 88 | 89 | for row, test := range tests { 90 | r, err := Mod(test.a, test.b) 91 | if test.hasError { 92 | assert.Error(t, err, "row %d", row) 93 | continue 94 | } 95 | assert.NoError(t, err, "row %d", row) 96 | assert.Equal(t, test.r, r, "row %d", row) 97 | } 98 | ``` 99 | 100 | 除了 Equal Error NoError,assert 还提供其它很多意义明确的判断方法,如:NotNil, NotEmpty, HTTPSucess 等。 101 | 102 | ### 使用 testify/mock 隔离第三方依赖或者复杂调用 103 | 104 | 很多时候,测试环境不具备 routine 执行的必要条件。比如查询 consul 里的 KV,即使准备了测试consul,也要先往里面塞测试数据,十分麻烦。又比如查询 AWS S3 的文件列表,每个开发人员一个测试 bucket 太混乱,大家用同一个测试 bucket 更混乱。必须找个方式伪造 consul client 和 AWS S3 client。通过伪造 consul client 查询 KV 的方法,免去连接 consul, 直接返回预设的结果。 105 | 106 | 首先考虑一下怎样伪造 client。假设 client 被定义为 var client *SomeClient。当 SomeClient 是 type SomeClient struct{...} 时,我们永远没法在 test 环境修改 client 的行为。当是 type SomeClient interface{...} 时,我们可以在测试代码中实现一个符合 SomeClient interface 的 struct,用这个 struct 的实例替换原来的 client。 107 | 108 | 假设一个 IP 限流程序从 consul 获取阈值并更新: 109 | 110 | ``` 111 | type SettingGetter interface { 112 | Get(key string) ([]byte, error) 113 | } 114 | 115 | type ConsulKV struct { 116 | kv *consul.KV 117 | } 118 | 119 | func (ck *ConsulKV) Get(key string) (value []byte, err error) { 120 | pair, _, err := ck.kv.Get(key, nil) 121 | if err != nil { 122 | return nil, err 123 | } 124 | return pair.Value, nil 125 | } 126 | 127 | type IPLimit struct { 128 | Threshold int64 129 | SettingGetter SettingGetter 130 | } 131 | 132 | func (il *IPLimit) UpdateThreshold() error { 133 | value, err := il.SettingGetter.Get(KeyIPRateThreshold) 134 | if err != nil { 135 | return err 136 | } 137 | 138 | threshold, err := strconv.Atoi(string(value)) 139 | if err != nil { 140 | return err 141 | } 142 | 143 | il.Threshold = int64(threshold) 144 | return nil 145 | } 146 | ``` 147 | 148 | 因为 consul.KV 是个 struct,没法方便替换,而我们只用到它的 Get 功能,所以简单定义一个 SettingGetter,ConsulKV 实现了这个接口,IPLimit 通过 SettingGetter 获得值,转换并更新。 149 | 150 | 在测试的时候,我们不能使用 ConsulKV,需要伪造一个 SettingGetter,像下面这样: 151 | 152 | ``` 153 | type MockSettingGetter struct {} 154 | 155 | func (m *MockSettingGetter) Get(key string) ([]byte, error) { 156 | if key == "threshold" { 157 | return []byte("100"), nil 158 | } 159 | if key == "nothing" { 160 | return nil, fmt.Errorf("notfound") 161 | } 162 | ... 163 | } 164 | 165 | ipLimit := &IPLimit{SettingGetter: &MockSettingGetter{}} 166 | // ... test with ipLimit 167 | ``` 168 | 169 | 这样的确可以隔离 test 对 consul 的访问,但不方便 Table Driven。可以使用 testfiy/mock 改造一下,变成下面这样子: 170 | 171 | ``` 172 | import "github.com/stretchr/testify/mock" 173 | 174 | type MockSettingGetter struct { 175 | mock.Mock 176 | } 177 | 178 | func (m *MockSettingGetter) Get(key string) (value []byte, err error) { 179 | args := m.Called(key) 180 | return args.Get(0).([]byte), args.Error(1) 181 | } 182 | 183 | func TestUpdateThreshold(t *testing.T) { 184 | tests := []struct { 185 | v string 186 | err error 187 | rs int64 188 | hasErr bool 189 | }{ 190 | {v: "1000", err: nil, rs: 1000, hasErr: false}, 191 | {v: "a", err: nil, rs: 0, hasErr: true}, 192 | {v: "", err: fmt.Errorf("consul is down"), rs: 0, hasErr: true}, 193 | } 194 | 195 | for idx, test := range tests { 196 | mockSettingGetter := new(MockSettingGetter) 197 | mockSettingGetter.On("Get", mock.Anything).Return([]byte(test.v), test.err) 198 | 199 | limiter := &IPLimit{SettingGetter: mockSettingGetter} 200 | err := limiter.UpdateThreshold() 201 | if test.hasErr { 202 | assert.Error(t, err, "row %d", idx) 203 | } else { 204 | assert.NoError(t, err, "row %d", idx) 205 | } 206 | assert.Equal(t, test.rs, limiter.Threshold, "thredshold should equal, row %d", idx) 207 | } 208 | } 209 | ``` 210 | 211 | testfiy/mock 使得伪造对象的输入输出值可以在运行时决定。更多技巧可看 testify/mock 的文档。 212 | 213 | 再说到上面提到的 AWS S3,AWS 的 Go SDK 已经给我们定义好了 API 的 interface,每个服务下都有个 xxxiface 目录,比如 S3 的是 github.com/aws/aws-sdk-go/service/s3/s3iface,如果查看它的源码,会发现它的 API interface 列了一大堆方法,将这几十个方法都伪造一次而实际中只用到一两个显得很蠢。要想没那么蠢,一个方法是将 S3 的 API 像上面那样再封装一下,另一个方法可以像下面这样: 214 | 215 | ``` 216 | import ( 217 | "github.com/aws/aws-sdk-go/service/s3" 218 | "github.com/aws/aws-sdk-go/service/s3/s3iface" 219 | ) 220 | 221 | type MockS3API struct { 222 | s3iface.S3API 223 | mock.Mock 224 | } 225 | 226 | func (m *MockS3API) ListObjects(input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) { 227 | args := m.Called(input) 228 | return args.Get(0).(*s3.ListObjectsOutput), args.Error(1) 229 | } 230 | ``` 231 | 232 | struct 里内嵌一个匿名 interface,免去定义无关方法的苦恼。 233 | 234 | ### mock http request 235 | 236 | 单元测试中还有个难题是如何伪造 HTTP 请求的结果。如果像上面那样封装一下,可能会漏掉一些极端情况的测试,比如连接网络出错,失败的状态码。Golang 有个 httptest 库,可以在 test 时创建一个 server,让 client 连上 server。这样做会有点绕,事实上 Golang 的 http.Client 有个 Transport 成员,输入输出都通过它,通过篡改 Transport 就可以返回我们需要的数据。 237 | 238 | 以一段获得本机外网 IP 的代码为例: 239 | 240 | ``` 241 | type IPApi struct { 242 | Client *http.Client 243 | } 244 | 245 | // MyIP return public ip address of current machine 246 | func (ia *IPApi) MyIP() (ip string, err error) { 247 | resp, err := ia.Client.Get(MyIPUrl) 248 | if err != nil { 249 | return "", err 250 | } 251 | defer resp.Body.Close() 252 | 253 | body, err := ioutil.ReadAll(resp.Body) 254 | if err != nil { 255 | return "", err 256 | } 257 | 258 | if resp.StatusCode != 200 { 259 | return "", fmt.Errorf("status code: %d", resp.StatusCode) 260 | } 261 | 262 | infos := make(map[string]string) 263 | err = json.Unmarshal(body, &infos) 264 | if err != nil { 265 | return "", err 266 | } 267 | 268 | ip, ok := infos["ip"] 269 | if !ok { 270 | return "", ErrInvalidRespResult 271 | } 272 | return ip, nil 273 | } 274 | ``` 275 | 276 | 可以这样写单元测试: 277 | 278 | ``` 279 | // RoundTripFunc . 280 | type RoundTripFunc func(req *http.Request) *http.Response 281 | 282 | // RoundTrip . 283 | func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { 284 | return f(req), nil 285 | } 286 | 287 | // NewTestClient returns *http.Client with Transport replaced to avoid making real calls 288 | func NewTestClient(fn RoundTripFunc) *http.Client { 289 | return &http.Client{ 290 | Transport: RoundTripFunc(fn), 291 | } 292 | } 293 | 294 | func TestMyIP(t *testing.T) { 295 | tests := []struct { 296 | code int 297 | text string 298 | ip string 299 | hasError bool 300 | }{ 301 | {code: 200, text: "{\"ip\":\"1.2.3.4\"}", ip: "1.2.3.4", hasError: false}, 302 | {code: 403, text: "", ip: "", hasError: true}, 303 | {code: 200, text: "abcd", ip: "", hasError: true}, 304 | } 305 | 306 | for row, test := range tests { 307 | client := NewTestClient(func(req *http.Request) *http.Response { 308 | assert.Equal(t, req.URL.String(), MyIPUrl, "ip url should match, row %d", row) 309 | return &http.Response{ 310 | StatusCode: test.code, 311 | Body: ioutil.NopCloser(bytes.NewBufferString(test.text)), 312 | Header: make(http.Header), 313 | } 314 | }) 315 | api := &IPApi{Client: client} 316 | 317 | ip, err := api.MyIP() 318 | if test.hasError { 319 | assert.Error(t, err, "row %d", row) 320 | } else { 321 | assert.NoError(t, err, "row %d", row) 322 | } 323 | assert.Equal(t, test.ip, ip, "ip should equal, row %d", row) 324 | } 325 | } 326 | ``` 327 | 328 | ### stub redis 329 | 330 | 假如程序里用到 Redis,要伪造一个 Redis Client 用之前的办法也是可以的,但因为有 miniredis 的存在,我们有更好的办法。miniredis 是在 Golang 程序中运行的 Redis Server,它实现了大部分原装 Redis 的功能,测试的时候 miniredis.Run() 然后将 Redis Client 连向 miniredis 就可以了。 331 | 332 | 这种方式称为 stub,和 mock 有一些微妙的差别,可参考 [stackoverflow](https://stackoverflow.com/questions/3459287/whats-the-difference-between-a-mock-stub) 的讨论。 333 | 334 | miniredis 使用方式如下,主要需要考虑保障每个测试都有个干净的 redis 数据库。: 335 | 336 | ``` 337 | var testRdsSrv *miniredis.Miniredis 338 | 339 | func TestMain(m *testing.M) { 340 | s, err := miniredis.Run() 341 | if err != nil { 342 | panic(err) 343 | } 344 | defer s.Close() 345 | os.Exit(m.Run() 346 | } 347 | 348 | func TestSomeRedis(t *testing.T) { 349 | tests := []struct {...}{...} 350 | for row, test := range tests { 351 | testRdsSrv.FlushAll() 352 | rClient := redis.NewClient(&redis.Options{ 353 | Addr: testRdsSrv.Addr(), 354 | }) 355 | // do something with rClient 356 | } 357 | testRdsSrv.FlushAll() 358 | } 359 | ``` 360 | 361 | ### stub MySQL 362 | 363 | 要测试用到关系数据库的代码更加麻烦,因为很多时候看程序正确与否就看它写入到数据库里的数据对不对,关系数据库的操作不能简单 mock 一下,测试的时候需要一个真的数据库。 364 | 365 | MySQL 或者其它关系数据库没有类似 miniredis 的解决方案,我们在测试之前要搭好一个干净的 MySQL 测试 Server,里面的表也要建好。这些条件没法只靠写 Go 代码实现,需要使用一些工具,以及在代码工程里做一点约定。 366 | 367 | 我想到的一个方案是,工程里有个 sql 文件,里面有建库建表语句,编写一个 docker-compose 配置,用于创建 MySQL Server,执行建库建表语句,编写 Makefile 将「启动 MySQL」,「建表」,「go test」,「关闭 MySQL」 组织起来。 368 | 369 | 我试了一下,实现了整个流程后测试挺顺畅的,相关配置代码太多就不在这里贴了,有兴趣可看 [Github testingo](https://github.com/euclidr/testingo) 370 | 371 | 实现过程中主要遇到两个问题,一个是需要确认 MySQL 的 docker 真正正常运行后才能建库建表,一个是考虑修改默认 storage-engine 为 Memory 以加快测试速度。 372 | 373 | ## 参考资料 374 | 375 | 1. [以上所有测试的详细例子](https://github.com/euclidr/testingo) 376 | 2. [testing](https://golang.org/pkg/testing/) 377 | 3. [testify](https://github.com/stretchr/testify) 378 | 4. [Unit Testing http client in Go](http://hassansin.github.io/Unit-Testing-http-client-in-Go) 379 | 5. [Integration Test With Database in Golang](https://hackernoon.com/integration-test-with-database-in-golang-355dc123fdc9) 380 | 6. [miniredis](https://github.com/alicebob/miniredis) -------------------------------------------------------------------------------- /awss3.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // Reference 4 | // https://aws.amazon.com/blogs/developer/mocking-out-then-aws-sdk-for-go-for-unit-testing/ 5 | 6 | import ( 7 | "fmt" 8 | 9 | "github.com/aws/aws-sdk-go/aws" 10 | "github.com/aws/aws-sdk-go/aws/session" 11 | "github.com/aws/aws-sdk-go/service/s3" 12 | "github.com/aws/aws-sdk-go/service/s3/s3iface" 13 | ) 14 | 15 | // ListObjects(*s3.ListObjectsInput) (*s3.ListObjectsOutput, error) 16 | 17 | func ListFileNames(s3api s3iface.S3API, bucket string, prefix string) (names []string, err error) { 18 | input := &s3.ListObjectsInput{ 19 | Bucket: aws.String(bucket), 20 | MaxKeys: aws.Int64(10), 21 | Prefix: aws.String(prefix), 22 | } 23 | output, err := s3api.ListObjects(input) 24 | if err != nil { 25 | return nil, err 26 | } 27 | names = make([]string, len(output.Contents)) 28 | for idx, content := range output.Contents { 29 | names[idx] = *content.Key 30 | } 31 | return names, nil 32 | } 33 | 34 | func ListFileNamesExample() { 35 | sess := session.Must(session.NewSession()) 36 | s3api := s3.New(sess) 37 | names, _ := ListFileNames(s3api, "examplebucket", "/a/b") 38 | fmt.Println(names) 39 | } 40 | -------------------------------------------------------------------------------- /awss3_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | 9 | "github.com/aws/aws-sdk-go/aws" 10 | "github.com/aws/aws-sdk-go/service/s3" 11 | "github.com/aws/aws-sdk-go/service/s3/s3iface" 12 | "github.com/stretchr/testify/mock" 13 | ) 14 | 15 | type MockS3API struct { 16 | s3iface.S3API 17 | mock.Mock 18 | } 19 | 20 | func (m *MockS3API) ListObjects(input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) { 21 | args := m.Called(input) 22 | return args.Get(0).(*s3.ListObjectsOutput), args.Error(1) 23 | } 24 | 25 | func TestListFileNames(t *testing.T) { 26 | tests := []struct { 27 | output *s3.ListObjectsOutput 28 | err error 29 | hasError bool 30 | count int 31 | first string 32 | }{ 33 | { 34 | output: &s3.ListObjectsOutput{ 35 | Contents: []*s3.Object{ 36 | {Key: aws.String("/a/b/1.txt")}, 37 | {Key: aws.String("/a/b/2.txt")}, 38 | }}, 39 | err: nil, 40 | hasError: false, 41 | count: 2, 42 | first: "/a/b/1.txt", 43 | }, 44 | { 45 | output: nil, 46 | err: fmt.Errorf("bad network"), 47 | hasError: true, 48 | }, 49 | } 50 | 51 | for row, test := range tests { 52 | s3api := new(MockS3API) 53 | s3api.On("ListObjects", mock.Anything).Return(test.output, test.err) 54 | 55 | names, err := ListFileNames(s3api, "anybucket", "anyprefix") 56 | if test.hasError { 57 | assert.Error(t, err, "row: %d", row) 58 | continue 59 | } 60 | assert.NoError(t, err, "row: %d", row) 61 | assert.Equal(t, test.count, len(names), "names count, row: %d", row) 62 | assert.Equal(t, test.first, names[0]) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /databases/schema.sql: -------------------------------------------------------------------------------- 1 | CREATE DATABASE `testingo`; 2 | 3 | USE `testingo`; 4 | 5 | CREATE TABLE `animal` ( 6 | `id` int(10) unsigned NOT NULL AUTO_INCREMENT, 7 | `name` varchar(100) CHARACTER SET utf8 NOT NULL, 8 | `place` varchar(100) CHARACTER SET utf8 NOT NULL, 9 | PRIMARY KEY (`id`) 10 | ); 11 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.1' 2 | 3 | services: 4 | mysqltest: 5 | image: mysql:5.7 6 | command: 7 | - --default-authentication-plugin=mysql_native_password 8 | - --default-storage-engine=Memory 9 | environment: 10 | MYSQL_ROOT_PASSWORD: root 11 | ports: 12 | - 33061:3306 13 | 14 | waitformysqltest: 15 | image: mysql:5.7 16 | command: > 17 | /bin/bash -c "maxcounter=45; 18 | counter=1; 19 | while ! mysql --protocol TCP -hdbhost -uroot -proot -e 'show databases;' > /dev/null 2>&1; do 20 | sleep 1 21 | counter=`expr $${counter} + 1` 22 | if [ $${counter} -gt $${maxcounter} ]; then 23 | echo 'We have been waiting for MySQL too long already; failing.' 24 | exit 1 25 | fi; 26 | done" 27 | links: 28 | - mysqltest:dbhost 29 | 30 | testdatafiller: 31 | image: mysql:5.7 32 | depends_on: 33 | - mysqltest 34 | command: /bin/bash -c "mysql -hdbhost -uroot -proot