├── go.mod ├── .editorconfig ├── .gitignore ├── rest ├── server_test.go ├── request_test.go ├── response_test.go ├── server.go ├── rest.go ├── rest_test.go ├── response.go └── request.go ├── .github └── workflows │ └── go.yml ├── file_test.go ├── LICENSE ├── number_test.go ├── assert_test.go ├── errors_test.go ├── README.md ├── assert.go ├── file.go ├── assertion_test.go ├── errors.go ├── number.go ├── util_test.go ├── assertion.go └── util.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/issue9/assert/v4 2 | 3 | go 1.17 4 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig is awesome: http://EditorConfig.org 2 | 3 | # top-most EditorConfig file 4 | root = true 5 | 6 | # Unix-style newlines with a newline ending every file 7 | [*] 8 | end_of_line = lf 9 | insert_final_newline = true 10 | charset = utf-8 11 | 12 | # html 13 | [*.{htm,html,js,css}] 14 | indent_style = space 15 | indent_size = 4 16 | 17 | # 配置文件 18 | [*.{yml,yaml,json}] 19 | indent_style = space 20 | indent_size = 2 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | # osx 27 | .DS_Store 28 | 29 | # ide 30 | .vscode 31 | .idea 32 | .vs 33 | *.swp 34 | -------------------------------------------------------------------------------- /rest/server_test.go: 
-------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package rest 6 | 7 | import ( 8 | "net/http" 9 | "testing" 10 | 11 | "github.com/issue9/assert/v4" 12 | ) 13 | 14 | func TestNew(t *testing.T) { 15 | a := assert.New(t, false) 16 | 17 | srv := NewTLSServer(a, nil, nil) 18 | a.NotNil(srv) 19 | a.Equal(srv.client, &http.Client{}) 20 | a.True(len(srv.server.URL) > 0) 21 | 22 | client := &http.Client{} 23 | srv = NewServer(a, nil, client) 24 | a.NotNil(srv) 25 | a.Equal(srv.client, client) 26 | 27 | srv.Close() 28 | a.True(srv.closed) 29 | srv.Close() 30 | } 31 | -------------------------------------------------------------------------------- /rest/request_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package rest 6 | 7 | import ( 8 | "net/http" 9 | "testing" 10 | 11 | "github.com/issue9/assert/v4" 12 | ) 13 | 14 | func TestRequest_buildPath(t *testing.T) { 15 | srv := NewServer(assert.New(t, false), h, nil) 16 | a := srv.Assertion() 17 | a.NotNil(srv) 18 | 19 | req := srv.NewRequest(http.MethodGet, "/get") 20 | a.NotNil(req) 21 | a.Equal(req.buildPath(), srv.URL()+"/get") 22 | 23 | req.Param("id", "1").Query("page", "5") 24 | a.Equal(req.buildPath(), srv.URL()+"/get?page=5") 25 | 26 | req = srv.NewRequest(http.MethodGet, "/users/{id}/orders/{oid}") 27 | a.NotNil(req) 28 | a.Equal(req.buildPath(), srv.URL()+"/users/{id}/orders/{oid}") 29 | req.Param("id", "1").Param("oid", "2").Query("page", "5") 30 | a.Equal(req.buildPath(), srv.URL()+"/users/1/orders/2?page=5") 31 | } 32 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | on: [push, pull_request] 3 
| 4 | jobs: 5 | 6 | test: 7 | name: Test 8 | runs-on: ${{ matrix.os }} 9 | 10 | strategy: 11 | matrix: 12 | os: [ubuntu-latest, macOS-latest, windows-latest] 13 | go: ['1.17.x', '1.25.x'] 14 | 15 | steps: 16 | 17 | - name: Check out code into the Go module directory 18 | uses: actions/checkout@v4 19 | 20 | - name: Set up Go ${{ matrix.go }} 21 | uses: actions/setup-go@v5 22 | with: 23 | go-version: ${{ matrix.go }} 24 | id: go 25 | 26 | - name: Vet 27 | run: go vet -v ./... 28 | 29 | - name: Test 30 | run: go test -v -coverprofile='coverage.txt' -covermode=atomic ./... 31 | 32 | - name: Upload Coverage report 33 | uses: codecov/codecov-action@v5 34 | with: 35 | token: ${{secrets.CODECOV_TOKEN}} 36 | file: ./coverage.txt 37 | -------------------------------------------------------------------------------- /file_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package assert 6 | 7 | import ( 8 | "os" 9 | "testing" 10 | ) 11 | 12 | func TestAssertion_FileExists_FileNotExists(t *testing.T) { 13 | a := New(t, false) 14 | 15 | a.FileExists("./assert.go", "a.FileExists(./assert.go) failed"). 16 | FileNotExists("c:/win", "a.FileNotExists(c:/win) failed") 17 | 18 | fsys := os.DirFS("./") 19 | a.FileExistsFS(fsys, "assert.go", "a.FileExistsFS(./assert) failed"). 20 | FileNotExistsFS(fsys, "win", "a.FileNotExistsFS(c:/win) failed") 21 | } 22 | 23 | func TestAssertion_IsDir_IsNotDir(t *testing.T) { 24 | a := New(t, false) 25 | 26 | a.IsDir("./rest", "a.IsDir(./rest) failed"). 27 | IsNotDir("./assert.go", "a.IsNotDir(./assert.go) failed") 28 | 29 | fsys := os.DirFS("./") 30 | a.IsDirFS(fsys, "rest", "a.IsDirFS(./rest) failed"). 
31 | IsNotDirFS(fsys, "assert.go", "a.IsNotDirFS(assert.go) failed") 32 | } 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 caixw 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | 23 | -------------------------------------------------------------------------------- /number_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package assert 6 | 7 | import "testing" 8 | 9 | func TestAssertion_Length_NotLength(t *testing.T) { 10 | a := New(t, false) 11 | 12 | a.Length(nil, 0) 13 | a.Length([]int{1, 2}, 2) 14 | a.Length([3]int{1, 2, 3}, 3) 15 | a.NotLength([3]int{1, 2, 3}, 2) 16 | a.Length(map[string]string{"1": "1", "2": "2"}, 2) 17 | a.NotLength(map[string]string{"1": "1", "2": "2"}, 3) 18 | slices := []rune{'a', 'b', 'c'} 19 | ps := &slices 20 | pps := &ps 21 | a.Length(pps, 3) 22 | a.NotLength(pps, 2) 23 | a.Length("string", 6) 24 | a.NotLength("string", 4) 25 | } 26 | 27 | func TestAssertion_Greater_Less(t *testing.T) { 28 | a := New(t, false) 29 | 30 | a.Greater(uint16(5), 3).Less(uint8(5), 6).GreaterEqual(uint64(5), 5).LessEqual(uint(5), 5) 31 | } 32 | 33 | func TestAssertion_Positive_Negative(t *testing.T) { 34 | a := New(t, false) 35 | 36 | a.Positive(float32(5)).Negative(float64(-5)) 37 | } 38 | 39 | func TestAssertion_Between(t *testing.T) { 40 | a := New(t, false) 41 | 42 | a.Between(int8(5), 1, 6). 43 | BetweenEqual(int16(5), 5, 6). 44 | BetweenEqual(int32(6), 5, 6). 45 | BetweenEqualMin(int64(5), 5, 6). 
46 | BetweenEqualMax(uint32(5), 4, 5) 47 | } 48 | -------------------------------------------------------------------------------- /assert_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package assert 6 | 7 | import "testing" 8 | 9 | func TestDefaultFailureSprint(t *testing.T) { 10 | f := NewFailure("A", nil, nil) 11 | if f.Action != "A" || f.User() != "" || len(f.Values) != 0 { 12 | t.Error("err1") 13 | } 14 | if s := DefaultFailureSprint(f); s != "A 断言失败!" { 15 | t.Error("err2") 16 | } 17 | 18 | // 带 user 19 | f = NewFailure("AB", []interface{}{1, 2}, nil) 20 | if f.Action != "AB" || f.User() != "1 2" || len(f.Values) != 0 { 21 | t.Error("err3") 22 | } 23 | if s := DefaultFailureSprint(f); s != "AB 断言失败!用户反馈信息:1 2" { 24 | t.Error("err4", s) 25 | } 26 | 27 | // 带 values 28 | f = NewFailure("AB", nil, map[string]interface{}{"k1": "v1", "k2": 2}) 29 | if f.Action != "AB" || f.User() != "" || len(f.Values) != 2 { 30 | t.Error("err5") 31 | } 32 | if s := DefaultFailureSprint(f); s != "AB 断言失败!反馈以下参数:\nk1=v1\nk2=2\n" { 33 | t.Error("err6", s) 34 | } 35 | 36 | // 带 user,values 37 | f = NewFailure("AB", []interface{}{1, 2}, map[string]interface{}{"k1": "v1", "k2": 2}) 38 | if f.Action != "AB" || f.User() == "" || len(f.Values) != 2 { 39 | t.Error("err7") 40 | } 41 | if s := DefaultFailureSprint(f); s != "AB 断言失败!反馈以下参数:\nk1=v1\nk2=2\n用户反馈信息:1 2" { 42 | t.Error("err8", s) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /rest/response_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package rest 6 | 7 | import ( 8 | "net/http" 9 | "testing" 10 | 11 | "github.com/issue9/assert/v4" 12 | ) 13 | 14 | func TestRequest_Do(t *testing.T) { 15 | a := 
assert.New(t, false) 16 | srv := NewServer(a, h, nil) 17 | 18 | srv.Get("/get"). 19 | Do(nil). 20 | Success(). 21 | Status(201) 22 | 23 | srv.NewRequest(http.MethodGet, "/not-exists"). 24 | Do(nil). 25 | Fail() 26 | 27 | srv.NewRequest(http.MethodGet, "/get"). 28 | Do(BuildHandler(a, 202, "", nil)). 29 | Status(202) 30 | 31 | r := Get(a, "/get") 32 | r.Do(BuildHandler(a, 202, "", nil)).Status(202) 33 | r.Do(BuildHandler(a, 203, "", nil)).Status(203) 34 | a.Panic(func() { 35 | r.Do(nil) 36 | }) 37 | } 38 | 39 | func TestResponse(t *testing.T) { 40 | srv := NewServer(assert.New(t, false), h, nil) 41 | 42 | srv.NewRequest(http.MethodGet, "/body"). 43 | Header("content-type", "application/json"). 44 | Query("page", "5"). 45 | StringBody(`{"id":5}`). 46 | Do(nil). 47 | Status(http.StatusCreated). 48 | NotStatus(http.StatusNotFound). 49 | Header("content-type", "application/json;charset=utf-8"). 50 | NotHeader("content-type", "invalid value"). 51 | Body([]byte(`{"id":6}`)). 52 | StringBody(`{"id":6}`). 53 | BodyNotEmpty() 54 | 55 | srv.NewRequest(http.MethodGet, "/get"). 56 | Query("page", "5"). 57 | Do(nil). 58 | Status(http.StatusCreated). 59 | NotHeader("content-type", "invalid value"). 
60 | BodyEmpty() 61 | } 62 | -------------------------------------------------------------------------------- /rest/server.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package rest 6 | 7 | import ( 8 | "net/http" 9 | "net/http/httptest" 10 | 11 | "github.com/issue9/assert/v4" 12 | ) 13 | 14 | // Server 测试服务 15 | type Server struct { 16 | a *assert.Assertion 17 | server *httptest.Server 18 | client *http.Client 19 | closed bool 20 | } 21 | 22 | // NewServer 声明新的测试服务 23 | // 24 | // 如果 client 为 nil,则会采用 &http.Client{} 作为默认值 25 | func NewServer(a *assert.Assertion, h http.Handler, client *http.Client) *Server { 26 | return newServer(a, httptest.NewServer(h), client) 27 | } 28 | 29 | // NewTLSServer 声明新的测试服务 30 | // 31 | // 如果 client 为 nil,则会采用 &http.Client{} 作为默认值 32 | func NewTLSServer(a *assert.Assertion, h http.Handler, client *http.Client) *Server { 33 | return newServer(a, httptest.NewTLSServer(h), client) 34 | } 35 | 36 | func newServer(a *assert.Assertion, srv *httptest.Server, client *http.Client) *Server { 37 | if client == nil { 38 | client = &http.Client{} 39 | } 40 | 41 | s := &Server{ 42 | a: a, 43 | server: srv, 44 | client: client, 45 | } 46 | 47 | a.TB().Cleanup(func() { 48 | s.Close() 49 | }) 50 | 51 | return s 52 | } 53 | 54 | func (srv *Server) URL() string { return srv.server.URL } 55 | 56 | func (srv *Server) Assertion() *assert.Assertion { return srv.a } 57 | 58 | // Close 关闭服务 59 | // 60 | // 如果未手动调用,则在 testing.TB.Cleanup 中自动调用。 61 | func (srv *Server) Close() { 62 | if srv.closed { 63 | return 64 | } 65 | 66 | srv.server.Close() 67 | srv.closed = true 68 | } 69 | -------------------------------------------------------------------------------- /errors_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // 
SPDX-License-Identifier: MIT 4 | 5 | package assert 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "testing" 11 | ) 12 | 13 | func TestAssertion_Error(t *testing.T) { 14 | a := New(t, false) 15 | 16 | err := errors.New("test") 17 | a.Error(err, "a.Error(err) failed") 18 | a.ErrorString(err, "test", "ErrorString(err) failed") 19 | 20 | err2 := &errorImpl{msg: "msg"} 21 | a.Error(err2, "ErrorString(errorImpl) failed") 22 | a.ErrorString(err2, "msg", "ErrorString(errorImpl) failed") 23 | 24 | var err3 error 25 | a.NotError(err3, "var err1 error failed") 26 | 27 | err4 := errors.New("err4") 28 | err5 := fmt.Errorf("err5 with %w", err4) 29 | a.ErrorIs(err5, err4) 30 | } 31 | 32 | func TestAssertion_Panic(t *testing.T) { 33 | a := New(t, false) 34 | 35 | f1 := func() { 36 | panic("panic message") 37 | } 38 | 39 | a.Panic(f1) 40 | a.PanicString(f1, "message") 41 | a.PanicType(f1, "abc") 42 | a.PanicValue(f1, "panic message") 43 | 44 | f1 = func() { 45 | panic(errors.New("panic")) 46 | } 47 | a.PanicType(f1, errors.New("abc")) 48 | 49 | f1 = func() { 50 | panic(&errorImpl{msg: "panic"}) 51 | } 52 | a.PanicType(f1, &errorImpl{msg: "abc"}) 53 | 54 | f1 = func() {} 55 | a.NotPanic(f1) 56 | } 57 | 58 | func TestHasPanic(t *testing.T) { 59 | f1 := func() { 60 | panic("panic") 61 | } 62 | 63 | if has, _ := hasPanic(f1); !has { 64 | t.Error("f1未发生panic") 65 | } 66 | 67 | f2 := func() { 68 | f1() 69 | } 70 | 71 | if has, msg := hasPanic(f2); !has { 72 | t.Error("f2未发生panic") 73 | } else if msg != "panic" { 74 | t.Errorf("f2发生了panic,但返回信息不正确,应为[panic],但其实返回了%v", msg) 75 | } 76 | 77 | f3 := func() { 78 | defer func() { 79 | if msg := recover(); msg != nil { 80 | t.Logf("TestHasPanic.f3 recover msg:[%v]", msg) 81 | } 82 | }() 83 | 84 | f1() 85 | } 86 | 87 | if has, msg := hasPanic(f3); has { 88 | t.Errorf("f3发生了panic,其信息为:[%v]", msg) 89 | } 90 | 91 | f4 := func() { 92 | //todo 93 | } 94 | 95 | if has, msg := hasPanic(f4); has { 96 | t.Errorf("f4发生panic,其信息为[%v]", msg) 97 | } 98 | } 99 
| -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | assert 2 | ====== 3 | 4 | [![Go](https://github.com/issue9/assert/workflows/Go/badge.svg)](https://github.com/issue9/assert/actions?query=workflow%3AGo) 5 | [![codecov](https://codecov.io/gh/issue9/assert/branch/master/graph/badge.svg)](https://codecov.io/gh/issue9/assert) 6 | [![license](https://img.shields.io/badge/license-MIT-brightgreen.svg?style=flat)](https://opensource.org/licenses/MIT) 7 | [![PkgGoDev](https://pkg.go.dev/badge/github.com/issue9/assert)](https://pkg.go.dev/github.com/issue9/assert/v4) 8 | [![Go version](https://img.shields.io/github/go-mod/go-version/issue9/assert)](https://golang.org) 9 | 10 | assert 包是对 testing 的一个简单扩展,提供的一系列的断言函数, 11 | 方便在测试函数中使用: 12 | 13 | ```go 14 | func TestA(t *testing.T) { 15 | v := true 16 | a := assert.New(t, false) 17 | a.True(v) 18 | } 19 | 20 | // 也可以对 testing.B 使用 21 | func Benchmark1(b *testing.B) { 22 | a := assert.New(b, false) 23 | v := false 24 | a.True(v) 25 | for(i:=0; i 0 { 65 | keys := make([]string, 0, len(f.Values)) 66 | for k := range f.Values { 67 | keys = append(keys, k) 68 | } 69 | sort.Strings(keys) // TODO(go1.21): slices.Sort 70 | 71 | s.WriteString("反馈以下参数:\n") 72 | for _, k := range keys { 73 | s.WriteString(k) 74 | s.WriteByte('=') 75 | s.WriteString(fmt.Sprint(f.Values[k])) 76 | s.WriteByte('\n') 77 | } 78 | } 79 | 80 | if u := f.User(); u != "" { 81 | s.WriteString("用户反馈信息:") 82 | s.WriteString(u) 83 | } 84 | 85 | return s.String() 86 | } 87 | 88 | // NewFailure 声明 [Failure] 对象 89 | // 90 | // user 表示用户提交的反馈,其第一个元素如果是 string,那么将调用 fmt.Sprintf(user[0], user[1:]...) 91 | // 对数据进行格式化,否则采用 fmt.Sprint(user...) 
格式化数据; 92 | // kv 表示当前错误返回的数据; 93 | func NewFailure(action string, user []interface{}, kv map[string]interface{}) *Failure { 94 | f := failurePool.Get().(*Failure) 95 | f.Action = action 96 | f.user = user 97 | f.Values = kv 98 | return f 99 | } 100 | 101 | // User returns the feedback message supplied by the user 102 | func (f *Failure) User() string { 103 | // NOTE: the message is formatted lazily here rather than in [NewFailure], so no formatting work is done when it is never read. 104 | 105 | if len(f.user) == 0 { 106 | return "" 107 | } 108 | 109 | switch v := f.user[0].(type) { 110 | case string: 111 | return fmt.Sprintf(v, f.user[1:]...) 112 | default: 113 | return fmt.Sprint(f.user...) 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /file.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package assert 6 | 7 | import ( 8 | "errors" 9 | "io/fs" 10 | "os" 11 | ) 12 | // FileExists asserts that path exists on the OS filesystem; any os.Stat error fails the assertion 13 | func (a *Assertion) FileExists(path string, msg ...interface{}) *Assertion { 14 | a.TB().Helper() 15 | 16 | if _, err := os.Stat(path); err != nil { 17 | return a.Assert(false, NewFailure("FileExists", msg, map[string]interface{}{"err": err})) 18 | } 19 | return a 20 | } 21 | // FileNotExists asserts that path does not exist on the OS filesystem; os.Stat errors other than fs.ErrNotExist fail the assertion 22 | func (a *Assertion) FileNotExists(path string, msg ...interface{}) *Assertion { 23 | a.TB().Helper() 24 | 25 | _, err := os.Stat(path) 26 | if err == nil { 27 | return a.Assert(false, NewFailure("FileNotExists", msg, nil)) 28 | } 29 | if !errors.Is(err, fs.ErrNotExist) { // e.g. permission errors: absence is not proven 30 | return a.Assert(false, NewFailure("FileNotExists", msg, map[string]interface{}{"err": err})) 31 | } 32 | 33 | return a 34 | } 35 | // FileExistsFS asserts that path exists in fsys; any fs.Stat error fails the assertion 36 | func (a *Assertion) FileExistsFS(fsys fs.FS, path string, msg ...interface{}) *Assertion { 37 | a.TB().Helper() 38 | 39 | if _, err := fs.Stat(fsys, path); err != nil { 40 | return a.Assert(false, NewFailure("FileExistsFS", msg, map[string]interface{}{"err":
err})) 41 | } 42 | 43 | return a 44 | } 45 | // FileNotExistsFS asserts that path does not exist in fsys; fs.Stat errors other than fs.ErrNotExist fail the assertion 46 | func (a *Assertion) FileNotExistsFS(fsys fs.FS, path string, msg ...interface{}) *Assertion { 47 | a.TB().Helper() 48 | 49 | _, err := fs.Stat(fsys, path) 50 | if err == nil { 51 | return a.Assert(false, NewFailure("FileNotExistsFS", msg, nil)) 52 | } 53 | if !errors.Is(err, fs.ErrNotExist) { // e.g. permission or invalid-path errors: absence is not proven 54 | return a.Assert(false, NewFailure("FileNotExistsFS", msg, map[string]interface{}{"err": err})) 55 | } 56 | 57 | return a 58 | } 59 | 60 | // IsDir asserts that path exists on the OS filesystem and is a directory 61 | func (a *Assertion) IsDir(path string, msg ...interface{}) *Assertion { 62 | a.TB().Helper() 63 | 64 | s, err := os.Stat(path) 65 | if err != nil { 66 | return a.Assert(false, NewFailure("IsDir", msg, map[string]interface{}{"err": err})) 67 | } 68 | return a.Assert(s.IsDir(), NewFailure("IsDir", msg, nil)) 69 | } 70 | // IsDirFS asserts that path exists in fsys and is a directory 71 | func (a *Assertion) IsDirFS(fsys fs.FS, path string, msg ...interface{}) *Assertion { 72 | a.TB().Helper() 73 | 74 | s, err := fs.Stat(fsys, path) 75 | if err != nil { 76 | return a.Assert(false, NewFailure("IsDirFS", msg, map[string]interface{}{"err": err})) 77 | } 78 | return a.Assert(s.IsDir(), NewFailure("IsDirFS", msg, nil)) 79 | } 80 | 81 | // IsNotDir asserts that path exists on the OS filesystem and is not a directory; a missing path fails the assertion 82 | func (a *Assertion) IsNotDir(path string, msg ...interface{}) *Assertion { 83 | a.TB().Helper() 84 | 85 | s, err := os.Stat(path) 86 | if err != nil { 87 | return a.Assert(false, NewFailure("IsNotDir", msg, map[string]interface{}{"err": err})) 88 | } 89 | return a.Assert(!s.IsDir(), NewFailure("IsNotDir", msg, nil)) 90 | } 91 | // IsNotDirFS asserts that path exists in fsys and is not a directory; path follows io/fs naming (slash-separated, no leading "./") 92 | func (a *Assertion) IsNotDirFS(fsys fs.FS, path string, msg ...interface{}) *Assertion { 93 | a.TB().Helper() 94 | 95 | s, err := fs.Stat(fsys, path) // look the path up in fsys, not on the OS filesystem 96 | if err != nil { 97 | return a.Assert(false, NewFailure("IsNotDirFS", msg, map[string]interface{}{"err": err})) 98 | } 99 | return a.Assert(!s.IsDir(), NewFailure("IsNotDirFS", msg, nil)) 100 | } 101 | --------------------------------------------------------------------------------
/assertion_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package assert 6 | 7 | import ( 8 | "database/sql" 9 | "regexp" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | type errorImpl struct { 15 | msg string 16 | } 17 | 18 | func (err *errorImpl) Error() string { 19 | return err.msg 20 | } 21 | 22 | func TestAssertion_True_False(t *testing.T) { 23 | a := New(t, true) 24 | 25 | if t != a.TB() { 26 | t.Error("a.T与t不相等") 27 | } 28 | 29 | a.True(true) 30 | a.True(true, "a.True(5==5 failed") 31 | 32 | a.False(false, "a.False(false) failed") 33 | a.False(false, "a.False(4==5) failed") 34 | } 35 | 36 | func TestAssertion_Equal_NotEqual_Nil_NotNil(t *testing.T) { 37 | a := New(t, false) 38 | 39 | v1 := 4 40 | v2 := 4 41 | v3 := 5 42 | v4 := "5" 43 | 44 | a.Equal(4, 4, "a.Equal(4,4) failed") 45 | a.Equal(v1, v2, "a.Equal(v1,v2) failed") 46 | 47 | a.NotEqual(4, 5, "a.NotEqual(4,5) failed"). 48 | NotEqual(v1, v3, "a.NotEqual(v1,v3) failed"). 49 | NotEqual(v3, v4, "a.NotEqual(v3,v4) failed") 50 | 51 | var v5 interface{} 52 | v6 := 0 53 | v7 := []int{} 54 | 55 | a.Empty(v5, "a.Empty failed"). 56 | Empty(v6, "a.Empty(0) failed"). 57 | Empty(v7, "a.Empty(v7) failed") 58 | 59 | a.NotEmpty(1, "a.NotEmpty(1) failed") 60 | 61 | a.Nil(v5) 62 | 63 | a.NotNil(v7, "a.Nil(v7) failed"). 64 | NotNil(v6, "a.NotNil(v6) failed") 65 | } 66 | 67 | func TestAssertion_Zero_NotZero(t *testing.T) { 68 | a := New(t, false) 69 | 70 | var v interface{} 71 | a.Zero(0) 72 | a.Zero(nil) 73 | a.Zero(time.Time{}) 74 | a.Zero(v) 75 | a.Zero([2]int{0, 0}) 76 | a.Zero([0]int{}) 77 | a.Zero(&time.Time{}) 78 | a.Zero(sql.NullTime{}) 79 | 80 | a.NotZero([]int{0, 0}) 81 | a.NotZero([]int{}) 82 | } 83 | 84 | func TestAssertion_Contains(t *testing.T) { 85 | a := New(t, false) 86 | 87 | a.Contains([]int{1, 2, 3}, []int8{1, 2}). 
88 | NotContains([]int{1, 2, 3}, []int8{1, 3}) 89 | } 90 | 91 | func TestAssertion_TypeEqual(t *testing.T) { 92 | a := New(t, true) 93 | 94 | a.TypeEqual(false, 1, 2) 95 | a.TypeEqual(false, 1, 1) 96 | a.TypeEqual(false, 1.0, 2.0) 97 | 98 | v1 := 5 99 | pv1 := &v1 100 | a.TypeEqual(false, 1, v1) 101 | a.TypeEqual(true, 1, &pv1) 102 | 103 | v2 := &errorImpl{} 104 | v3 := errorImpl{} 105 | a.TypeEqual(false, v2, v2) 106 | a.TypeEqual(true, v2, v3) 107 | a.TypeEqual(true, v2, &v3) 108 | a.TypeEqual(true, &v2, &v3) 109 | } 110 | 111 | func TestAssertion_Same(t *testing.T) { 112 | a := New(t, false) 113 | 114 | a.NotSame(5, 5). 115 | NotSame(struct{}{}, struct{}{}). 116 | NotSame(func() {}, func() {}) 117 | 118 | i := 5 119 | a.NotSame(i, i) 120 | 121 | empty := struct{}{} 122 | empty2 := empty 123 | a.NotSame(empty, empty) 124 | a.NotSame(empty, empty2) 125 | a.Same(&empty, &empty) 126 | a.Same(&empty, &empty2) 127 | 128 | f := func() {} 129 | f2 := f 130 | a.Same(f, f) 131 | a.Same(f, f2) 132 | 133 | a.NotSame(5, 5) 134 | a.NotSame(f, 5) 135 | } 136 | 137 | func TestAssertion_Match(t *testing.T) { 138 | a := New(t, false) 139 | 140 | a.Match(regexp.MustCompile("^[1-9]*$"), "123") 141 | a.NotMatch(regexp.MustCompile("^[1-9]*$"), "x123") 142 | 143 | a.Match(regexp.MustCompile("^[1-9]*$"), []byte("123")) 144 | a.NotMatch(regexp.MustCompile("^[1-9]*$"), []byte("x123")) 145 | } 146 | 147 | func TestAssert_When(t *testing.T) { 148 | a := New(t, false) 149 | 150 | a.When(true, func(a *Assertion) { 151 | a.True(true) 152 | }) 153 | } 154 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package assert 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "strings" 11 | ) 12 | 13 | // Error 断言有错误发生 14 | // 15 | // 传递未初始化的 error 值(var err error = nil),将断言失败 16 | // 
17 | // [Assertion.NotNil] 的特化版本,限定了类型为 error。 18 | func (a *Assertion) Error(expr error, msg ...interface{}) *Assertion { 19 | a.TB().Helper() 20 | return a.Assert(!isNil(expr), NewFailure("Error", msg, map[string]interface{}{"v": expr})) 21 | } 22 | 23 | // ErrorString 断言有错误发生且错误信息中包含指定的字符串 str 24 | // 25 | // 传递未初始化的 error 值(var err error = nil),将断言失败 26 | func (a *Assertion) ErrorString(expr error, str string, msg ...interface{}) *Assertion { 27 | a.TB().Helper() 28 | 29 | if isNil(expr) { // 空值,必定没有错误 30 | return a.Assert(false, NewFailure("ErrorString", msg, map[string]interface{}{"v": expr})) 31 | } 32 | return a.Assert(strings.Contains(expr.Error(), str), NewFailure("ErrorString", msg, map[string]interface{}{"v": expr})) 33 | } 34 | 35 | // ErrorIs 断言 expr 为 target 类型 36 | // 37 | // 相当于 a.True(errors.Is(expr, target)) 38 | func (a *Assertion) ErrorIs(expr, target error, msg ...interface{}) *Assertion { 39 | a.TB().Helper() 40 | return a.Assert(errors.Is(expr, target), NewFailure("ErrorIs", msg, map[string]interface{}{"err": expr})) 41 | } 42 | 43 | // NotError 断言没有错误 44 | // 45 | // [Assertion.Nil] 的特化版本,限定了类型为 error。 46 | func (a *Assertion) NotError(expr error, msg ...interface{}) *Assertion { 47 | a.TB().Helper() 48 | return a.Assert(isNil(expr), NewFailure("NotError", msg, map[string]interface{}{"v": expr})) 49 | } 50 | 51 | // Panic 断言函数会发生 panic 52 | func (a *Assertion) Panic(fn func(), msg ...interface{}) *Assertion { 53 | a.TB().Helper() 54 | has, _ := hasPanic(fn) 55 | return a.Assert(has, NewFailure("Panic", msg, nil)) 56 | } 57 | 58 | // PanicString 断言函数会发生 panic 且 panic 信息中包含指定的字符串内容 59 | func (a *Assertion) PanicString(fn func(), str string, msg ...interface{}) *Assertion { 60 | a.TB().Helper() 61 | 62 | if has, m := hasPanic(fn); has { 63 | return a.Assert(strings.Contains(fmt.Sprint(m), str), NewFailure("PanicString", msg, map[string]interface{}{"msg": m})) 64 | } 65 | return a.Assert(false, NewFailure("PanicString", msg, nil)) 66 | } 67 | 
68 | // PanicType asserts that fn panics and that the panic value has the same type as typ (types are resolved by getType) 69 | func (a *Assertion) PanicType(fn func(), typ interface{}, msg ...interface{}) *Assertion { 70 | a.TB().Helper() 71 | 72 | if has, m := hasPanic(fn); has { 73 | t1, t2 := getType(true, m, typ) 74 | return a.Assert(t1 == t2, NewFailure("PanicType", msg, map[string]interface{}{"v1": t1, "v2": t2})) 75 | } 76 | return a.Assert(false, NewFailure("PanicType", msg, nil)) 77 | } 78 | 79 | // PanicValue asserts that fn panics with a value equal to v 80 | func (a *Assertion) PanicValue(fn func(), v interface{}, msg ...interface{}) *Assertion { 81 | a.TB().Helper() 82 | 83 | if has, m := hasPanic(fn); has { 84 | return a.Assert(isEqual(m, v), NewFailure("PanicValue", msg, map[string]interface{}{"v": m})) 85 | } 86 | return a.Assert(false, NewFailure("PanicValue", msg, nil)) // report under this assertion's own name 87 | } 88 | 89 | // NotPanic asserts that fn does not panic 90 | func (a *Assertion) NotPanic(fn func(), msg ...interface{}) *Assertion { 91 | a.TB().Helper() 92 | has, m := hasPanic(fn) 93 | return a.Assert(!has, NewFailure("NotPanic", msg, map[string]interface{}{"err": m})) 94 | } 95 | 96 | // hasPanic reports whether calling fn panics; 97 | // if it does, the recovered value is returned as msg. 98 | func hasPanic(fn func()) (has bool, msg interface{}) { 99 | has = true // cleared only when fn returns normally, so panic(nil) is detected even when recover() yields nil 100 | defer func() { 101 | msg = recover() 102 | }() 103 | fn() 104 | has = false 105 | 106 | return 107 | } 108 | -------------------------------------------------------------------------------- /rest/rest.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | // Package rest is a simple API testing library 6 | package rest 7 | 8 | import ( 9 | "bufio" 10 | "bytes" 11 | "io" 12 | "net/http" 13 | "net/http/httptest" 14 | 15 | "github.com/issue9/assert/v4" 16 | ) 17 | 18 | // BuildHandler builds an [http.Handler] for use in tests 19 | // 20 | // It simply writes its response in the following steps: 21 | // - adds headers with Add rather than Set, so existing values are not overwritten; 22 | // - writes the status code code; 23 | // - writes body, unless it is the empty string; 24 | func
BuildHandler(a *assert.Assertion, code int, body string, headers map[string]string) http.Handler { 25 | return http.HandlerFunc(BuildHandlerFunc(a, code, body, headers)) 26 | } 27 | 28 | func BuildHandlerFunc(a *assert.Assertion, code int, body string, headers map[string]string) func(http.ResponseWriter, *http.Request) { 29 | return func(w http.ResponseWriter, _ *http.Request) { 30 | a.TB().Helper() 31 | 32 | for k, v := range headers { 33 | w.Header().Add(k, v) 34 | } 35 | w.WriteHeader(code) 36 | 37 | if body != "" { 38 | _, err := w.Write([]byte(body)) 39 | a.NotError(err) 40 | } 41 | } 42 | } 43 | 44 | func (srv *Server) RawHTTP(req, resp string) *Server { 45 | srv.Assertion().TB().Helper() 46 | RawHTTP(srv.Assertion(), srv.client, req, resp) 47 | return srv 48 | } 49 | 50 | // RawHTTP 通过原始数据进行比较请求和返回数据是符合要求 51 | // 52 | // reqRaw 表示原始的请求数据。其格式如下: 53 | // 54 | // POST https://example.com/path HTTP/1.1 55 | // 56 | // text 57 | // 58 | // 会忽略 HOST 报头,而是应该将主机部分直接写在请求地址中。 59 | // 60 | // respRaw 表示返回之后的原始数据; 61 | // 62 | // NOTE: 仅判断状态码、报头和实际内容是否相同,而不是直接比较两个 http.Response 的值。 63 | func RawHTTP(a *assert.Assertion, client *http.Client, reqRaw, respRaw string) { 64 | if client == nil { 65 | client = &http.Client{} 66 | } 67 | a.TB().Helper() 68 | 69 | r, resp := readRaw(a, reqRaw, respRaw) 70 | if r == nil { 71 | return 72 | } 73 | 74 | ret, err := client.Do(r) 75 | a.NotError(err).NotNil(ret) 76 | 77 | compare(a, resp, ret.StatusCode, ret.Header, ret.Body) 78 | a.NotError(ret.Body.Close()) 79 | } 80 | 81 | // RawHandler 通过原始数据进行比较请求和返回数据是符合要求 82 | // 83 | // 功能上与 RawHTTP 相似,处理方式从 http.Client 变成了 http.Handler。 84 | func RawHandler(a *assert.Assertion, h http.Handler, reqRaw, respRaw string) { 85 | if h == nil { 86 | panic("h 不能为空") 87 | } 88 | a.TB().Helper() 89 | 90 | r, resp := readRaw(a, reqRaw, respRaw) 91 | if r == nil { 92 | return 93 | } 94 | 95 | ret := httptest.NewRecorder() 96 | h.ServeHTTP(ret, r) 97 | 98 | compare(a, resp, ret.Code, ret.Header(), 
ret.Body) 99 | } 100 | 101 | func readRaw(a *assert.Assertion, reqRaw, respRaw string) (*http.Request, *http.Response) { 102 | a.TB().Helper() 103 | 104 | resp, err := http.ReadResponse(bufio.NewReader(bytes.NewBufferString(respRaw)), nil) 105 | a.NotError(err).NotNil(resp) 106 | 107 | r, err := http.ReadRequest(bufio.NewReader(bytes.NewBufferString(reqRaw))) 108 | a.NotError(err).NotNil(r) 109 | r.RequestURI = "" // 作为 client 不需要此值 110 | 111 | return r, resp 112 | } 113 | 114 | func compare(a *assert.Assertion, resp *http.Response, status int, header http.Header, body io.Reader) { 115 | a.Equal(resp.StatusCode, status, "compare 断言失败,状态码的期望值 %d 与实际值 %d 不同", resp.StatusCode, status) 116 | 117 | for k := range resp.Header { 118 | respV := resp.Header.Get(k) 119 | retV := header.Get(k) 120 | a.Equal(respV, retV, "compare 断言失败,报头 %s 的期望值 %s 与实际值 %s 不相同", k, respV, retV) 121 | } 122 | 123 | retB, err := io.ReadAll(body) 124 | a.NotError(err).NotNil(retB) 125 | respB, err := io.ReadAll(resp.Body) 126 | a.NotError(err).NotNil(respB) 127 | retB = bytes.TrimSpace(retB) 128 | respB = bytes.TrimSpace(respB) 129 | a.Equal(respB, retB, "compare 断言失败,内容的期望值与实际值不相同\n%s\n\n%s\n", respB, retB) 130 | } 131 | -------------------------------------------------------------------------------- /rest/rest_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package rest 6 | 7 | import ( 8 | "encoding/json" 9 | "encoding/xml" 10 | "fmt" 11 | "io" 12 | "net/http" 13 | "net/http/httptest" 14 | "strings" 15 | "testing" 16 | 17 | "github.com/issue9/assert/v4" 18 | ) 19 | 20 | type bodyTest struct { 21 | XMLName struct{} `json:"-" xml:"root"` 22 | ID int `json:"id" xml:"id"` 23 | } 24 | 25 | var h = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 26 | if r.URL.Path == "/get" { 27 | w.WriteHeader(http.StatusCreated) 28 | return 29 | } 30 | 31 | if 
r.URL.Path == "/body" { 32 | if r.Header.Get("content-type") == "application/json" { 33 | b := &bodyTest{} 34 | bs, err := io.ReadAll(r.Body) 35 | if err != nil { 36 | fmt.Println(err) 37 | w.WriteHeader(http.StatusInternalServerError) 38 | return 39 | } 40 | 41 | if err := json.Unmarshal(bs, b); err != nil { 42 | fmt.Println(err) 43 | w.WriteHeader(http.StatusInternalServerError) 44 | return 45 | } 46 | 47 | b.ID++ 48 | bs, err = json.Marshal(b) 49 | if err != nil { 50 | fmt.Println(err) 51 | w.WriteHeader(http.StatusInternalServerError) 52 | return 53 | } 54 | w.Header().Add("content-Type", "application/json;charset=utf-8") 55 | w.WriteHeader(http.StatusCreated) 56 | w.Write(bs) 57 | return 58 | } 59 | 60 | if r.Header.Get("content-type") == "application/xml" { 61 | b := &bodyTest{} 62 | bs, err := io.ReadAll(r.Body) 63 | if err != nil { 64 | fmt.Println(err) 65 | w.WriteHeader(http.StatusInternalServerError) 66 | return 67 | } 68 | 69 | if err := xml.Unmarshal(bs, b); err != nil { 70 | fmt.Println(err) 71 | w.WriteHeader(http.StatusInternalServerError) 72 | return 73 | } 74 | 75 | b.ID++ 76 | bs, err = xml.Marshal(b) 77 | if err != nil { 78 | fmt.Println(err) 79 | w.WriteHeader(http.StatusInternalServerError) 80 | return 81 | } 82 | w.Header().Add("content-Type", "application/xml;charset=utf-8") 83 | w.WriteHeader(http.StatusCreated) 84 | w.Write(bs) 85 | return 86 | } 87 | 88 | w.WriteHeader(http.StatusUnsupportedMediaType) 89 | return 90 | } 91 | 92 | w.WriteHeader(http.StatusNotFound) 93 | }) 94 | 95 | func TestBuildHandler(t *testing.T) { 96 | a := assert.New(t, false) 97 | 98 | h := BuildHandler(a, 201, "body", map[string]string{"k1": "v1"}) 99 | w := httptest.NewRecorder() 100 | r, err := http.NewRequest(http.MethodGet, "/", nil) 101 | a.NotError(err).NotNil(r) 102 | h.ServeHTTP(w, r) 103 | a.Equal(w.Code, 201). 
104 | Equal(w.Header().Get("k1"), "v1") 105 | } 106 | 107 | var raw = []*struct { 108 | req, resp string 109 | }{ 110 | { 111 | req: `GET {host}/get HTTP/1.1 112 | 113 | `, 114 | resp: `HTTP/1.1 201 115 | 116 | `, 117 | }, 118 | { 119 | req: `POST {host}/body HTTP/1.1 120 | Host: 这行将被忽略 121 | Content-Type: application/json 122 | Content-Length: 8 123 | 124 | {"id":5} 125 | 126 | `, 127 | resp: `HTTP/1.1 201 128 | Content-Type: application/json;charset=utf-8 129 | 130 | {"id":6} 131 | 132 | `, 133 | }, 134 | { 135 | req: `DELETE {host}/body?page=5 HTTP/1.0 136 | Content-Type: application/xml 137 | Content-Length: 23 138 | 139 | 6 140 | 141 | `, 142 | resp: `HTTP/1.0 201 143 | Content-Type: application/xml;charset=utf-8 144 | 145 | 7 146 | 147 | `, 148 | }, 149 | } 150 | 151 | func TestServer_RawHTTP(t *testing.T) { 152 | a := assert.New(t, true) 153 | s := NewServer(a, h, nil) 154 | 155 | for _, item := range raw { 156 | req := strings.Replace(item.req, "{host}", s.URL(), 1) 157 | s.RawHTTP(req, item.resp) 158 | } 159 | } 160 | 161 | func TestRawHandler(t *testing.T) { 162 | a := assert.New(t, true) 163 | host := "http://localhost:88" 164 | 165 | for _, item := range raw { 166 | req := strings.Replace(item.req, "{host}", host, 1) 167 | RawHandler(a, h, req, item.resp) 168 | } 169 | } 170 | -------------------------------------------------------------------------------- /rest/response.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package rest 6 | 7 | import ( 8 | "bytes" 9 | "io" 10 | "net/http" 11 | "net/http/httptest" 12 | 13 | "github.com/issue9/assert/v4" 14 | ) 15 | 16 | // Response 测试请求的返回结构 17 | type Response struct { 18 | resp *http.Response 19 | a *assert.Assertion 20 | body []byte 21 | } 22 | 23 | // Do 执行请求操作 24 | // 25 | // h 默认为空,如果不为空,则表示当前请求忽略 [http.Client],而是访问 h.ServeHTTP 的内容。 26 | func (req *Request) Do(h 
http.Handler) *Response { 27 | if req.client == nil && h == nil { 28 | panic("h 不能为空") 29 | } 30 | 31 | req.a.TB().Helper() 32 | 33 | r := req.Request() 34 | var err error 35 | var resp *http.Response 36 | if h != nil { 37 | w := httptest.NewRecorder() 38 | h.ServeHTTP(w, r) 39 | resp = w.Result() 40 | } else { 41 | resp, err = req.client.Do(r) 42 | req.a.NotError(err).NotNil(resp) 43 | } 44 | 45 | var bs []byte 46 | if resp.Body != nil { 47 | bs, err = io.ReadAll(resp.Body) 48 | if err != io.EOF { 49 | req.a.NotError(err) 50 | } 51 | req.a.NotError(resp.Body.Close()) 52 | } 53 | 54 | return &Response{ 55 | a: req.a, 56 | resp: resp, 57 | body: bs, 58 | } 59 | } 60 | 61 | // Resp 返回 [http.Response] 实例 62 | // 63 | // NOTE: [http.Response.Body] 内容已经被读取且关闭。 64 | func (resp *Response) Resp() *http.Response { return resp.resp } 65 | 66 | func (resp *Response) assert(expr bool, f *assert.Failure) *Response { 67 | resp.a.TB().Helper() 68 | resp.a.Assert(expr, f) 69 | return resp 70 | } 71 | 72 | // Success 状态码是否在 100-399 之间 73 | func (resp *Response) Success(msg ...interface{}) *Response { 74 | resp.a.TB().Helper() 75 | succ := resp.resp.StatusCode >= 100 && resp.resp.StatusCode < 400 76 | return resp.assert(succ, assert.NewFailure("Success", msg, map[string]interface{}{"status": resp.resp.StatusCode})) 77 | } 78 | 79 | // Fail 状态码是否大于 399 80 | func (resp *Response) Fail(msg ...interface{}) *Response { 81 | resp.a.TB().Helper() 82 | fail := resp.resp.StatusCode >= 400 83 | return resp.assert(fail, assert.NewFailure("Fail", msg, map[string]interface{}{"status": resp.resp.StatusCode})) 84 | } 85 | 86 | // Status 判断状态码是否与 status 相等 87 | func (resp *Response) Status(status int, msg ...interface{}) *Response { 88 | resp.a.TB().Helper() 89 | eq := resp.resp.StatusCode == status 90 | return resp.assert(eq, assert.NewFailure("Status", msg, map[string]interface{}{"status": resp.resp.StatusCode, "val": status})) 91 | } 92 | 93 | // NotStatus 判断状态码是否与 status 不相等 94 | func (resp 
*Response) NotStatus(status int, msg ...interface{}) *Response { 95 | resp.a.TB().Helper() 96 | neq := resp.resp.StatusCode != status 97 | return resp.assert(neq, assert.NewFailure("NotStatus", msg, map[string]interface{}{"status": resp.resp.StatusCode})) 98 | } 99 | 100 | // Header 判断指定的报头是否与 val 相同 101 | // 102 | // msg 可以为空,会返回一个默认的错误提示信息 103 | func (resp *Response) Header(key string, val string, msg ...interface{}) *Response { 104 | resp.a.TB().Helper() 105 | h := resp.resp.Header.Get(key) 106 | return resp.assert(h == val, assert.NewFailure("Header", msg, map[string]interface{}{"header": key, "v1": h, "v2": val})) 107 | } 108 | 109 | // NotHeader 指定的报头必定不与 val 相同。 110 | func (resp *Response) NotHeader(key string, val string, msg ...interface{}) *Response { 111 | resp.a.TB().Helper() 112 | h := resp.resp.Header.Get(key) 113 | return resp.assert(h != val, assert.NewFailure("NotHeader", msg, map[string]interface{}{"header": key, "v": h})) 114 | } 115 | 116 | // Body 断言内容与 val 相同 117 | func (resp *Response) Body(val []byte, msg ...interface{}) *Response { 118 | resp.a.TB().Helper() 119 | return resp.assert(bytes.Equal(resp.body, val), assert.NewFailure("Body", msg, map[string]interface{}{"body": string(resp.body), "val": string(val)})) 120 | } 121 | 122 | // StringBody 断言内容与 val 相同 123 | func (resp *Response) StringBody(val string, msg ...interface{}) *Response { 124 | resp.a.TB().Helper() 125 | b := string(resp.body) 126 | return resp.assert(b == val, assert.NewFailure("StringBody", msg, map[string]interface{}{"body": b, "val": val})) 127 | } 128 | 129 | // BodyNotEmpty 报文内容是否不为空 130 | func (resp *Response) BodyNotEmpty(msg ...interface{}) *Response { 131 | resp.a.TB().Helper() 132 | return resp.assert(len(resp.body) > 0, assert.NewFailure("BodyNotEmpty", msg, nil)) 133 | } 134 | 135 | // BodyEmpty 报文内容是否为空 136 | func (resp *Response) BodyEmpty(msg ...interface{}) *Response { 137 | resp.a.TB().Helper() 138 | return resp.assert(len(resp.body) == 0, 
assert.NewFailure("BodyEmpty", msg, map[string]interface{}{"body": resp.body})) 139 | } 140 | 141 | // BodyFunc 指定对 body 内容的断言方式 142 | func (resp *Response) BodyFunc(f func(a *assert.Assertion, body []byte)) *Response { 143 | resp.a.TB().Helper() 144 | 145 | b := make([]byte, len(resp.body)) 146 | copy(b, resp.body) 147 | f(resp.a, b) 148 | 149 | return resp 150 | } 151 | -------------------------------------------------------------------------------- /rest/request.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package rest 6 | 7 | import ( 8 | "bytes" 9 | "io" 10 | "net/http" 11 | "net/url" 12 | "strings" 13 | 14 | "github.com/issue9/assert/v4" 15 | ) 16 | 17 | // Request 请求的参数封装 18 | type Request struct { 19 | path string 20 | method string 21 | body io.Reader 22 | queries url.Values 23 | cookies []*http.Cookie 24 | params map[string]string 25 | headers map[string]string 26 | a *assert.Assertion 27 | client *http.Client 28 | } 29 | 30 | // NewRequest 获取一条请求的结果 31 | // 32 | // method 表示请求方法 33 | // path 表示请求的路径,域名部分无须填定。可以通过 {} 指定参数,比如: 34 | // 35 | // r := NewRequest(http.MethodGet, "/users/{id}") 36 | // 37 | // 之后就可以使用 Params 指定 id 的具体值,达到复用的目的: 38 | // 39 | // resp1 := r.Param("id", "1").Do() 40 | // resp2 := r.Param("id", "2").Do() 41 | func (srv *Server) NewRequest(method, path string) *Request { 42 | return NewRequest(srv.a, method, srv.URL()+path).Client(srv.client) 43 | } 44 | 45 | func (srv *Server) Get(path string) *Request { 46 | return srv.NewRequest(http.MethodGet, path) 47 | } 48 | 49 | func (srv *Server) Put(path string, body []byte) *Request { 50 | return srv.NewRequest(http.MethodPut, path).Body(body) 51 | } 52 | 53 | func (srv *Server) Post(path string, body []byte) *Request { 54 | return srv.NewRequest(http.MethodPost, path).Body(body) 55 | } 56 | 57 | func (srv *Server) Patch(path string, body []byte) *Request 
{ 58 | return srv.NewRequest(http.MethodPatch, path).Body(body) 59 | } 60 | 61 | func (srv *Server) Delete(path string) *Request { 62 | return srv.NewRequest(http.MethodDelete, path) 63 | } 64 | 65 | // NewRequest 以调用链的方式构建一个访问请求对象 66 | func NewRequest(a *assert.Assertion, method, path string) *Request { 67 | return &Request{ 68 | a: a, 69 | method: method, 70 | path: path, 71 | } 72 | } 73 | 74 | func Get(a *assert.Assertion, path string) *Request { 75 | return NewRequest(a, http.MethodGet, path) 76 | } 77 | 78 | func Delete(a *assert.Assertion, path string) *Request { 79 | return NewRequest(a, http.MethodDelete, path) 80 | } 81 | 82 | func Post(a *assert.Assertion, path string, body []byte) *Request { 83 | return NewRequest(a, http.MethodPost, path).Body(body) 84 | } 85 | 86 | func Put(a *assert.Assertion, path string, body []byte) *Request { 87 | return NewRequest(a, http.MethodPut, path).Body(body) 88 | } 89 | 90 | func Patch(a *assert.Assertion, path string, body []byte) *Request { 91 | return NewRequest(a, http.MethodPatch, path).Body(body) 92 | } 93 | 94 | // Client 指定采用的客户端实例 95 | // 96 | // 可以为空,如果为空,那么在 Do 函数中的参数必不能为空。 97 | func (req *Request) Client(c *http.Client) *Request { 98 | req.client = c 99 | return req 100 | } 101 | 102 | // Query 添加一个请求参数 103 | func (req *Request) Query(key, val string) *Request { 104 | if req.queries == nil { 105 | req.queries = url.Values{} 106 | } 107 | 108 | req.queries.Add(key, val) 109 | 110 | return req 111 | } 112 | 113 | // Cookie 添加 Cookie 114 | func (req *Request) Cookie(c *http.Cookie) *Request { 115 | req.cookies = append(req.cookies, c) 116 | return req 117 | } 118 | 119 | // Header 指定请求时的报头 120 | func (req *Request) Header(key, val string) *Request { 121 | if req.headers == nil { 122 | req.headers = make(map[string]string, 5) 123 | } 124 | 125 | req.headers[key] = val 126 | 127 | return req 128 | } 129 | 130 | // Param 替换参数 131 | func (req *Request) Param(key, val string) *Request { 132 | if req.params == nil { 
133 | req.params = make(map[string]string, 5) 134 | } 135 | 136 | req.params[key] = val 137 | 138 | return req 139 | } 140 | 141 | // Body 指定提交的内容 142 | func (req *Request) Body(body []byte) *Request { 143 | req.body = bytes.NewReader(body) 144 | return req 145 | } 146 | 147 | func (req *Request) StringBody(body string) *Request { 148 | req.body = bytes.NewBufferString(body) 149 | return req 150 | } 151 | 152 | // BodyFunc 指定一个未编码的对象 153 | // 154 | // marshal 对 obj 的编码函数,比如 [json.Marshal] 等。 155 | func (req *Request) BodyFunc(obj interface{}, marshal func(interface{}) ([]byte, error)) *Request { 156 | req.a.TB().Helper() 157 | 158 | data, err := marshal(obj) 159 | req.a.NotError(err).NotNil(data) 160 | return req.Body(data) 161 | } 162 | 163 | func (req *Request) buildPath() string { 164 | path := req.path 165 | 166 | for key, val := range req.params { 167 | key = "{" + key + "}" 168 | path = strings.ReplaceAll(path, key, val) 169 | } 170 | 171 | if len(req.queries) > 0 { 172 | path += ("?" 
+ req.queries.Encode()) 173 | } 174 | 175 | return path 176 | } 177 | 178 | // Request 生成 [http.Request] 实例 179 | func (req *Request) Request() *http.Request { 180 | req.a.TB().Helper() 181 | 182 | r, err := http.NewRequest(req.method, req.buildPath(), req.body) 183 | req.a.NotError(err).NotNil(r) 184 | r.Close = true 185 | 186 | for k, v := range req.headers { 187 | r.Header.Add(k, v) 188 | } 189 | 190 | for _, c := range req.cookies { 191 | r.AddCookie(c) 192 | } 193 | 194 | return r 195 | } 196 | -------------------------------------------------------------------------------- /number.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package assert 6 | 7 | import ( 8 | "fmt" 9 | "reflect" 10 | ) 11 | 12 | // Length 断言长度是否为指定的值 13 | // 14 | // v 可以是以下类型: 15 | // - map 16 | // - string 17 | // - slice 18 | // - array 19 | func (a *Assertion) Length(v interface{}, l int, msg ...interface{}) *Assertion { 20 | a.TB().Helper() 21 | 22 | rl, err := getLen(v) 23 | if err != "" { 24 | a.Assert(false, NewFailure("Length", msg, map[string]interface{}{"err": err})) 25 | } 26 | return a.Assert(rl == l, NewFailure("Length", msg, map[string]interface{}{"l1": rl, "l2": l})) 27 | } 28 | 29 | // NotLength 断言长度不是指定的值 30 | // 31 | // v 可以是以下类型: 32 | // - map 33 | // - string 34 | // - slice 35 | // - array 36 | func (a *Assertion) NotLength(v interface{}, l int, msg ...interface{}) *Assertion { 37 | a.TB().Helper() 38 | 39 | rl, err := getLen(v) 40 | if err != "" { 41 | a.Assert(false, NewFailure("NotLength", msg, map[string]interface{}{"err": err})) 42 | } 43 | return a.Assert(rl != l, NewFailure("NotLength", msg, map[string]interface{}{"l": rl})) 44 | } 45 | 46 | func (a *Assertion) Greater(v interface{}, val float64, msg ...interface{}) *Assertion { 47 | vv, ok := getNumber(v) 48 | if !ok { 49 | return a.Assert(false, NewFailure("Greater", msg, 
nil)) 50 | } 51 | return a.Assert(vv > val, NewFailure("Greater", msg, nil)) 52 | } 53 | 54 | func (a *Assertion) Less(v interface{}, val float64, msg ...interface{}) *Assertion { 55 | vv, ok := getNumber(v) 56 | if !ok { 57 | return a.Assert(false, NewFailure("Less", msg, nil)) 58 | } 59 | return a.Assert(vv < val, NewFailure("Less", msg, nil)) 60 | } 61 | 62 | func (a *Assertion) GreaterEqual(v interface{}, val float64, msg ...interface{}) *Assertion { 63 | vv, ok := getNumber(v) 64 | if !ok { 65 | return a.Assert(false, NewFailure("GreaterEqual", msg, nil)) 66 | } 67 | return a.Assert(vv >= val, NewFailure("GreaterEqual", msg, nil)) 68 | } 69 | 70 | func (a *Assertion) LessEqual(v interface{}, val float64, msg ...interface{}) *Assertion { 71 | vv, ok := getNumber(v) 72 | if !ok { 73 | return a.Assert(false, NewFailure("LessEqual", msg, nil)) 74 | } 75 | return a.Assert(vv <= val, NewFailure("LessEqual", msg, nil)) 76 | } 77 | 78 | // Positive 断言 v 为正数 79 | // 80 | // NOTE: 不包含 0 81 | func (a *Assertion) Positive(v interface{}, msg ...interface{}) *Assertion { 82 | vv, ok := getNumber(v) 83 | if !ok { 84 | return a.Assert(false, NewFailure("Positive", msg, nil)) 85 | } 86 | return a.Assert(vv > 0, NewFailure("Positive", msg, nil)) 87 | } 88 | 89 | // Negative 断言 v 为负数 90 | // 91 | // NOTE: 不包含 0 92 | func (a *Assertion) Negative(v interface{}, msg ...interface{}) *Assertion { 93 | vv, ok := getNumber(v) 94 | if !ok { 95 | return a.Assert(false, NewFailure("Negative", msg, nil)) 96 | } 97 | return a.Assert(vv < 0, NewFailure("Negative", msg, nil)) 98 | } 99 | 100 | // Between 断言 v 是否存在于 (min,max) 之间 101 | func (a *Assertion) Between(v interface{}, min, max float64, msg ...interface{}) *Assertion { 102 | vv, ok := getNumber(v) 103 | if !ok { 104 | return a.Assert(false, NewFailure("Between", msg, nil)) 105 | } 106 | 107 | return a.Assert(vv > min && vv < max, NewFailure("Between", msg, nil)) 108 | } 109 | 110 | // BetweenEqual 断言 v 是否存在于 [min,max] 之间 111 | func (a 
*Assertion) BetweenEqual(v interface{}, min, max float64, msg ...interface{}) *Assertion { 112 | vv, ok := getNumber(v) 113 | if !ok { 114 | return a.Assert(false, NewFailure("BetweenEqual", msg, nil)) 115 | } 116 | 117 | return a.Assert(vv >= min && vv <= max, NewFailure("BetweenEqual", msg, nil)) 118 | } 119 | 120 | // BetweenEqualMin 断言 v 是否存在于 [min,max) 之间 121 | func (a *Assertion) BetweenEqualMin(v interface{}, min, max float64, msg ...interface{}) *Assertion { 122 | vv, ok := getNumber(v) 123 | if !ok { 124 | return a.Assert(false, NewFailure("BetweenEqualMin", msg, nil)) 125 | } 126 | 127 | return a.Assert(vv >= min && vv < max, NewFailure("BetweenEqualMin", msg, nil)) 128 | } 129 | 130 | // BetweenEqualMax 断言 v 是否存在于 (min,max] 之间 131 | func (a *Assertion) BetweenEqualMax(v interface{}, min, max float64, msg ...interface{}) *Assertion { 132 | vv, ok := getNumber(v) 133 | if !ok { 134 | return a.Assert(false, NewFailure("BetweenEqualMax", msg, nil)) 135 | } 136 | 137 | return a.Assert(vv > min && vv <= max, NewFailure("BetweenEqualMax", msg, nil)) 138 | } 139 | 140 | // bool 表示是否成功转换 141 | func getNumber(v interface{}) (float64, bool) { 142 | switch val := v.(type) { 143 | case int: 144 | return float64(val), true 145 | case int8: 146 | return float64(val), true 147 | case int16: 148 | return float64(val), true 149 | case int32: 150 | return float64(val), true 151 | case int64: 152 | return float64(val), true 153 | case uint: 154 | return float64(val), true 155 | case uint8: 156 | return float64(val), true 157 | case uint16: 158 | return float64(val), true 159 | case uint32: 160 | return float64(val), true 161 | case uint64: 162 | return float64(val), true 163 | case float32: 164 | return float64(val), true 165 | case float64: 166 | return float64(val), true 167 | } 168 | 169 | return 0, false 170 | } 171 | 172 | func getLen(v interface{}) (l int, msg string) { 173 | r := reflect.ValueOf(v) 174 | for r.Kind() == reflect.Ptr { 175 | r = r.Elem() 176 | } 177 | 
178 | if v == nil { 179 | return 0, "" 180 | } 181 | 182 | switch r.Kind() { 183 | case reflect.Array, reflect.String, reflect.Slice, reflect.Map, reflect.Chan: 184 | return r.Len(), "" 185 | } 186 | return 0, fmt.Sprintf("无法获取 %s 类型的长度信息", r.Kind()) 187 | } 188 | -------------------------------------------------------------------------------- /util_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package assert 6 | 7 | import ( 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestIsZero(t *testing.T) { 13 | zero := func(v interface{}) { 14 | t.Helper() 15 | if !isZero(v) { 16 | t.Errorf("zero: %v", v) 17 | } 18 | } 19 | 20 | zero(nil) 21 | zero(struct{}{}) 22 | zero(time.Time{}) 23 | zero(&time.Time{}) 24 | } 25 | 26 | func TestIsEqual(t *testing.T) { 27 | eq := func(v1, v2 interface{}) { 28 | t.Helper() 29 | if !isEqual(v1, v2) { 30 | t.Errorf("eq:[%v]!=[%v]", v1, v2) 31 | } 32 | } 33 | 34 | neq := func(v1, v2 interface{}) { 35 | t.Helper() 36 | if isEqual(v1, v2) { 37 | t.Errorf("eq:[%v]==[%v]", v1, v2) 38 | } 39 | } 40 | 41 | eq([]byte("abc"), "abc") 42 | eq("abc", []byte("abc")) 43 | 44 | eq([]byte("中文abc"), "中文abc") 45 | eq("中文abc", []byte("中文abc")) 46 | 47 | eq([]rune("中文abc"), "中文abc") 48 | eq("中文abc", []rune("中文abc")) 49 | 50 | eq(5, 5.0) 51 | eq(int8(5), 5) 52 | eq(5, int8(5)) 53 | eq(float64(5), int8(5)) 54 | eq([]int{1, 2, 3}, []int{1, 2, 3}) 55 | eq([]int{1, 2, 3}, []int8{1, 2, 3}) 56 | eq([]float32{1, 2.0, 3}, []int8{1, 2, 3}) 57 | eq([]float32{1, 2.0, 3}, []float64{1, 2, 3}) 58 | 59 | // 比较两个元素类型可相互转换的数组 60 | eq( 61 | [][]int{ 62 | {1, 2}, 63 | {3, 4}, 64 | }, 65 | [][]int8{ 66 | {1, 2}, 67 | {3, 4}, 68 | }, 69 | ) 70 | 71 | // 比较两个元素类型可转换的 map 72 | eq( 73 | []map[int]int{ 74 | {1: 1, 2: 2}, 75 | {3: 3, 4: 4}, 76 | }, 77 | []map[int]int8{ 78 | {1: 1, 2: 2}, 79 | {3: 3, 4: 4}, 80 | }, 81 | ) 82 | 
eq(map[string]int{"1": 1, "2": 2}, map[string]int8{"1": 1, "2": 2}) 83 | 84 | // 比较两个元素类型可转换的 map 85 | eq( 86 | map[int]string{ 87 | 1: "1", 88 | 2: "2", 89 | }, 90 | map[int][]byte{ 91 | 1: []byte("1"), 92 | 2: []byte("2"), 93 | }, 94 | ) 95 | 96 | // array 对比 97 | eq([2]int{1, 2}, [2]int{1, 2}) 98 | eq([2]int{9, 3}, [2]int8{9, 3}) 99 | eq([2]int8{1, 4}, [2]int{1, 4}) 100 | eq([2]int{1, 5}, []int8{1, 5}) 101 | 102 | neq(map[int]int{1: 1, 2: 2}, map[int8]int{1: 1, 2: 2}) 103 | neq([]int{1, 2, 3}, []int{3, 2, 1}) 104 | neq("5", 5) 105 | neq(true, "true") 106 | neq(true, 1) 107 | neq(true, "1") 108 | // 判断包含不同键名的两个 map 109 | neq(map[int]int{1: 1, 2: 2}, map[int]int{5: 5, 6: 6}) 110 | 111 | // time 112 | loc := time.FixedZone("utf+8", 8*3600) 113 | now := time.Now() 114 | eq(time.Time{}, time.Time{}) 115 | neq(now.In(loc), now.In(time.UTC)) // 时区不同 116 | n1 := time.Now() 117 | n2 := n1.Add(0) 118 | eq(n1, n2) 119 | 120 | // 指针 121 | v1 := 5 122 | v2 := 5 123 | p1 := &v1 124 | p2 := &v1 125 | eq(p1, p2) // 指针相等 126 | p2 = &v2 127 | eq(p1, p2) // 指向内容相等 128 | } 129 | 130 | func TestIsEmpty(t *testing.T) { 131 | if isEmpty([]string{""}) { 132 | t.Error("isEmpty([]string{\"\"})") 133 | } 134 | 135 | if !isEmpty([]string{}) { 136 | t.Error("isEmpty([]string{})") 137 | } 138 | 139 | if !isEmpty([]int{}) { 140 | t.Error("isEmpty([]int{})") 141 | } 142 | 143 | if !isEmpty(map[string]int{}) { 144 | t.Error("isEmpty(map[string]int{})") 145 | } 146 | 147 | if !isEmpty(0) { 148 | t.Error("isEmpty(0)") 149 | } 150 | 151 | if !isEmpty(int64(0)) { 152 | t.Error("isEmpty(int64(0))") 153 | } 154 | 155 | if !isEmpty(uint64(0)) { 156 | t.Error("isEmpty(uint64(0))") 157 | } 158 | 159 | if !isEmpty(0.0) { 160 | t.Error("isEmpty(0.0)") 161 | } 162 | 163 | if !isEmpty(float32(0)) { 164 | t.Error("isEmpty(0.0)") 165 | } 166 | 167 | if !isEmpty("") { 168 | t.Error("isEmpty(``)") 169 | } 170 | 171 | if !isEmpty([0]int{}) { 172 | t.Error("isEmpty([0]int{})") 173 | } 174 | 175 | if 
!isEmpty(time.Time{}) { 176 | t.Error("isEmpty(time.Time{})") 177 | } 178 | 179 | if !isEmpty(&time.Time{}) { 180 | t.Error("isEmpty(&time.Time{})") 181 | } 182 | 183 | if isEmpty(" ") { 184 | t.Error("isEmpty(\" \")") 185 | } 186 | } 187 | 188 | func TestIsNil(t *testing.T) { 189 | if !isNil(nil) { 190 | t.Error("isNil(nil)") 191 | } 192 | 193 | var v1 []int 194 | if !isNil(v1) { 195 | t.Error("isNil(v1)") 196 | } 197 | 198 | var v2 map[string]string 199 | if !isNil(v2) { 200 | t.Error("isNil(v2)") 201 | } 202 | } 203 | 204 | func TestIsContains(t *testing.T) { 205 | fn := func(result bool, container, item interface{}) { 206 | t.Helper() 207 | if result != isContains(container, item) { 208 | t.Errorf("%v == (isContains(%v, %v))出错\n", result, container, item) 209 | } 210 | } 211 | 212 | fn(false, nil, nil) 213 | 214 | fn(true, "abc", "a") 215 | fn(true, "abc", "c") 216 | fn(true, "abc", "bc") 217 | fn(true, "abc", byte('a')) // string vs byte 218 | fn(true, "abc", rune('a')) // string vs rune 219 | fn(true, "abc", []byte("ab")) // string vs []byte 220 | fn(true, "abc", []rune("ab")) // string vs []rune 221 | 222 | fn(true, []byte("abc"), "a") 223 | fn(true, []byte("abc"), "c") 224 | fn(true, []byte("abc"), "bc") 225 | fn(true, []byte("abc"), byte('a')) 226 | fn(true, []byte("abc"), rune('a')) 227 | fn(true, []byte("abc"), []byte("ab")) 228 | fn(true, []byte("abc"), []rune("ab")) 229 | 230 | fn(true, []rune("abc"), "a") 231 | fn(true, []rune("abc"), "c") 232 | fn(true, []rune("abc"), "bc") 233 | fn(true, []rune("abc"), byte('a')) 234 | fn(true, []rune("abc"), rune('a')) 235 | fn(true, []rune("abc"), []byte("ab")) 236 | fn(true, []rune("abc"), []rune("ab")) 237 | 238 | fn(true, "中文a", "中") 239 | fn(true, "中文a", "a") 240 | fn(true, "中文a", '中') 241 | 242 | fn(true, []int{1, 2, 3}, 1) 243 | fn(true, []int{1, 2, 3}, int8(3)) 244 | fn(true, []int{1, 2, 4}, []int{1, 2}) 245 | fn(true, []interface{}{[]int{1, 2}, 5, 6}, []int8{1, 2}) 246 | fn(true, []interface{}{[]int{1, 2}, 
5, 6}, 5) 247 | 248 | fn(true, map[string]int{"1": 1, "2": 2}, map[string]int8{"1": 1}) 249 | fn(true, 250 | map[string][]int{ 251 | "1": {1, 2, 3}, 252 | "2": {4, 5, 6}, 253 | }, 254 | map[string][]int8{ 255 | "1": {1, 2, 3}, 256 | "2": {4, 5, 6}, 257 | }, 258 | ) 259 | 260 | fn(false, map[string]int{}, nil) 261 | fn(false, map[string]int{"1": 1, "2": 2}, map[string]int8{}) 262 | fn(false, map[string]int{"1": 1, "2": 2}, map[string]int8{"1": 110}) // 同键名,不同值 263 | fn(false, map[string]int{"1": 1, "2": 2}, map[string]int8{"5": 5}) 264 | fn(false, []int{1, 2, 3}, nil) 265 | fn(false, []int{1, 2, 3}, []int8{1, 3}) 266 | fn(false, []int{1, 2, 3}, []int{1, 2, 3, 4}) 267 | fn(false, []int{}, []int{1}) // 空数组 268 | } 269 | -------------------------------------------------------------------------------- /assertion.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package assert 6 | 7 | import ( 8 | "fmt" 9 | "reflect" 10 | "regexp" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | // Assertion 是对 [testing.TB] 的二次包装 16 | type Assertion struct { 17 | tb testing.TB 18 | print func(...interface{}) 19 | } 20 | 21 | // New 返回 [Assertion] 对象 22 | // 23 | // fatal 决定在出错时是调用 [testing.TB.Error] 还是 [testing.TB.Fatal]; 24 | func New(tb testing.TB, fatal bool) *Assertion { 25 | p := tb.Error 26 | if fatal { 27 | p = tb.Fatal 28 | } 29 | 30 | return &Assertion{ 31 | tb: tb, 32 | print: p, 33 | } 34 | } 35 | 36 | // NewWithEnv 以指定的环境变量初始化 [Assertion] 对象 37 | // 38 | // env 是以 [testing.TB.Setenv] 的形式调用。 39 | func NewWithEnv(tb testing.TB, fatal bool, env map[string]string) *Assertion { 40 | for k, v := range env { 41 | tb.Setenv(k, v) 42 | } 43 | return New(tb, fatal) 44 | } 45 | 46 | // Assert 断言 expr 条件成立 47 | // 48 | // f 表示在断言失败时输出的信息 49 | // 50 | // 普通用户直接使用 [Assertion.True] 效果是一样的,此函数主要供 [Assertion] 自身调用。 51 | func (a *Assertion) Assert(expr bool, f 
*Failure) *Assertion { 52 | if !expr { 53 | a.TB().Helper() 54 | a.print(GetFailureSprintFunc()(f)) 55 | } 56 | failurePool.Put(f) 57 | return a 58 | } 59 | 60 | // TB 返回 [testing.TB] 接口 61 | func (a *Assertion) TB() testing.TB { return a.tb } 62 | 63 | // True 断言表达式 expr 为真 64 | // 65 | // args 对应 [fmt.Printf] 函数中的参数,其中 args[0] 对应第一个参数 format,依次类推, 66 | // 其它断言函数的 args 参数,功能与此相同。 67 | func (a *Assertion) True(expr bool, msg ...interface{}) *Assertion { 68 | a.TB().Helper() 69 | return a.Assert(expr, NewFailure("True", msg, nil)) 70 | } 71 | 72 | func (a *Assertion) False(expr bool, msg ...interface{}) *Assertion { 73 | a.TB().Helper() 74 | return a.Assert(!expr, NewFailure("False", msg, nil)) 75 | } 76 | 77 | func (a *Assertion) Nil(expr interface{}, msg ...interface{}) *Assertion { 78 | a.TB().Helper() 79 | return a.Assert(isNil(expr), NewFailure("Nil", msg, map[string]interface{}{"v": expr})) 80 | } 81 | 82 | func (a *Assertion) NotNil(expr interface{}, msg ...interface{}) *Assertion { 83 | a.TB().Helper() 84 | return a.Assert(!isNil(expr), NewFailure("NotNil", msg, map[string]interface{}{"v": expr})) 85 | } 86 | 87 | func (a *Assertion) Equal(v1, v2 interface{}, msg ...interface{}) *Assertion { 88 | a.TB().Helper() 89 | return a.Assert(isEqual(v1, v2), NewFailure("Equal", msg, map[string]interface{}{"v1": v1, "v2": v2})) 90 | } 91 | 92 | func (a *Assertion) NotEqual(v1, v2 interface{}, msg ...interface{}) *Assertion { 93 | a.TB().Helper() 94 | return a.Assert(!isEqual(v1, v2), NewFailure("NotEqual", msg, map[string]interface{}{"v1": v1, "v2": v2})) 95 | } 96 | 97 | func (a *Assertion) Empty(expr interface{}, msg ...interface{}) *Assertion { 98 | a.TB().Helper() 99 | return a.Assert(isEmpty(expr), NewFailure("Empty", msg, map[string]interface{}{"v": expr})) 100 | } 101 | 102 | func (a *Assertion) NotEmpty(expr interface{}, msg ...interface{}) *Assertion { 103 | a.TB().Helper() 104 | return a.Assert(!isEmpty(expr), NewFailure("NotEmpty", msg, 
map[string]interface{}{"v": expr})) 105 | } 106 | 107 | // Contains 断言 container 包含 item 或是包含 item 中的所有项 108 | // 109 | // 若 container string、[]byte 和 []rune 类型, 110 | // 都将会以字符串的形式判断其是否包含 item。 111 | // 若 container 是个列表(array、slice、map)则判断其元素中是否包含 item 中的 112 | // 的所有项,或是 item 本身就是 container 中的一个元素。 113 | func (a *Assertion) Contains(container, item interface{}, msg ...interface{}) *Assertion { 114 | a.TB().Helper() 115 | return a.Assert(isContains(container, item), NewFailure("Contains", msg, map[string]interface{}{"container": container, "item": item})) 116 | } 117 | 118 | // NotContains 断言 container 不包含 item 或是不包含 item 中的所有项 119 | func (a *Assertion) NotContains(container, item interface{}, msg ...interface{}) *Assertion { 120 | a.TB().Helper() 121 | return a.Assert(!isContains(container, item), NewFailure("NotContains", msg, map[string]interface{}{"container": container, "item": item})) 122 | } 123 | 124 | // Zero 断言是否为零值 125 | // 126 | // 最终调用的是 [reflect.Value.IsZero] 进行判断,如果是指针,则会判断指向的对象。 127 | func (a *Assertion) Zero(v interface{}, msg ...interface{}) *Assertion { 128 | a.TB().Helper() 129 | return a.Assert(isZero(v), NewFailure("Zero", msg, map[string]interface{}{"v": v})) 130 | } 131 | 132 | // NotZero 断言是否为非零值 133 | // 134 | // 最终调用的是 [reflect.Value.IsZero] 进行判断,如果是指针,则会判断指向的对象。 135 | func (a *Assertion) NotZero(v interface{}, msg ...interface{}) *Assertion { 136 | a.TB().Helper() 137 | return a.Assert(!isZero(v), NewFailure("NotZero", msg, map[string]interface{}{"v": v})) 138 | } 139 | 140 | // TypeEqual 断言两个值的类型是否相同 141 | // 142 | // ptr 如果为 true,则会在对象为指针时,查找其指向的对象。 143 | func (a *Assertion) TypeEqual(ptr bool, v1, v2 interface{}, msg ...interface{}) *Assertion { 144 | if v1 == v2 { 145 | return a 146 | } 147 | 148 | a.TB().Helper() 149 | 150 | t1, t2 := getType(ptr, v1, v2) 151 | return a.Assert(t1 == t2, NewFailure("TypeEqual", msg, map[string]interface{}{"v1": t1, "v2": t2})) 152 | } 153 | 154 | // Same 断言为同一个对象 155 | func (a *Assertion) Same(v1, 
v2 interface{}, msg ...interface{}) *Assertion { 156 | a.TB().Helper() 157 | return a.Assert(isSame(v1, v2), NewFailure("Same", msg, nil)) 158 | } 159 | 160 | // NotSame 断言为不是同一个对象 161 | func (a *Assertion) NotSame(v1, v2 interface{}, msg ...interface{}) *Assertion { 162 | a.TB().Helper() 163 | return a.Assert(!isSame(v1, v2), NewFailure("NotSame", msg, nil)) 164 | } 165 | 166 | func isSame(v1, v2 interface{}) bool { 167 | rv1 := reflect.ValueOf(v1) 168 | if !canPointer(rv1.Kind()) { 169 | return false 170 | } 171 | rv2 := reflect.ValueOf(v2) 172 | if !canPointer(rv2.Kind()) { 173 | return false 174 | } 175 | 176 | return rv1.Pointer() == rv2.Pointer() 177 | } 178 | 179 | func canPointer(k reflect.Kind) bool { 180 | switch k { 181 | case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice, reflect.UnsafePointer, reflect.Func: 182 | return true 183 | default: 184 | return false 185 | } 186 | } 187 | 188 | // Match 断言 v 是否匹配正则表达式 reg 189 | func (a *Assertion) Match(reg *regexp.Regexp, v interface{}, msg ...interface{}) *Assertion { 190 | a.TB().Helper() 191 | switch val := v.(type) { 192 | case string: 193 | return a.Assert(reg.MatchString(val), NewFailure("Match", msg, map[string]interface{}{"v": val})) 194 | case []byte: 195 | return a.Assert(reg.Match(val), NewFailure("Match", msg, map[string]interface{}{"v": val})) 196 | default: 197 | return a.Assert(reg.MatchString(fmt.Sprint(val)), NewFailure("Match", msg, map[string]interface{}{"v": val})) 198 | } 199 | } 200 | 201 | // NotMatch 断言 v 是否不匹配正则表达式 reg 202 | func (a *Assertion) NotMatch(reg *regexp.Regexp, v interface{}, msg ...interface{}) *Assertion { 203 | a.TB().Helper() 204 | switch val := v.(type) { 205 | case string: 206 | return a.Assert(!reg.MatchString(val), NewFailure("NotMatch", msg, map[string]interface{}{"v": val})) 207 | case []byte: 208 | return a.Assert(!reg.Match(val), NewFailure("NotMatch", msg, map[string]interface{}{"v": val})) 209 | default: 210 | return 
a.Assert(!reg.MatchString(fmt.Sprint(val)), NewFailure("NotMatch", msg, map[string]interface{}{"v": val})) 211 | } 212 | } 213 | 214 | // When 断言 expr 为 true 且在条件成立时调用 f 215 | // 216 | // 当有一组依赖 expr 的断言时,可以调用此方法。f 的参数 a 即为当前实例。 217 | func (a *Assertion) When(expr bool, f func(a *Assertion), msg ...interface{}) *Assertion { 218 | if expr { 219 | f(a) 220 | } 221 | return a 222 | } 223 | 224 | // Wait 等待一定时间再执行后续操作 225 | func (a *Assertion) Wait(d time.Duration) *Assertion { 226 | time.Sleep(d) 227 | return a 228 | } 229 | 230 | // WaitSeconds 等待 s 秒再执行后续操作 231 | func (a *Assertion) WaitSeconds(s int) *Assertion { return a.Wait(time.Duration(s) * time.Second) } 232 | 233 | // Go 以 goroutine 方式执行 f 234 | func (a *Assertion) Go(f func(*Assertion)) *Assertion { 235 | go f(a) 236 | return a 237 | } 238 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2014-2024 caixw 2 | // 3 | // SPDX-License-Identifier: MIT 4 | 5 | package assert 6 | 7 | import ( 8 | "bytes" 9 | "reflect" 10 | "strings" 11 | ) 12 | 13 | // 判断一个值是否为空(0, "", false, 空数组等)。 14 | // []string{""}空数组里套一个空字符串,不会被判断为空。 15 | func isEmpty(expr interface{}) bool { 16 | if isZero(expr) { 17 | return true 18 | } 19 | 20 | rv := reflect.ValueOf(expr) 21 | for rv.Kind() == reflect.Ptr { 22 | rv = rv.Elem() 23 | } 24 | switch rv.Kind() { 25 | case reflect.Slice, reflect.Map, reflect.Array, reflect.Chan: // 长度为 0 的数组也是 empty 26 | return rv.Len() == 0 27 | default: 28 | return false 29 | } 30 | } 31 | 32 | func isZero(v interface{}) bool { 33 | if isNil(v) || reflect.ValueOf(v).IsZero() { 34 | return true 35 | } 36 | 37 | rv := reflect.ValueOf(v) 38 | for rv.Kind() == reflect.Ptr { 39 | rv = rv.Elem() 40 | } 41 | return rv.IsZero() 42 | } 43 | 44 | // isNil 判断一个值是否为 nil。 45 | // 当特定类型的变量,已经声明,但还未赋值时,也将返回 true 46 | func isNil(expr interface{}) bool { 47 | if nil == 
expr {
		return true
	}

	// reflect.Value.IsNil panics for kinds that can never be nil, so only the
	// nilable kinds are asked.
	v := reflect.ValueOf(expr)
	switch v.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		return v.IsNil()
	default:
		return false
	}
}

// isEqual reports whether v1 and v2 are equal.
//
// Besides [reflect.DeepEqual], values whose types differ but are convertible
// are compared after conversion, so the following are all considered equal:
//
//	int8(5) == int(5)
//	[]int{1,2} == []int8{1,2}
//	[]int{1,2} == [2]int8{1,2}
//	[]int{1,2} == []float32{1,2}
//	map[string]int{"1":1,"2":2} == map[string]int8{"1":1,"2":2}
//
//	// Maps with different key kinds are never equal, even if the keys are convertible.
//	map[int]int{1:1,2:2} != map[int8]int{1:1,2:2}
//
// A scalar conversion only counts when it is lossless, so 1.5 != 1 and
// int64(256) != int8(0). A string equals a []byte or []rune holding the same
// text, but never an integer.
func isEqual(v1, v2 interface{}) bool {
	if reflect.DeepEqual(v1, v2) {
		return true
	}

	vv1 := reflect.ValueOf(v1)
	vv2 := reflect.ValueOf(v2)

	// At least one side is nil: they are equal only if both are.
	if !vv1.IsValid() || !vv2.IsValid() {
		return vv1.IsValid() == vv2.IsValid()
	}

	// NOTE(review): == on reflect.Value compares the Value structs, not the
	// values they represent. Kept unchanged for backward compatibility.
	if vv1 == vv2 {
		return true
	}

	vv1Type := vv1.Type()
	vv2Type := vv2.Type()

	// There is intentionally no `v1 == v2` shortcut here: whenever == would
	// report true, reflect.DeepEqual already did, and for "comparable" types
	// holding interfaces (e.g. struct{ X interface{} } wrapping a slice) ==
	// panics at run time.

	switch vv1Type.Kind() {
	case reflect.Struct, reflect.Ptr, reflect.Func, reflect.Interface:
		// Fully decided by reflect.DeepEqual above.
		return false

	case reflect.Slice, reflect.Array:
		if k := vv2.Kind(); k != reflect.Slice && k != reflect.Array {
			// A different kind that still converts to v1's type,
			// e.g. v1 is a []byte and v2 is a string.
			if vv2Type.ConvertibleTo(vv1Type) {
				return isEqual(vv1.Interface(), vv2.Convert(vv1Type).Interface())
			}
			return false
		}

		// reflect.DeepEqual rejects element types that are merely convertible,
		// e.g. []int{8,9} vs []int8{8,9}, so compare element by element.
		if vv1.Len() != vv2.Len() {
			return false
		}
		for i := 0; i < vv1.Len(); i++ {
			if !isEqual(vv1.Index(i).Interface(), vv2.Index(i).Interface()) {
				return false
			}
		}
		return true // every element compared equal

	case reflect.Map:
		if vv2.Kind() != reflect.Map {
			return false
		}
		if vv1.IsNil() != vv2.IsNil() || vv1.Len() != vv2.Len() {
			return false
		}
		if vv1.Pointer() == vv2.Pointer() {
			return true
		}

		// Maps with different key kinds are never equal.
		keyType := vv2Type.Key()
		if vv1Type.Key().Kind() != keyType.Kind() {
			return false
		}
		// Same key kind but distinct key types (e.g. string and a named
		// string type): MapIndex panics on a key of the wrong type, so
		// convert each key first, or give up when that is impossible.
		convertKey := vv1Type.Key() != keyType
		if convertKey && !vv1Type.Key().ConvertibleTo(keyType) {
			return false
		}

		for _, key := range vv1.MapKeys() {
			key2 := key
			if convertKey {
				key2 = key.Convert(keyType)
			}
			val2 := vv2.MapIndex(key2)
			if !val2.IsValid() {
				return false
			}
			if !isEqual(vv1.MapIndex(key).Interface(), val2.Interface()) {
				return false
			}
		}
		return true // every entry compared equal

	case reflect.String:
		if vv2.Kind() == reflect.String {
			return vv1.String() == vv2.String()
		}
		// v2 may be a []byte or []rune holding the same text. Integers also
		// convert to string (as a code point) but are not equal to one, so
		// only slices are considered here.
		if vv2.Kind() == reflect.Slice && vv2Type.ConvertibleTo(vv1Type) {
			return isEqual(vv1.Interface(), vv2.Convert(vv1Type).Interface())
		}
		return false
	}

	// The remaining kinds are scalars (bool, numbers, channels, ...).
	// An integer converts to a string as a code point; never treat that as equality.
	if vv2.Kind() == reflect.String {
		return false
	}

	// equalAsType converts from to the type of to and compares. When the
	// conversion can be reversed but does not round-trip (1.5 -> 1,
	// int64(256) -> int8(0)), information was lost and the values are unequal.
	equalAsType := func(from, to reflect.Value) bool {
		ft, tt := from.Type(), to.Type()
		if !ft.ConvertibleTo(tt) {
			return false
		}
		converted := from.Convert(tt)
		if tt.ConvertibleTo(ft) && converted.Convert(ft).Interface() != from.Interface() {
			return false
		}
		return converted.Interface() == to.Interface()
	}
	return equalAsType(vv1, vv2) || equalAsType(vv2, vv1)
}

// isContains reports whether container contains item.
//
// A nil container contains nothing. Otherwise container contains item when the
// two are equal (see isEqual) or, after dereferencing pointers on both sides:
//
//   - a string, []byte or []rune container textually contains an item of type
//     string, []byte, []rune, byte or rune;
//   - a non-empty array or slice container has an element equal to item, or
//     item is a non-empty slice whose elements appear consecutively, in order,
//     in container;
//   - a non-empty map container holds every key of a non-empty map item with
//     an equal value.
func isContains(container, item interface{}) bool {
	if container == nil { // nil contains nothing
		return false
	}

	cv := reflect.ValueOf(container)
	for cv.Kind() == reflect.Ptr {
		cv = cv.Elem()
	}
	iv := reflect.ValueOf(item)
	for iv.Kind() == reflect.Ptr {
		iv = iv.Elem()
	}

	if isEqual(container, item) {
		return true
	}

	// A nil item or a nil pointer on either side leaves an invalid Value;
	// calling Interface on it would panic, and it is treated as not contained.
	if !cv.IsValid() || !iv.IsValid() {
		return false
	}

	// Textual containment.
	switch c := cv.Interface().(type) {
	case string:
		switch i := iv.Interface().(type) {
		case string:
			return strings.Contains(c, i)
		case []byte:
			return strings.Contains(c, string(i))
		case []rune:
			return strings.Contains(c, string(i))
		case byte:
			return strings.IndexByte(c, i) != -1
		case rune:
			return strings.ContainsRune(c, i)
		}
	case []byte:
		switch i := iv.Interface().(type) {
		case string:
			return bytes.Contains(c, []byte(i))
		case []byte:
			return bytes.Contains(c, i)
		case []rune:
			return strings.Contains(string(c), string(i))
		case byte:
			return bytes.IndexByte(c, i) != -1
		case rune:
			return bytes.ContainsRune(c, i)
		}
	case []rune:
		switch i := iv.Interface().(type) {
		case string:
			return strings.Contains(string(c), i)
		case []byte:
			return strings.Contains(string(c), string(i))
		case []rune:
			return strings.Contains(string(c), string(i))
		case byte:
			return strings.IndexByte(string(c), i) != -1
		case rune:
			return strings.ContainsRune(string(c), i)
		}
	}

	switch cv.Kind() {
	case reflect.Slice, reflect.Array:
		n := cv.Len()
		if n == 0 { // an empty list holds nothing, not even another empty value
			return false
		}

		// item is one of container's elements.
		for i := 0; i < n; i++ {
			if isEqual(cv.Index(i).Interface(), iv.Interface()) {
				return true
			}
		}

		// Otherwise item must be a non-empty slice, no longer than container,
		// whose elements appear consecutively in container. Every start offset
		// is tried so a partial match (e.g. [1,2] in [1,1,2]) can restart.
		if iv.Kind() != reflect.Slice {
			return false
		}
		m := iv.Len()
		if m == 0 || m > n {
			return false
		}
		for start := 0; start+m <= n; start++ {
			j := 0
			for j < m && isEqual(cv.Index(start+j).Interface(), iv.Index(j).Interface()) {
				j++
			}
			if j == m {
				return true
			}
		}
		return false

	case reflect.Map:
		if cv.Len() == 0 || iv.Kind() != reflect.Map || iv.Len() == 0 || iv.Len() > cv.Len() {
			return false
		}

		// MapIndex panics on a key of the wrong type: require the same key
		// kind and convert item's keys to container's key type when they differ.
		cKey, iKey := cv.Type().Key(), iv.Type().Key()
		if cKey.Kind() != iKey.Kind() || !iKey.ConvertibleTo(cKey) {
			return false
		}

		// Every entry of item must exist in container with an equal value.
		for _, key := range iv.MapKeys() {
			ckey := key
			if iKey != cKey {
				ckey = key.Convert(cKey)
			}
			cvItem := cv.MapIndex(ckey)
			if !cvItem.IsValid() { // container lacks this key
				return false
			}
			if !isEqual(cvItem.Interface(), iv.MapIndex(key).Interface()) {
				return false
			}
		}
		return true
	}

	return false
}

// getType returns the dynamic types of v1 and v2.
//
// When ptr is true, pointer types are unwrapped to the type they ultimately
// point to (e.g. **int becomes int). A nil argument yields a nil reflect.Type.
func getType(ptr bool, v1, v2 interface{}) (t1, t2 reflect.Type) {
	t1 = reflect.TypeOf(v1)
	t2 = reflect.TypeOf(v2)

	if ptr {
		// reflect.TypeOf(nil) is a nil Type; calling Kind on it would panic.
		for t1 != nil && t1.Kind() == reflect.Ptr {
			t1 = t1.Elem()
		}
		for t2 != nil && t2.Kind() == reflect.Ptr {
			t2 = t2.Elem()
		}
	}

	return t1, t2
}