├── .gitignore ├── README.md ├── conf ├── conf.go └── conf_test.go ├── crypto └── padding │ ├── padding.go │ ├── pkcs7.go │ ├── pkcs7_test.go │ ├── zero.go │ └── zero_test.go ├── go.mod ├── http ├── README.md ├── httpclient │ ├── client.go │ └── client_test.go └── httplog │ ├── httplog.go │ └── patch │ ├── patch.go │ └── patch_test.go ├── log ├── log.go └── log_test.go ├── sql ├── README.md ├── builder │ ├── builder.go │ └── builder_test.go ├── scanner │ ├── scanner.go │ └── scanner_test.go ├── sql.go └── sql_test.go ├── tool └── json.go └── web ├── README.md ├── context.go ├── context_test.go ├── middleware.go ├── router.go ├── server.go ├── server_test.go ├── web.go ├── web_test.go └── webutil └── response.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, build with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | go.sum 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # zgo 2 | a general go library, inspired by [zocle](https://github.com/zhaoweikid/zocle) [zbase](https://github.com/zhaoweikid/zbase) 3 | 4 | simple is better 5 | -------------------------------------------------------------------------------- /conf/conf.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "flag" 8 | "io/ioutil" 9 | 10 | "github.com/BurntSushi/toml" 11 | "github.com/go-yaml/yaml" 12 | ) 13 | 14 | type formater struct { 15 | Marshal func(v interface{}) ([]byte, error) 16 | Unmarshal func(data []byte, v interface{}) error 17 | } 18 | 19 | var ( 20 | formaters = map[string]formater{ 21 | "json": formater{ 22 | Marshal: json.Marshal, 23 | Unmarshal: json.Unmarshal, 24 | }, 25 | "yaml": formater{ 26 | Marshal: yaml.Marshal, 27 | Unmarshal: yaml.Unmarshal, 28 | }, 29 | "yml": formater{ 30 | Marshal: yaml.Marshal, 31 | Unmarshal: yaml.Unmarshal, 32 | }, 33 | "toml": formater{ 34 | Marshal: func(v interface{}) ([]byte, error) { 35 | b := bytes.Buffer{} 36 | err := toml.NewEncoder(&b).Encode(v) 37 | return b.Bytes(), err 38 | }, 39 | Unmarshal: toml.Unmarshal, 40 | }, 41 | } 42 | ) 43 | 44 | func Install(path string, v interface{}) error { 45 | 46 | ext := "" 47 | // get extension 48 | for i := len(path) - 1; i >= 0; i-- { 49 | if path[i] == '.' { 50 | ext = path[i+1:] 51 | break 52 | } 53 | } 54 | if ext == "" { 55 | return errors.New("invalid file extension") 56 | } 57 | 58 | data, err := ioutil.ReadFile(path) 59 | if err != nil { 60 | return err 61 | } 62 | 63 | if formater, ok := formaters[ext]; ok { 64 | err := formater.Unmarshal(data, v) 65 | if err != nil { 66 | return err 67 | } 68 | 69 | } else { 70 | return errors.New("no support extension") 71 | } 72 | return nil 73 | } 74 | 75 | var ConfigPath string 76 | 77 | func InstallFlag(defaultv string, v interface{}) error { 78 | flag.StringVar(&ConfigPath, "c", defaultv, "config path") 79 | flag.Parse() 80 | return Install(ConfigPath, v) 81 | } 82 | -------------------------------------------------------------------------------- /conf/conf_test.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import "io/ioutil" 4 | import "testing" 5 | import "github.com/JoveYu/zgo/log" 6 | 7 | type Config struct { 8 | Num int `json:"num" yaml:"num" toml:"num"` 9 | Text string `json:"text" yaml:"text" toml:"text"` 10 | NumList []int `json:"num_list" yaml:"num_list" toml:"num_list"` 11 | TextDict map[string]string `json:"text_dict" yaml:"text_dict" toml:"text_dict"` 12 | } 13 | 14 | func TestJson(t *testing.T) { 15 | log.Install("stdout") 16 | data := `{"num":123, "text":"hello", "num_list": [1, 2], "text_dict":{"key1": "value1", "key2": "value2"}}` 17 | ioutil.WriteFile("/tmp/zgo_conf.json", []byte(data), 0755) 18 | 19 | c := Config{} 20 | 21 | err := Install("/tmp/zgo_conf.json", &c) 22 | if err != nil { 23 | log.Error("%v", err) 24 | } 25 | log.Debug("json %+v", c) 26 | } 27 | 28 | func TestYaml(t *testing.T) { 29 | log.Install("stdout") 30 | data := ` 31 | --- 32 | num: 123 33 | num_list: 34 | - 1 35 | - 2 36 | text: hello 37 | text_dict: 38 | key1: value1 39 | key2: value2 40 | ` 41 | ioutil.WriteFile("/tmp/zgo_conf.yaml", []byte(data), 0755) 42 | 43 | c := Config{} 44 | 45 | err := Install("/tmp/zgo_conf.yaml", &c) 46 | if err != nil { 47 | log.Error("%v", err) 48 | } 49 | log.Debug("yaml %+v", c) 50 | } 51 | 52 | func TestToml(t *testing.T) { 53 | log.Install("stdout") 54 | data := ` 55 | num = 123 56 | text = "hello" 57 | num_list = [ 1, 2,] 58 | [text_dict] 59 | key1 = "value1" 60 | key2 = "value2" 61 | ` 62 | ioutil.WriteFile("/tmp/zgo_conf.toml", []byte(data), 0755) 63 | 64 | c := Config{} 65 | 66 | err := Install("/tmp/zgo_conf.toml", &c) 67 | if err != nil { 68 | log.Error("%v", err) 69 | } 70 | log.Debug("toml %+v", c) 71 | } 72 | -------------------------------------------------------------------------------- /crypto/padding/padding.go: -------------------------------------------------------------------------------- 1 | package padding 2 | 3 | type Padding interface { 4 | Pad(data []byte) ([]byte, error) 5 | UnPad(data []byte) ([]byte, error) 6 | } 7 | -------------------------------------------------------------------------------- /crypto/padding/pkcs7.go: -------------------------------------------------------------------------------- 1 | package padding 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | ) 7 | 8 | type Pkcs7Padding struct { 9 | BlockSize int 10 | } 11 | 12 | func (p Pkcs7Padding) Pad(data []byte) ([]byte, error) { 13 | if p.BlockSize < 1 || p.BlockSize > 255 { 14 | return nil, errors.New("block size error") 15 | } 16 | 17 | length := p.BlockSize - len(data)%p.BlockSize 18 | 19 | return append(data, bytes.Repeat([]byte{byte(length)}, length)...), nil 20 | } 21 | 22 | func (p Pkcs7Padding) UnPad(data []byte) ([]byte, error) { 23 | datalen := len(data) 24 | length := int(data[datalen-1]) 25 | 26 | if datalen%p.BlockSize != 0 { 27 | return nil, errors.New("not padded correctly") 28 | } 29 | 30 | if length > p.BlockSize || length <= 0 { 31 | return nil, errors.New("not padded correctly") 32 | } 33 | 34 | padding := data[datalen-length : datalen-1] 35 | for _, i := range padding { 36 | if int(i) != length { 37 | return nil, errors.New("not padded correctly") 38 | } 39 | } 40 | 41 | return data[:datalen-length], nil 42 | } 43 | -------------------------------------------------------------------------------- /crypto/padding/pkcs7_test.go: -------------------------------------------------------------------------------- 1 | package padding 2 | 3 | import "testing" 4 | import "github.com/JoveYu/zgo/log" 5 | 6 | func TestPkcs7(t *testing.T) { 7 | log.Install("stdout") 8 | 9 | blocksize := 8 10 | alldata := [][]byte{ 11 | {0xAB, 0xCD, 0xEF}, 12 | {0xAB, 0xCD, 0xEF, 0xEF, 0xEF, 0xEF, 0xEF}, 13 | {0xAB, 0xCD, 0xEF, 0xEF, 0xEF, 0xEF, 0xEF, 0xEF}, 14 | } 15 | 16 | for _, data := range alldata { 17 | pad := Pkcs7Padding{blocksize} 18 | padded, err := pad.Pad(data) 19 | if err != nil { 20 | log.Error(err) 21 | return 22 | } 23 | unpadded, err := pad.UnPad(padded) 24 | if err != nil { 25 | log.Error(err) 26 | return 27 | } 28 | 29 | log.Debug("before: %X", data) 30 | log.Debug("pad: %X", padded) 31 | log.Debug("unpad: %X", unpadded) 32 | 33 | } 34 | 35 | } 36 | -------------------------------------------------------------------------------- /crypto/padding/zero.go: -------------------------------------------------------------------------------- 1 | package padding 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | ) 7 | 8 | type ZeroPadding struct { 9 | BlockSize int 10 | } 11 | 12 | func (p ZeroPadding) Pad(data []byte) ([]byte, error) { 13 | if p.BlockSize < 1 || p.BlockSize > 255 { 14 | return nil, errors.New("block size error") 15 | } 16 | 17 | length := p.BlockSize - len(data)%p.BlockSize 18 | 19 | return append(data, bytes.Repeat([]byte{byte(0)}, length)...), nil 20 | } 21 | 22 | func (p ZeroPadding) UnPad(data []byte) ([]byte, error) { 23 | datalen := len(data) 24 | 25 | if datalen%p.BlockSize != 0 { 26 | return nil, errors.New("not padded correctly") 27 | } 28 | 29 | var length int 30 | for length = 0; length <= datalen; length++ { 31 | if int(data[datalen-1-length]) != 0 { 32 | break 33 | } 34 | } 35 | 36 | return data[:datalen-length], nil 37 | } 38 | -------------------------------------------------------------------------------- /crypto/padding/zero_test.go: -------------------------------------------------------------------------------- 1 | package padding 2 | 3 | import "testing" 4 | import "github.com/JoveYu/zgo/log" 5 | 6 | func TestZero(t *testing.T) { 7 | log.Install("stdout") 8 | 9 | blocksize := 8 10 | alldata := [][]byte{ 11 | {0xAB, 0xCD, 0xEF}, 12 | {0xAB, 0xCD, 0xEF, 0xEF, 0xEF, 0xEF, 0xEF}, 13 | {0xAB, 0xCD, 0xEF, 0xEF, 0xEF, 0xEF, 0xEF, 0xEF}, 14 | } 15 | 16 | for _, data := range alldata { 17 | pad := ZeroPadding{blocksize} 18 | padded, err := pad.Pad(data) 19 | if err != nil { 20 | log.Error(err) 21 | return 22 | } 23 | unpadded, err := pad.UnPad(padded) 24 | if err != nil { 25 | log.Error(err) 26 | return 27 | } 28 | 29 | log.Debug("before: %X", data) 30 | log.Debug("pad: %X", padded) 31 | log.Debug("unpad: %X", unpadded) 32 | 33 | } 34 | 35 | } 36 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/JoveYu/zgo 2 | 3 | require ( 4 | github.com/BurntSushi/toml v0.3.1 5 | github.com/go-sql-driver/mysql v1.4.1 6 | github.com/go-yaml/yaml v2.1.0+incompatible 7 | github.com/kr/pretty v0.1.0 // indirect 8 | github.com/mattn/go-sqlite3 v1.10.0 9 | google.golang.org/appengine v1.4.0 // indirect 10 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect 11 | gopkg.in/yaml.v2 v2.2.2 // indirect 12 | ) 13 | -------------------------------------------------------------------------------- /http/README.md: -------------------------------------------------------------------------------- 1 | # zgo/http 2 | 3 | 4 | ## zgo/http/httplog 5 | 6 | 一行代码记录所有http日志 7 | 8 | 只需要在import中添加`_ "github.com/JoveYu/zgo/http/httplog/patch"` 即可自动打印http库的请求日志 9 | 10 | ```go 11 | package main 12 | 13 | import ( 14 | "net/http" 15 | 16 | _ "github.com/JoveYu/zgo/http/httplog/patch" 17 | "github.com/JoveYu/zgo/log" 18 | ) 19 | 20 | func main() { 21 | log.Install("stdout") 22 | http.Get("http://baidu.com") 23 | } 24 | ``` 25 | 26 | ## zgo/http/httpclient 27 | 28 | 扩展标准库的client,添加一些常用的函数,默认开启httplog 29 | 30 | -------------------------------------------------------------------------------- /http/httpclient/client.go: -------------------------------------------------------------------------------- 1 | package httpclient 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "encoding/xml" 7 | "net/http" 8 | 9 | _ "github.com/JoveYu/zgo/http/httplog/patch" 10 | ) 11 | 12 | var DefaultClient = http.DefaultClient 13 | 14 | var Do = DefaultClient.Do 15 | var Get = DefaultClient.Get 16 | var Post = DefaultClient.Post 17 | var PostForm = DefaultClient.PostForm 18 | var Head = DefaultClient.Head 19 | 20 | func PostJson(url string, v interface{}) (*http.Response, error) { 21 | data, err := json.Marshal(v) 22 | if err != nil { 23 | return nil, err 24 | } 25 | return Post(url, "application/json", bytes.NewReader(data)) 26 | } 27 | 28 | func PostXml(url string, v interface{}) (*http.Response, error) { 29 | data, err := xml.Marshal(v) 30 | if err != nil { 31 | return nil, err 32 | } 33 | return Post(url, "application/xml", bytes.NewReader(data)) 34 | } 35 | -------------------------------------------------------------------------------- /http/httpclient/client_test.go: -------------------------------------------------------------------------------- 1 | package httpclient 2 | 3 | import "io/ioutil" 4 | import "strings" 5 | import "testing" 6 | import "net/http" 7 | import "net/url" 8 | import "encoding/xml" 9 | import "github.com/JoveYu/zgo/log" 10 | 11 | func TestClient(t *testing.T) { 12 | log.Install("stdout") 13 | req, err := http.NewRequest("GET", "http://httpbin.org/get", nil) 14 | if err != nil { 15 | log.Error(err) 16 | } 17 | resp, _ := Do(req) 18 | body, _ := ioutil.ReadAll(resp.Body) 19 | log.Debug("%s", body) 20 | 21 | resp, _ = Get("http://httpbin.org/get") 22 | body, _ = ioutil.ReadAll(resp.Body) 23 | log.Debug("%s", body) 24 | 25 | resp, _ = Head("http://httpbin.org/get") 26 | body, _ = ioutil.ReadAll(resp.Body) 27 | log.Debug("%s", body) 28 | 29 | resp, _ = Post("http://httpbin.org/post", "application/json", strings.NewReader("{\"key\":\"value\"}")) 30 | body, _ = ioutil.ReadAll(resp.Body) 31 | log.Debug("%s", body) 32 | 33 | resp, _ = PostForm("http://httpbin.org/post", url.Values{ 34 | "key": []string{"1", "2"}, 35 | }) 36 | body, _ = ioutil.ReadAll(resp.Body) 37 | log.Debug("%s", body) 38 | 39 | resp, _ = PostJson("http://httpbin.org/post", map[string]int{"key": 1}) 40 | body, _ = ioutil.ReadAll(resp.Body) 41 | log.Debug("%s", body) 42 | 43 | type User struct { 44 | XMLName xml.Name `xml:"xml"` 45 | Name string `xml:"name"` 46 | Id int `xml:"id,attr"` 47 | } 48 | user := User{ 49 | Name: "test", 50 | Id: 1, 51 | } 52 | resp, _ = PostXml("http://httpbin.org/post", user) 53 | body, _ = ioutil.ReadAll(resp.Body) 54 | log.Debug("%s", body) 55 | 56 | } 57 | -------------------------------------------------------------------------------- /http/httplog/httplog.go: -------------------------------------------------------------------------------- 1 | package httplog 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | 7 | "github.com/JoveYu/zgo/log" 8 | ) 9 | 10 | var DefaultLogRequest = func(start time.Time, req *http.Request, resp *http.Response, err error) { 11 | 12 | if err == nil { 13 | log.Info("ep=http|method=%s|url=%s|code=%d|req=%d|resp=%d|time=%d", 14 | req.Method, req.URL, resp.StatusCode, req.ContentLength, resp.ContentLength, 15 | time.Now().Sub(start)/time.Microsecond, 16 | ) 17 | } else { 18 | log.Warn("ep=http|method=%s|url=%s|code=%d|req=%d|resp=%d|time=%d|err=%s", 19 | req.Method, req.URL, 0, req.ContentLength, 0, 20 | time.Now().Sub(start)/time.Microsecond, err, 21 | ) 22 | } 23 | } 24 | 25 | var DefaultTransport = &Transport{ 26 | RoundTripper: http.DefaultTransport, 27 | } 28 | 29 | type Transport struct { 30 | http.RoundTripper 31 | LogRequest func(time.Time, *http.Request, *http.Response, error) 32 | } 33 | 34 | func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { 35 | start := time.Now() 36 | 37 | resp, err := t.RoundTripper.RoundTrip(req) 38 | if t.LogRequest != nil { 39 | t.LogRequest(start, req, resp, err) 40 | } else { 41 | DefaultLogRequest(start, req, resp, err) 42 | } 43 | 44 | return resp, err 45 | } 46 | -------------------------------------------------------------------------------- /http/httplog/patch/patch.go: -------------------------------------------------------------------------------- 1 | package patch 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/JoveYu/zgo/http/httplog" 7 | ) 8 | 9 | func init() { 10 | http.DefaultTransport = httplog.DefaultTransport 11 | } 12 | -------------------------------------------------------------------------------- /http/httplog/patch/patch_test.go: -------------------------------------------------------------------------------- 1 | package patch 2 | 3 | import ( 4 | "github.com/JoveYu/zgo/log" 5 | "net/http" 6 | "testing" 7 | ) 8 | 9 | func TestPatch(t *testing.T) { 10 | log.Install("stdout") 11 | http.Get("http://httpbin.org/get") 12 | } 13 | -------------------------------------------------------------------------------- /log/log.go: -------------------------------------------------------------------------------- 1 | // TODO log time rotate 2 | // TODO log size rotate 3 | // TODO multi logger 4 | // TODO windows color 5 | // TODO check tty 6 | 7 | package log 8 | 9 | import ( 10 | "fmt" 11 | "log" 12 | "os" 13 | "strings" 14 | "sync" 15 | ) 16 | 17 | const ( 18 | LevelDebug = (iota + 1) * 10 19 | LevelInfo 20 | LevelWarn 21 | LevelError 22 | LevelFatal 23 | ) 24 | 25 | const ( 26 | tagDebug = "[D]" 27 | tagInfo = "[I]" 28 | tagWarn = "[W]" 29 | tagError = "[E]" 30 | tagFatal = "[F]" 31 | tagLog = "[L]" 32 | ) 33 | 34 | // ref: https://en.wikipedia.org/wiki/ANSI_escape_code 35 | const ( 36 | colorDebug = "\033[37m" 37 | colorInfo = "\033[36m" 38 | colorWarn = "\033[33m" 39 | colorError = "\033[31m" 40 | colorFatal = "\033[35m" 41 | colorReset = "\033[0m" 42 | ) 43 | 44 | // TODO 45 | const ( 46 | RotateNo = iota 47 | RotateTimeDay 48 | RotateTimeHour 49 | RotateTimeMinute 50 | RotateTimeSecond 51 | RotateSizeKB 52 | RotateSizeMB 53 | RotateSizeGB 54 | ) 55 | 56 | type LevelLogger struct { 57 | *log.Logger 58 | fp *os.File 59 | mu sync.Mutex 60 | Prefix string 61 | Filename string 62 | Level int 63 | 64 | // TODO 65 | Rotate int 66 | MaxSize int 67 | MaxBackup int 68 | } 69 | 70 | var ( 71 | DefaultLog *LevelLogger 72 | ) 73 | 74 | func GetLogger() *LevelLogger { 75 | if DefaultLog == nil { 76 | fmt.Println("can not GetLogger before Install") 77 | os.Exit(1) 78 | } 79 | return DefaultLog 80 | } 81 | 82 | func Install(dest string) *LevelLogger { 83 | 84 | var base *log.Logger 85 | var fp *os.File 86 | 87 | if dest == "stdout" { 88 | fp = os.Stdout 89 | base = log.New(fp, "", log.Ldate|log.Ltime|log.Lmicroseconds|log.Lshortfile) 90 | } else { 91 | fp, err := os.OpenFile(dest, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) 92 | if err != nil { 93 | fmt.Printf("can not open logfile: %v\n", err) 94 | } 95 | base = log.New(fp, "", log.Ldate|log.Ltime|log.Lmicroseconds|log.Lshortfile) 96 | } 97 | 98 | l := LevelLogger{ 99 | Logger: base, 100 | fp: fp, 101 | Prefix: "", 102 | Filename: dest, 103 | Level: LevelDebug, 104 | } 105 | 106 | // first logger as DefaultLog 107 | if DefaultLog == nil { 108 | DefaultLog = &l 109 | } 110 | 111 | return &l 112 | } 113 | 114 | func (l *LevelLogger) Log(level int, depth int, prefix string, v ...interface{}) { 115 | if len(v) == 0 { 116 | return 117 | } 118 | if level >= l.Level { 119 | var tag, color, message string 120 | switch level { 121 | case LevelDebug: 122 | tag = tagDebug 123 | color = colorDebug 124 | case LevelInfo: 125 | tag = tagInfo 126 | color = colorInfo 127 | case LevelWarn: 128 | tag = tagWarn 129 | color = colorWarn 130 | case LevelError: 131 | tag = tagError 132 | color = colorError 133 | case LevelFatal: 134 | tag = tagFatal 135 | color = colorFatal 136 | default: 137 | tag = tagLog 138 | color = colorReset 139 | } 140 | if format, ok := v[0].(string); ok { 141 | message = fmt.Sprintf(format, v[1:]...) 142 | } else { 143 | format := strings.Repeat("%+v ", len(v)) 144 | message = fmt.Sprintf(format, v...) 145 | } 146 | if l.Filename == "stdout" { 147 | // XXX debug only, slow with 4 lock 148 | l.mu.Lock() 149 | l.Logger.SetPrefix(color) 150 | l.Logger.Output(depth, fmt.Sprint(tag, " ", prefix, message, colorReset)) 151 | l.Logger.SetPrefix("") 152 | l.mu.Unlock() 153 | } else { 154 | l.Logger.Output(depth, fmt.Sprint(tag, " ", prefix, message)) 155 | } 156 | } 157 | } 158 | 159 | func (l *LevelLogger) SetPrefix(prefix string) { 160 | l.Prefix = prefix 161 | } 162 | 163 | func (l *LevelLogger) SetLevel(level int) { 164 | l.Level = level 165 | } 166 | 167 | func (l *LevelLogger) Debug(v ...interface{}) { 168 | l.Log(LevelDebug, 3, l.Prefix, v...) 169 | } 170 | 171 | func (l *LevelLogger) Info(v ...interface{}) { 172 | l.Log(LevelInfo, 3, l.Prefix, v...) 173 | } 174 | 175 | func (l *LevelLogger) Warn(v ...interface{}) { 176 | l.Log(LevelWarn, 3, l.Prefix, v...) 177 | } 178 | 179 | func (l *LevelLogger) Error(v ...interface{}) { 180 | l.Log(LevelError, 3, l.Prefix, v...) 181 | } 182 | 183 | func (l *LevelLogger) Fatal(v ...interface{}) { 184 | l.Log(LevelFatal, 3, l.Prefix, v...) 185 | os.Exit(1) 186 | } 187 | 188 | func Debug(v ...interface{}) { 189 | if DefaultLog != nil { 190 | DefaultLog.Log(LevelDebug, 3, DefaultLog.Prefix, v...) 191 | } 192 | } 193 | 194 | func Info(v ...interface{}) { 195 | if DefaultLog != nil { 196 | DefaultLog.Log(LevelInfo, 3, DefaultLog.Prefix, v...) 197 | } 198 | } 199 | 200 | func Warn(v ...interface{}) { 201 | if DefaultLog != nil { 202 | DefaultLog.Log(LevelWarn, 3, DefaultLog.Prefix, v...) 203 | } 204 | } 205 | 206 | func Error(v ...interface{}) { 207 | if DefaultLog != nil { 208 | DefaultLog.Log(LevelError, 3, DefaultLog.Prefix, v...) 209 | } 210 | } 211 | 212 | func Fatal(v ...interface{}) { 213 | if DefaultLog != nil { 214 | DefaultLog.Log(LevelFatal, 3, DefaultLog.Prefix, v...) 215 | } 216 | os.Exit(1) 217 | } 218 | 219 | func Debugd(depth int, v ...interface{}) { 220 | if DefaultLog != nil { 221 | DefaultLog.Log(LevelDebug, 3+depth, DefaultLog.Prefix, v...) 222 | } 223 | } 224 | 225 | func Infod(depth int, v ...interface{}) { 226 | if DefaultLog != nil { 227 | DefaultLog.Log(LevelInfo, 3+depth, DefaultLog.Prefix, v...) 228 | } 229 | } 230 | 231 | func Warnd(depth int, v ...interface{}) { 232 | if DefaultLog != nil { 233 | DefaultLog.Log(LevelWarn, 3+depth, DefaultLog.Prefix, v...) 234 | } 235 | } 236 | 237 | func Errord(depth int, v ...interface{}) { 238 | if DefaultLog != nil { 239 | DefaultLog.Log(LevelError, 3+depth, DefaultLog.Prefix, v...) 240 | } 241 | } 242 | 243 | func Fatald(depth int, v ...interface{}) { 244 | if DefaultLog != nil { 245 | DefaultLog.Log(LevelFatal, 3+depth, DefaultLog.Prefix, v...) 246 | } 247 | os.Exit(1) 248 | } 249 | -------------------------------------------------------------------------------- /log/log_test.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import "errors" 4 | import "testing" 5 | 6 | func TestInstall(t *testing.T) { 7 | Install("stdout") 8 | log := GetLogger() 9 | log.Debug("test") 10 | log.Debug("test %s", "format") 11 | 12 | Debug(errors.New("test error"), errors.New("test error")) 13 | Debug("test") 14 | Info("test %s", "format") 15 | Warn("test %s", "format") 16 | Error("test %s", "format") 17 | } 18 | func TestLevel(t *testing.T) { 19 | Install("stdout") 20 | log := GetLogger() 21 | for i := 0; i < 10; i++ { 22 | log.Debug("中文 debug %d", i) 23 | log.Info("😀 info %d", i) 24 | log.Warn("warn %d", i) 25 | log.Error("error %d", i) 26 | log.Printf("print %d", i) 27 | } 28 | 29 | log.SetLevel(LevelWarn) 30 | for i := 0; i < 10; i++ { 31 | log.Debug("debug %d", i) 32 | log.Info("info %d", i) 33 | log.Warn("warn %d", i) 34 | log.Error("error %d", i) 35 | } 36 | // log.Fatal("fatal") 37 | } 38 | func TestPrefix(t *testing.T) { 39 | log := Install("stdout") 40 | log.Debug("test") 41 | log.SetPrefix("prefix:") 42 | log.Debug("test") 43 | log.Debug("test") 44 | log.SetPrefix("") 45 | log.Debug("test") 46 | } 47 | -------------------------------------------------------------------------------- /sql/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## zgo/sql 3 | 4 | 整体思路和使用习惯与 `github.com/JoveYu/zpy/base/dbpool.py` 一致,主要根据go语言静态强类型,以及不能使用可选参数的特性,调整了下使用方式 5 | 6 | 整体使用了一段时间还是比较好用的,符合我的风格 7 | 8 | 1. 方便的初始化数据库连接池 9 | 2. 方便利用结构化数据组装SQL 10 | 3. 提供方便的Scan,可以直接查询结果到struct 11 | 4. 统一日志输出,打印连接池状态 12 | 13 | ## 文档 14 | 15 | [https://godoc.org/github.com/JoveYu/zgo/sql](https://godoc.org/github.com/JoveYu/zgo/sql) 16 | 17 | ## TODO 18 | 19 | 1. sqlbuilder 支持selectjoin 比较容易 现在没需求 20 | 21 | ## Example 22 | 23 | ```go 24 | package main 25 | 26 | import ( 27 | "fmt" 28 | _ "github.com/mattn/go-sqlite3" 29 | "time" 30 | 31 | "github.com/JoveYu/zgo/log" 32 | "github.com/JoveYu/zgo/sql" 33 | ) 34 | 35 | func main() { 36 | log.Install("stdout") 37 | 38 | sql.Install(sql.DBConf{ 39 | "testdb": []string{"sqlite3", "file::memory:?mode=memory&cache=shared"}, 40 | }) 41 | db := sql.GetDB("testdb") 42 | 43 | db.Exec("drop table if exists test") 44 | db.Exec("create table if not exists test(id integer not null primary key, name text, time datetime)") 45 | 46 | // INSERT INTO `test`(`id`,`name`,`time`) VALUES(1,'name 1','2019-02-20 16:20:07') 47 | for i := 1; i <= 10; i++ { 48 | db.Insert("test", sql.Values{ 49 | "id": i, 50 | "name": fmt.Sprintf("name %d", i), 51 | "time": time.Now(), 52 | }) 53 | } 54 | 55 | // SELECT * FROM `test` WHERE (`id` = 2) order by id desc limit 1 56 | db.Select("test", sql.Where{ 57 | "id": 2, 58 | "_other": "order by id desc limit 1", 59 | }) 60 | 61 | // SELECT count(1) FROM `test` WHERE (`id` > 2) and (`id` between 4 and 6) and (`name` in ('foo','bar')) 62 | db.Select("test", sql.Where{ 63 | "id >": 2, 64 | "id between": []int{4, 6}, 65 | "name in": []string{"foo", "bar"}, 66 | "_field": "count(1)", 67 | }) 68 | 69 | // SELECT * FROM `test` WHERE (`id` > 2) GROUP BY name HAVING (`id` > 3) 70 | db.Select("test", sql.Where{ 71 | "id >": 2, 72 | "_groupby": "name", 73 | "_having": sql.Where{ 74 | "id >": 3, 75 | }, 76 | }) 77 | 78 | // UPDATE `test` SET `id`=-1 WHERE (`id` > 9) 79 | db.Update("test", sql.Values{ 80 | "id": -1, 81 | }, sql.Where{ 82 | "id >": 9, 83 | }) 84 | 85 | // UPDATE `test` SET `name`='jove' 86 | db.Update("test", sql.Values{ 87 | "name": "jove", 88 | }, sql.Where{}) 89 | 90 | // DELETE FROM `test` WHERE (`name` != 'jove') 91 | db.Delete("test", sql.Where{ 92 | "name !=": "jove", 93 | }) 94 | 95 | // select scan to struct 96 | type User struct { 97 | Id int `zdb:"id"` 98 | Name string `zdb:"name"` 99 | Time time.Time `zdb:"time"` 100 | } 101 | user := []User{} 102 | db.SelectScan(&user, "test", sql.Where{}) 103 | log.Debug(user) 104 | // [{Id:-1 Name:jove Time:2019-02-20 17:20:25.03967 +0800 +0800}...... 105 | 106 | } 107 | 108 | ``` 109 | 110 | -------------------------------------------------------------------------------- /sql/builder/builder.go: -------------------------------------------------------------------------------- 1 | // build sql just like use zpy/base/dbpool.py 2 | // not build for all sql 3 | 4 | // TODO use $1 instead of ? for pg 5 | // TODO SelectJoin 6 | // TODO InsertMany 7 | 8 | package builder 9 | 10 | import ( 11 | "fmt" 12 | "reflect" 13 | "strings" 14 | ) 15 | 16 | type Where map[string]interface{} 17 | type Values map[string]interface{} 18 | 19 | func Select(table string, where Where) (string, []interface{}) { 20 | 21 | var args []interface{} 22 | 23 | field := "*" 24 | groupby := "" 25 | having := Where{} 26 | other := "" 27 | 28 | if value, ok := where["_field"]; ok { 29 | field = value.(string) 30 | delete(where, "_field") 31 | } 32 | if value, ok := where["_groupby"]; ok { 33 | groupby = value.(string) 34 | delete(where, "_groupby") 35 | } 36 | if value, ok := where["_having"]; ok { 37 | having = value.(Where) 38 | delete(where, "_having") 39 | } 40 | if value, ok := where["_other"]; ok { 41 | other = value.(string) 42 | delete(where, "_other") 43 | } 44 | 45 | sb := strings.Builder{} 46 | sb.WriteString(fmt.Sprintf("SELECT %s FROM `%s`", field, table)) 47 | 48 | // where 49 | if len(where) > 0 { 50 | sql, arg := where2sql(where) 51 | sb.WriteString(" WHERE ") 52 | sb.WriteString(sql) 53 | args = append(args, arg...) 54 | } 55 | 56 | // groupby 57 | if groupby != "" { 58 | sb.WriteString(" GROUP BY ") 59 | sb.WriteString(groupby) 60 | } 61 | 62 | // having 63 | if len(having) > 0 { 64 | sql, arg := where2sql(having) 65 | sb.WriteString(" HAVING ") 66 | sb.WriteString(sql) 67 | args = append(args, arg...) 68 | } 69 | 70 | // orderby limit offset 71 | if other != "" { 72 | sb.WriteString(" ") 73 | sb.WriteString(other) 74 | } 75 | 76 | return sb.String(), args 77 | } 78 | 79 | func Insert(table string, value Values) (string, []interface{}) { 80 | k, v, i := values2insert(value) 81 | sql := fmt.Sprintf("INSERT INTO `%s`(%s) VALUES(%s)", table, k, v) 82 | return sql, i 83 | } 84 | 85 | func Update(table string, value Values, where Where) (string, []interface{}) { 86 | var args []interface{} 87 | 88 | other := "" 89 | if value, ok := where["_other"]; ok { 90 | other = value.(string) 91 | delete(where, "_other") 92 | } 93 | 94 | sb := strings.Builder{} 95 | sb.WriteString(fmt.Sprintf("UPDATE `%s`", table)) 96 | 97 | // set 98 | k, v := values2set(value) 99 | sb.WriteString(" SET ") 100 | sb.WriteString(k) 101 | args = append(args, v...) 102 | 103 | // where 104 | if len(where) > 0 { 105 | sql, arg := where2sql(where) 106 | sb.WriteString(" WHERE ") 107 | sb.WriteString(sql) 108 | args = append(args, arg...) 109 | } 110 | 111 | // orderby limit offset 112 | if other != "" { 113 | sb.WriteString(" ") 114 | sb.WriteString(other) 115 | } 116 | 117 | return sb.String(), args 118 | } 119 | 120 | func Delete(table string, where Where) (string, []interface{}) { 121 | var args []interface{} 122 | 123 | other := "" 124 | if value, ok := where["_other"]; ok { 125 | other = value.(string) 126 | delete(where, "_other") 127 | } 128 | 129 | sb := strings.Builder{} 130 | sb.WriteString(fmt.Sprintf("DELETE FROM `%s`", table)) 131 | 132 | // where 133 | if len(where) > 0 { 134 | sql, arg := where2sql(where) 135 | sb.WriteString(" WHERE ") 136 | sb.WriteString(sql) 137 | args = append(args, arg...) 138 | } 139 | 140 | // orderby limit offset 141 | if other != "" { 142 | sb.WriteString(" ") 143 | sb.WriteString(other) 144 | } 145 | 146 | return sb.String(), args 147 | } 148 | 149 | func FormatSql(query string, args ...interface{}) string { 150 | if len(args) == 0 { 151 | return query 152 | } 153 | if strings.Count(query, "?") != len(args) { 154 | return query 155 | } 156 | // XXX for logging only, not real sql 157 | // TODO pg is not '?' 158 | query = strings.Replace(query, "?", "[%+v]", -1) 159 | return fmt.Sprintf(query, args...) 160 | } 161 | 162 | func values2insert(values Values) (string, string, []interface{}) { 163 | var args []interface{} 164 | var name []string 165 | var value []string 166 | for k, v := range values { 167 | name = append(name, fmt.Sprintf("`%s`", k)) 168 | value = append(value, "?") 169 | args = append(args, v) 170 | } 171 | return strings.Join(name, ","), strings.Join(value, ","), args 172 | } 173 | 174 | func values2set(values Values) (sql string, args []interface{}) { 175 | var sqls []string 176 | for k, v := range values { 177 | sqls = append(sqls, fmt.Sprintf("`%s`=?", k)) 178 | args = append(args, v) 179 | } 180 | return strings.Join(sqls, ","), args 181 | } 182 | 183 | func where2sql(where Where) (sql string, args []interface{}) { 184 | var key, op string 185 | var sqls []string 186 | for k, v := range where { 187 | k = strings.Trim(k, " ") 188 | idx := strings.IndexByte(k, ' ') 189 | if idx == -1 { 190 | key = k 191 | op = "=" 192 | } else { 193 | key = k[:idx] 194 | op = k[idx+1:] 195 | } 196 | s, i := exp2sql(key, op, v) 197 | sqls = append(sqls, s) 198 | args = append(args, i...) 199 | } 200 | return strings.Join(sqls, " and "), args 201 | } 202 | 203 | func exp2sql(key string, op string, value interface{}) (sql string, args []interface{}) { 204 | 205 | builder := strings.Builder{} 206 | builder.WriteString(fmt.Sprintf("(`%s` %s ", key, op)) 207 | 208 | if strings.Contains(op, "in") { 209 | builder.WriteString("(") 210 | for idx, v := range interface2slice(value) { 211 | if idx == 0 { 212 | builder.WriteString("?") 213 | } else { 214 | builder.WriteString(",?") 215 | } 216 | args = append(args, v) 217 | } 218 | builder.WriteString("))") 219 | } else if strings.Contains(op, "between") { 220 | builder.WriteString("? and ?)") 221 | v := interface2slice(value) 222 | args = append(args, v[0]) 223 | args = append(args, v[1]) 224 | } else { 225 | builder.WriteString("?)") 226 | args = append(args, value) 227 | } 228 | return builder.String(), args 229 | } 230 | func interface2slice(value interface{}) []interface{} { 231 | v := reflect.ValueOf(value) 232 | if v.Kind() != reflect.Slice { 233 | return nil 234 | } 235 | s := make([]interface{}, v.Len()) 236 | for i := 0; i < v.Len(); i++ { 237 | s[i] = v.Index(i).Interface() 238 | } 239 | return s 240 | } 241 | -------------------------------------------------------------------------------- /sql/builder/builder_test.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "testing" 7 | "time" 8 | 9 | _ "github.com/mattn/go-sqlite3" 10 | 11 | "github.com/JoveYu/zgo/log" 12 | ) 13 | 14 | func TestAll(t *testing.T) { 15 | log.Install("stdout") 16 | 17 | db, _ := sql.Open("sqlite3", "file::memory:?mode=memory&cache=shared") 18 | db.Exec("drop table if exists test") 19 | db.Exec("create table if not exists test(id integer not null primary key, name text, time datetime)") 20 | for i := 1; i <= 10; i++ { 21 | sql, args := Insert("test", Values{ 22 | "id": i, 23 | "name": fmt.Sprintf("name %d", i), 24 | "time": time.Now(), 25 | }) 26 | log.Debug("sql: %s, args: %v", sql, args) 27 | log.Debug("sql: %s", FormatSql(sql, args...)) 28 | _, err := db.Exec(sql, args...) 29 | log.Debug("insert err:%v", err) 30 | } 31 | 32 | sql, args := Select("test", Where{ 33 | "id": 1, 34 | "id >": 0, 35 | "id is": nil, 36 | "id not in": []int{3, 4}, 37 | "name between": []string{"name 1", "name 5"}, 38 | "_field": "count(*)", 39 | "_groupby": "name", 40 | "_having": Where{ 41 | "id >": 0, 42 | }, 43 | "_other": "limit 1", 44 | }) 45 | 46 | log.Debug("sql: %s, args: %v", sql, args) 47 | log.Debug("sql: %s", FormatSql(sql, args...)) 48 | 49 | _, err := db.Query(sql, args...) 50 | log.Debug("select err:%v", err) 51 | 52 | sql, args = Update("test", Values{ 53 | "name": "new name", 54 | }, Where{ 55 | "id >": 3, 56 | }) 57 | log.Debug("sql: %s, args: %v", sql, args) 58 | log.Debug("sql: %s", FormatSql(sql, args...)) 59 | _, err = db.Exec(sql, args...) 60 | log.Debug("update err:%v", err) 61 | 62 | sql, args = Delete("test", Where{ 63 | "id !=": 3, 64 | }) 65 | log.Debug("sql: %s, args: %v", sql, args) 66 | log.Debug("sql: %s", FormatSql(sql, args...)) 67 | _, err = db.Exec(sql, args...) 68 | log.Debug("delete err:%v", err) 69 | 70 | sql, args = Select("test", Where{ 71 | "id >": 0, 72 | "id > ": 1, 73 | }) 74 | log.Debug("sql: %s, args: %v", sql, args) 75 | 76 | } 77 | -------------------------------------------------------------------------------- /sql/scanner/scanner.go: -------------------------------------------------------------------------------- 1 | package scanner 2 | 3 | import ( 4 | "database/sql" 5 | "errors" 6 | "reflect" 7 | 8 | "github.com/JoveYu/zgo/log" 9 | ) 10 | 11 | var ( 12 | StructTag string = "zdb" 13 | // for useless field to scan 14 | tmpField sql.RawBytes = []byte{} 15 | ) 16 | 17 | func ScanStruct(rows *sql.Rows, dest interface{}) error { 18 | v := reflect.ValueOf(dest) 19 | 20 | if v.Kind() != reflect.Ptr { 21 | return errors.New("must pass a pointer, not a value") 22 | } 23 | 24 | v = v.Elem() 25 | tp := v.Type() 26 | 27 | switch tp.Kind() { 28 | case reflect.Slice: 29 | for rows.Next() { 30 | obj := reflect.New(tp.Elem()) 31 | 32 | err := scanOne(rows, obj.Interface()) 33 | if err != nil { 34 | return err 35 | } 36 | 37 | v.Set(reflect.Append(v, obj.Elem())) 38 | } 39 | 40 | case reflect.Struct: 41 | if rows.Next() { 42 | err := scanOne(rows, dest) 43 | if err != nil { 44 | return err 45 | } 46 | } 47 | default: 48 | return errors.New("unknow dest") 49 | } 50 | return nil 51 | } 52 | 53 | func scanOne(rows *sql.Rows, dest interface{}) error { 54 | v := reflect.ValueOf(dest) 55 | v = v.Elem() 56 | tp := v.Type() 57 | 58 | if v.Kind() != reflect.Struct { 59 | return errors.New("dest is not struct") 60 | } 61 | 62 | cols, err := rows.Columns() 63 | if err != nil { 64 | return err 65 | } 66 | 67 | fields := make([]interface{}, len(cols)) 68 | 69 | for idx, col := range cols { 70 | ok := false 71 | for i := 0; i < v.NumField(); i++ { 72 | f := v.Field(i) 73 | tag := tp.Field(i).Tag.Get(StructTag) 74 | if tag == col { 75 | ok = true 76 | fields[idx] = f.Addr().Interface() 77 | break 78 | } 79 | } 80 | if !ok { 81 | fields[idx] = &tmpField 82 | log.Warn("sql scanner skip field [%s] in struct", col) 83 | } 84 | } 85 | 86 | err = rows.Scan(fields...) 87 | if err != nil { 88 | return err 89 | } 90 | 91 | return nil 92 | } 93 | -------------------------------------------------------------------------------- /sql/scanner/scanner_test.go: -------------------------------------------------------------------------------- 1 | package scanner 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | _ "github.com/mattn/go-sqlite3" 9 | 10 | "github.com/JoveYu/zgo/log" 11 | "github.com/JoveYu/zgo/sql" 12 | ) 13 | 14 | type Test struct { 15 | Id int `zdb:"id"` 16 | Name string `zdb:"name"` 17 | Time time.Time `zdb:"time"` 18 | } 19 | 20 | func TestAll(t *testing.T) { 21 | log.Install("stdout") 22 | sql.Install(sql.DBConf{ 23 | "sqlite3": []string{"sqlite3", "file::memory:?mode=memory&cache=shared"}, 24 | }) 25 | db := sql.GetDB("sqlite3") 26 | db.Exec("drop table if exists test") 27 | db.Exec("create table if not exists test(id integer not null primary key, name text, time datetime)") 28 | for i := 1; i <= 3; i++ { 29 | db.Insert("test", sql.Values{ 30 | "id": i, 31 | "name": fmt.Sprintf("name %d", i), 32 | "time": time.Now(), 33 | }) 34 | } 35 | rows, _ := db.Select("test", sql.Where{"id >": 2}) 36 | 37 | test := []Test{} 38 | log.Debug(test) 39 | err := ScanStruct(rows, &test) 40 | log.Debug(err) 41 | log.Debug(test) 42 | 43 | rows, _ = db.Select("test", sql.Where{}) 44 | 45 | test2 := Test{} 46 | log.Debug(test2) 47 | err = ScanStruct(rows, &test2) 48 | log.Debug(err) 49 | log.Debug(test2) 50 | 51 | } 52 | -------------------------------------------------------------------------------- /sql/sql.go: -------------------------------------------------------------------------------- 1 | // use sql not orm 2 | // use simple sql not join 3 | 4 | // use go sql just like python dbpool.py 5 | // ref : https://github.com/JoveYu/zpy/blob/master/base/dbpool.py 6 | 7 | // XXX overwrite too many func for logging 8 | 9 | package sql 10 | 11 | import ( 12 | "context" 13 | "database/sql" 14 | "strings" 15 | "time" 16 | 17 | "github.com/JoveYu/zgo/log" 18 | "github.com/JoveYu/zgo/sql/builder" 19 | "github.com/JoveYu/zgo/sql/scanner" 20 | ) 21 | 22 | var ( 23 | dbMap = make(map[string]DB) 24 | ) 25 | 26 | type DBTool struct { 27 | db *DB 28 | tx *Tx 29 | } 30 | 31 | type DB struct { 32 | *DBTool 33 | *sql.DB 34 | name string 35 | driver string 36 | dsn string 37 | } 38 | 39 | type Tx struct { 40 | *DBTool 41 | *sql.Tx 42 | db *DB 43 | } 44 | 45 | type Where builder.Where 46 | type Values builder.Values 47 | type DBConf map[string][]string 48 | 49 | func Install(conf DBConf) map[string]DB { 50 | log.Debug("available sql driver: %s", sql.Drivers()) 51 | for k, v := range conf { 52 | if len(v) != 2 { 53 | log.Fatal("parse db config error") 54 | } 55 | db, err := sql.Open(v[0], v[1]) 56 | if err != nil { 57 | log.Fatal("%s", err) 58 | } 59 | 60 | // escape password 61 | dsn := v[1] 62 | start := strings.IndexByte(dsn, ':') 63 | end := strings.IndexByte(dsn, '@') 64 | if start > 0 && end > 0 { 65 | dsn = dsn[:start+1] + "***" + dsn[end:] 66 | } 67 | 68 | zdb := DB{ 69 | DB: db, 70 | name: k, 71 | driver: v[0], 72 | dsn: dsn, 73 | } 74 | zdb.DBTool = &DBTool{db: &zdb} 75 | 76 | dbMap[k] = zdb 77 | log.Info("ep=%s|func=install|name=%s|conf=%s", zdb.driver, zdb.name, zdb.dsn) 78 | } 79 | return dbMap 80 | } 81 | 82 | func GetDB(name string) *DB { 83 | if db, ok := dbMap[name]; ok { 84 | return &db 85 | } else { 86 | log.Error("can not get db [%s]", name) 87 | return nil 88 | } 89 | } 90 | 91 | func (t *Tx) Exec(query string, args ...interface{}) (result sql.Result, err error) { 92 | defer t.db.timeit(time.Now(), &err, true, query, args...) 93 | 94 | result, err = t.Tx.Exec(query, args...) 95 | return 96 | } 97 | 98 | func (t *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { 99 | defer t.db.timeit(time.Now(), &err, true, query, args...) 100 | 101 | result, err = t.Tx.ExecContext(ctx, query, args...) 102 | return 103 | } 104 | 105 | func (t *Tx) Query(query string, args ...interface{}) (rows *sql.Rows, err error) { 106 | defer t.db.timeit(time.Now(), &err, true, query, args...) 107 | 108 | rows, err = t.Tx.Query(query, args...) 109 | return 110 | } 111 | 112 | func (t *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { 113 | defer t.db.timeit(time.Now(), &err, true, query, args...) 114 | 115 | rows, err = t.Tx.QueryContext(ctx, query, args...) 116 | return 117 | } 118 | 119 | func (t *Tx) QueryRow(query string, args ...interface{}) *sql.Row { 120 | defer t.db.timeit(time.Now(), nil, true, query, args...) 121 | 122 | return t.Tx.QueryRow(query, args...) 123 | } 124 | 125 | func (t *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 126 | defer t.db.timeit(time.Now(), nil, true, query, args...) 127 | 128 | return t.Tx.QueryRowContext(ctx, query, args...) 129 | } 130 | 131 | func (t *Tx) Commit() error { 132 | d := t.db 133 | log.Info("ep=%s|name=%s|func=commit", d.driver, d.name) 134 | return t.Tx.Commit() 135 | } 136 | 137 | func (t *Tx) Rollback() error { 138 | d := t.db 139 | log.Info("ep=%s|name=%s|func=rollback", d.driver, d.name) 140 | return t.Tx.Rollback() 141 | } 142 | 143 | func (d *DB) timeit(start time.Time, err *error, trans bool, query string, args ...interface{}) { 144 | stat := d.DB.Stats() 145 | duration := time.Since(start) 146 | 147 | t := 0 148 | if trans { 149 | t = 1 150 | } 151 | 152 | if *err == nil { 153 | log.Info("ep=%s|name=%s|use=%d|idle=%d|max=%d|wait=%d|waittime=%d|time=%d|trans=%d|sql=%s|err=", 154 | d.driver, d.name, stat.InUse, stat.Idle, stat.MaxOpenConnections, stat.WaitCount, 155 | stat.WaitDuration/time.Microsecond, duration/time.Microsecond, t, 156 | builder.FormatSql(query, args...), 157 | ) 158 | } else { 159 | log.Warn("ep=%s|name=%s|use=%d|idle=%d|max=%d|wait=%d|waittime=%d|time=%d|trans=%d|sql=%s|err=%s", 160 | d.driver, d.name, stat.InUse, stat.Idle, stat.MaxOpenConnections, stat.WaitCount, 161 | stat.WaitDuration/time.Microsecond, duration/time.Microsecond, t, 162 | builder.FormatSql(query, args...), *err, 163 | ) 164 | } 165 | } 166 | 167 | func (d *DB) Begin() (*Tx, error) { 168 | log.Info("ep=%s|name=%s|func=begin", d.driver, d.name) 169 | tx, err := d.DB.Begin() 170 | ztx := Tx{ 171 | Tx: tx, 172 | db: d, 173 | } 174 | ztx.DBTool = &DBTool{tx: &ztx} 175 | return &ztx, err 176 | } 177 | 178 | func (d *DB) Exec(query string, args ...interface{}) (result sql.Result, err error) { 179 | defer d.timeit(time.Now(), &err, false, query, args...) 180 | 181 | result, err = d.DB.Exec(query, args...) 182 | return 183 | } 184 | 185 | func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { 186 | defer d.timeit(time.Now(), &err, false, query, args...) 187 | 188 | result, err = d.DB.ExecContext(ctx, query, args...) 189 | return 190 | } 191 | 192 | func (d *DB) Query(query string, args ...interface{}) (rows *sql.Rows, err error) { 193 | defer d.timeit(time.Now(), &err, false, query, args...) 194 | 195 | rows, err = d.DB.Query(query, args...) 196 | return 197 | } 198 | 199 | func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { 200 | defer d.timeit(time.Now(), &err, false, query, args...) 201 | 202 | rows, err = d.DB.QueryContext(ctx, query, args...) 203 | return 204 | } 205 | 206 | func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { 207 | defer d.timeit(time.Now(), nil, false, query, args...) 208 | 209 | return d.DB.QueryRow(query, args...) 210 | } 211 | 212 | func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 213 | defer d.timeit(time.Now(), nil, false, query, args...) 214 | 215 | return d.DB.QueryRowContext(ctx, query, args...) 216 | } 217 | 218 | func (d *DBTool) QueryScan(obj interface{}, query string, args ...interface{}) error { 219 | var rows *sql.Rows 220 | var err error 221 | 222 | if d.tx != nil { 223 | rows, err = d.tx.Query(query, args...) 224 | } else { 225 | rows, err = d.db.Query(query, args...) 226 | } 227 | if err != nil { 228 | return err 229 | } 230 | defer rows.Close() 231 | 232 | err = scanner.ScanStruct(rows, obj) 233 | if err != nil { 234 | return err 235 | } 236 | 237 | return nil 238 | } 239 | 240 | func (d *DBTool) QueryContextScan(ctx context.Context, obj interface{}, query string, args ...interface{}) error { 241 | var rows *sql.Rows 242 | var err error 243 | 244 | if d.tx != nil { 245 | rows, err = d.tx.QueryContext(ctx, query, args...) 246 | } else { 247 | rows, err = d.db.QueryContext(ctx, query, args...) 248 | } 249 | if err != nil { 250 | return err 251 | } 252 | defer rows.Close() 253 | 254 | err = scanner.ScanStruct(rows, obj) 255 | if err != nil { 256 | return err 257 | } 258 | 259 | return nil 260 | } 261 | 262 | func (d *DBTool) SelectScan(obj interface{}, table string, where Where) error { 263 | sql, args := builder.Select(table, d.escapeWhere(where)) 264 | return d.QueryScan(obj, sql, args...) 265 | } 266 | 267 | func (d *DBTool) SelectContextScan(ctx context.Context, obj interface{}, table string, where Where) error { 268 | sql, args := builder.Select(table, d.escapeWhere(where)) 269 | return d.QueryContextScan(ctx, obj, sql, args...) 270 | } 271 | 272 | // TODO 273 | func (d *DBTool) QueryMap(query string, args ...interface{}) (data []map[string]interface{}, err error) { 274 | var rows *sql.Rows 275 | 276 | if d.tx != nil { 277 | rows, err = d.tx.Query(query, args...) 278 | } else { 279 | rows, err = d.db.Query(query, args...) 280 | } 281 | 282 | if err != nil { 283 | return nil, err 284 | } 285 | defer rows.Close() 286 | 287 | cols, err := rows.Columns() 288 | if err != nil { 289 | return nil, err 290 | } 291 | 292 | for rows.Next() { 293 | 294 | values := make([]interface{}, len(cols)) 295 | for i := range values { 296 | values[i] = new(interface{}) 297 | } 298 | 299 | err := rows.Scan(values...) 300 | if err != nil { 301 | return nil, err 302 | } 303 | 304 | m := make(map[string]interface{}) 305 | for i, col := range cols { 306 | m[col] = *(values[i].((*interface{}))) 307 | } 308 | data = append(data, m) 309 | } 310 | 311 | return data, nil 312 | } 313 | 314 | // TODO 315 | func (d *DBTool) SelectMap(table string, where Where) ([]map[string]interface{}, error) { 316 | sql, args := builder.Select(table, d.escapeWhere(where)) 317 | return d.QueryMap(sql, args...) 318 | } 319 | 320 | func (d *DBTool) Select(table string, where Where) (*sql.Rows, error) { 321 | sql, args := builder.Select(table, d.escapeWhere(where)) 322 | if d.tx != nil { 323 | return d.tx.Query(sql, args...) 324 | } else { 325 | return d.db.Query(sql, args...) 326 | } 327 | } 328 | 329 | func (d *DBTool) SelectContext(ctx context.Context, table string, where Where) (*sql.Rows, error) { 330 | sql, args := builder.Select(table, d.escapeWhere(where)) 331 | if d.tx != nil { 332 | return d.tx.QueryContext(ctx, sql, args...) 333 | } else { 334 | return d.db.QueryContext(ctx, sql, args...) 335 | } 336 | } 337 | 338 | func (d *DBTool) Insert(table string, value Values) (sql.Result, error) { 339 | sql, args := builder.Insert(table, builder.Values(value)) 340 | 341 | if d.tx != nil { 342 | return d.tx.Exec(sql, args...) 343 | } else { 344 | return d.db.Exec(sql, args...) 345 | } 346 | } 347 | 348 | func (d *DBTool) InsertContext(ctx context.Context, table string, value Values) (sql.Result, error) { 349 | sql, args := builder.Insert(table, builder.Values(value)) 350 | 351 | if d.tx != nil { 352 | return d.tx.ExecContext(ctx, sql, args...) 353 | } else { 354 | return d.db.ExecContext(ctx, sql, args...) 355 | } 356 | } 357 | 358 | func (d *DBTool) Update(table string, value Values, where Where) (sql.Result, error) { 359 | 360 | sql, args := builder.Update(table, builder.Values(value), d.escapeWhere(where)) 361 | 362 | if d.tx != nil { 363 | return d.tx.Exec(sql, args...) 364 | } else { 365 | return d.db.Exec(sql, args...) 366 | } 367 | } 368 | func (d *DBTool) UpdateContext(ctx context.Context, table string, value Values, where Where) (sql.Result, error) { 369 | 370 | sql, args := builder.Update(table, builder.Values(value), d.escapeWhere(where)) 371 | 372 | if d.tx != nil { 373 | return d.tx.ExecContext(ctx, sql, args...) 374 | } else { 375 | return d.db.ExecContext(ctx, sql, args...) 376 | } 377 | } 378 | 379 | func (d *DBTool) Delete(table string, where Where) (sql.Result, error) { 380 | 381 | sql, args := builder.Delete(table, d.escapeWhere(where)) 382 | 383 | if d.tx != nil { 384 | return d.tx.Exec(sql, args...) 385 | } else { 386 | return d.db.Exec(sql, args...) 387 | } 388 | } 389 | func (d *DBTool) DeleteContext(ctx context.Context, table string, where Where) (sql.Result, error) { 390 | 391 | sql, args := builder.Delete(table, d.escapeWhere(where)) 392 | 393 | if d.tx != nil { 394 | return d.tx.ExecContext(ctx, sql, args...) 395 | } else { 396 | return d.db.ExecContext(ctx, sql, args...) 397 | } 398 | } 399 | 400 | func (d *DBTool) escapeWhere(where Where) builder.Where { 401 | if value, ok := where["_having"]; ok { 402 | where["_having"] = builder.Where(value.(Where)) 403 | } 404 | return builder.Where(where) 405 | } 406 | -------------------------------------------------------------------------------- /sql/sql_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import "fmt" 4 | import "context" 5 | import "sync" 6 | import "time" 7 | import "testing" 8 | import "github.com/JoveYu/zgo/log" 9 | import _ "github.com/mattn/go-sqlite3" 10 | import _ "github.com/go-sql-driver/mysql" 11 | 12 | func TestInstall(t *testing.T) { 13 | log.Install("stdout") 14 | Install(DBConf{ 15 | "sqlite3": []string{"sqlite3", "file::memory:?mode=memory&cache=shared"}, 16 | // "mysql": []string{"mysql", "test:123456@tcp(127.0.0.1:3306)/zgo?charset=utf8mb4"}, 17 | }) 18 | db := GetDB("sqlite3") 19 | 20 | db.Exec("wrong sql test") 21 | 22 | db.Exec("drop table if exists test") 23 | db.Exec("create table if not exists test(id integer not null primary key, name text, time datetime)") 24 | 25 | for i := 1; i <= 10; i++ { 26 | db.InsertContext(context.TODO(), "test", Values{ 27 | "id": i, 28 | "name": fmt.Sprintf("name %d", i), 29 | "time": time.Now(), 30 | }) 31 | } 32 | 33 | rows, err := db.SelectMap("test", Where{ 34 | "_field": "count(*)", 35 | }) 36 | log.Debug("select count: %s, err: %s", rows, err) 37 | 38 | rows, err = db.SelectMap("test", Where{ 39 | "id in": []int{2, 3}, 40 | }) 41 | log.Debug("select in: %s", rows) 42 | 43 | rows, err = db.SelectMap("test", Where{ 44 | "id between": []int{2, 5}, 45 | "_other": "order by id desc", 46 | }) 47 | log.Debug("select between: %s", rows) 48 | 49 | db.Delete("test", Where{ 50 | "id >": 5, 51 | }) 52 | 53 | rows, err = db.SelectMap("test", Where{ 54 | "_field": "count(*)", 55 | }) 56 | log.Debug("select count: %s", rows) 57 | 58 | db.Update("test", Values{ 59 | "name": "new name", 60 | }, Where{ 61 | "id <": 3, 62 | }) 63 | rows, err = db.SelectMap("test", Where{}) 64 | log.Debug("select update: %s", rows) 65 | 66 | db.Update("test", Values{ 67 | "name": "new name", 68 | }, Where{}) 69 | 70 | rows, err = db.SelectMap("test", Where{ 71 | "name": "??", 72 | }) 73 | log.Debug("select ? %s", rows) 74 | 75 | } 76 | 77 | func TestTransaction(t *testing.T) { 78 | log.Install("stdout") 79 | Install(DBConf{ 80 | "sqlite3": []string{"sqlite3", "file::memory:?mode=memory&cache=shared"}, 81 | "mysql": []string{"mysql", "test:123456@tcp(127.0.0.1:3306)/test?charset=utf8mb4"}, 82 | }) 83 | db := GetDB("sqlite3") 84 | 85 | db.Exec("drop table if exists test") 86 | db.Exec("create table if not exists test(id integer not null primary key, name text, time datetime)") 87 | 88 | tx, _ := db.Begin() 89 | tx.Insert("test", Values{ 90 | "id": 1, 91 | "name": "name", 92 | "time": time.Now(), 93 | }) 94 | 95 | rows, err := db.SelectMap("test", Where{}) 96 | log.Debug("%s %s", rows, err) 97 | 98 | tx.Commit() 99 | 100 | rows, err = db.SelectMap("test", Where{}) 101 | log.Debug("%s", rows) 102 | 103 | } 104 | 105 | func TestMulitRun(t *testing.T) { 106 | log.Install("stdout") 107 | Install(DBConf{ 108 | "sqlite3": []string{"sqlite3", "file::memory:?mode=memory&cache=shared"}, 109 | }) 110 | db := GetDB("sqlite3") 111 | db.Exec("drop table if exists test") 112 | db.Exec("create table if not exists test(id integer not null primary key, name text, time datetime)") 113 | 114 | count := 3 115 | 116 | for i := 0; i < count; i++ { 117 | db.Insert("test", Values{ 118 | "id": i, 119 | "name": fmt.Sprintf("name %d", i), 120 | "time": time.Now(), 121 | }) 122 | } 123 | var wa sync.WaitGroup 124 | wa.Add(count) 125 | for i := 0; i < count; i++ { 126 | go func() { 127 | db.SelectMap("test", Where{}) 128 | wa.Done() 129 | }() 130 | } 131 | wa.Wait() 132 | } 133 | 134 | type User struct { 135 | Id int `zdb:"id"` 136 | Name string `zdb:"name"` 137 | Time time.Time `zdb:"time"` 138 | Other string 139 | } 140 | 141 | func TestScan(t *testing.T) { 142 | log.Install("stdout") 143 | Install(DBConf{ 144 | "sqlite3": []string{"sqlite3", "file::memory:?mode=memory&cache=shared"}, 145 | }) 146 | db := GetDB("sqlite3") 147 | db.Exec("drop table if exists test") 148 | db.Exec("create table if not exists test(id integer not null primary key, name text, time datetime)") 149 | 150 | count := 3 151 | 152 | for i := 0; i < count; i++ { 153 | db.Insert("test", Values{ 154 | "id": i, 155 | "name": fmt.Sprintf("name %d", i), 156 | "time": time.Now(), 157 | }) 158 | } 159 | 160 | user := []User{} 161 | log.Debug(user) 162 | err := db.SelectScan(&user, "test", Where{}) 163 | log.Debug(err) 164 | log.Debug(user) 165 | 166 | user1 := User{} 167 | log.Debug(user1) 168 | err = db.SelectScan(&user1, "test", Where{"_other": "order by id desc"}) 169 | log.Debug(err) 170 | log.Debug(user1) 171 | 172 | } 173 | -------------------------------------------------------------------------------- /tool/json.go: -------------------------------------------------------------------------------- 1 | package tool 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | type JSONTimeISO time.Time 9 | 10 | func (t JSONTimeISO) MarshalJSON() ([]byte, error) { 11 | s := time.Time(t).Format(time.RFC3339) 12 | return []byte(fmt.Sprintf("\"%s\"", s)), nil 13 | } 14 | 15 | type JSONTimeTimestamp time.Time 16 | 17 | func (t JSONTimeTimestamp) MarshalJSON() ([]byte, error) { 18 | return []byte(fmt.Sprintf("%d", time.Time(t).Unix())), nil 19 | } 20 | -------------------------------------------------------------------------------- /web/README.md: -------------------------------------------------------------------------------- 1 | 2 | # zgo/web 3 | 4 | ## 介绍 5 | 6 | 对自带的http模块的适当封装,尽量贴近HTTP的本质,方便使用 7 | 8 | 1. 引入ctx 提供常用的工具函数 直接使用 参考`context.go` 9 | 2. 基于正则表达式的路由(30行实现),灵活简单,支持url配置输入参数 10 | 3. 灵活的中间件支持,目前自带CORS JSONP 中间件 11 | 4. 完全兼容http库,支持不使用路由,独立使用context,使用`ContextHandler` 12 | 5. 支持对请求报文进行调试 使用`web.DefaultServer.Debug=true` 13 | 6. 统一的日志打印 14 | 15 | ## 文档 16 | 17 | [https://godoc.org/github.com/JoveYu/zgo/web](https://godoc.org/github.com/JoveYu/zgo/web) 18 | 19 | ## TODO 20 | 21 | 1. 后续可以考虑支持多backend 比如fasthttp 目前对极致性能需求不大 22 | 2. 目前正则路由已经比较灵活,后续有需要在考虑更强大路由 23 | 3. 控制内存分配,引入pool 24 | 25 | ## Example 26 | 27 | ```go 28 | package main 29 | 30 | import ( 31 | "fmt" 32 | "strconv" 33 | 34 | "github.com/JoveYu/zgo/log" 35 | "github.com/JoveYu/zgo/web" 36 | ) 37 | 38 | func ping(ctx web.Context) { 39 | ctx.Abort(403, fmt.Sprintf("%s pong", ctx.Method())) 40 | } 41 | 42 | func hello(ctx web.Context) { 43 | name := ctx.Param("name") 44 | 45 | ctx.WriteHeader(200) 46 | ctx.WriteString("hello ") 47 | ctx.WriteString(name) 48 | } 49 | 50 | func add(ctx web.Context) { 51 | astr := ctx.Param("a") 52 | bstr := ctx.Param("b") 53 | 54 | a, _ := strconv.ParseInt(astr, 10, 64) 55 | b, _ := strconv.ParseInt(bstr, 10, 64) 56 | 57 | ctx.WriteHeader(200) 58 | 59 | ctx.WriteString(fmt.Sprintf("%d", a+b)) 60 | } 61 | 62 | func redir(ctx web.Context) { 63 | ctx.Redirect(302, "http://baidu.com") 64 | } 65 | 66 | func query(ctx web.Context) { 67 | query := ctx.GetQuery("test") 68 | ctx.WriteHeader(200) 69 | ctx.WriteJSON(map[string]string{"test": query}) 70 | } 71 | 72 | func main() { 73 | log.Install("stdout") 74 | 75 | // curl /ping -> pong 76 | web.GET("^/ping$", ping) 77 | web.POST("^/ping$", ping) 78 | 79 | // curl /params/world -> hello world 80 | web.GET("^/params/(?P\\w+)$", hello) 81 | 82 | // curl /add/1/2 -> 3 83 | web.GET("^/add/(?P\\d+)/(?P\\d+)$", add) 84 | 85 | // curl /redir -> to baidu 86 | web.GET("^/redir$", redir) 87 | 88 | // curl /query?test=123 -> {"test":"123"} 89 | web.GET("^/query$", query) 90 | 91 | web.Run("127.0.0.1:7000") 92 | } 93 | 94 | ``` 95 | 96 | 97 | -------------------------------------------------------------------------------- /web/context.go: -------------------------------------------------------------------------------- 1 | // context for web framework 2 | 3 | package web 4 | 5 | import ( 6 | "context" 7 | "encoding/json" 8 | "fmt" 9 | "mime" 10 | "mime/multipart" 11 | "net" 12 | "net/http" 13 | "net/url" 14 | "strings" 15 | ) 16 | 17 | type ContextHandlerFunc func(Context) 18 | 19 | type ContextFlag struct { 20 | FormParsed bool 21 | BreakNext bool 22 | Status int 23 | } 24 | 25 | type Context struct { 26 | context.Context 27 | Request *http.Request 28 | ResponseWriter http.ResponseWriter 29 | Charset string 30 | 31 | // for debug 32 | Debug bool 33 | DebugBody *strings.Builder 34 | 35 | // for router params 36 | Params map[string]string 37 | 38 | Flag *ContextFlag 39 | } 40 | 41 | func NewContext(w http.ResponseWriter, r *http.Request) Context { 42 | return Context{ 43 | Request: r, 44 | ResponseWriter: w, 45 | Charset: "utf-8", 46 | 47 | Debug: false, 48 | 49 | Params: map[string]string{}, 50 | 51 | Flag: &ContextFlag{ 52 | FormParsed: false, 53 | BreakNext: false, 54 | Status: 200, 55 | }, 56 | } 57 | } 58 | 59 | func ContextHandler(f ContextHandlerFunc) http.Handler { 60 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 61 | ctx := NewContext(w, r) 62 | f(ctx) 63 | }) 64 | } 65 | 66 | func ContextCancelHandler(f ContextHandlerFunc) http.Handler { 67 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 68 | ctx := NewContext(w, r) 69 | c, cancel := context.WithCancel(context.Background()) 70 | defer cancel() 71 | 72 | ctx.Context = c 73 | f(ctx) 74 | }) 75 | } 76 | 77 | func (ctx *Context) BreakNext() { 78 | ctx.Flag.BreakNext = true 79 | } 80 | 81 | func (ctx *Context) Param(k string) string { 82 | v, ok := ctx.Params[k] 83 | if !ok { 84 | return "" 85 | } 86 | return v 87 | } 88 | 89 | func (ctx *Context) Method() string { 90 | return ctx.Request.Method 91 | } 92 | 93 | func (ctx *Context) URL() *url.URL { 94 | return ctx.Request.URL 95 | } 96 | 97 | func (ctx *Context) ReadJSON(v interface{}) error { 98 | return json.NewDecoder(ctx.Request.Body).Decode(v) 99 | } 100 | 101 | func (ctx *Context) Write(b []byte) (int, error) { 102 | if ctx.Debug { 103 | ctx.DebugBody.Write(b) 104 | } 105 | 106 | return ctx.ResponseWriter.Write(b) 107 | } 108 | 109 | func (ctx *Context) WriteHeader(status int) { 110 | ctx.Flag.Status = status 111 | ctx.ResponseWriter.WriteHeader(status) 112 | } 113 | 114 | func (ctx *Context) WriteString(s string) { 115 | ctx.Write([]byte(s)) 116 | } 117 | 118 | func (ctx *Context) WriteJSON(v interface{}) error { 119 | ctx.SetContentType("application/json") 120 | return json.NewEncoder(ctx).Encode(v) 121 | } 122 | 123 | func (ctx *Context) WriteJSONP(v interface{}) error { 124 | callback := ctx.GetQuery("callback") 125 | if callback != "" { 126 | data, err := json.Marshal(v) 127 | if err != nil { 128 | return err 129 | } 130 | 131 | ctx.SetContentType("application/javascript") 132 | ctx.WriteString(fmt.Sprintf("%s(%s)", callback, data)) 133 | return nil 134 | } else { 135 | return ctx.WriteJSON(v) 136 | } 137 | } 138 | 139 | func (ctx *Context) WriteFile(path string) { 140 | http.ServeFile(ctx.ResponseWriter, ctx.Request, path) 141 | } 142 | 143 | // simple cors allow ajax 144 | func (ctx *Context) CORS() { 145 | origin := ctx.GetHeader("Origin") 146 | if origin != "" { 147 | ctx.SetHeader("Access-Control-Allow-Origin", origin) 148 | ctx.SetHeader("Access-Control-Allow-Credentials", "true") 149 | } 150 | 151 | method := ctx.GetHeader("Access-Control-Request-Method") 152 | if method != "" { 153 | ctx.SetHeader("Access-Control-Allow-Methods", method) 154 | } 155 | 156 | header := ctx.GetHeader("Access-Control-Request-Headers") 157 | if header != "" { 158 | ctx.SetHeader("Access-Control-Allow-Headers", header) 159 | } 160 | } 161 | 162 | func (ctx *Context) Headers() http.Header { 163 | return ctx.Request.Header 164 | } 165 | 166 | func (ctx *Context) GetHeader(k string) string { 167 | return ctx.Request.Header.Get(k) 168 | } 169 | 170 | func (ctx *Context) SetHeader(k string, v string) { 171 | ctx.ResponseWriter.Header().Set(k, v) 172 | } 173 | 174 | func (ctx *Context) AddHeader(k string, v string) { 175 | ctx.ResponseWriter.Header().Add(k, v) 176 | } 177 | 178 | func (ctx *Context) Cookies() []*http.Cookie { 179 | return ctx.Request.Cookies() 180 | } 181 | 182 | func (ctx *Context) GetCookie(k string) *http.Cookie { 183 | c, err := ctx.Request.Cookie(k) 184 | if err != nil { 185 | return nil 186 | } 187 | return c 188 | } 189 | 190 | func (ctx *Context) GetCookieV(k string) string { 191 | c := ctx.GetCookie(k) 192 | if c != nil { 193 | return c.Value 194 | } 195 | return "" 196 | } 197 | 198 | func (ctx *Context) SetCookie(c *http.Cookie) { 199 | http.SetCookie(ctx.ResponseWriter, c) 200 | } 201 | 202 | func (ctx *Context) SetCookieKV(k string, v string) { 203 | c := http.Cookie{ 204 | Name: k, 205 | Value: v, 206 | Path: "/", 207 | HttpOnly: true, 208 | } 209 | ctx.SetCookie(&c) 210 | } 211 | 212 | func (ctx *Context) DelCookie(k string) { 213 | c := http.Cookie{ 214 | Name: k, 215 | Path: "/", 216 | MaxAge: -1, 217 | } 218 | ctx.SetCookie(&c) 219 | } 220 | 221 | func (ctx *Context) UserAgent() string { 222 | return ctx.GetHeader("User-Agent") 223 | } 224 | 225 | func (ctx *Context) Query() url.Values { 226 | return ctx.Request.URL.Query() 227 | } 228 | 229 | func (ctx *Context) GetQuery(k string) string { 230 | return ctx.Request.URL.Query().Get(k) 231 | } 232 | 233 | func (ctx *Context) FormFile(k string) (multipart.File, *multipart.FileHeader, error) { 234 | return ctx.Request.FormFile(k) 235 | } 236 | 237 | func (ctx *Context) Form() url.Values { 238 | if !ctx.Flag.FormParsed { 239 | ctx.Request.ParseForm() 240 | ctx.Flag.FormParsed = true 241 | } 242 | return ctx.Request.Form 243 | } 244 | 245 | func (ctx *Context) GetForm(k string) string { 246 | if !ctx.Flag.FormParsed { 247 | ctx.Request.ParseForm() 248 | ctx.Flag.FormParsed = true 249 | } 250 | return ctx.Request.Form.Get(k) 251 | } 252 | 253 | func (ctx *Context) SetCharset(c string) { 254 | ctx.Charset = c 255 | } 256 | 257 | // allow use ext to set content type 258 | // SetContentType("json") 259 | // SetContentType("application/json") 260 | func (ctx *Context) SetContentType(t string) string { 261 | // if is ext 262 | if !strings.ContainsRune(t, '/') { 263 | t = mime.TypeByExtension(fmt.Sprintf(".%s", t)) 264 | } 265 | if t != "" { 266 | ctx.SetHeader("Content-Type", fmt.Sprintf("%s; charset=%s", t, ctx.Charset)) 267 | } 268 | return t 269 | } 270 | 271 | func (ctx *Context) ClientIP() string { 272 | clientIP := ctx.GetHeader("X-Forwarded-For") 273 | clientIP = strings.TrimSpace(strings.Split(clientIP, ",")[0]) 274 | if clientIP == "" { 275 | clientIP = strings.TrimSpace(ctx.GetHeader("X-Real-Ip")) 276 | } 277 | if clientIP != "" { 278 | return clientIP 279 | } 280 | if ip, _, err := net.SplitHostPort(strings.TrimSpace(ctx.Request.RemoteAddr)); err == nil { 281 | return ip 282 | } 283 | return "" 284 | } 285 | 286 | func (ctx *Context) Abort(status int, body string) { 287 | ctx.SetContentType("text/plain") 288 | ctx.WriteHeader(status) 289 | ctx.WriteString(body) 290 | } 291 | 292 | func (ctx *Context) AbortJSON(status int, v interface{}) { 293 | ctx.WriteHeader(status) 294 | ctx.WriteJSON(v) 295 | } 296 | 297 | func (ctx *Context) Redirect(status int, url string) { 298 | ctx.SetHeader("Location", url) 299 | ctx.WriteHeader(status) 300 | ctx.WriteString("Redirecting to ") 301 | ctx.WriteString(url) 302 | } 303 | -------------------------------------------------------------------------------- /web/context_test.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/JoveYu/zgo/log" 8 | ) 9 | 10 | func thandler(ctx Context) { 11 | m := ctx.Method() 12 | url := ctx.URL() 13 | log.Debug("%s %s", m, url) 14 | a := ctx.GetQuery("a") 15 | log.Debug("query a=%s", a) 16 | b := ctx.GetForm("b") 17 | log.Debug("form b=%s", b) 18 | ua := ctx.UserAgent() 19 | log.Debug("ua %s", ua) 20 | 21 | ctx.WriteHeader(200) 22 | ctx.WriteString("hello world") 23 | } 24 | 25 | func TestCtxHandler(t *testing.T) { 26 | log.Install("stdout") 27 | 28 | http.Handle("/", ContextHandler(thandler)) 29 | http.ListenAndServe("127.0.0.1:7000", nil) 30 | } 31 | -------------------------------------------------------------------------------- /web/middleware.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | func CORS(ctx Context) { 4 | ctx.CORS() 5 | } 6 | -------------------------------------------------------------------------------- /web/router.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "regexp" 5 | ) 6 | 7 | type Router struct { 8 | r string 9 | cr *regexp.Regexp 10 | method string 11 | handlers []ContextHandlerFunc 12 | } 13 | -------------------------------------------------------------------------------- /web/server.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httputil" 7 | "regexp" 8 | "strings" 9 | "time" 10 | 11 | "github.com/JoveYu/zgo/log" 12 | ) 13 | 14 | const ( 15 | PoweredBy string = "zgo/0.0.1" 16 | ) 17 | 18 | type Server struct { 19 | Addr string 20 | Routers []Router 21 | Charset string 22 | Debug bool 23 | } 24 | 25 | func NewServer() *Server { 26 | server := Server{ 27 | Charset: "utf-8", 28 | Debug: false, 29 | } 30 | return &server 31 | } 32 | 33 | func (s *Server) StaticFile(path string, dir string) { 34 | r := fmt.Sprintf("^%s.*$", path) 35 | handler := func(ctx Context) { 36 | // disable list directory 37 | if strings.HasSuffix(ctx.URL().Path, "/") { 38 | http.NotFound(ctx.ResponseWriter, ctx.Request) 39 | return 40 | } 41 | 42 | // XXX status 200 in log is wrong 43 | handler := http.StripPrefix(path, http.FileServer(http.Dir(dir))) 44 | handler.ServeHTTP(ctx.ResponseWriter, ctx.Request) 45 | } 46 | s.Router("GET", r, handler) 47 | } 48 | 49 | func (s *Server) Router(method string, path string, handlers ...ContextHandlerFunc) { 50 | cr, err := regexp.Compile(path) 51 | if err != nil { 52 | log.Warn("can not add route [%s] %s", path, err) 53 | return 54 | } 55 | 56 | s.Routers = append(s.Routers, Router{ 57 | r: path, 58 | cr: cr, 59 | method: method, 60 | handlers: handlers, 61 | }) 62 | } 63 | 64 | func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { 65 | tstart := time.Now() 66 | 67 | ctx := NewContext(w, r) 68 | defer s.LogRequest(tstart, &ctx) 69 | 70 | // debug 71 | if s.Debug { 72 | ctx.Debug = true 73 | ctx.DebugBody = &strings.Builder{} 74 | 75 | // debug req 76 | data, err := httputil.DumpRequest(r, true) 77 | if err != nil { 78 | log.Error("can not dump req: %s", err) 79 | } 80 | for _, b := range strings.Split(string(data), "\n") { 81 | log.Debug("> %s", b) 82 | } 83 | 84 | // debug resp 85 | defer func(ctx Context) { 86 | log.Debug("< %s %d %s", ctx.Request.Proto, 87 | ctx.Flag.Status, http.StatusText(ctx.Flag.Status), 88 | ) 89 | for k, v := range ctx.ResponseWriter.Header() { 90 | for _, vv := range v { 91 | log.Debug("< %s: %s", k, vv) 92 | } 93 | // XXX Content-Length and Date is missing 94 | } 95 | log.Debug("<") 96 | for _, b := range strings.Split(ctx.DebugBody.String(), "\n") { 97 | log.Debug("< %s", b) 98 | } 99 | }(ctx) 100 | } 101 | 102 | path := ctx.URL().Path 103 | 104 | // default header 105 | ctx.SetHeader("X-Powered-By", PoweredBy) 106 | ctx.SetContentType("text/plain") 107 | 108 | for _, router := range s.Routers { 109 | 110 | // HEAD request use GET Handler 111 | if ctx.Method() != router.method && !(ctx.Method() == "HEAD" && router.method == "GET") { 112 | continue 113 | } 114 | 115 | if !router.cr.MatchString(path) { 116 | continue 117 | } 118 | 119 | match := router.cr.FindStringSubmatch(path) 120 | if len(match[0]) != len(path) { 121 | continue 122 | } 123 | 124 | if len(match) > 1 { 125 | for idx, name := range router.cr.SubexpNames()[1:] { 126 | ctx.Params[name] = match[idx+1] 127 | } 128 | } 129 | 130 | for _, h := range router.handlers { 131 | h(ctx) 132 | if ctx.Flag.BreakNext { 133 | break 134 | } 135 | } 136 | return 137 | } 138 | 139 | ctx.Abort(http.StatusNotFound, http.StatusText(http.StatusNotFound)) 140 | } 141 | 142 | func (s *Server) Run(addr string) error { 143 | return http.ListenAndServe(addr, s) 144 | } 145 | 146 | func (s *Server) LogRequest(tstart time.Time, ctx *Context) { 147 | 148 | log.Info("%d|%s|%s|%s|%s|%d", 149 | ctx.Flag.Status, ctx.Method(), ctx.URL().Path, 150 | ctx.Query().Encode(), ctx.ClientIP(), 151 | time.Since(tstart)/time.Microsecond, 152 | ) 153 | } 154 | -------------------------------------------------------------------------------- /web/server_test.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/JoveYu/zgo/log" 7 | ) 8 | 9 | func TestRouter(t *testing.T) { 10 | log.Install("stdout") 11 | 12 | server := NewServer() 13 | server.Router("GET", "^/$", thandler) 14 | server.Router("GET", "/(?P\\w+)$", thandler) 15 | 16 | server.Run("127.0.0.1:7000") 17 | } 18 | -------------------------------------------------------------------------------- /web/web.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | var ( 8 | DefaultServer = NewServer() 9 | ) 10 | 11 | func Route(method string, path string, f ...ContextHandlerFunc) { 12 | DefaultServer.Router(method, path, f...) 13 | } 14 | 15 | func GET(path string, f ...ContextHandlerFunc) { 16 | DefaultServer.Router(http.MethodGet, path, f...) 17 | } 18 | 19 | func POST(path string, f ...ContextHandlerFunc) { 20 | DefaultServer.Router(http.MethodPost, path, f...) 21 | } 22 | 23 | func PUT(path string, f ...ContextHandlerFunc) { 24 | DefaultServer.Router(http.MethodPut, path, f...) 25 | } 26 | 27 | func DELETE(path string, f ...ContextHandlerFunc) { 28 | DefaultServer.Router(http.MethodDelete, path, f...) 29 | } 30 | 31 | func PATCH(path string, f ...ContextHandlerFunc) { 32 | DefaultServer.Router(http.MethodPatch, path, f...) 33 | } 34 | 35 | func OPTIONS(path string, f ...ContextHandlerFunc) { 36 | DefaultServer.Router(http.MethodOptions, path, f...) 37 | } 38 | 39 | func StaticFile(path string, dir string) { 40 | DefaultServer.StaticFile(path, dir) 41 | } 42 | 43 | func Run(addr string) error { 44 | return DefaultServer.Run(addr) 45 | } 46 | -------------------------------------------------------------------------------- /web/web_test.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/JoveYu/zgo/log" 7 | ) 8 | 9 | func handler1(ctx Context) { 10 | log.Debug("handler1") 11 | ctx.BreakNext() 12 | } 13 | func handler2(ctx Context) { 14 | log.Debug("handler2") 15 | } 16 | 17 | func TestWeb(t *testing.T) { 18 | log.Install("stdout") 19 | 20 | StaticFile("/static/", "/tmp/") 21 | GET("^/$", handler1, handler2) 22 | GET("/(?P\\w+)$", thandler) 23 | 24 | DefaultServer.Debug = true 25 | Run("127.0.0.1:7000") 26 | } 27 | -------------------------------------------------------------------------------- /web/webutil/response.go: -------------------------------------------------------------------------------- 1 | package webutil 2 | 3 | type Response struct { 4 | Code string `json:"code"` 5 | Message string `json:"msg"` 6 | Error string `json:"err"` 7 | Data interface{} `json:"data"` 8 | } 9 | 10 | type Map map[interface{}]interface{} 11 | 12 | // from qfcommon 13 | const ( 14 | OK string = "0000" 15 | 16 | ERR_DB string = "2000" 17 | ERR_RPC string = "2001" 18 | ERR_SESSION string = "2002" 19 | ERR_DATA string = "2003" 20 | ERR_IO string = "2004" 21 | 22 | ERR_LOGIN string = "2100" 23 | ERR_PARAM string = "2101" 24 | ERR_USER string = "2102" 25 | ERR_ROLE string = "2103" 26 | ERR_PWD string = "2104" 27 | 28 | ERR_REQUEST string = "2200" 29 | ERR_IP string = "2201" 30 | ERR_MAC string = "2202" 31 | 32 | ERR_NODATA string = "2300" 33 | ERR_DATAEXIST string = "2301" 34 | 35 | ERR_UNKNOW string = "2400" 36 | ) 37 | 38 | var ( 39 | ErrMsg map[string]string = map[string]string{ 40 | OK: "", 41 | ERR_DB: "数据库错误", 42 | ERR_RPC: "内部服务错误", 43 | ERR_SESSION: "用户未登陆", 44 | ERR_DATA: "数据错误", 45 | ERR_IO: "输入输出错误", 46 | ERR_LOGIN: "登陆错误", 47 | ERR_PARAM: "参数错误", 48 | ERR_USER: "用户错误", 49 | ERR_ROLE: "角色错误", 50 | ERR_PWD: "密码错误", 51 | ERR_REQUEST: "非法请求", 52 | ERR_IP: "IP受限", 53 | ERR_MAC: "校验mac错误", 54 | ERR_NODATA: "无数据", 55 | ERR_DATAEXIST: "数据已经存在", 56 | ERR_UNKNOW: "未知错误", 57 | } 58 | ) 59 | 60 | func Success(data interface{}, msg string) Response { 61 | return Response{ 62 | Code: OK, 63 | Message: msg, 64 | Error: ErrMsg[OK], 65 | Data: data, 66 | } 67 | } 68 | 69 | func Error(code string, data interface{}, msg string) Response { 70 | resperr, ok := ErrMsg[code] 71 | if !ok { 72 | resperr = "Error" 73 | } 74 | return Response{ 75 | Code: code, 76 | Message: msg, 77 | Error: resperr, 78 | Data: data, 79 | } 80 | } 81 | --------------------------------------------------------------------------------