├── OWNERS
├── docs
├── cn
│ ├── dashboard-workflow.md
│ ├── assets
│ │ ├── ast.png
│ │ ├── reload.png
│ │ ├── multi-tenant.png
│ │ ├── breaker_process.png
│ │ ├── deployment_idc.png
│ │ ├── deployment_k8s.png
│ │ ├── sliding_window.png
│ │ ├── backend-conn-pool.png
│ │ └── rateLimiterAndBreaker.png
│ ├── monitoring.md
│ ├── proxy.md
│ ├── cluster_deployment.md
│ ├── config-dynamic-reload.md
│ ├── multi-tenant.md
│ ├── connection-management.md
│ ├── RESTful_api.md
│ ├── fault-tolerant.md
│ ├── proxy-config.md
│ ├── quickstart.md
│ └── namespace-config.md
└── en
│ └── assets
│ └── weir-architecture.png
├── .gitignore
├── CONTRIBUTING.md
├── pkg
├── proxy
│ ├── constant
│ │ ├── context.go
│ │ └── charset.go
│ ├── backend
│ │ ├── instance.go
│ │ ├── client
│ │ │ ├── tls.go
│ │ │ ├── req.go
│ │ │ └── conn.go
│ │ ├── selector_test.go
│ │ ├── selector.go
│ │ └── backend.go
│ ├── driver
│ │ ├── driver.go
│ │ ├── mock_NamespaceManager.go
│ │ ├── queryctx_exec_test.go
│ │ ├── mock_Stmt.go
│ │ ├── queryctx_metrics.go
│ │ ├── mock_Namespace.go
│ │ ├── domain.go
│ │ ├── sessionvars.go
│ │ ├── resultset.go
│ │ └── connmgr.go
│ ├── namespace
│ │ ├── errcode.go
│ │ ├── ratelimiter.go
│ │ ├── ratelimiter_test.go
│ │ ├── domain.go
│ │ ├── frontend.go
│ │ ├── user.go
│ │ ├── namespace.go
│ │ ├── builder.go
│ │ └── manager.go
│ ├── metrics
│ │ ├── session.go
│ │ ├── backend.go
│ │ ├── server.go
│ │ ├── metrics.go
│ │ └── queryctx.go
│ ├── server
│ │ ├── buffered_read_conn.go
│ │ ├── tokenlimiter.go
│ │ ├── server_util.go
│ │ ├── column.go
│ │ ├── conn_stmt.go
│ │ ├── conn_util.go
│ │ ├── packetio_test.go
│ │ ├── column_test.go
│ │ ├── packetio.go
│ │ └── driver.go
│ └── proxy.go
├── util
│ ├── datastructure
│ │ ├── dsutil.go
│ │ └── dsutil_test.go
│ ├── sync2
│ │ ├── boolindex.go
│ │ ├── doc.go
│ │ ├── toggle.go
│ │ ├── toggle_test.go
│ │ ├── semaphore_flaky_test.go
│ │ ├── semaphore.go
│ │ ├── atomic_test.go
│ │ └── atomic.go
│ ├── rand2
│ │ ├── rand_test.go
│ │ └── rand.go
│ ├── passwd
│ │ └── passwd.go
│ ├── rate_limit_breaker
│ │ ├── rate_limit
│ │ │ ├── leaky_bucket_test.go
│ │ │ ├── sliding_window.go
│ │ │ ├── sliding_window_test.go
│ │ │ └── leaky_bucket.go
│ │ ├── sliding_window.go
│ │ └── circuit_breaker
│ │ │ └── circuit_breaker_test.go
│ ├── errors
│ │ ├── errors_test.go
│ │ └── errors.go
│ ├── timer
│ │ ├── randticker.go
│ │ ├── randticker_flaky_test.go
│ │ ├── timer_flaky_test.go
│ │ ├── time_wheel_test.go
│ │ ├── time_wheel.go
│ │ └── timer.go
│ └── ast
│ │ └── ast_util.go
├── config
│ ├── marshaller.go
│ ├── namespace_example.yaml
│ ├── proxy_example.yaml
│ ├── namespace.go
│ ├── proxy.go
│ └── marshaller_test.go
└── configcenter
│ ├── factory.go
│ ├── file.go
│ └── etcd.go
├── code-of-conduct.md
├── .github
├── ISSUE_TEMPLATE
│ ├── development-task.md
│ ├── bug-report.md
│ └── feature-request.md
└── pull_request_template.md
├── conf
├── weirproxy.yaml
├── weirproxy_etcd.yaml
└── namespace
│ └── test_namespace.yaml
├── README-CN.md
├── tests
└── proxy
│ └── backend
│ └── connpool_test.go
├── Makefile
├── README.md
├── go.mod
├── cmd
└── weirproxy
│ └── main.go
└── .golangci.yaml
/OWNERS:
--------------------------------------------------------------------------------
1 | committers:
2 |
--------------------------------------------------------------------------------
/docs/cn/dashboard-workflow.md:
--------------------------------------------------------------------------------
1 | # Dashboard工作流
2 |
3 | (TODO)
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | bin
2 | *.iml
3 | .idea
4 | *.swp
5 | .DS_Store
6 | vendor
7 | .vscode/
8 |
--------------------------------------------------------------------------------
/docs/cn/assets/ast.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/ast.png
--------------------------------------------------------------------------------
/docs/cn/assets/reload.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/reload.png
--------------------------------------------------------------------------------
/docs/cn/assets/multi-tenant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/multi-tenant.png
--------------------------------------------------------------------------------
/docs/cn/assets/breaker_process.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/breaker_process.png
--------------------------------------------------------------------------------
/docs/cn/assets/deployment_idc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/deployment_idc.png
--------------------------------------------------------------------------------
/docs/cn/assets/deployment_k8s.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/deployment_k8s.png
--------------------------------------------------------------------------------
/docs/cn/assets/sliding_window.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/sliding_window.png
--------------------------------------------------------------------------------
/docs/cn/assets/backend-conn-pool.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/backend-conn-pool.png
--------------------------------------------------------------------------------
/docs/en/assets/weir-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/en/assets/weir-architecture.png
--------------------------------------------------------------------------------
/docs/cn/assets/rateLimiterAndBreaker.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/rateLimiterAndBreaker.png
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Please refer to PingCAP [CONTRIBUTING.md](https://github.com/pingcap/community/blob/master/CONTRIBUTING.md)
--------------------------------------------------------------------------------
/docs/cn/monitoring.md:
--------------------------------------------------------------------------------
1 | # 监控与告警
2 |
3 | ## Grafana监控
4 |
5 | Grafana监控分为proxy级别和namespace级别, [这里](assets/grafana_proxy.json) 给出了一个proxy级别监控大盘的配置示例.
6 |
--------------------------------------------------------------------------------
/pkg/proxy/constant/context.go:
--------------------------------------------------------------------------------
1 | package constant
2 |
const (
	// ContextKeyPrefix is prepended to every weir-defined context key name,
	// keeping them in a recognizable "__w_" namespace.
	ContextKeyPrefix = "__w_"

	// ContextKeySessionVariable is the weir context key for session system
	// variables; it expands to ContextKeyPrefix + "session_sysvars".
	ContextKeySessionVariable = ContextKeyPrefix + "session_sysvars"
)
6 |
--------------------------------------------------------------------------------
/pkg/proxy/backend/instance.go:
--------------------------------------------------------------------------------
1 | package backend
2 |
// Instance is a backend instance, known only by its network address.
type Instance struct {
	addr string
}

// Addr returns the address the instance was constructed with; the zero
// Instance yields "".
func (i *Instance) Addr() string {
	a := i.addr
	return a
}
10 |
--------------------------------------------------------------------------------
/code-of-conduct.md:
--------------------------------------------------------------------------------
1 | # PingCAP Community Code of Conduct
2 |
3 | Please refer to our [PingCAP Community Code of Conduct](https://github.com/pingcap/community/blob/master/CODE_OF_CONDUCT.md)
4 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/development-task.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: "🗒️ Development Task"
3 | about: As a developer, I want to record a development task.
4 | labels: type/enhancement
5 | ---
6 |
7 | ## Development Task
8 |
--------------------------------------------------------------------------------
/pkg/proxy/constant/charset.go:
--------------------------------------------------------------------------------
1 | package constant
2 |
3 | import "github.com/pingcap/parser/charset"
4 |
5 | const (
6 | DefaultCharset = charset.CharsetUTF8MB4
7 | DefaultCollationID = charset.CollationUTF8MB4
8 | )
--------------------------------------------------------------------------------
/docs/cn/proxy.md:
--------------------------------------------------------------------------------
1 | # 应用层代理
2 |
3 | Weir作为TiDB分布式数据库治理平台, 其代理组件weir-proxy在实现时复用了TiDB 4.0的协议层和SQL解析器, 因此weir-proxy在协议和语法层面, 对MySQL的兼容性与TiDB 4.0相同.
4 |
5 | 使用原生MySQL客户端即可连接weir-proxy, 通过weir-proxy将客户端的SQL请求转发到后端的TiDB Server集群.
6 |
--------------------------------------------------------------------------------
/pkg/util/datastructure/dsutil.go:
--------------------------------------------------------------------------------
1 | package datastructure
2 |
// StringSliceToSet returns a set holding every element of ss. Repeated
// elements collapse to a single key. The returned map is always non-nil, even
// when ss is nil or empty, so callers may insert into it directly.
func StringSliceToSet(ss []string) map[string]struct{} {
	set := make(map[string]struct{}, len(ss))
	for i := range ss {
		set[ss[i]] = struct{}{}
	}
	return set
}
10 |
--------------------------------------------------------------------------------
/docs/cn/cluster_deployment.md:
--------------------------------------------------------------------------------
1 | # 部署
2 |
3 | ## 物理机部署
4 |
5 | 物理机部署应该尽可能选择多个机房部署, 做到高可用, 其中每个实例应该保持配置相同, 上游可以选择带有健康检查的 Server Load Balancer .
6 |
7 | ## kubernetes 下部署
8 |
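9 | 在 Kubernetes 下部署时, 可以利用 nodeSelector 将 Pod 调度到指定的节点上: 先为 Node 对象添加标签, 再在 Pod 的 nodeSelector 中引用这些标签, 即可将 Pod 定位到特定的节点或节点组, 从而确保指定的 Pod 只运行在具有一定隔离性、安全性或监管属性的节点上. 如需将多个 weir-proxy 副本尽量分散到不同节点上, 可以配合使用 Pod 反亲和性 (podAntiAffinity) 或拓扑分布约束 (topologySpreadConstraints).
10 | 上游可以直接通过 Service 或其他转发组件将流量转发到 weir-proxy.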
11 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: "🐛 Bug Report"
3 | about: Something isn't working as expected
4 | labels: 'type/bug'
5 | ---
6 |
7 | ## Bug Report
8 |
9 |
10 |
11 | ### What did you do?
12 |
13 |
14 |
15 | ### What did you expect to see?
16 |
17 | ### What did you see instead?
18 |
19 | ### What version of Weir are you using (`weir-proxy -V`)?
20 |
--------------------------------------------------------------------------------
/pkg/proxy/driver/driver.go:
--------------------------------------------------------------------------------
1 | package driver
2 |
3 | import (
4 | "crypto/tls"
5 |
6 | "github.com/tidb-incubator/weir/pkg/proxy/server"
7 | )
8 |
9 | type DriverImpl struct {
10 | nsmgr NamespaceManager
11 | }
12 |
13 | func NewDriverImpl(nsmgr NamespaceManager) *DriverImpl {
14 | return &DriverImpl{
15 | nsmgr: nsmgr,
16 | }
17 | }
18 |
19 | func (d *DriverImpl) OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState) (server.QueryCtx, error) {
20 | return NewQueryCtxImpl(d.nsmgr, connID), nil
21 | }
22 |
--------------------------------------------------------------------------------
/conf/weirproxy.yaml:
--------------------------------------------------------------------------------
1 | version: "v1"
2 | proxy_server:
3 | addr: "0.0.0.0:6000"
4 | max_connections: 1000
5 | session_timeout: 600
6 | admin_server:
7 | addr: "0.0.0.0:6001"
8 | enable_basic_auth: false
9 | user: ""
10 | password: ""
11 | log:
12 | level: "debug"
13 | format: "console"
14 | log_file:
15 | filename: ""
16 | max_size: 300
17 | max_days: 1
18 | max_backups: 1
19 | registry:
20 | enable: false
21 | config_center:
22 | type: "file"
23 | config_file:
24 | path: "./conf/namespace"
25 | strict_parse: false
26 | performance:
27 | tcp_keep_alive: true
28 |
--------------------------------------------------------------------------------
/pkg/util/sync2/boolindex.go:
--------------------------------------------------------------------------------
1 | package sync2
2 |
3 | import "sync/atomic"
4 |
// BoolIndex is an atomically updated flag that selects one of two slots
// (0 or 1), e.g. the active half of a two-element array. The zero value
// selects slot 0.
type BoolIndex struct {
	index int32
}

// Set selects slot 1 when index is true and slot 0 otherwise.
func (b *BoolIndex) Set(index bool) {
	var slot int32
	if index {
		slot = 1
	}
	atomic.StoreInt32(&b.index, slot)
}

// Get returns the currently selected slot, the other slot, and whether the
// current slot is 1.
func (b *BoolIndex) Get() (int32, int32, bool) {
	if atomic.LoadInt32(&b.index) != 1 {
		return 0, 1, false
	}
	return 1, 0, true
}
27 |
--------------------------------------------------------------------------------
/pkg/proxy/namespace/errcode.go:
--------------------------------------------------------------------------------
1 | package namespace
2 |
3 | import (
4 | "github.com/pingcap/errors"
5 | )
6 |
var (
	// ErrDuplicatedUser reports that a username appears more than once.
	ErrDuplicatedUser = errors.New("duplicated user")
	// ErrInvalidSelectorType reports an unrecognized backend selector type.
	ErrInvalidSelectorType = errors.New("invalid selector type")

	// Circuit-breaker configuration errors.
	ErrNilBreakerName              = errors.New("breaker name nil")
	ErrInvalidFailureRateThreshold = errors.New("invalid FailureRateThreshold")
	ErrInvalidOpenStatusDurationMs = errors.New("invalid OpenStatusDurationMs")
	// ErrInvalidopenStatusDurationMs is the original, mis-cased name.
	//
	// Deprecated: use ErrInvalidOpenStatusDurationMs. Both names refer to the
	// same error value, so == and errors.Is comparisons match either.
	ErrInvalidopenStatusDurationMs = ErrInvalidOpenStatusDurationMs
	ErrInvalidSqlTimeout           = errors.New("invalid sql timeout")

	// ErrInvalidScope reports an unrecognized scope value.
	ErrInvalidScope = errors.New("invalid scope")
)
18 |
--------------------------------------------------------------------------------
/conf/weirproxy_etcd.yaml:
--------------------------------------------------------------------------------
1 | version: "v1"
2 | proxy_server:
3 | addr: "0.0.0.0:6000"
4 | max_connections: 10
5 | admin_server:
6 | addr: "0.0.0.0:6001"
7 | enable_basic_auth: false
8 | user: ""
9 | password: ""
10 | log:
11 | level: "debug"
12 | format: "console"
13 | log_file:
14 | filename: ""
15 | max_size: 300
16 | max_days: 1
17 | max_backups: 1
18 | registry:
19 | enable: false
20 | config_center:
21 | type: "etcd"
22 | config_etcd:
23 | addrs:
24 | - "127.0.0.1:2379"
25 | base_path: "/weir/defaultcluster"
26 | username: ""
27 | password: ""
28 | strict_parse: false
29 | performance:
30 | tcp_keep_alive: true
31 |
--------------------------------------------------------------------------------
/pkg/config/marshaller.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import "github.com/goccy/go-yaml"
4 |
5 | func UnmarshalNamespaceConfig(data []byte) (*Namespace, error) {
6 | var cfg Namespace
7 | if err := yaml.Unmarshal(data, &cfg); err != nil {
8 | return nil, err
9 | }
10 | return &cfg, nil
11 | }
12 |
13 | func MarshalNamespaceConfig(cfg *Namespace) ([]byte, error) {
14 | return yaml.Marshal(cfg)
15 | }
16 |
17 | func UnmarshalProxyConfig(data []byte) (*Proxy, error) {
18 | var cfg Proxy
19 | if err := yaml.Unmarshal(data, &cfg); err != nil {
20 | return nil, err
21 | }
22 | return &cfg, nil
23 | }
24 |
25 | func MarshalProxyConfig(cfg *Proxy) ([]byte, error) {
26 | return yaml.Marshal(cfg)
27 | }
28 |
--------------------------------------------------------------------------------
/docs/cn/config-dynamic-reload.md:
--------------------------------------------------------------------------------
1 | # 配置热加载
2 |
3 | Weir 作为多租户的 TiDB 数据库治理平台, 租户配置变更是一项比较频繁的操作. 如果每次增加, 修改, 删除租户配置都需要重新启动 Weir Proxy 才能使配置生效, 无疑会对用户体验和服务稳定性带来很大影响. 因此, 我们为 Weir 的 Namespace 配置提供了热加载支持. (注: 热加载要求配置中心使用 etcd )
4 |
5 |
6 |
7 | Weir 支持通过[管理接口](RESTful_api.md)手动触发配置热加载, 整个过程包括准备 (Prepare) 和提交 (Commit) 两个阶段. Weir 维护了两个 Namespace 队列以及分别指向它们的两个指针: 一个指向当前正在使用的队列, 另一个指向准备阶段的队列.
8 | - 在准备阶段, Weir Proxy 会从配置中心拉取 Namespace 的最新配置, 解析配置并初始化 Namespace 存储在准备阶段队列中.
9 | - 在提交阶段, Weir Proxy会执行一次原子切换操作, 当前队列和准备阶段队列指针调换, 使用新的 Namespace 处理客户端的请求, 同时将旧的 Namespace 延迟关闭.
10 |
11 | 整个热加载过程, Weir Proxy不会主动关闭客户端连接, 客户端是无感知的, 对于一些非核心配置的调整, 甚至不需要重建后端数据库连接池, 对提升客户端体验和保持Weir本身以及后端TiDB集群稳定性都有比较大的帮助.
12 |
--------------------------------------------------------------------------------
/pkg/util/datastructure/dsutil_test.go:
--------------------------------------------------------------------------------
1 | package datastructure
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/assert"
7 | )
8 |
9 | func TestStringSliceToSet(t *testing.T) {
10 | tests := []struct {
11 | name string
12 | args []string
13 | want map[string]struct{}
14 | }{
15 | {name: "nil", args: nil, want: map[string]struct{}{}},
16 | {name: "one", args: []string{"db0"}, want: map[string]struct{}{"db0": {}}},
17 | {name: "two", args: []string{"db0", "db1"}, want: map[string]struct{}{"db0": {}, "db1": {}}},
18 | }
19 | for _, tt := range tests {
20 | t.Run(tt.name, func(t *testing.T) {
21 | ret := StringSliceToSet(tt.args)
22 | assert.Equal(t, ret, tt.want)
23 | })
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/docs/cn/multi-tenant.md:
--------------------------------------------------------------------------------
1 | # 多租户软隔离
2 |
3 | Weir作为TiDB分布式数据库的统一接入层和治理平台, 需要具备一定的多租户管理能力, 我们将Weir设计成一个共享多租户系统.
4 |
5 | ## 基本概念
6 |
7 | 在Weir系统中使用Namespace表示租户. 每个Weir集群可以管理多个Namespace, 每个Namespace有其独立管理的用户 (User) 和资源配置参数. Namespace中的资源主要是指TiDB集群的访问能力.
8 |
9 | 在Weir集群中, User名称是全局唯一的. 每个User只属于一个Namespace, 每个Namespace可以有多个User. 这样便于在客户端连接Weir Proxy时, 根据MySQL的Username找到其Namespace, 从而确定该User的一些访问权限, 如可访问Database, 访问IP等.
10 |
11 | ## 软隔离
12 |
13 | 多个Namespace关联的TiDB集群可以是相同也可以不同, 但Weir Proxy会为每个Namespace创建该TiDB集群的后端连接池, 在后端数据库连接层面将各个Namespace软隔离, 在一定程度上使各个Namespace对同一集群的访问互相不受影响.
14 |
15 |
16 |
17 | 各个Namespace可以独立动态加载, 可以动态调整Namespace配置的各项参数, Namespace重新加载时会初始化资源, 并将原有Namespace的资源回收.
18 |
--------------------------------------------------------------------------------
/README-CN.md:
--------------------------------------------------------------------------------
1 | # Weir
2 |
3 | Weir是伴鱼公司研发的开源数据库代理平台, 主要为TiDB分布式数据库提供数据库治理功能.
4 |
5 | ## 功能特性
6 |
7 | > 1. Weir 为 MySQL 协议提供应用层代理,兼容 TiDB 4.0。[L7层负载](docs/cn/proxy.md)
8 | > 2. Weir 使用连接池进行后端连接管理,并支持负载均衡。[连接管理](docs/cn/connection-management.md)
9 | > 3. Weir 支持多租户管理,所有命名空间都可以在运行时动态重新加载。[多租户软隔离](docs/cn/multi-tenant.md)
10 | > 4. Weir 支持 qps 限流和熔断机制来保护客户端和 TiDB 服务器。[熔断限流机制](docs/cn/fault-tolerant.md)
11 |
12 | ## 使用手册
13 |
14 | - [快速上手](docs/cn/quickstart.md)
15 | - [Proxy配置详解](docs/cn/proxy-config.md)
16 | - [Namespace配置详解](docs/cn/namespace-config.md)
17 | - [集群部署](docs/cn/cluster_deployment.md)
18 | - [配置热加载](docs/cn/config-dynamic-reload.md)
19 | - [监控与告警](docs/cn/monitoring.md)
20 | - [RESTful-Api](docs/cn/RESTful_api.md)
21 |
22 | ## FAQ
23 |
--------------------------------------------------------------------------------
/pkg/util/sync2/doc.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | // Package sync2 provides extra functionality along the same lines as sync.
18 | package sync2
19 |
--------------------------------------------------------------------------------
/pkg/configcenter/factory.go:
--------------------------------------------------------------------------------
1 | package configcenter
2 |
3 | import (
4 | "github.com/tidb-incubator/weir/pkg/config"
5 | "github.com/pingcap/errors"
6 | )
7 |
8 | const (
9 | ConfigCenterTypeFile = "file"
10 | ConfigCenterTypeEtcd = "etcd"
11 | )
12 |
13 | type ConfigCenter interface {
14 | GetNamespace(ns string) (*config.Namespace, error)
15 | ListAllNamespace() ([]*config.Namespace, error)
16 | }
17 |
18 | func CreateConfigCenter(cfg config.ConfigCenter) (ConfigCenter, error) {
19 | switch cfg.Type {
20 | case ConfigCenterTypeFile:
21 | return CreateFileConfigCenter(cfg.ConfigFile.Path)
22 | case ConfigCenterTypeEtcd:
23 | return CreateEtcdConfigCenter(cfg.ConfigEtcd)
24 | default:
25 | return nil, errors.New("invalid config center type")
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/pkg/config/namespace_example.yaml:
--------------------------------------------------------------------------------
1 | version: "v1"
2 | namespace: "test_namespace"
3 | frontend:
4 | allowed_dbs:
5 | - "test_weir_db"
6 | slow_sql_time: 50
7 | denied_sqls:
8 | denied_ips:
9 | idle_timeout: 3600
10 | users:
11 | - username: "hello"
12 | password: "world"
13 | - username: "hello1"
14 | password: "world1"
15 | backend:
16 | instances:
17 | - "127.0.0.1:4000"
18 | username: "root"
19 | password: ""
20 | selector_type: "random"
21 | pool_size: 10
22 | idle_timeout: 60
23 | breaker:
24 | scope: "sql"
25 | strategies:
26 | - min_qps: 3
27 | failure_rate_threshold: 0
28 | failure_num: 5
29 | sql_timeout_ms: 2000
30 | open_status_duration_ms: 5000
31 | size: 10
32 | cell_interval_ms: 1000
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: "🚀 Feature Request"
3 | about: I have a suggestion
4 | labels: type/enhancement
5 | ---
6 |
7 | ## Feature Request
8 |
9 | ### Describe your feature request related problem
10 |
11 |
12 |
13 | ### Describe the feature you'd like
14 |
15 |
16 |
17 | ### Describe alternatives you've considered
18 |
19 |
20 |
21 | ### Teachability, Documentation, Adoption, Migration Strategy
22 |
23 |
--------------------------------------------------------------------------------
/pkg/util/rand2/rand_test.go:
--------------------------------------------------------------------------------
1 | package rand2
2 |
3 | import (
4 | "math/rand"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
// TestNewRand checks that rand2.New, wrapping a *rand.Rand, yields exactly the
// same values as an identically seeded plain *rand.Rand for each method.
func TestNewRand(t *testing.T) {
	// Both generators start from seed 1. Each assertion draws once from each,
	// so the two stay in lockstep. Adding, removing or reordering a draw on
	// only one side would desynchronize every later comparison.
	src1 := rand.NewSource(1)
	stdRd := rand.New(src1)
	src2 := rand.NewSource(1)
	rd := New(rand.New(src2))

	assert.Equal(t, stdRd.Int63(), rd.Int63())
	assert.Equal(t, stdRd.Uint32(), rd.Uint32())
	assert.Equal(t, stdRd.Uint64(), rd.Uint64())
	assert.Equal(t, stdRd.Int31(), rd.Int31())
	assert.Equal(t, stdRd.Int63n(100), rd.Int63n(100))
	assert.Equal(t, stdRd.Int31n(100), rd.Int31n(100))
	assert.Equal(t, stdRd.Intn(20), rd.Intn(20))
	assert.Equal(t, stdRd.Float64(), rd.Float64())
	assert.Equal(t, stdRd.Float32(), rd.Float32())
}
26 |
--------------------------------------------------------------------------------
/pkg/util/passwd/passwd.go:
--------------------------------------------------------------------------------
1 | package passwd
2 |
3 | import "crypto/sha1"
4 |
// CalculatePassword computes the MySQL native-password auth response for
// password, given the server-supplied scramble (salt):
//
//	SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
//
// It returns nil when password is empty. The result is a fresh
// sha1.Size-byte slice; the caller's scramble slice is not modified.
func CalculatePassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	// stage1 = SHA1(password); stage2 = SHA1(stage1).
	stage1 := sha1.Sum(password)
	stage2 := sha1.Sum(stage1[:])

	// scrambleHash = SHA1(scramble + stage2).
	// hash.Hash.Write never returns an error, so its result is ignored.
	h := sha1.New()
	h.Write(scramble)
	h.Write(stage2[:])
	token := h.Sum(nil)

	// token = scrambleHash XOR stage1
	for i := range token {
		token[i] ^= stage1[i]
	}
	return token
}
34 |
--------------------------------------------------------------------------------
/pkg/proxy/backend/client/tls.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "crypto/tls"
5 | "crypto/x509"
6 | )
7 |
// NewClientTLSConfig builds a client-side TLS config. caPem supplies the
// trusted root CAs, and certPem/keyPem the client certificate presented to
// the server. serverName is the name the server certificate is checked
// against.
//
// When insecureSkipVerify is true, crypto/tls performs no verification of
// the server's certificate chain or host name at all.
//
// It panics with "failed to add ca PEM" if caPem contains no usable
// certificate, and with the parse error if the key pair is invalid.
func NewClientTLSConfig(caPem, certPem, keyPem []byte, insecureSkipVerify bool, serverName string) *tls.Config {
	roots := x509.NewCertPool()
	if ok := roots.AppendCertsFromPEM(caPem); !ok {
		panic("failed to add ca PEM")
	}

	keyPair, err := tls.X509KeyPair(certPem, keyPem)
	if err != nil {
		panic(err)
	}

	return &tls.Config{
		ServerName:         serverName,
		RootCAs:            roots,
		Certificates:       []tls.Certificate{keyPair},
		InsecureSkipVerify: insecureSkipVerify,
	}
}
29 |
--------------------------------------------------------------------------------
/pkg/proxy/metrics/session.go:
--------------------------------------------------------------------------------
1 | package metrics
2 |
// Label constants: metric label names and values used by weir's metrics.
const (
	LblUnretryable = "unretryable"
	LblReachMax    = "reach_max"
	LblOK          = "ok"
	LblError       = "error"
	LblCommit      = "commit"
	LblAbort       = "abort"
	LblRollback    = "rollback"
	LblType        = "type"
	LblDb          = "db"
	LblTable       = "table"
	LblResult      = "result"
	LblSQLType     = "sql_type"
	LblGeneral     = "general"
	LblInternal    = "internal"
	LblTxnMode     = "txn_mode"
	// LbTxnMode is the original name of LblTxnMode, missing the "Lbl" prefix
	// every other label constant uses.
	//
	// Deprecated: use LblTxnMode; both have the same value.
	LbTxnMode      = LblTxnMode
	LblPessimistic = "pessimistic"
	LblOptimistic  = "optimistic"
	LblStore       = "store"
	LblAddress     = "address"
	LblBatchGet    = "batch_get"
	LblGet         = "get"
	LblNamespace   = "namespace"
	LblCluster     = "cluster"

	LblBackendAddr = "backend_addr"
)
31 |
--------------------------------------------------------------------------------
/pkg/proxy/namespace/ratelimiter.go:
--------------------------------------------------------------------------------
1 | package namespace
2 |
3 | import (
4 | "context"
5 | "sync"
6 |
7 | "github.com/tidb-incubator/weir/pkg/util/rate_limit_breaker/rate_limit"
8 | )
9 |
10 | type NamespaceRateLimiter struct {
11 | limiters *sync.Map
12 | qpsThreshold int
13 | scope string
14 | }
15 |
16 | func NewNamespaceRateLimiter(scope string, qpsThreshold int) *NamespaceRateLimiter {
17 | return &NamespaceRateLimiter{
18 | limiters: &sync.Map{},
19 | scope: scope,
20 | qpsThreshold: qpsThreshold,
21 | }
22 | }
23 |
24 | func (n *NamespaceRateLimiter) Scope() string {
25 | return n.scope
26 | }
27 |
28 | func (n *NamespaceRateLimiter) Limit(ctx context.Context, key string) error {
29 | if n.qpsThreshold <= 0 {
30 | return nil
31 | }
32 | limiter, _ := n.limiters.LoadOrStore(key, rate_limit.NewSlidingWindowRateLimiter(int64(n.qpsThreshold)))
33 | return limiter.(*rate_limit.SlidingWindowRateLimiter).Limit()
34 | }
35 |
--------------------------------------------------------------------------------
/pkg/proxy/namespace/ratelimiter_test.go:
--------------------------------------------------------------------------------
1 | package namespace
2 |
3 | import (
4 | "context"
5 | "testing"
6 | "time"
7 |
8 | "github.com/stretchr/testify/require"
9 | )
10 |
// TestNamespaceRateLimiter_Limit checks per-key limiting with a threshold of 2.
// key1 is allowed twice and rejected on the third call. key2 is unaffected by
// key1's usage. After sleeping one second, key1 is allowed again.
func TestNamespaceRateLimiter_Limit(t *testing.T) {
	ctx := context.Background()
	key1 := "hello"
	key2 := "world"
	rateLimiter := NewNamespaceRateLimiter("namespace", 2)
	require.NoError(t, rateLimiter.Limit(ctx, key1))
	require.NoError(t, rateLimiter.Limit(ctx, key1))
	// Third request for key1 exceeds the threshold of 2.
	require.Error(t, rateLimiter.Limit(ctx, key1))
	// Limits are tracked per key, so key2 still has its full budget.
	require.NoError(t, rateLimiter.Limit(ctx, key2))
	// Wait for key1's window to move on; the exact window length lives in
	// rate_limit (NOTE(review): confirm it is <= 1s, or this test is flaky).
	time.Sleep(time.Second)
	require.NoError(t, rateLimiter.Limit(ctx, key1))
}
23 |
24 | func TestNamespaceRateLimiter_ZeroThreshold(t *testing.T) {
25 | ctx := context.Background()
26 | key1 := "hello"
27 | rateLimiter := NewNamespaceRateLimiter("namespace", 0)
28 | require.NoError(t, rateLimiter.Limit(ctx, key1))
29 | require.NoError(t, rateLimiter.Limit(ctx, key1))
30 | }
31 |
--------------------------------------------------------------------------------
/pkg/proxy/namespace/domain.go:
--------------------------------------------------------------------------------
1 | package namespace
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/tidb-incubator/weir/pkg/proxy/driver"
7 | )
8 |
// Namespace is one tenant's view as used by the proxy. It combines the
// Frontend checks (authentication, database and SQL access) with backend
// connections and the namespace's breaker and rate limiter.
type Namespace interface {
	// Name returns the namespace's name.
	Name() string
	// Auth reports whether username may log in. passwdBytes is the client's
	// auth response; presumably it is checked against a scramble computed
	// with salt (TODO confirm against the implementation).
	Auth(username string, passwdBytes []byte, salt []byte) bool
	// IsDatabaseAllowed reports whether db may be used in this namespace.
	IsDatabaseAllowed(db string) bool
	// ListDatabases returns the databases visible to this namespace.
	ListDatabases() []string
	// IsDeniedSQL and IsAllowedSQL test a statement's sqlFeature value
	// (presumably a fingerprint/hash of the SQL — TODO confirm) against the
	// namespace's deny and allow lists.
	IsDeniedSQL(sqlFeature uint32) bool
	IsAllowedSQL(sqlFeature uint32) bool
	// GetPooledConn obtains a backend connection for this namespace.
	GetPooledConn(context.Context) (driver.PooledBackendConn, error)
	// Close presumably releases the namespace's resources, such as its
	// backend pool (TODO confirm).
	Close()
	// GetBreaker returns the namespace's circuit breaker.
	GetBreaker() (driver.Breaker, error)
	// GetRateLimiter returns the namespace's rate limiter.
	GetRateLimiter() driver.RateLimiter
}

// Frontend is the client-facing subset of Namespace: authentication plus
// database and SQL access checks.
type Frontend interface {
	Auth(username string, passwdBytes []byte, salt []byte) bool
	IsDatabaseAllowed(db string) bool
	ListDatabases() []string
	IsDeniedSQL(sqlFeature uint32) bool
	IsAllowedSQL(sqlFeature uint32) bool
}

// Backend is the server-facing subset of Namespace: pooled connections to the
// backend instances, and shutdown.
type Backend interface {
	Close()
	GetPooledConn(context.Context) (driver.PooledBackendConn, error)
}
--------------------------------------------------------------------------------
/pkg/proxy/driver/mock_NamespaceManager.go:
--------------------------------------------------------------------------------
1 | // Code generated by mockery v2.3.0. DO NOT EDIT.
2 |
3 | package driver
4 |
5 | import mock "github.com/stretchr/testify/mock"
6 |
// MockNamespaceManager is an autogenerated mock type for the NamespaceManager type
// (mockery output: regenerate rather than editing by hand).
type MockNamespaceManager struct {
	mock.Mock
}

// Auth provides a mock function with given fields: username, pwd, salt
//
// Each configured return value may be either a plain value or a function of
// the call's arguments; functions are invoked to compute the result. A nil
// first return value yields a nil Namespace. The second return value must be
// set, or the bool type assertion panics.
func (_m *MockNamespaceManager) Auth(username string, pwd []byte, salt []byte) (Namespace, bool) {
	ret := _m.Called(username, pwd, salt)

	var r0 Namespace
	if rf, ok := ret.Get(0).(func(string, []byte, []byte) Namespace); ok {
		r0 = rf(username, pwd, salt)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(Namespace)
		}
	}

	var r1 bool
	if rf, ok := ret.Get(1).(func(string, []byte, []byte) bool); ok {
		r1 = rf(username, pwd, salt)
	} else {
		r1 = ret.Get(1).(bool)
	}

	return r0, r1
}
34 |
--------------------------------------------------------------------------------
/conf/namespace/test_namespace.yaml:
--------------------------------------------------------------------------------
1 | version: "v1"
2 | namespace: "test_namespace"
3 | frontend:
4 | allowed_dbs:
5 | - "test_weir_db"
6 | slow_sql_time: 50
7 | sql_blacklist:
8 | - sql: "select * from tbl0"
9 | - sql: "select * from tbl1"
10 | sql_whitelist:
11 | - sql: "select * from tbl2"
12 | - sql: "select * from tbl3"
13 | denied_ips:
14 | idle_timeout: 3600
15 | users:
16 | - username: "hello"
17 | password: "world"
18 | - username: "hello1"
19 | password: "world1"
20 | backend:
21 | instances:
22 | - "127.0.0.1:4000"
23 | username: "root"
24 | password: ""
25 | selector_type: "random"
26 | pool_size: 10
27 | idle_timeout: 60
28 | breaker:
29 | scope: "sql"
30 | strategies:
31 | - min_qps: 3
32 | failure_rate_threshold: 0
33 | failure_num: 5
34 | sql_timeout_ms: 2000
35 | open_status_duration_ms: 5000
36 | size: 10
37 | cell_interval_ms: 1000
38 | rate_limiter:
39 | scope: "db"
40 | qps: 1000
41 |
--------------------------------------------------------------------------------
/pkg/config/proxy_example.yaml:
--------------------------------------------------------------------------------
1 | version: "v1"
2 | proxy_server:
3 | addr: "0.0.0.0:6000"
4 | max_connections: 1
5 | admin_server:
6 | addr: "0.0.0.0:6001"
7 | enable_basic_auth: false
8 | user: "hello"
9 | password: "world"
10 | performance:
11 | tcp_keep_alive: true
12 | log:
13 | # Log level: debug, info, warn, error, fatal.
14 | level: "debug"
15 | # Log format, one of json, text, console.
16 | format: "console"
17 | # File logging.
18 | log_file:
19 | # Log file name.
20 | filename: ""
21 | # Max log file size in MB (upper limit to 4096MB).
22 | max_size: 300
23 | # Maximum number of days to keep old log files. Old logs are not cleaned up by default.
24 | max_days: 1
25 | # Maximum number of old log files to retain. No clean up by default.
26 | max_backups: 1
27 | registry:
28 | enable: false
29 | type: "etcd"
30 | addrs:
31 | - "192.168.0.1:2379"
32 | - "192.168.0.2:2379"
33 | - "192.168.0.3:2379"
34 | config_center:
35 | type: "file"
36 | config_file:
37 | path: "./etc"
--------------------------------------------------------------------------------
/pkg/proxy/driver/queryctx_exec_test.go:
--------------------------------------------------------------------------------
1 | package driver
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/pingcap/parser"
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestFirstTableNameVisitor_TableName(t *testing.T) {
11 | type fields struct {
12 | sql string
13 | want string
14 | }
15 | tests := []fields{
16 | {sql: "SELECT 1", want: ""},
17 | {sql: "SELECT * FROM tbl1", want: "tbl1"},
18 | {sql: "SELECT * FROM tbl1,tbl2", want: "tbl1"},
19 | {sql: "SELECT * FROM tbl1 INNER JOIN tbl2 on tbl1.a = tbl2.a", want: "tbl1"},
20 | {sql: "INSERT INTO tbl1 VALUES (1,2,3)", want: "tbl1"},
21 | {sql: "DELETE FROM tbl1 WHERE id=1", want: "tbl1"},
22 | {sql: "UPDATE tbl1 SET a=1 WHERE id=1", want: "tbl1"},
23 | }
24 | for _, tt := range tests {
25 | t.Run(tt.sql, func(t *testing.T) {
26 | stmt, err := parser.New().ParseOneStmt(tt.sql, "", "")
27 | f := &FirstTableNameVisitor{}
28 | stmt.Accept(f)
29 | assert.NoError(t, err)
30 | assert.Equal(t, tt.want, f.TableName())
31 | })
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/tests/proxy/backend/connpool_test.go:
--------------------------------------------------------------------------------
1 | package backend
2 |
3 | import (
4 | "context"
5 | "testing"
6 | "time"
7 |
8 | "github.com/tidb-incubator/weir/pkg/proxy/backend"
9 | "github.com/stretchr/testify/require"
10 | )
11 |
12 | func TestConnPool_ErrorClose_Success(t *testing.T) {
13 | cfg := backend.ConnPoolConfig{
14 | Config: backend.Config{
15 | Addr:"127.0.0.1:3306",
16 | UserName:"root",
17 | Password:"123456",
18 | },
19 | Capacity:1, // pool size is set to 1
20 | IdleTimeout:0,
21 | }
22 | pool := backend.NewConnPool("test", &cfg)
23 | err := pool.Init()
24 | require.NoError(t, err)
25 |
26 | ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
27 | defer cancelFunc()
28 |
29 | conn1, err := pool.GetConn(ctx)
30 | require.NoError(t, err)
31 |
32 | // conn is closed, and another conn is created by pool
33 | err = conn1.ErrorClose()
34 | require.NoError(t, err)
35 |
36 | conn2, err := pool.GetConn(ctx)
37 | require.NoError(t, err)
38 |
39 | err = conn2.ErrorClose()
40 | require.NoError(t, err)
41 | }
42 |
43 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
PROJECTNAME = $(shell basename "$(PWD)")
TOOL_BIN_PATH := $(shell pwd)/.tools/bin
GOBASE = $(shell pwd)
BUILD_TAGS ?=
LDFLAGS ?=
export GOBIN := $(TOOL_BIN_PATH)
export PATH := $(TOOL_BIN_PATH):$(PATH)

# Build with the Go race detector when invoked as `make WITH_RACE=1`.
ifeq ("$(WITH_RACE)", "1")
RACE_FLAG := -race
else
RACE_FLAG :=
endif

# None of these targets creates a file with its own name. Declaring them phony
# keeps make from skipping a recipe when a file or directory of that name
# happens to exist in the working tree.
.PHONY: default weirproxy go-test go-lint-check go-lint-fix install-tools

default: weirproxy

# Build the proxy binary into bin/weirproxy.
weirproxy:
	go build $(RACE_FLAG) -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -tags '${BUILD_TAGS}' -o bin/weirproxy cmd/weirproxy/main.go

# Run all tests with coverage; writes .coverage.out/.coverage.func/.coverage.html
# and prints the total coverage line.
go-test:
	go test -coverprofile=.coverage.out ./...
	go tool cover -func=.coverage.out -o .coverage.func
	tail -1 .coverage.func
	go tool cover -html=.coverage.out -o .coverage.html

go-lint-check: install-tools
	golangci-lint run

go-lint-fix: install-tools
	golangci-lint run --fix

# Install golangci-lint v1.30.0 into TOOL_BIN_PATH unless it is already present.
install-tools:
	@mkdir -p $(TOOL_BIN_PATH)
	@test -e $(TOOL_BIN_PATH)/golangci-lint >/dev/null 2>&1 || curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(TOOL_BIN_PATH) v1.30.0
33 |
--------------------------------------------------------------------------------
/pkg/proxy/driver/mock_Stmt.go:
--------------------------------------------------------------------------------
1 | // Code generated by mockery v2.3.0. DO NOT EDIT.
2 |
3 | package driver
4 |
5 | import mock "github.com/stretchr/testify/mock"
6 |
7 | // MockStmt is an autogenerated mock type for the Stmt type
8 | type MockStmt struct {
9 | mock.Mock
10 | }
11 |
12 | // ColumnNum provides a mock function with given fields:
13 | func (_m *MockStmt) ColumnNum() int {
14 | ret := _m.Called()
15 |
16 | var r0 int
17 | if rf, ok := ret.Get(0).(func() int); ok {
18 | r0 = rf()
19 | } else {
20 | r0 = ret.Get(0).(int)
21 | }
22 |
23 | return r0
24 | }
25 |
26 | // ID provides a mock function with given fields:
27 | func (_m *MockStmt) ID() int {
28 | ret := _m.Called()
29 |
30 | var r0 int
31 | if rf, ok := ret.Get(0).(func() int); ok {
32 | r0 = rf()
33 | } else {
34 | r0 = ret.Get(0).(int)
35 | }
36 |
37 | return r0
38 | }
39 |
40 | // ParamNum provides a mock function with given fields:
41 | func (_m *MockStmt) ParamNum() int {
42 | ret := _m.Called()
43 |
44 | var r0 int
45 | if rf, ok := ret.Get(0).(func() int); ok {
46 | r0 = rf()
47 | } else {
48 | r0 = ret.Get(0).(int)
49 | }
50 |
51 | return r0
52 | }
53 |
--------------------------------------------------------------------------------
/pkg/proxy/driver/queryctx_metrics.go:
--------------------------------------------------------------------------------
1 | package driver
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/tidb-incubator/weir/pkg/proxy/metrics"
7 | wast "github.com/tidb-incubator/weir/pkg/util/ast"
8 | "github.com/pingcap/parser/ast"
9 | )
10 |
11 | func (q *QueryCtxImpl) recordQueryMetrics(ctx context.Context, stmt ast.StmtNode, err error, durationMilliSecond float64) {
12 | ns := q.ns.Name()
13 | db := q.currentDB
14 | firstTableName, _ := wast.GetAstTableNameFromCtx(ctx)
15 | stmtType := metrics.GetStmtTypeName(stmt)
16 | retLabel := metrics.RetLabel(err)
17 |
18 | metrics.QueryCtxQueryCounter.WithLabelValues(ns, db, firstTableName, stmtType, retLabel).Inc()
19 | metrics.QueryCtxQueryDurationHistogram.WithLabelValues(ns, db, firstTableName, stmtType).Observe(durationMilliSecond)
20 | }
21 |
22 | func (q *QueryCtxImpl) recordDeniedQueryMetrics(ctx context.Context, stmt ast.StmtNode) {
23 | ns := q.ns.Name()
24 | db := q.currentDB
25 | firstTableName, _ := wast.GetAstTableNameFromCtx(ctx)
26 | stmtType := metrics.GetStmtTypeName(stmt)
27 |
28 | metrics.QueryCtxQueryDeniedCounter.WithLabelValues(ns, db, firstTableName, stmtType).Inc()
29 | }
30 |
--------------------------------------------------------------------------------
/pkg/util/sync2/toggle.go:
--------------------------------------------------------------------------------
1 | package sync2
2 |
3 | import (
4 | "errors"
5 | "sync"
6 | )
7 |
// ErrToggleNotPrepared is returned by (*Toggle).Toggle when no value has been
// staged with SwapOther since the last successful toggle.
var ErrToggleNotPrepared = errors.New("not prepared")

// Toggle holds two slots: an active one returned by Current and a standby one
// filled by SwapOther. Toggle promotes the standby slot to active. All access
// to the slots is guarded by an RWMutex.
type Toggle struct {
	data     [2]interface{}
	idx      int32 // index of the active slot, always 0 or 1
	prepared bool  // true once SwapOther has staged a value for the next Toggle
	lock     sync.RWMutex
}

// NewToggle returns a Toggle whose active slot holds o and whose standby slot is nil.
func NewToggle(o interface{}) *Toggle {
	tg := &Toggle{}
	tg.data[0] = o
	return tg
}

// Current returns the value in the active slot.
func (t *Toggle) Current() interface{} {
	t.lock.RLock()
	defer t.lock.RUnlock()
	return t.data[t.idx]
}

// SwapOther stores o in the standby slot, marks the toggle as prepared and
// returns the value the standby slot held before (nil if never set).
func (t *Toggle) SwapOther(o interface{}) interface{} {
	t.lock.Lock()
	defer t.lock.Unlock()

	standby := &t.data[toggleIdx(t.idx)]
	previous := *standby
	*standby = o
	t.prepared = true
	return previous
}

// Toggle makes the standby slot active. If SwapOther has not been called since
// the previous successful Toggle it returns ErrToggleNotPrepared and changes nothing.
func (t *Toggle) Toggle() error {
	t.lock.Lock()
	defer t.lock.Unlock()

	if !t.prepared {
		return ErrToggleNotPrepared
	}
	t.idx, t.prepared = toggleIdx(t.idx), false
	return nil
}

// toggleIdx returns the index of the other slot (0 -> 1, 1 -> 0).
func toggleIdx(idx int32) int32 {
	return (idx + 1) % 2
}
60 |
--------------------------------------------------------------------------------
/pkg/proxy/server/buffered_read_conn.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017 PingCAP, Inc.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // See the License for the specific language governing permissions and
12 | // limitations under the License.
13 |
14 | package server
15 |
16 | import (
17 | "bufio"
18 | "net"
19 | )
20 |
// defaultReaderSize is the bufio buffer size, in bytes, used by newBufferedReadConn.
const defaultReaderSize = 16 * 1024

// bufferedReadConn wraps a net.Conn so that Read goes through a bufio.Reader,
// while every other net.Conn method is served by the embedded connection.
type bufferedReadConn struct {
	net.Conn
	rb *bufio.Reader
}

// Read reads from the buffered reader instead of the raw connection.
func (conn bufferedReadConn) Read(b []byte) (n int, err error) {
	n, err = conn.rb.Read(b)
	return n, err
}

// newBufferedReadConn wraps conn with a defaultReaderSize-byte read buffer.
func newBufferedReadConn(conn net.Conn) *bufferedReadConn {
	reader := bufio.NewReaderSize(conn, defaultReaderSize)
	return &bufferedReadConn{Conn: conn, rb: reader}
}
39 |
--------------------------------------------------------------------------------
/pkg/proxy/metrics/backend.go:
--------------------------------------------------------------------------------
1 | package metrics
2 |
3 | import "github.com/prometheus/client_golang/prometheus"
4 |
5 | const (
6 | BackendEventIniting = "initing"
7 | BackendEventInited = "inited"
8 | BackendEventClosing = "closing"
9 | BackendEventClosed = "closed"
10 | )
11 |
12 | var (
13 | BackendEventCounter = prometheus.NewCounterVec(
14 | prometheus.CounterOpts{
15 | Namespace: ModuleWeirProxy,
16 | Subsystem: LabelBackend,
17 | Name: "backend_event_total",
18 | Help: "Counter of backend event.",
19 | }, []string{LblCluster, LblNamespace, LblType})
20 |
21 | BackendQueryCounter = prometheus.NewCounterVec(
22 | prometheus.CounterOpts{
23 | Namespace: ModuleWeirProxy,
24 | Subsystem: LabelBackend,
25 | Name: "b_conn_cnt",
26 | Help: "Counter of backend query count.",
27 | }, []string{LblCluster, LblNamespace, LblBackendAddr})
28 |
29 | BackendConnInUseGauge = prometheus.NewGaugeVec(
30 | prometheus.GaugeOpts{
31 | Namespace: ModuleWeirProxy,
32 | Subsystem: LabelBackend,
33 | Name: "b_conn_in_use",
34 | Help: "Number of backend conn in use.",
35 | }, []string{LblCluster, LblNamespace, LblBackendAddr})
36 | )
37 |
--------------------------------------------------------------------------------
/pkg/util/rate_limit_breaker/rate_limit/leaky_bucket_test.go:
--------------------------------------------------------------------------------
1 | package rate_limit
2 |
3 | import (
4 | "fmt"
5 | "sync"
6 | "testing"
7 | "time"
8 | )
9 |
10 | func TestLeakyBucketRateLimiter_Wait(t *testing.T) {
11 | // not really a test
12 | t.Skip()
13 | start := time.Now()
14 | qpsThreshold := int64(20000)
15 | rateLimiter := NewLeakyBucketRateLimiter(qpsThreshold)
16 | defer rateLimiter.Close()
17 | go func() {
18 | for i := 10; i < 0; i++ {
19 | rateLimiter.ChangeQpsThreshold(qpsThreshold)
20 | }
21 | }()
22 |
23 | wg := &sync.WaitGroup{}
24 | for i := 0; i < 1; i++ {
25 | processorName := fmt.Sprintf("#%d", i)
26 | wg.Add(1)
27 | go func() {
28 | processorLeakyBucketQueue(wg, processorName, rateLimiter, 10000)
29 | }()
30 | }
31 | wg.Wait()
32 |
33 | dur := time.Now().Sub(start)
34 | fmt.Printf("duration: %s\n", dur)
35 | }
36 |
37 | func processorLeakyBucketQueue(wg *sync.WaitGroup, processorName string, rateLimiter *LeakyBucketRateLimiter, iterates int) {
38 | for i := 0; i < iterates; i++ {
39 | rateLimiter.Limit()
40 | // fmt.Printf("processor=%s, time: %s. task_id: %d\n", processorName, time.Now().Format("15:04:05"), i)
41 | }
42 | wg.Done()
43 | }
44 |
--------------------------------------------------------------------------------
/pkg/util/errors/errors_test.go:
--------------------------------------------------------------------------------
1 | package errors
2 |
3 | import (
4 | stderrors "errors"
5 | "testing"
6 |
7 | "github.com/pingcap/errors"
8 | "github.com/siddontang/go-mysql/mysql"
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | func TestIs(t *testing.T) {
13 | badConn := mysql.ErrBadConn
14 | err := errors.Wrapf(badConn, "same error type")
15 | assert.True(t, Is(err, badConn))
16 | }
17 |
18 | func TestStdIs(t *testing.T) {
19 | badConn := mysql.ErrBadConn
20 | err := errors.Wrapf(badConn, "another error type")
21 | assert.False(t, stderrors.Is(err, badConn))
22 | }
23 |
24 | func TestCheckAndGetMyError_True(t *testing.T) {
25 | myErr := mysql.NewError(1105, "unknown")
26 | err, is := CheckAndGetMyError(myErr)
27 | assert.True(t, is)
28 | assert.NotNil(t, err)
29 | }
30 |
31 | func TestCheckAndGetMyError_False(t *testing.T) {
32 | myErr := errors.New("not a myError")
33 | err, is := CheckAndGetMyError(myErr)
34 | assert.False(t, is)
35 | assert.Nil(t, err)
36 | }
37 |
38 | func TestCheckAndGetMyError_Cause_True(t *testing.T) {
39 | myErr := errors.Wrapf(mysql.NewError(1105, "unknown"), "wrap error")
40 | err, is := CheckAndGetMyError(myErr)
41 | assert.True(t, is)
42 | assert.NotNil(t, err)
43 | }
44 |
--------------------------------------------------------------------------------
/docs/cn/connection-management.md:
--------------------------------------------------------------------------------
1 | # 连接管理
2 |
3 | 作为 TiDB 的应用层代理中间件, Weir Proxy 的一项重要职责就是管理客户端的连接与后端 TiDB Server 集群的连接. Weir Proxy 使用连接池机制, 客户端连接可以复用数据库连接执行SQL请求.
4 |
5 | ## 后端连接池
6 |
7 | Weir Proxy 会为每个租户(namespace)的 TiDB Server 集群中的每个实例创建一个连接池, 连接池的配置参数是租户级别的, 其最大连接数是确定的. 客户端在向 Weir Proxy 发起 SQL 查询请求时, Weir Proxy 会首先根据集群负载均衡策略选出一个后端实例, 从该后端实例的连接池中取出一个连接, 根据客户端连接当前状态初始化连接后, 执行 SQL 查询语句, 查询完成后再放回连接池中.
8 |
9 |
10 |
11 | ### 连接绑定
12 |
13 | 对于大多数普通SQL查询, 都可以使用连接池完成. 但是对于一些依赖后端连接状态的查询场景, 连接池就不适用了. 例如: 事务, Prepare查询.
14 |
15 | 面对这些场景, Weir Proxy的做法是: 在状态改变时从后端连接池取出一个连接, 并"绑定"到当前客户端连接上, 期间客户端连接的所有查询请求全部使用这个绑定连接, 直到状态恢复时, 再将连接放回连接池.
16 |
17 | 以下命令可能会触发后端连接绑定 (是否真正绑定与当前状态有关):
18 |
19 | - BEGIN
20 | - SET AUTOCOMMIT = 0
21 | - Binary Prepare (COM_STMT_PREPARE命令)
22 |
23 | 以下命令可能会触发后端连接解绑 (是否真正解绑与当前状态有关):
24 |
25 | - COMMIT / ROLLBACK
26 | - SET AUTOCOMMIT = 1
27 | - Binary Close (COM_STMT_CLOSE命令)
28 |
29 | ## 连接状态传递
30 |
31 | 客户端连接在执行某些SQL语句时会改变自身状态, 这些状态会影响SQL语句的执行, 例如: 切换Database, 设置系统变量等.
32 |
33 | 当客户端连接执行这些语句时, Weir Proxy会记录这些连接状态, 而在执行真正的查询请求时, Weir Proxy会对后端连接 (连接池连接或绑定连接) 执行一次初始化操作, 将后端连接的状态与客户端连接状态同步, 然后再执行查询请求.
34 |
35 | 目前支持传递给后端连接的状态:
36 |
37 | - USE DB
38 | - 设置Session级别系统变量 (TODO)
39 |
--------------------------------------------------------------------------------
/docs/cn/RESTful_api.md:
--------------------------------------------------------------------------------
1 | # API接口
2 |
3 | ## 移除 namespace
4 |
5 | #### Request
6 | - Method: **POST**
7 | - URL: ```/admin/namespace/remove/:namespace```
8 | - Headers:
9 |
10 | #### Response
11 | - Body
12 | ```
13 | {
14 | "code":200,
15 | "msg":"success"
16 | }
17 | ```
18 |
19 | #### 错误码
20 |
21 | | 错误码 | 信息 |
22 | | --- | --- |
23 | | 400 | bad namespace parameter |
24 | | 200 | success |
25 |
26 |
27 | ## 准备重新加载 namespace
28 |
29 | #### Request
30 | - Method: **POST**
31 | - URL: ```/admin/namespace/reload/prepare/:namespace```
32 |
33 | #### Response
34 | - Body
35 | ```
36 | {
37 | "code":200,
38 | "msg":"success"
39 | }
40 | ```
41 |
42 | #### 错误码
43 |
44 | | 错误码 | 信息 |
45 | | --- | --- |
46 | | 400 | bad namespace parameter |
47 | | 500 | get namespace value from configcenter error |
48 | | 500 | prepare reload namespace error |
49 | | 200 | success |
50 |
51 |
52 | ## 提交重新加载 namespace
53 |
54 | #### Request
55 | - Method: **POST**
56 | - URL: ```/admin/namespace/reload/commit/:namespace```
57 |
58 | #### Response
59 | - Body
60 | ```
61 | {
62 | "code":200,
63 | "msg":"success"
64 | }
65 | ```
66 |
67 | #### 错误码
68 |
69 | | 错误码 | 信息 |
70 | | --- | --- |
71 | | 400 | bad namespace parameter |
72 | | 500 | commit reload namespace error |
73 | | 200 | success |
--------------------------------------------------------------------------------
/pkg/proxy/namespace/frontend.go:
--------------------------------------------------------------------------------
1 | package namespace
2 |
3 | import (
4 | "bytes"
5 |
6 | "github.com/tidb-incubator/weir/pkg/util/passwd"
7 | )
8 |
9 | type SQLInfo struct {
10 | SQL string
11 | }
12 |
13 | type FrontendNamespace struct {
14 | allowedDBs []string
15 | allowedDBSet map[string]struct{}
16 | userPasswd map[string]string
17 | sqlBlacklist map[uint32]SQLInfo
18 | sqlWhitelist map[uint32]SQLInfo
19 | }
20 |
21 | func (n *FrontendNamespace) Auth(username string, passwdBytes []byte, salt []byte) bool {
22 | userPasswd, ok := n.userPasswd[username]
23 | if !ok {
24 | return false
25 | }
26 | userPasswdBytes := passwd.CalculatePassword(salt, []byte(userPasswd))
27 | return bytes.Equal(userPasswdBytes, passwdBytes)
28 | }
29 |
30 | func (n *FrontendNamespace) IsDatabaseAllowed(db string) bool {
31 | _, ok := n.allowedDBSet[db]
32 | return ok
33 | }
34 |
35 | func (n *FrontendNamespace) ListDatabases() []string {
36 | ret := make([]string, len(n.allowedDBs))
37 | copy(ret, n.allowedDBs)
38 | return ret
39 | }
40 |
41 | func (n *FrontendNamespace) IsDeniedSQL(sqlFeature uint32) bool {
42 | _, ok := n.sqlBlacklist[sqlFeature]
43 | return ok
44 | }
45 |
46 | func (n *FrontendNamespace) IsAllowedSQL(sqlFeature uint32) bool {
47 | _, ok := n.sqlWhitelist[sqlFeature]
48 | return ok
49 | }
50 |
--------------------------------------------------------------------------------
/pkg/util/errors/errors.go:
--------------------------------------------------------------------------------
1 | package errors
2 |
3 | import (
4 | "reflect"
5 |
6 | gomysql "github.com/siddontang/go-mysql/mysql"
7 | )
8 |
9 | // copied from errors.Is(), but replace Unwrap() with Cause()
10 | func Is(err, target error) bool {
11 | if target == nil {
12 | return err == target
13 | }
14 |
15 | isComparable := reflect.TypeOf(target).Comparable()
16 | for {
17 | if isComparable && err == target {
18 | return true
19 | }
20 | if x, ok := err.(interface{ Is(error) bool }); ok && x.Is(target) {
21 | return true
22 | }
23 | // TODO: consider supporing target.Is(err). This would allow
24 | // user-definable predicates, but also may allow for coping with sloppy
25 | // APIs, thereby making it easier to get away with them.
26 | if err = Cause(err); err == nil {
27 | return false
28 | }
29 | }
30 | }
31 |
32 | func CheckAndGetMyError(err error) (*gomysql.MyError, bool) {
33 | if err == nil {
34 | return nil, false
35 | }
36 |
37 | for {
38 | if err1, ok := err.(*gomysql.MyError); ok {
39 | return err1, true
40 | }
41 | if err = Cause(err); err == nil {
42 | return nil, false
43 | }
44 | }
45 | }
46 |
// Cause returns the result of err.Cause() if err implements
// interface{ Cause() error }, and nil otherwise (including when err is nil).
// Unlike pingcap/errors.Cause it steps only one level down the chain.
func Cause(err error) error {
	if c, ok := err.(interface{ Cause() error }); ok {
		return c.Cause()
	}
	return nil
}
56 |
--------------------------------------------------------------------------------
/pkg/proxy/server/tokenlimiter.go:
--------------------------------------------------------------------------------
1 | // Copyright 2015 PingCAP, Inc.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // See the License for the specific language governing permissions and
12 | // limitations under the License.
13 |
14 | package server
15 |
// Token is a permission slot handed out by a TokenLimiter.
type Token struct {
}

// TokenLimiter bounds the number of concurrent tasks with a buffered channel
// pre-filled with count tokens.
type TokenLimiter struct {
	count uint
	ch    chan *Token
}

// Put returns tk to the limiter. It blocks if the channel is already full.
func (tl *TokenLimiter) Put(tk *Token) {
	tl.ch <- tk
}

// Get takes a token, blocking until one is available.
func (tl *TokenLimiter) Get() *Token {
	tk := <-tl.ch
	return tk
}

// NewTokenLimiter creates a TokenLimiter holding count ready-to-use tokens.
func NewTokenLimiter(count uint) *TokenLimiter {
	ch := make(chan *Token, count)
	for remaining := count; remaining > 0; remaining-- {
		ch <- &Token{}
	}
	return &TokenLimiter{count: count, ch: ch}
}
45 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Weir
2 |
3 | Weir is a database proxy middleware platform, mainly providing traffic management for TiDB.
4 |
5 | Weir is maintained by [伴鱼](https://www.ipalfish.com/) and [PingCAP](https://pingcap.com/).
6 |
7 | [中文文档](README-CN.md)
8 |
9 | ## Features
10 |
11 | - __L7 Proxy__
12 |
13 | Weir provides an application-layer proxy for the MySQL protocol and is compatible with TiDB 4.0.
14 |
15 | - __Connection Management__
16 |
17 | Weir uses connection pools to manage backend connections and supports load balancing.
18 |
19 | - __Multi-tenant Management__
20 |
21 | Weir supports multi-tenant management. All namespaces can be dynamically reloaded at runtime.
22 |
23 | - __Fault Tolerance__
24 |
25 | Weir supports rate limiting and circuit breaking to protect both clients and TiDB servers.
26 |
27 | ## Architecture
28 |
29 | There are three core components in Weir platform: proxy, controller and UI dashboard.
30 |
31 |
32 |
33 | ## Roadmap
34 |
35 | - Web Application Firewall (WAF) for SQL
36 | - Database Mesh for TiDB
37 | - SQL audit
38 |
39 | ## Code of Conduct
40 |
41 | This project is for everyone. We ask that our users and contributors take a few minutes to review our [Code of Conduct](code-of-conduct.md).
42 |
43 | ## License
44 |
45 | Weir is under the Apache 2.0 license. See the [LICENSE](./LICENSE) file for details.
46 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/tidb-incubator/weir
2 |
3 | go 1.14
4 |
5 | require (
6 | github.com/gin-gonic/gin v1.7.2
7 | github.com/go-playground/validator/v10 v10.8.0 // indirect
8 | github.com/goccy/go-yaml v1.8.2
9 | github.com/golang/protobuf v1.5.2 // indirect
10 | github.com/json-iterator/go v1.1.11 // indirect
11 | github.com/mattn/go-isatty v0.0.13 // indirect
12 | github.com/opentracing/opentracing-go v1.1.0
13 | github.com/pingcap/check v0.0.0-20200212061837-5e12011dc712
14 | github.com/pingcap/errors v0.11.5-0.20190809092503-95897b64e011
15 | github.com/pingcap/failpoint v0.0.0-20200702092429-9f69995143ce
16 | github.com/pingcap/parser v0.0.0-20200803072748-fdf66528323d
17 | github.com/pingcap/tidb v1.1.0-beta.0.20200826081922-9c1c21270001
18 | github.com/prometheus/client_golang v1.5.1
19 | github.com/shirou/gopsutil v3.21.6+incompatible // indirect
20 | github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726
21 | github.com/siddontang/go-mysql v1.1.0
22 | github.com/stretchr/testify v1.6.1
23 | github.com/tklauser/go-sysconf v0.3.7 // indirect
24 | github.com/ugorji/go v1.2.6 // indirect
25 | go.etcd.io/etcd v0.5.0-alpha.5.0.20191023171146-3cf2f69b5738
26 | go.uber.org/zap v1.15.0
27 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110
28 | google.golang.org/protobuf v1.27.1 // indirect
29 | gopkg.in/yaml.v2 v2.4.0 // indirect
30 | )
31 |
32 | replace github.com/siddontang/go-mysql => github.com/ibanyu/go-mysql v1.1.0
33 |
--------------------------------------------------------------------------------
/pkg/util/sync2/toggle_test.go:
--------------------------------------------------------------------------------
1 | package sync2
2 |
3 | import (
4 | "testing"
5 | "time"
6 |
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func Test_Toggle_Current(t *testing.T) {
11 | currVal := 1
12 | toggle := NewToggle(currVal)
13 | assert.Equal(t, currVal, toggle.Current())
14 | }
15 |
16 | func Test_Toggle_SwapOther(t *testing.T) {
17 | currVal := 1
18 | toggle := NewToggle(currVal)
19 | swap1 := 2
20 | ret := toggle.SwapOther(swap1)
21 | assert.Nil(t, ret)
22 | assert.Equal(t, currVal, toggle.Current())
23 |
24 | swap2 := 3
25 | ret = toggle.SwapOther(swap2)
26 | assert.Equal(t, swap1, ret)
27 | assert.Equal(t, currVal, toggle.Current())
28 | }
29 |
30 | func Test_Toggle_Toggle_Success(t *testing.T) {
31 | currVal := 1
32 | toggle := NewToggle(currVal)
33 |
34 | swap := 2
35 | _ = toggle.SwapOther(swap)
36 | err := toggle.Toggle()
37 | assert.NoError(t, err)
38 | assert.Equal(t, swap, toggle.Current())
39 | }
40 |
41 | func Test_Toggle_Toggle_Error_NotPrepared(t *testing.T) {
42 | currVal := 1
43 | toggle := NewToggle(currVal)
44 |
45 | err := toggle.Toggle()
46 | assert.EqualError(t, err, ErrToggleNotPrepared.Error())
47 | }
48 | func BenchmarkToggle(b *testing.B) {
49 | toggle := NewToggle(1)
50 | go func() {
51 | for {
52 | toggle.SwapOther(2)
53 | time.Sleep(1 * time.Millisecond)
54 | _ = toggle.Toggle()
55 | time.Sleep(1 * time.Millisecond)
56 | }
57 | }()
58 | for i := 0; i < b.N; i++ {
59 | toggle.Current()
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/pkg/proxy/server/server_util.go:
--------------------------------------------------------------------------------
1 | // Copyright 2015 PingCAP, Inc.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // See the License for the specific language governing permissions and
12 | // limitations under the License.
13 |
14 | package server
15 |
16 | import (
17 | "sync"
18 |
19 | "github.com/pingcap/tidb/sessionctx/variable"
20 | "github.com/pingcap/tidb/util/logutil"
21 | "github.com/pingcap/tidb/util/timeutil"
22 | "go.uber.org/zap"
23 | )
24 |
25 | // setSysTimeZoneOnce is used for parallel run tests. When several servers are running,
26 | // only the first will actually do setSystemTimeZoneVariable, thus we can avoid data race.
27 | var setSysTimeZoneOnce = &sync.Once{}
28 |
29 | func setSystemTimeZoneVariable() {
30 | setSysTimeZoneOnce.Do(func() {
31 | tz, err := timeutil.GetSystemTZ()
32 | if err != nil {
33 | logutil.BgLogger().Error(
34 | "Error getting SystemTZ, use default value instead",
35 | zap.Error(err),
36 | zap.String("default system_time_zone", variable.SysVars["system_time_zone"].Value))
37 | return
38 | }
39 | variable.SysVars["system_time_zone"].Value = tz
40 | })
41 | }
42 |
--------------------------------------------------------------------------------
/pkg/proxy/backend/client/req.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "github.com/siddontang/go-mysql/utils"
5 | )
6 |
7 | func (c *Conn) writeCommand(command byte) error {
8 | c.ResetSequence()
9 |
10 | return c.WritePacket([]byte{
11 | 0x01, //1 bytes long
12 | 0x00,
13 | 0x00,
14 | 0x00, //sequence
15 | command,
16 | })
17 | }
18 |
19 | func (c *Conn) writeCommandBuf(command byte, arg []byte) error {
20 | c.ResetSequence()
21 |
22 | length := len(arg) + 1
23 | data := utils.ByteSliceGet(length + 4)
24 | data[4] = command
25 |
26 | copy(data[5:], arg)
27 |
28 | err := c.WritePacket(data)
29 |
30 | utils.ByteSlicePut(data)
31 |
32 | return err
33 | }
34 |
35 | func (c *Conn) writeCommandStr(command byte, arg string) error {
36 | return c.writeCommandBuf(command, utils.StringToByteSlice(arg))
37 | }
38 |
39 | func (c *Conn) writeCommandUint32(command byte, arg uint32) error {
40 | c.ResetSequence()
41 |
42 | return c.WritePacket([]byte{
43 | 0x05, //5 bytes long
44 | 0x00,
45 | 0x00,
46 | 0x00, //sequence
47 |
48 | command,
49 |
50 | byte(arg),
51 | byte(arg >> 8),
52 | byte(arg >> 16),
53 | byte(arg >> 24),
54 | })
55 | }
56 |
57 | func (c *Conn) writeCommandStrStr(command byte, arg1 string, arg2 string) error {
58 | c.ResetSequence()
59 |
60 | data := make([]byte, 4, 6+len(arg1)+len(arg2))
61 |
62 | data = append(data, command)
63 | data = append(data, arg1...)
64 | data = append(data, 0)
65 | data = append(data, arg2...)
66 |
67 | return c.WritePacket(data)
68 | }
69 |
--------------------------------------------------------------------------------
/pkg/util/rand2/rand.go:
--------------------------------------------------------------------------------
1 | package rand2
2 |
3 | import (
4 | "math/rand"
5 | "sync"
6 | )
7 |
// Rand is a mutex-guarded wrapper around *rand.Rand. Every method holds the
// embedded mutex for the duration of the underlying call, so a single Rand can
// be shared between goroutines.
//
// Unlock is deferred in every method. Intn, Int63n and Int31n panic when
// n <= 0, and a recovered panic must not leave the mutex locked. Otherwise
// every later call on this Rand would block forever.
type Rand struct {
	sync.Mutex
	stdRand *rand.Rand
}

// New returns a Rand that draws its values from src.
func New(src rand.Source) *Rand {
	return &Rand{
		stdRand: rand.New(src),
	}
}

// Int63 returns a non-negative pseudo-random 63-bit integer as an int64.
func (r *Rand) Int63() int64 {
	r.Lock()
	defer r.Unlock()
	return r.stdRand.Int63()
}

// Uint32 returns a pseudo-random 32-bit value as a uint32.
func (r *Rand) Uint32() uint32 {
	r.Lock()
	defer r.Unlock()
	return r.stdRand.Uint32()
}

// Uint64 returns a pseudo-random 64-bit value as a uint64.
func (r *Rand) Uint64() uint64 {
	r.Lock()
	defer r.Unlock()
	return r.stdRand.Uint64()
}

// Int31 returns a non-negative pseudo-random 31-bit integer as an int32.
func (r *Rand) Int31() int32 {
	r.Lock()
	defer r.Unlock()
	return r.stdRand.Int31()
}

// Int returns a non-negative pseudo-random int.
func (r *Rand) Int() int {
	r.Lock()
	defer r.Unlock()
	return r.stdRand.Int()
}

// Int63n returns a non-negative pseudo-random number in [0,n). It panics if n <= 0.
func (r *Rand) Int63n(n int64) int64 {
	r.Lock()
	defer r.Unlock()
	return r.stdRand.Int63n(n)
}

// Int31n returns a non-negative pseudo-random number in [0,n). It panics if n <= 0.
func (r *Rand) Int31n(n int32) int32 {
	r.Lock()
	defer r.Unlock()
	return r.stdRand.Int31n(n)
}

// Intn returns a non-negative pseudo-random number in [0,n). It panics if n <= 0.
func (r *Rand) Intn(n int) int {
	r.Lock()
	defer r.Unlock()
	return r.stdRand.Intn(n)
}

// Float64 returns a pseudo-random number in [0.0,1.0).
func (r *Rand) Float64() float64 {
	r.Lock()
	defer r.Unlock()
	return r.stdRand.Float64()
}

// Float32 returns a pseudo-random number in [0.0,1.0).
func (r *Rand) Float32() float32 {
	r.Lock()
	defer r.Unlock()
	return r.stdRand.Float32()
}
89 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 |
7 |
8 | ### What problem does this PR solve?
9 |
10 |
11 |
12 | ### What is changed and how it works?
13 |
14 | ### Check List
15 |
16 |
17 |
18 | Tests
19 |
20 |
21 |
22 | - Unit test
23 | - Integration test
24 | - Manual test (add detailed scripts or steps below)
25 | - No code
26 |
27 | Code changes
28 |
29 | - Has configuration change
30 | - Has HTTP API interfaces change (Don't forget to [add the declarative for API](https://github.com/tikv/pd/blob/master/docs/development.md#updating-api-documentation))
31 | - Has persistent data change
32 |
33 | Side effects
34 |
35 | - Possible performance regression
36 | - Increased code complexity
37 | - Breaking backward compatibility
38 |
39 | Related changes
40 |
41 | - PR to update [`pingcap/docs`](https://github.com/pingcap/docs)/[`pingcap/docs-cn`](https://github.com/pingcap/docs-cn):
42 | - PR to update [`pingcap/tidb-ansible`](https://github.com/pingcap/tidb-ansible):
43 | - Need to cherry-pick to the release branch
44 |
45 | ### Release note
46 |
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/pkg/util/rate_limit_breaker/rate_limit/sliding_window.go:
--------------------------------------------------------------------------------
1 | package rate_limit
2 |
3 | import (
4 | "errors"
5 | . "github.com/tidb-incubator/weir/pkg/util/rate_limit_breaker"
6 | "sync"
7 | "sync/atomic"
8 | )
9 |
// ErrRateLimited is returned by SlidingWindowRateLimiter.Limit when a request
// is rejected.
var ErrRateLimited = errors.New("rate limited")
11 |
// SlidingWindowRateLimiter is a concurrency-safe rate limiter based on a
// sliding window.
type SlidingWindowRateLimiter struct {
	sw           *SlidingWindow // guarded by mu
	mu           *sync.Mutex    // guards sw
	qpsThreshold int64          // read/write through atomic operations only
}
18 |
19 | func NewSlidingWindowRateLimiter(qpsThreshold int64) *SlidingWindowRateLimiter {
20 | swrl := &SlidingWindowRateLimiter{
21 | // 滑动窗口覆盖 1s 时间,划分为 10 个 cell,每个 cell 时长为 100ms。
22 | sw: NewSlidingWindow(10, 100),
23 | mu: &sync.Mutex{},
24 | qpsThreshold: qpsThreshold,
25 | }
26 | return swrl
27 | }
28 |
29 | // 如果被限流,则返回 ErrRateLimited;未被限流,则返回 nil
30 | func (swrl *SlidingWindowRateLimiter) Limit() error {
31 | nowMs := GetNowMs()
32 | qpsThreshold := atomic.LoadInt64(&swrl.qpsThreshold)
33 |
34 | swrl.mu.Lock()
35 | defer swrl.mu.Unlock()
36 |
37 | const HitMetric = "hit"
38 | hits := swrl.sw.GetHit(nowMs, HitMetric)
39 | actualDurationMs := swrl.sw.GetActualDurationMs(nowMs)
40 | // actualQPS = hits / (actualDurationMs / 1000)
41 | // actualQPS >= qpsThreshold 改写即得下述表达式。
42 | if hits*1000 >= qpsThreshold*actualDurationMs {
43 | return ErrRateLimited
44 | } else {
45 | swrl.sw.Hit(nowMs, HitMetric)
46 | return nil
47 | }
48 | }
49 |
50 | func (swrl *SlidingWindowRateLimiter) ChangeQpsThreshold(newQpsThreshold int64) {
51 | atomic.StoreInt64(&swrl.qpsThreshold, newQpsThreshold)
52 | }
53 |
--------------------------------------------------------------------------------
/docs/cn/fault-tolerant.md:
--------------------------------------------------------------------------------
1 | # 熔断限流机制
2 |
3 |
4 |
5 | ## 熔断
6 | ```
7 | scope: "sql"
8 | strategies:
9 | - min_qps: 3
10 | failure_rate_threshold: 0
11 | failure_num: 5
12 | sql_timeout_ms: 2000
13 | open_status_duration_ms: 5000
14 | size: 10
15 | cell_interval_ms: 1000
16 | ```
17 | ### 熔断过程
18 |
从配置文件中可以看出 strategies 是一个数组,因此一个 namespace 可以在 scope 下配置多种熔断策略。当请求从客户端进入 weir 时,weir 会根据连接所使用的账户选择要进入的租户(namespace),同时启动对应租户下的熔断管理器,
20 | 熔断管理器根据 scope 可以选择当前是哪一类熔断器,再根据类别中的特征,比如库名表名,sql 特征等选择对应的熔断器对象,进行计数统计,如下图:
21 |
22 |
23 |
24 | 当熔断时返回错误 **circuit breaker triggered** 。
25 |
26 | ### 熔断级别
27 |
28 | 这里要说的是熔断级别的问题,目前 weir 支持4种熔断级别
29 |
30 | | 级别 | 说明 |
31 | | --- | --- |
| namespace | 租户级别熔断, 熔断触发时对此租户的所有 sql 进行熔断 |
33 | | db | 库级别熔断, 熔断触发时所有对此库的 sql 进行熔断 |
34 | | table | 表级别熔断, 熔断触发时所有对此表的 sql 进行熔断 |
35 | | sql | sql 级别熔断, 对每一个特征 sql 做监控管理,颗粒度最细,熔断触发时对这一类 sql 进行熔断 |
36 |
37 | sql 熔断在内存中的输入 eg :
38 |
39 | ```
40 | select * from test_table where id = 2;
select * from test_table where id in (1,2,3);
```
在熔断器中将被转化成:

```
select * from test_table where id = ?;
select * from test_table where id in (?);
48 | ```
49 |
50 | 这里我们在 ast 解析时通过判断 ast 的 node 类型进行了值的替换, 并重写了 sql 进行输出, 这样就可以确定一类 sql 并提取他们的摘要方便我们后续使用, 如下图:
51 |
52 |
熔断器计数采用滑动窗口计数器:滑动窗口实现简单,并且能够应对周期较长的统计。
54 |
55 |
56 | ## 限流
57 |
限流分为阻塞式限流和拒绝式限流,当前版本实现的是拒绝式限流。
拒绝式限流的统计数据同样采用滑动窗口计数器,在周期内统计 qps,一旦 qps 达到或超过阈值即执行限流,期间返回错误 **rate limited**
60 |
--------------------------------------------------------------------------------
/pkg/proxy/backend/selector_test.go:
--------------------------------------------------------------------------------
1 | package backend
2 |
3 | import (
4 | "strconv"
5 | "testing"
6 |
7 | "github.com/tidb-incubator/weir/pkg/util/rand2"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
// testSource is a stub random source for tests. It implements Int63/Seed, and
// its Int63 always returns val, so tests control which index the selector picks.
type testSource struct {
	val int64
}
14 |
15 | func (t *testSource) Int63() int64 {
16 | return t.val
17 | }
18 |
19 | func (*testSource) Seed(seed int64) {
20 | }
21 |
22 | func TestRandomSelector_Select_Success(t *testing.T) {
23 | source := &testSource{}
24 | rd := rand2.New(source)
25 | selector := NewRandomSelector(rd)
26 |
27 | host := "127.0.0.1"
28 | ports := []int{4000, 4001, 4002}
29 | instances := prepareInstances(host, ports)
30 |
31 | for i := 0; i < len(ports); i++ {
32 | source.val = int64(i)
33 | instance, err := selector.Select(instances)
34 | assert.NoError(t, err)
35 | assert.Equal(t, getAddr(host, ports[i]), instance.addr)
36 | }
37 | }
38 |
39 | func TestRandomSelector_Select_ErrNoInstanceToSelect(t *testing.T) {
40 | source := &testSource{}
41 | rd := rand2.New(source)
42 | selector := NewRandomSelector(rd)
43 |
44 | host := "127.0.0.1"
45 | var ports []int
46 | instances := prepareInstances(host, ports)
47 |
48 | instance, err := selector.Select(instances)
49 | assert.Nil(t, instance)
50 | assert.EqualError(t, err, ErrNoInstanceToSelect.Error())
51 | }
52 |
53 | func prepareInstances(host string, ports []int) []*Instance {
54 | var instances []*Instance
55 | for _, p := range ports {
56 | instance := &Instance{
57 | addr: getAddr(host, p),
58 | }
59 | instances = append(instances, instance)
60 | }
61 | return instances
62 | }
63 |
// getAddr joins host and port as "host:port" (no IPv6 bracketing).
func getAddr(host string, port int) string {
	return host + ":" + strconv.FormatInt(int64(port), 10)
}
67 |
--------------------------------------------------------------------------------
/pkg/util/sync2/semaphore_flaky_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sync2
18 |
19 | import (
20 | "testing"
21 | "time"
22 | )
23 |
24 | func TestSemaNoTimeout(t *testing.T) {
25 | s := NewSemaphore(1, 0)
26 | s.Acquire()
27 | released := false
28 | go func() {
29 | time.Sleep(10 * time.Millisecond)
30 | released = true
31 | s.Release()
32 | }()
33 | s.Acquire()
34 | if !released {
35 | t.Errorf("release: false, want true")
36 | }
37 | }
38 |
39 | func TestSemaTimeout(t *testing.T) {
40 | s := NewSemaphore(1, 5*time.Millisecond)
41 | s.Acquire()
42 | go func() {
43 | time.Sleep(10 * time.Millisecond)
44 | s.Release()
45 | }()
46 | if s.Acquire() {
47 | t.Errorf("Acquire: true, want false")
48 | }
49 | time.Sleep(10 * time.Millisecond)
50 | if !s.Acquire() {
51 | t.Errorf("Acquire: false, want true")
52 | }
53 | }
54 |
55 | func TestSemaTryAcquire(t *testing.T) {
56 | s := NewSemaphore(1, 0)
57 | if !s.TryAcquire() {
58 | t.Errorf("TryAcquire: false, want true")
59 | }
60 | if s.TryAcquire() {
61 | t.Errorf("TryAcquire: true, want false")
62 | }
63 | s.Release()
64 | if !s.TryAcquire() {
65 | t.Errorf("TryAcquire: false, want true")
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/pkg/proxy/backend/selector.go:
--------------------------------------------------------------------------------
1 | package backend
2 |
3 | import (
4 | "errors"
5 | "math/rand"
6 | "time"
7 |
8 | "github.com/tidb-incubator/weir/pkg/util/rand2"
9 | )
10 |
// Selector type identifiers. Values start at 1, so the zero value is not a
// registered selector type.
const (
	SelectorTypeRandom = iota + 1
)

// Selector names, as written in the backend selector_type configuration.
const (
	SelectorNameUnknown = "unknown"
	SelectorNameRandom  = "random"
)

// Lookup tables between selector types and names, in both directions.
var (
	selectorTypeMap = map[int]string{
		SelectorTypeRandom: SelectorNameRandom,
	}
	selectorNameMap = map[string]int{
		SelectorNameRandom: SelectorTypeRandom,
	}
)

// Errors returned by selector construction and selection.
var (
	ErrNoInstanceToSelect  = errors.New("no instance to select")
	ErrInvalidSelectorType = errors.New("invalid selector type")
)
33 |
// Selector picks one backend instance from a candidate list.
// Implementations in this file return ErrNoInstanceToSelect for an empty list.
type Selector interface {
	Select(instances []*Instance) (*Instance, error)
}
37 |
// RandomSelector is a Selector that picks an instance uniformly at random
// using rd.
type RandomSelector struct {
	rd *rand2.Rand
}
41 |
42 | func CreateSelector(selectorType int) (Selector, error) {
43 | switch selectorType {
44 | case SelectorTypeRandom:
45 | source := rand.NewSource(time.Now().Unix())
46 | rd := rand2.New(source)
47 | return NewRandomSelector(rd), nil
48 | default:
49 | return nil, ErrInvalidSelectorType
50 | }
51 | }
52 |
53 | func NewRandomSelector(rd *rand2.Rand) *RandomSelector {
54 | return &RandomSelector{
55 | rd: rd,
56 | }
57 | }
58 |
59 | func (s *RandomSelector) Select(instances []*Instance) (*Instance, error) {
60 | length := len(instances)
61 | if length == 0 {
62 | return nil, ErrNoInstanceToSelect
63 | }
64 | idx := s.rd.Int63n(int64(length))
65 | return instances[idx], nil
66 | }
67 |
68 | func SelectorNameToType(name string) (int, bool) {
69 | t, ok := selectorNameMap[name]
70 | return t, ok
71 | }
72 |
73 | func SelectorTypeToName(t int) (string, bool) {
74 | n, ok := selectorTypeMap[t]
75 | return n, ok
76 | }
77 |
--------------------------------------------------------------------------------
/docs/cn/proxy-config.md:
--------------------------------------------------------------------------------
1 | # Proxy配置详解
2 |
3 | 以下给出了Weir Proxy的示例配置🌰.
4 |
5 | ```
6 | version: "v1"
7 | proxy_server:
8 | addr: "0.0.0.0:6000"
9 | max_connections: 1000
10 | session_timeout: 600
11 | admin_server:
12 | addr: "0.0.0.0:6001"
13 | enable_basic_auth: false
14 | user: ""
15 | password: ""
16 | log:
17 | level: "debug"
18 | format: "console"
19 | log_file:
20 | filename: ""
21 | max_size: 300
22 | max_days: 1
23 | max_backups: 1
24 | registry:
25 | enable: false
26 | config_center:
27 | type: "file"
28 | config_file:
29 | path: "./conf/namespace"
30 | strict_parse: false
31 | performance:
32 | tcp_keep_alive: true
33 | ```
34 |
35 | | 配置名 | 说明 |
36 | | --- | --- |
37 | | version | 配置 schema 的版本号, 目前为 v1 |
38 | | proxy_server | Proxy 代理服务相关配置 |
39 | | proxy_server.addr | Proxy服务端口监听地址 |
40 | | proxy_server.max_connections | 最大客户端连接数 |
| proxy_server.session_timeout | 客户端空闲连接超时时间 |
42 | | admin_server | Proxy 管理相关配置 |
43 | | admin_server.addr | Proxy admin 口监听地址 |
44 | | admin_server.enable_basic_auth | 是否开启Basic Auth |
45 | | admin_server.user | Basic Auth User |
46 | | admin_server.password | Basic Auth Password |
47 | | log | 日志配置 |
48 | | log.level | 日志级别 (支持 debug, info, warn, error) |
49 | | log.format | 日志输出方式 (支持 console, file) |
50 | | log.log_file | 日志文件相关配置 |
51 | | log.log_file.filename | 日志文件名 |
52 | | log.log_file.max_size | 单个日志文件最大尺寸 |
53 | | log.log_file.max_days | 单个日志文件保存最大天数 |
54 | | config_center | 配置中心 |
55 | | config_center.type | 配置中心类型 (支持 file, etcd) |
56 | | config_center.config_file | 配置文件信息,在 type 为file时有效 |
57 | | config_center.config_file.path | Namespace配置文件所在目录 |
| config_center.config_etcd.strict_parse | 是否严格解析命名空间配置。禁用时,列出所有命名空间时将忽略单个命名空间配置的解析错误 |
59 | | performance | 性能相关配置 |
| performance.tcp_keep_alive | 是否对客户端连接开启 TCP Keep Alive |
61 |
--------------------------------------------------------------------------------
/cmd/weirproxy/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "flag"
5 | "fmt"
6 | "io/ioutil"
7 | "os"
8 | "os/signal"
9 | "sync"
10 | "syscall"
11 |
12 | "github.com/pingcap/tidb/util/logutil"
13 | "github.com/tidb-incubator/weir/pkg/config"
14 | "github.com/tidb-incubator/weir/pkg/proxy"
15 | "go.uber.org/zap"
16 | )
17 |
// configFilePath holds the -config flag: the path of the proxy YAML config file.
var configFilePath = flag.String("config", "conf/weirproxy.yaml", "weir proxy config file path")
21 |
22 | func main() {
23 | flag.Parse()
24 | proxyConfigData, err := ioutil.ReadFile(*configFilePath)
25 | if err != nil {
26 | fmt.Printf("read config file error: %v\n", err)
27 | os.Exit(1)
28 | }
29 |
30 | proxyCfg, err := config.UnmarshalProxyConfig(proxyConfigData)
31 | if err != nil {
32 | fmt.Printf("parse config file error: %v\n", err)
33 | os.Exit(1)
34 | }
35 |
36 | p := proxy.NewProxy(proxyCfg)
37 |
38 | if err = p.Init(); err != nil {
39 | fmt.Printf("proxy init error: %v\n", err)
40 | p.Close()
41 | os.Exit(1)
42 | }
43 |
44 | sc := make(chan os.Signal, 1)
45 | signal.Notify(sc,
46 | syscall.SIGINT,
47 | syscall.SIGTERM,
48 | syscall.SIGQUIT,
49 | syscall.SIGPIPE,
50 | syscall.SIGUSR1,
51 | )
52 |
53 | var wg sync.WaitGroup
54 | wg.Add(1)
55 |
56 | go func() {
57 | defer wg.Done()
58 | for {
59 | sig := <-sc
60 | if sig == syscall.SIGINT || sig == syscall.SIGTERM || sig == syscall.SIGQUIT {
61 | logutil.BgLogger().Warn("get os signal, close proxy server", zap.String("signal", sig.String()))
62 | p.Close()
63 | break
64 | } else {
65 | logutil.BgLogger().Warn("ignore os signal", zap.String("signal", sig.String()))
66 | }
67 | }
68 | }()
69 |
70 | if err := p.Run(); err != nil {
71 | logutil.BgLogger().Error("proxy run error, exit", zap.Error(err))
72 | }
73 |
74 | wg.Wait()
75 | }
76 |
--------------------------------------------------------------------------------
/pkg/util/timer/randticker.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package timer
18 |
19 | import (
20 | "math/rand"
21 | "time"
22 | )
23 |
// RandTicker is just like time.Ticker, except that
// it adds randomness to the events.
type RandTicker struct {
	// C delivers tick times. NewRandTicker gives it a buffer of one, and a tick
	// is dropped if the previous one has not been received yet. It is closed
	// after Stop.
	C <-chan time.Time
	// done is closed by Stop to terminate the ticking goroutine.
	done chan struct{}
}
29 | }
30 |
31 | // NewRandTicker creates a new RandTicker. d is the duration,
32 | // and variance specifies the variance. The ticker will tick
33 | // every d +/- variance.
34 | func NewRandTicker(d, variance time.Duration) *RandTicker {
35 | c := make(chan time.Time, 1)
36 | done := make(chan struct{})
37 | go func() {
38 | rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
39 | for {
40 | vr := time.Duration(rnd.Int63n(int64(2*variance)) - int64(variance))
41 | tmr := time.NewTimer(d + vr)
42 | select {
43 | case <-tmr.C:
44 | select {
45 | case c <- time.Now():
46 | default:
47 | }
48 | case <-done:
49 | tmr.Stop()
50 | close(c)
51 | return
52 | }
53 | }
54 | }()
55 | return &RandTicker{
56 | C: c,
57 | done: done,
58 | }
59 | }
60 |
61 | // Stop stops the ticker and closes the underlying channel.
62 | func (tkr *RandTicker) Stop() {
63 | close(tkr.done)
64 | }
65 |
--------------------------------------------------------------------------------
/pkg/proxy/metrics/server.go:
--------------------------------------------------------------------------------
1 | package metrics
2 |
3 | import (
4 | "strconv"
5 |
6 | "github.com/pingcap/errors"
7 | "github.com/pingcap/parser/terror"
8 | "github.com/prometheus/client_golang/prometheus"
9 | )
10 |
var (
	// PanicCounter measures the count of panics, labelled by cluster and type.
	PanicCounter = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: ModuleWeirProxy,
			Subsystem: LabelServer,
			Name:      "panic_total",
			Help:      "Counter of panic.",
		}, []string{LblCluster, LblType})

	// QueryTotalCounter counts queries, labelled by cluster, type and result.
	QueryTotalCounter = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: ModuleWeirProxy,
			Subsystem: LabelServer,
			Name:      "query_total",
			Help:      "Counter of queries.",
		}, []string{LblCluster, LblType, LblResult})

	// ExecuteErrorCounter counts execute errors, labelled by cluster and type.
	ExecuteErrorCounter = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: ModuleWeirProxy,
			Subsystem: LabelServer,
			Name:      "execute_error_total",
			Help:      "Counter of execute errors.",
		}, []string{LblCluster, LblType})

	// ConnGauge reports the number of connections, labelled by cluster.
	ConnGauge = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Namespace: ModuleWeirProxy,
			Subsystem: LabelServer,
			Name:      "connections",
			Help:      "Number of connections.",
		}, []string{LblCluster})

	// Server event names.
	// NOTE(review): declared as vars rather than consts; presumably never
	// reassigned — confirm before converting.
	EventStart        = "start"
	EventGracefulDown = "graceful_shutdown"
	// EventKill occurs when the server.Kill() function is called.
	EventKill  = "kill"
	EventClose = "close"
)
51 |
52 | // ExecuteErrorToLabel converts an execute error to label.
53 | func ExecuteErrorToLabel(err error) string {
54 | err = errors.Cause(err)
55 | switch x := err.(type) {
56 | case *terror.Error:
57 | return x.Class().String() + ":" + strconv.Itoa(int(x.Code()))
58 | default:
59 | return "unknown"
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/pkg/util/timer/randticker_flaky_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package timer
18 |
19 | import (
20 | "testing"
21 | "time"
22 | )
23 |
// Timing parameters for TestTick: a 100ms period with ±20ms jitter.
const (
	testVariance = 20 * time.Millisecond
	testDuration = 100 * time.Millisecond
)
28 |
29 | func TestTick(t *testing.T) {
30 | tkr := NewRandTicker(testDuration, testVariance)
31 | for i := 0; i < 5; i++ {
32 | start := time.Now()
33 | end := <-tkr.C
34 | diff := start.Add(testDuration).Sub(end)
35 | tolerance := testVariance + 20*time.Millisecond
36 | if diff < -tolerance || diff > tolerance {
37 | t.Errorf("start: %v, end: %v, diff %v. Want <%v tolerenace", start, end, diff, tolerance)
38 | }
39 | }
40 | tkr.Stop()
41 | _, ok := <-tkr.C
42 | if ok {
43 | t.Error("Channel was not closed")
44 | }
45 | }
46 |
47 | func TestTickSkip(t *testing.T) {
48 | tkr := NewRandTicker(10*time.Millisecond, 1*time.Millisecond)
49 | time.Sleep(35 * time.Millisecond)
50 | end := <-tkr.C
51 | diff := time.Since(end)
52 | if diff < 20*time.Millisecond {
53 | t.Errorf("diff: %v, want >20ms", diff)
54 | }
55 |
56 | // This tick should be up-to-date
57 | end = <-tkr.C
58 | diff = time.Since(end)
59 | if diff > 1*time.Millisecond {
60 | t.Errorf("diff: %v, want <1ms", diff)
61 | }
62 | tkr.Stop()
63 | }
64 |
--------------------------------------------------------------------------------
/pkg/proxy/namespace/user.go:
--------------------------------------------------------------------------------
1 | package namespace
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/tidb-incubator/weir/pkg/config"
7 | "github.com/pingcap/errors"
8 | )
9 |
// UserNamespaceMapper maps a frontend username to the name of the namespace
// that owns it. Each username belongs to at most one namespace.
type UserNamespaceMapper struct {
	userToNamespace map[string]string // key: username, value: namespace name
}
13 |
14 | func CreateUserNamespaceMapper(namespaces []*config.Namespace) (*UserNamespaceMapper, error) {
15 | mapper := make(map[string]string)
16 | for _, ns := range namespaces {
17 | frontendNamespace := ns.Frontend
18 | for _, user := range frontendNamespace.Users {
19 | originNamespace, ok := mapper[user.Username]
20 | if ok {
21 | return nil, errors.WithMessage(ErrDuplicatedUser,
22 | fmt.Sprintf("user: %s, namespace: %s, %s", user.Username, originNamespace, ns.Namespace))
23 | }
24 | mapper[user.Username] = ns.Namespace
25 | }
26 | }
27 |
28 | ret := &UserNamespaceMapper{userToNamespace: mapper}
29 | return ret, nil
30 | }
31 |
32 | func (u *UserNamespaceMapper) GetUserNamespace(username string) (string, bool) {
33 | ns, ok := u.userToNamespace[username]
34 | return ns, ok
35 | }
36 |
37 | func (u *UserNamespaceMapper) Clone() *UserNamespaceMapper {
38 | ret := make(map[string]string)
39 | for k, v := range u.userToNamespace {
40 | ret[k] = v
41 | }
42 | return &UserNamespaceMapper{userToNamespace: ret}
43 | }
44 |
45 | func (u *UserNamespaceMapper) RemoveNamespaceUsers(ns string) {
46 | for k, namespace := range u.userToNamespace {
47 | if ns == namespace {
48 | delete(u.userToNamespace, k)
49 | }
50 | }
51 | }
52 |
53 | func (u *UserNamespaceMapper) AddNamespaceUsers(ns string, cfg *config.FrontendNamespace) error {
54 | for _, userInfo := range cfg.Users {
55 | if originNamespace, ok := u.userToNamespace[userInfo.Username]; ok {
56 | return errors.WithMessage(ErrDuplicatedUser, fmt.Sprintf("namespace: %s", originNamespace))
57 | }
58 | u.userToNamespace[userInfo.Username] = ns
59 | }
60 | return nil
61 | }
62 |
--------------------------------------------------------------------------------
/pkg/proxy/driver/mock_Namespace.go:
--------------------------------------------------------------------------------
1 | // Code generated by mockery v2.3.0. DO NOT EDIT.
2 |
3 | package driver
4 |
5 | import (
6 | context "context"
7 |
8 | mock "github.com/stretchr/testify/mock"
9 | )
10 |
// MockNamespace is an autogenerated mock type for the Namespace type.
// Generated by mockery (see file header); regenerate rather than hand-edit.
type MockNamespace struct {
	mock.Mock
}
15 |
// GetPooledConn provides a mock function with given fields: _a0
//
// Return value 0 may be configured either as a PooledBackendConn or as a
// func(context.Context) PooledBackendConn evaluated per call; nil yields a
// nil conn. Return value 1 may be an error or a func(context.Context) error.
func (_m *MockNamespace) GetPooledConn(_a0 context.Context) (PooledBackendConn, error) {
	ret := _m.Called(_a0)

	var r0 PooledBackendConn
	if rf, ok := ret.Get(0).(func(context.Context) PooledBackendConn); ok {
		r0 = rf(_a0)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(PooledBackendConn)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context) error); ok {
		r1 = rf(_a0)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}
38 |
// IsDatabaseAllowed provides a mock function with given fields: db
//
// The configured return value must be a bool or a func(string) bool.
// Anything else (including nil) panics on the type assertion.
func (_m *MockNamespace) IsDatabaseAllowed(db string) bool {
	ret := _m.Called(db)

	var r0 bool
	if rf, ok := ret.Get(0).(func(string) bool); ok {
		r0 = rf(db)
	} else {
		r0 = ret.Get(0).(bool)
	}

	return r0
}
52 |
// ListDatabases provides a mock function with given fields:
//
// The configured return value may be a []string or a func() []string;
// nil yields a nil slice.
func (_m *MockNamespace) ListDatabases() []string {
	ret := _m.Called()

	var r0 []string
	if rf, ok := ret.Get(0).(func() []string); ok {
		r0 = rf()
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).([]string)
		}
	}

	return r0
}
68 |
// Name provides a mock function with given fields:
//
// The configured return value must be a string or a func() string.
// Anything else (including nil) panics on the type assertion.
func (_m *MockNamespace) Name() string {
	ret := _m.Called()

	var r0 string
	if rf, ok := ret.Get(0).(func() string); ok {
		r0 = rf()
	} else {
		r0 = ret.Get(0).(string)
	}

	return r0
}
82 |
--------------------------------------------------------------------------------
/pkg/config/namespace.go:
--------------------------------------------------------------------------------
1 | package config
2 |
// Namespace is the configuration of one tenant (namespace), decoded from YAML.
type Namespace struct {
	Version string `yaml:"version"`
	// Namespace is the namespace name; config centers index configs by it.
	Namespace string `yaml:"namespace"`
	// Frontend holds client-facing settings (users, allowed databases, ...).
	Frontend FrontendNamespace `yaml:"frontend"`
	// Backend holds backend instances, credentials and pooling settings.
	Backend     BackendNamespace `yaml:"backend"`
	Breaker     BreakerInfo      `yaml:"breaker"`
	RateLimiter RateLimiterInfo  `yaml:"rate_limiter"`
}
11 |
// FrontendNamespace holds the client-facing settings of a namespace.
// NOTE(review): units of SlowSQLTime and IdleTimeout are not visible here — confirm.
type FrontendNamespace struct {
	// AllowedDBs presumably lists the databases clients may use; verify
	// against the Namespace implementation.
	AllowedDBs  []string `yaml:"allowed_dbs"`
	SlowSQLTime int      `yaml:"slow_sql_time"`
	DeniedIPs   []string `yaml:"denied_ips"`
	IdleTimeout int      `yaml:"idle_timeout"`
	// Users are the client accounts of this namespace. A username must be
	// unique across all namespaces (enforced by CreateUserNamespaceMapper).
	Users        []FrontendUserInfo `yaml:"users"`
	SQLBlackList []SQLInfo          `yaml:"sql_blacklist"`
	SQLWhiteList []SQLInfo          `yaml:"sql_whitelist"`
}
21 |
// FrontendUserInfo is a client account. Its Username determines which
// namespace a client connection is routed to.
type FrontendUserInfo struct {
	Username string `yaml:"username"`
	Password string `yaml:"password"`
}
26 |
// SQLInfo is a single SQL entry of a namespace's SQL blacklist or whitelist.
type SQLInfo struct {
	SQL string `yaml:"sql"`
}
30 |
// RateLimiterInfo configures a namespace's rate limiter.
// NOTE(review): QPS is presumably the threshold passed to
// SlidingWindowRateLimiter, and the valid Scope values are not visible here —
// confirm.
type RateLimiterInfo struct {
	Scope string `yaml:"scope"`
	QPS   int    `yaml:"qps"`
}
35 |
// BackendNamespace describes the backend instances of a namespace and how the
// proxy connects to them.
type BackendNamespace struct {
	Username string `yaml:"username"`
	Password string `yaml:"password"`
	// Instances are backend addresses in "host:port" form.
	Instances []string `yaml:"instances"`
	// SelectorType is a selector name such as "random"
	// (see backend.SelectorNameToType).
	SelectorType string `yaml:"selector_type"`
	// NOTE(review): PoolSize/IdleTimeout presumably configure the backend
	// connection pool; units of IdleTimeout not visible here.
	PoolSize    int `yaml:"pool_size"`
	IdleTimeout int `yaml:"idle_timeout"`
}
44 |
// StrategyInfo is one circuit-breaker strategy, i.e. one entry of
// breaker.strategies in the namespace config.
// NOTE(review): the "Ms" suffix suggests milliseconds. The precise semantics
// of each field live in the breaker implementation, which is not visible here.
type StrategyInfo struct {
	MinQps       int64 `yaml:"min_qps"`
	SqlTimeoutMs int64 `yaml:"sql_timeout_ms"`
	// FailureRatethreshold keeps its irregular capitalization; renaming an
	// exported field would break callers.
	FailureRatethreshold int64 `yaml:"failure_rate_threshold"`
	FailureNum           int64 `yaml:"failure_num"`
	OpenStatusDurationMs int64 `yaml:"open_status_duration_ms"`
	Size                 int64 `yaml:"size"`
	CellIntervalMs       int64 `yaml:"cell_interval_ms"`
}
54 |
// BreakerInfo configures a namespace's circuit breaker.
// Scope selects the breaker level; the docs list namespace, db, table and sql.
type BreakerInfo struct {
	Scope      string         `yaml:"scope"`
	Strategies []StrategyInfo `yaml:"strategies"`
}
59 |
--------------------------------------------------------------------------------
/pkg/util/timer/timer_flaky_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package timer
18 |
19 | import (
20 | "testing"
21 | "time"
22 |
23 | "github.com/tidb-incubator/weir/pkg/util/sync2"
24 | "github.com/stretchr/testify/assert"
25 | )
26 |
// Intervals used by the timer tests.
const (
	tenth   = 10 * time.Millisecond
	quarter = 25 * time.Millisecond
	half    = 50 * time.Millisecond
)
32 |
33 | var numcalls sync2.AtomicInt64
34 |
35 | func f() {
36 | numcalls.Add(1)
37 | }
38 |
39 | func TestWait(t *testing.T) {
40 | numcalls.Set(0)
41 | timer := NewTimer(quarter)
42 | assert.False(t, timer.Running())
43 | timer.Start(f)
44 | defer timer.Stop()
45 | assert.True(t, timer.Running())
46 | time.Sleep(tenth)
47 | assert.Equal(t, int64(0), numcalls.Get())
48 | time.Sleep(quarter)
49 | assert.Equal(t, int64(1), numcalls.Get())
50 | time.Sleep(quarter)
51 | assert.Equal(t, int64(2), numcalls.Get())
52 | }
53 |
54 | func TestReset(t *testing.T) {
55 | numcalls.Set(0)
56 | timer := NewTimer(half)
57 | timer.Start(f)
58 | defer timer.Stop()
59 | timer.SetInterval(quarter)
60 | time.Sleep(tenth)
61 | assert.Equal(t, int64(0), numcalls.Get())
62 | time.Sleep(quarter)
63 | assert.Equal(t, int64(1), numcalls.Get())
64 | }
65 |
66 | func TestIndefinite(t *testing.T) {
67 | numcalls.Set(0)
68 | timer := NewTimer(0)
69 | timer.Start(f)
70 | defer timer.Stop()
71 | timer.TriggerAfter(quarter)
72 | time.Sleep(tenth)
73 | assert.Equal(t, int64(0), numcalls.Get())
74 | time.Sleep(quarter)
75 | assert.Equal(t, int64(1), numcalls.Get())
76 | }
77 |
--------------------------------------------------------------------------------
/pkg/config/proxy.go:
--------------------------------------------------------------------------------
1 | package config
2 |
const (
	// MIN_SESSION_TIMEOUT is the minimum session timeout.
	// NOTE(review): units not visible here; presumably the same as
	// proxy_server.session_timeout — confirm.
	MIN_SESSION_TIMEOUT = 600

	// DefaultClusterName is the default cluster name, presumably used when
	// Proxy.Cluster is not set.
	DefaultClusterName = "default"
)
10 |
// Proxy is the top-level weir proxy configuration, decoded from YAML.
type Proxy struct {
	// Version is the config schema version (currently "v1").
	Version      string       `yaml:"version"`
	Cluster      string       `yaml:"cluster"`
	ProxyServer  ProxyServer  `yaml:"proxy_server"`
	AdminServer  AdminServer  `yaml:"admin_server"`
	Log          Log          `yaml:"log"`
	Registry     Registry     `yaml:"registry"`
	ConfigCenter ConfigCenter `yaml:"config_center"`
	Performance  Performance  `yaml:"performance"`
}
21 |
// ProxyServer configures the client-facing proxy listener.
type ProxyServer struct {
	// Addr is the listen address, e.g. "0.0.0.0:6000".
	Addr string `yaml:"addr"`
	// MaxConnections is the maximum number of client connections.
	MaxConnections uint32 `yaml:"max_connections"`
	// SessionTimeout is the idle timeout of client connections.
	SessionTimeout int `yaml:"session_timeout"`
}
27 |
// AdminServer configures the admin (management) listener and its optional
// HTTP Basic Auth credentials.
type AdminServer struct {
	Addr            string `yaml:"addr"`
	EnableBasicAuth bool   `yaml:"enable_basic_auth"`
	User            string `yaml:"user"`
	Password        string `yaml:"password"`
}
34 |
// Log configures logging. The docs list debug, info, warn and error as levels,
// and console and file as formats.
type Log struct {
	Level   string  `yaml:"level"`
	Format  string  `yaml:"format"`
	LogFile LogFile `yaml:"log_file"`
}
40 |
// LogFile configures log file output and rotation.
// NOTE(review): MaxSize units and MaxBackups semantics are not visible here —
// confirm against the logger setup.
type LogFile struct {
	Filename   string `yaml:"filename"`
	MaxSize    int    `yaml:"max_size"`
	MaxDays    int    `yaml:"max_days"`
	MaxBackups int    `yaml:"max_backups"`
}
47 |
// Registry configures service registration.
// NOTE(review): the supported Type values are not visible here; test fixtures
// use "etcd".
type Registry struct {
	Enable bool     `yaml:"enable"`
	Type   string   `yaml:"type"`
	Addrs  []string `yaml:"addrs"`
}
53 |
// ConfigCenter selects where namespace configs are loaded from. Type is "file"
// or "etcd" per the docs; only the sub-config matching Type is meaningful.
type ConfigCenter struct {
	Type       string     `yaml:"type"`
	ConfigFile ConfigFile `yaml:"config_file"`
	ConfigEtcd ConfigEtcd `yaml:"config_etcd"`
}
59 |
// ConfigFile configures the file-based config center.
type ConfigFile struct {
	// Path is the directory containing namespace YAML files.
	Path string `yaml:"path"`
}
63 |
// ConfigEtcd configures the etcd-based config center.
type ConfigEtcd struct {
	Addrs    []string `yaml:"addrs"`
	BasePath string   `yaml:"base_path"`
	Username string   `yaml:"username"`
	Password string   `yaml:"password"`
	// If strictParse is disabled, parsing namespace error will be ignored when list all namespaces.
	StrictParse bool `yaml:"strict_parse"`
}
72 |
// Performance holds performance-related settings.
type Performance struct {
	// TCPKeepAlive enables TCP keep-alive on client connections.
	TCPKeepAlive bool `yaml:"tcp_keep_alive"`
}
76 |
--------------------------------------------------------------------------------
/pkg/config/marshaller_test.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/assert"
7 | )
8 |
// testNamespaceConfig is a fully populated namespace config (Breaker and
// RateLimiter left zero) used for the marshal/unmarshal round-trip test.
var testNamespaceConfig = Namespace{
	Version:   "v1",
	Namespace: "test_ns",
	Frontend: FrontendNamespace{
		AllowedDBs:  []string{"db0", "db1"},
		SlowSQLTime: 10,
		DeniedIPs:   []string{"127.0.0.0", "128.0.0.0"},
		IdleTimeout: 10,
		Users: []FrontendUserInfo{
			{Username: "user0", Password: "pwd0"},
			{Username: "user1", Password: "pwd1"},
		},
	},
	Backend: BackendNamespace{
		Username:     "user0",
		Password:     "pwd0",
		Instances:    []string{"127.0.0.1:4000", "127.0.0.1:4001"},
		SelectorType: "random",
		PoolSize:     1,
		IdleTimeout:  20,
	},
}
31 |
// testProxyConfig is a proxy config used for the marshal/unmarshal round-trip
// test.
var testProxyConfig = Proxy{
	Version: "v1",
	ProxyServer: ProxyServer{
		Addr:           "0.0.0.0:4000",
		MaxConnections: 1,
	},
	AdminServer: AdminServer{
		Addr:            "0.0.0.0:4001",
		EnableBasicAuth: false,
		User:            "user",
		Password:        "pwd",
	},
	Log: Log{
		Level:  "info",
		Format: "console",
		LogFile: LogFile{
			Filename:   ".",
			MaxSize:    10,
			MaxDays:    1,
			MaxBackups: 1,
		},
	},
	Registry: Registry{
		Enable: false,
		Type:   "etcd",
		Addrs:  []string{"127.0.0.1:4000", "127.0.0.1:4001"},
	},
	ConfigCenter: ConfigCenter{
		Type: "file",
		ConfigFile: ConfigFile{
			Path: ".",
		},
	},
	Performance: Performance{
		TCPKeepAlive: true,
	},
}
69 |
70 | func TestNamespaceConfigEncodeAndDecode(t *testing.T) {
71 | data, err := MarshalNamespaceConfig(&testNamespaceConfig)
72 | assert.NoError(t, err)
73 | cfg, err := UnmarshalNamespaceConfig(data)
74 | assert.NoError(t, err)
75 | assert.Equal(t, testNamespaceConfig, *cfg)
76 | }
77 |
78 | func TestProxyConfigEncodeAndDecode(t *testing.T) {
79 | data, err := MarshalProxyConfig(&testProxyConfig)
80 | assert.NoError(t, err)
81 | cfg, err := UnmarshalProxyConfig(data)
82 | assert.NoError(t, err)
83 | assert.Equal(t, testProxyConfig, *cfg)
84 | }
85 |
--------------------------------------------------------------------------------
/pkg/configcenter/file.go:
--------------------------------------------------------------------------------
1 | package configcenter
2 |
3 | import (
4 | "github.com/pingcap/errors"
5 | "github.com/tidb-incubator/weir/pkg/config"
6 | "io/ioutil"
7 | "path"
8 | "path/filepath"
9 | )
10 |
var (
	// ErrNamespaceNotFound is returned by FileConfigCenter.GetNamespace when
	// none of the loaded config files declares the requested namespace.
	ErrNamespaceNotFound = errors.New("namespace not found")
)
14 |
// FileConfigCenter serves namespace configs that were loaded once, at
// construction time, from the *.yaml files of a local directory.
//
// FileConfigCenter is only for test use,
// please do not use it in production environment.
type FileConfigCenter struct {
	dir    string                       // directory the configs were loaded from
	cfgs   map[string]*config.Namespace // key: namespace
	nspath map[string]string            // key: namespace, value: config file path
}
22 |
23 | func CreateFileConfigCenter(nsdir string) (*FileConfigCenter, error) {
24 | yamlFiles, err := listAllYamlFiles(nsdir)
25 | if err != nil {
26 | return nil, err
27 | }
28 |
29 | c := newFileConfigCenter(nsdir)
30 |
31 | for _, yamlFile := range yamlFiles {
32 | fileData, err := ioutil.ReadFile(yamlFile)
33 | if err != nil {
34 | return nil, err
35 | }
36 | cfg, err := config.UnmarshalNamespaceConfig(fileData)
37 | if err != nil {
38 | return nil, err
39 | }
40 | c.cfgs[cfg.Namespace] = cfg
41 | c.nspath[cfg.Namespace] = yamlFile
42 | }
43 | return c, nil
44 | }
45 |
46 | func newFileConfigCenter(dir string) *FileConfigCenter {
47 | return &FileConfigCenter{
48 | dir: dir,
49 | cfgs: make(map[string]*config.Namespace),
50 | nspath: make(map[string]string),
51 | }
52 | }
53 |
// listAllYamlFiles returns the paths (joined with dir) of the non-directory
// entries directly inside dir whose name ends in ".yaml", in the order
// returned by ioutil.ReadDir (sorted by file name). The directory is not
// searched recursively.
//
// Sub-directories are skipped even when their name ends in ".yaml": they
// cannot be read as config files, and including them would make the caller's
// ReadFile fail and abort loading every namespace.
func listAllYamlFiles(dir string) ([]string, error) {
	infos, err := ioutil.ReadDir(dir)
	if err != nil {
		return nil, err
	}

	var ret []string
	for _, info := range infos {
		if info.IsDir() {
			continue
		}
		fileName := info.Name()
		if path.Ext(fileName) == ".yaml" {
			ret = append(ret, filepath.Join(dir, fileName))
		}
	}
	return ret, nil
}
70 |
71 | func (f *FileConfigCenter) GetNamespace(ns string) (*config.Namespace, error) {
72 | cfg, ok := f.cfgs[ns]
73 | if !ok {
74 | return nil, ErrNamespaceNotFound
75 | }
76 | return cfg, nil
77 | }
78 |
79 | func (f *FileConfigCenter) ListAllNamespace() ([]*config.Namespace, error) {
80 | var ret []*config.Namespace
81 | for _, cfg := range f.cfgs {
82 | ret = append(ret, cfg)
83 | }
84 | return ret, nil
85 | }
86 |
--------------------------------------------------------------------------------
/pkg/proxy/driver/domain.go:
--------------------------------------------------------------------------------
1 | package driver
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/siddontang/go-mysql/mysql"
7 | )
8 |
// NamespaceManager maps a client login to the Namespace it belongs to.
type NamespaceManager interface {
	// Auth checks the login of username and, on success, returns the
	// namespace that user belongs to together with true.
	// NOTE(review): pwd is presumably the client's scrambled auth response and
	// salt the handshake challenge — confirm against the implementation.
	Auth(username string, pwd, salt []byte) (Namespace, bool)
}

// Namespace is the per-tenant view the driver uses: access rules, backend
// connections, connection accounting, circuit breaker and rate limiter.
type Namespace interface {
	// Name returns the namespace name.
	Name() string
	// IsDatabaseAllowed reports whether clients of this namespace may use db.
	IsDatabaseAllowed(db string) bool
	// ListDatabases returns the databases visible to this namespace.
	ListDatabases() []string
	// IsDeniedSQL / IsAllowedSQL check an SQL statement against the deny and
	// allow lists. NOTE(review): sqlFeature is presumably a 32-bit hash of the
	// normalized SQL text — confirm at the call sites.
	IsDeniedSQL(sqlFeature uint32) bool
	IsAllowedSQL(sqlFeature uint32) bool
	// GetPooledConn borrows a backend connection from the namespace's pool.
	GetPooledConn(context.Context) (PooledBackendConn, error)
	// IncrConnCount / DescConnCount adjust the namespace's client connection
	// counter. NOTE(review): "Desc" presumably means "decrement".
	IncrConnCount()
	DescConnCount()
	// GetBreaker returns the namespace's circuit breaker.
	GetBreaker() (Breaker, error)
	// GetRateLimiter returns the namespace's rate limiter.
	GetRateLimiter() RateLimiter
}

// Breaker is the circuit-breaker interface a namespace exposes.
// NOTE(review): the per-method notes below are inferred from names and
// signatures only; confirm against the implementation.
type Breaker interface {
	// IsUseBreaker reports whether the breaker is enabled.
	IsUseBreaker() bool
	// GetBreakerScope returns the scope the breaker keys its statistics by.
	GetBreakerScope() string
	// Hit records one request outcome (isFail) for key name; idx presumably
	// selects the breaker strategy.
	Hit(name string, idx int, isFail bool) error
	// Status returns the breaker state for key name plus an index.
	Status(name string) (int32, int)
	// AddTimeWheelTask / RemoveTimeWheelTask schedule and cancel a timer task
	// tied to a client connection.
	AddTimeWheelTask(name string, connectionID uint64, flag *int32) error
	RemoveTimeWheelTask(connectionID uint64) error
	// CASHalfOpenProbeSent atomically updates the "half-open probe sent" flag.
	CASHalfOpenProbeSent(name string, idx int, halfOpenProbeSent bool) bool
	// CloseBreaker releases the breaker's resources.
	CloseBreaker()
}

// RateLimiter limits requests per key within a scope.
type RateLimiter interface {
	// Scope returns the scope the limiter keys its counters by.
	Scope() string
	// Limit accounts one request for key; a non-nil error presumably means the
	// request was rejected.
	Limit(ctx context.Context, key string) error
}
41 |
// PooledBackendConn is a backend connection borrowed from a connection pool.
type PooledBackendConn interface {
	// PutBack returns the connection to its pool for reuse.
	PutBack()

	// ErrorClose closes the connection and lets the pool create a new one in
	// its place. Call it instead of PutBack when the connection is broken.
	ErrorClose() error
	BackendConn
}

// SimpleBackendConn is a standalone (non-pooled) backend connection that its
// owner closes directly.
type SimpleBackendConn interface {
	Close() error
	BackendConn
}

// BackendConn is the set of operations the proxy performs on a connection to
// a backend TiDB/MySQL server.
// NOTE(review): per-method notes are inferred from names; confirm against the
// implementation.
type BackendConn interface {
	Ping() error
	UseDB(dbName string) error
	GetDB() string
	// Execute runs command (with optional args) and returns the go-mysql result.
	Execute(command string, args ...interface{}) (*mysql.Result, error)
	Begin() error
	Commit() error
	Rollback() error
	// StmtPrepare prepares sql on the backend.
	StmtPrepare(sql string) (Stmt, error)
	// StmtExecuteForward presumably forwards a raw COM_STMT_EXECUTE payload
	// to the backend unchanged.
	StmtExecuteForward(data []byte) (*mysql.Result, error)
	StmtClosePrepare(stmtId int) error
	SetCharset(charset string) error
	FieldList(table string, wildcard string) ([]*mysql.Field, error)
	SetAutoCommit(bool) error
	IsAutoCommit() bool
	IsInTransaction() bool
	GetCharset() string
	GetConnectionID() uint32
	GetStatus() uint16
}

// Stmt describes a statement prepared on a backend connection.
type Stmt interface {
	ID() int        // statement id
	ParamNum() int  // number of parameter placeholders
	ColumnNum() int // number of result columns
}
83 |
--------------------------------------------------------------------------------
/pkg/proxy/proxy.go:
--------------------------------------------------------------------------------
1 | package proxy
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/tidb-incubator/weir/pkg/config"
7 | "github.com/tidb-incubator/weir/pkg/configcenter"
8 | "github.com/tidb-incubator/weir/pkg/proxy/driver"
9 | "github.com/tidb-incubator/weir/pkg/proxy/metrics"
10 | "github.com/tidb-incubator/weir/pkg/proxy/namespace"
11 | "github.com/tidb-incubator/weir/pkg/proxy/server"
12 | )
13 |
// Proxy wires together the components of a weir-proxy process. cfg is set by
// NewProxy; every other field is populated by Init.
type Proxy struct {
	cfg          *config.Proxy
	svr          *server.Server              // proxy (client-facing) server
	apiServer    *HttpApiServer              // HTTP API server
	nsmgr        *namespace.NamespaceManager // namespaces built from the config center
	configCenter configcenter.ConfigCenter   // source of namespace configs
}
21 |
22 | func supplementProxyConfig(cfg *config.Proxy) *config.Proxy {
23 | if cfg.ProxyServer.SessionTimeout <= config.MIN_SESSION_TIMEOUT {
24 | cfg.ProxyServer.SessionTimeout = config.MIN_SESSION_TIMEOUT
25 | }
26 | if cfg.Cluster == "" {
27 | cfg.Cluster = config.DefaultClusterName
28 | }
29 | return cfg
30 | }
31 |
32 | func NewProxy(cfg *config.Proxy) *Proxy {
33 | return &Proxy{
34 | cfg: supplementProxyConfig(cfg),
35 | }
36 | }
37 |
38 | func (p *Proxy) Init() error {
39 | metrics.RegisterProxyMetrics(p.cfg.Cluster)
40 | cc, err := configcenter.CreateConfigCenter(p.cfg.ConfigCenter)
41 | if err != nil {
42 | return err
43 | }
44 | p.configCenter = cc
45 |
46 | nss, err := cc.ListAllNamespace()
47 | if err != nil {
48 | return err
49 | }
50 | nsmgr, err := namespace.CreateNamespaceManager(nss, namespace.BuildNamespace, namespace.DefaultAsyncCloseNamespace)
51 | if err != nil {
52 | return err
53 | }
54 | p.nsmgr = nsmgr
55 | driverImpl := driver.NewDriverImpl(nsmgr)
56 | svr, err := server.NewServer(p.cfg, driverImpl)
57 | if err != nil {
58 | return err
59 | }
60 | p.svr = svr
61 | apiServer, err := CreateHttpApiServer(svr, nsmgr, cc, p.cfg)
62 | if err != nil {
63 | return err
64 | }
65 | p.apiServer = apiServer
66 |
67 | return nil
68 | }
69 |
// Run starts the HTTP API server in a background goroutine and then runs the
// proxy server on the calling goroutine, returning what p.svr.Run returns.
// Init must have succeeded first; otherwise p.apiServer / p.svr are nil.
// TODO(eastfisher): refactor this function
func (p *Proxy) Run() error {
	go func() {
		// The API server starts after a fixed 200ms delay.
		// NOTE(review): presumably meant to let the proxy server start first;
		// this is timing-based, not a real readiness signal.
		time.Sleep(200 * time.Millisecond)
		p.apiServer.Run()
	}()
	return p.svr.Run()
}
78 |
79 | func (p *Proxy) Close() {
80 | if p.apiServer != nil {
81 | p.apiServer.Close()
82 | }
83 | if p.svr != nil {
84 | p.svr.Close()
85 | }
86 | }
87 |
--------------------------------------------------------------------------------
/pkg/util/sync2/semaphore.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sync2
18 |
19 | // What's in a name? Channels have all you need to emulate a counting
20 | // semaphore with a boatload of extra functionality. However, in some
21 | // cases, you just want a familiar API.
22 |
23 | import (
24 | "time"
25 | )
26 |
// Semaphore is a counting semaphore with the option to specify a timeout.
// Free slots are the elements currently buffered in the slots channel.
type Semaphore struct {
	slots   chan struct{}
	timeout time.Duration
}

// NewSemaphore creates a Semaphore with count free slots. The count parameter
// must be a positive number. A timeout of zero means that there is no timeout.
func NewSemaphore(count int, timeout time.Duration) *Semaphore {
	slots := make(chan struct{}, count)
	for filled := 0; filled < count; filled++ {
		slots <- struct{}{}
	}
	return &Semaphore{slots: slots, timeout: timeout}
}

// Acquire takes one slot. It returns true when a slot was obtained and false
// when the timeout elapsed first. With no timeout it waits as long as needed.
func (sem *Semaphore) Acquire() bool {
	// A nil channel never becomes ready, so with no timeout the select below
	// simply waits for a slot.
	var expired <-chan time.Time
	if sem.timeout != 0 {
		timer := time.NewTimer(sem.timeout)
		defer timer.Stop()
		expired = timer.C
	}
	select {
	case <-sem.slots:
		return true
	case <-expired:
		return false
	}
}

// TryAcquire takes a slot only if one is free right now and reports whether
// it did; it never waits.
func (sem *Semaphore) TryAcquire() bool {
	acquired := false
	select {
	case <-sem.slots:
		acquired = true
	default:
	}
	return acquired
}

// Release gives back one previously acquired slot. Releasing more slots than
// were acquired blocks, because the slots channel is already full.
func (sem *Semaphore) Release() {
	sem.slots <- struct{}{}
}

// Size returns the current number of free slots.
func (sem *Semaphore) Size() int {
	free := len(sem.slots)
	return free
}
86 |
--------------------------------------------------------------------------------
/pkg/util/rate_limit_breaker/rate_limit/sliding_window_test.go:
--------------------------------------------------------------------------------
1 | package rate_limit
2 |
3 | import (
4 | "github.com/stretchr/testify/assert"
5 | "testing"
6 | "time"
7 | )
8 |
9 | func TestSlidingWindowRateLimiter_Limit(t *testing.T) {
10 | rl := NewSlidingWindowRateLimiter(10)
11 | ch := make(chan int)
12 | go func() {
13 | for i := 0; i < 100; i++ {
14 | ch <- i
15 | }
16 | }()
17 |
18 | // 使用 timer 确保只处理 1s 之内的数据
19 | OUTER:
20 | for {
21 | select {
22 | case <-time.NewTimer(time.Second).C: // time is up
23 | break OUTER
24 | case i := <-ch:
25 | err := rl.Limit()
26 | if i < 10 {
27 | assert.Nil(t, err)
28 | } else {
29 | assert.Equal(t, err, ErrRateLimited)
30 | if i == 99 {
31 | break OUTER
32 | }
33 | }
34 | }
35 | }
36 | }
37 |
38 | func TestSlidingWindowRateLimiter_ChangeQpsThreshold(t *testing.T) {
39 | // 测试 ChangeQpsThreshold。qpsThreshold 一开始为 10
40 | // 1)发送 100 个请求,前 10 个通过,其他被限流
41 | // 2)等待 1s 并将 qpsThreshold 置为 20
42 | // 3)发送 100 个请求,前 20 个通过,其他被限流
43 | rl := NewSlidingWindowRateLimiter(10)
44 | ch := make(chan int)
45 | go func() {
46 | for i := 0; i < 100; i++ {
47 | ch <- i
48 | }
49 |
50 | // 等待 1s,将 qps 阈值置为 20
51 | <-time.NewTimer(time.Second).C
52 | rl.ChangeQpsThreshold(20)
53 | for i := 100; i < 200; i++ {
54 | ch <- i
55 | }
56 | }()
57 |
58 | OUTER:
59 | for {
60 | select {
61 | case <-time.NewTimer(time.Second * 2).C: // time is up
62 | break OUTER
63 | case i := <-ch:
64 | err := rl.Limit()
65 | if i < 100 { // 前 100 个,10 个通过
66 | if i < 10 {
67 | assert.Nil(t, err)
68 | } else {
69 | assert.Equal(t, err, ErrRateLimited)
70 | }
71 | } else { // 后 100 个,20 个通过
72 | if i < 120 {
73 | assert.Nil(t, err)
74 | } else {
75 | assert.Equal(t, err, ErrRateLimited)
76 | if i == 199 {
77 | break OUTER
78 | }
79 | }
80 | }
81 | }
82 | }
83 | }
84 |
85 | func TestSlidingWindowRateLimiter_ConcurrentLimit(t *testing.T) {
86 | rl := NewSlidingWindowRateLimiter(20000)
87 | resultCh := make(chan int)
88 | const GoRoutines = 100
89 | for i := 0; i < GoRoutines; i++ {
90 | go func() {
91 | passedCount := 0
92 | for j := 0; j < 1000; j++ {
93 | err := rl.Limit()
94 | if err == nil {
95 | passedCount++
96 | }
97 | }
98 | resultCh <- passedCount
99 | }()
100 | }
101 |
102 | sum := 0
103 | i := 0
104 | for {
105 | count := <-resultCh
106 | sum += count
107 | i++
108 | if i >= GoRoutines {
109 | break
110 | }
111 | }
112 | assert.Equal(t, sum, 20000)
113 | }
114 |
--------------------------------------------------------------------------------
/docs/cn/quickstart.md:
--------------------------------------------------------------------------------
1 | # 快速上手
2 |
3 | 本文介绍如何快速上手体验Weir平台.
4 |
5 | ## 前提
6 |
7 | 使用 weir-proxy 前, 需要先部署一套 TiDB 集群. 对 weir-proxy 来说, 后端数据库也可以使用 MySQL 代替 TiDB 进行测试.
8 |
9 | ### 安装 TiDB/MySQL
10 |
11 | 安装 TiDB 可参考 [TiDB数据库快速上手指南](https://docs.pingcap.com/zh/tidb/stable/quick-start-with-tidb) 进行安装。
12 |
13 | ### 构造数据
14 |
15 | 安装完成后, 连接 TiDB / MySQL 并执行以下 SQL 语句建库建表, 以便测试 weir-proxy.
16 |
17 | #### 建库
18 | ```
19 | DROP DATABASE IF EXISTS `test_weir_db`;
20 | CREATE DATABASE `test_weir_db`;
21 | USE `test_weir_db`;
22 | ```
23 |
24 | #### 建表
25 | ```
26 | CREATE TABLE `test_weir_user` (
27 | `id` bigint(22) unsigned NOT NULL AUTO_INCREMENT,
28 | `name` varchar(128) NOT NULL,
29 | PRIMARY KEY (`id`),
30 | UNIQUE `uniq_name` (`name`)
31 | ) DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
32 |
33 | CREATE TABLE `test_weir_admin` (
34 | `id` bigint(22) unsigned NOT NULL AUTO_INCREMENT,
35 | `name` varchar(128) NOT NULL,
36 | `status` varchar(128) NOT NULL DEFAULT 'normal',
37 | PRIMARY KEY (`id`),
38 | UNIQUE `uniq_name` (`name`)
39 | ) DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
40 | ```
41 |
42 | #### 写入测试数据
43 | ```
44 | INSERT INTO `test_weir_user` (name) VALUES ('Bob');
45 | INSERT INTO `test_weir_user` (name) VALUES ('Alice');
46 |
47 | INSERT INTO `test_weir_admin` (name) VALUES ('Ed');
48 | INSERT INTO `test_weir_admin` (name) VALUES ('Huang');
49 | ```
50 |
51 | ## 安装启动 weir-proxy
52 |
53 | ### 从源码编译安装
54 |
55 | 首先, 从github克隆代码仓库到本地.
56 |
57 | ```
58 | $ git clone https://github.com/tidb-incubator/weir
59 | ```
60 |
61 | 进入仓库目录, 构建 weir-proxy.
62 |
63 | ```
64 | $ cd weir && make weirproxy
65 | ```
66 |
67 | 生成的二进制文件位于 bin/ 目录下, 路径为 bin/weirproxy.
68 |
69 | 启动weir-proxy.
70 |
71 | ```
72 | $ ./bin/weirproxy &
73 | ```
74 |
75 | weir-proxy 默认读取示例配置文件 conf/weirproxy.yml 启动. 示例配置使用本地文件作为 namespace 配置中心, namespace 配置文件位于 conf/namespace/ 目录下.
76 |
77 | 使用MySQL客户端通过weir-proxy访问TiDB集群.
78 |
79 | ```
80 | $ mysql -h127.0.0.1 -P6000 -uhello -pworld test_weir_db
81 |
82 | mysql: [Warning] Using a password on the command line interface can be insecure.
83 | Welcome to the MySQL monitor. Commands end with ; or \g.
84 | Your MySQL connection id is 1
85 | Server version: 5.7.25-TiDB-None MySQL Community Server (GPL)
86 |
87 | Copyright (c) 2000, 2016, Oracle and/or its affiliates. All rights reserved.
88 |
89 | Oracle is a registered trademark of Oracle Corporation and/or its
90 | affiliates. Other names may be trademarks of their respective
91 | owners.
92 |
93 | Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
94 |
95 | mysql>
96 | ```
97 |
98 | 如果连接成功, 说明 weir-proxy 已经启动并可以使用了. 恭喜你!
99 |
100 | 目前 Weir 平台的代理层中间件 weir-proxy 已开源, 中控组件 weir-controller 和控制台 weir-dashboard 还未开源.
--------------------------------------------------------------------------------
/pkg/util/rate_limit_breaker/rate_limit/leaky_bucket.go:
--------------------------------------------------------------------------------
1 | package rate_limit
2 |
import (
	"sync"
	"sync/atomic"
	"time"
)
7 |
// LeakyBucketRateLimiter is a concurrent rate limiter based on the queue
// variant of the leaky bucket algorithm.
//
// See https://en.wikipedia.org/wiki/Leaky_bucket. A leaky bucket can be
// implemented in two ways:
//   - as a meter: equivalent to a token bucket;
//   - as a queue: stricter limiting that avoids burst flow (used here).
//
// Callers of Limit queue up, and a background goroutine releases one waiter
// per tick, where tick = 1s / qpsThreshold.
//
// Measured precision (16-inch MacBook Pro):
//   - 1K QPS:   1 req per 1 ms.
//   - 10K QPS:  1 req per 100 µs (0.1 ms); the timer supports this precision.
//   - 20K QPS:  1 req per 50 µs (0.05 ms); the timer supports this precision.
//   - 100K QPS: 1 req per 10 µs (0.01 ms); the timer starts to drift.
//
// To keep precision, do not exceed 100K QPS. (This is solvable; see the
// optimization note in leak.)
type LeakyBucketRateLimiter struct {
	qpsThreshold int64              // shared; read and written only via sync/atomic
	ch           chan chan struct{} // the leaky bucket: queue of waiting callers
	stopCh       chan struct{}      // closed by Close to stop the limiter
	changeCh     chan struct{}      // tells leak that qpsThreshold has changed
	closeOnce    sync.Once          // makes Close idempotent
}

// NewLeakyBucketRateLimiter creates a limiter releasing qpsThreshold requests
// per second and starts its background goroutine; call Close to stop it.
// qpsThreshold must be positive: zero makes the background goroutine panic
// with an integer division by zero.
func NewLeakyBucketRateLimiter(qpsThreshold int64) *LeakyBucketRateLimiter {
	lbrl := &LeakyBucketRateLimiter{
		qpsThreshold: qpsThreshold,
		ch:           make(chan chan struct{}, 1),
		stopCh:       make(chan struct{}),
		changeCh:     make(chan struct{}),
	}
	go lbrl.leak()
	return lbrl
}

// ChangeQpsThreshold sets a new threshold and waits until the background
// goroutine has switched to the matching tick. Once the limiter is closed it
// only stores the value and returns without waiting.
func (lbrl *LeakyBucketRateLimiter) ChangeQpsThreshold(newQpsThreshold int64) {
	atomic.StoreInt64(&lbrl.qpsThreshold, newQpsThreshold)
	select {
	case lbrl.changeCh <- struct{}{}:
	case <-lbrl.stopCh:
	}
}

// getTick returns the interval between two released requests: one second
// divided by the current threshold. The result is clamped to at least 1ns
// because time.NewTicker panics on non-positive intervals, which integer
// division produces for thresholds above 1e9 (and for negative thresholds).
func (lbrl *LeakyBucketRateLimiter) getTick() time.Duration {
	qpsThreshold := atomic.LoadInt64(&lbrl.qpsThreshold)
	tick := time.Second / time.Duration(qpsThreshold)
	if tick <= 0 {
		tick = 1
	}
	return tick
}

// leak is the background loop: on every tick it releases at most one queued
// caller, switches tickers when the threshold changes and exits once Close
// has been called.
func (lbrl *LeakyBucketRateLimiter) leak() {
	// The basic logic: derive a tick from qpsThreshold and allow only one
	// action per tick. When qpsThreshold is large the tick becomes tiny, which
	// causes two problems:
	//   1) the timer may not be precise enough, increasing the error;
	//   2) this loop iterates very often and burns CPU.
	// A possible optimization is to drop the "one action per tick" rule: keep
	// the tick in a reasonable range (e.g. >= 0.1ms, as small as possible) and
	// derive from it and qpsThreshold how many actions (an integer) each tick
	// allows.
	ticker := time.NewTicker(lbrl.getTick())
	// Stop whichever ticker is current at exit. Tickers from time.Tick can
	// never be stopped and used to leak on every threshold change.
	defer func() { ticker.Stop() }()
	for {
		select {
		case <-lbrl.stopCh: // stopped
			return
		case <-lbrl.changeCh: // threshold modified
			ticker.Stop()
			ticker = time.NewTicker(lbrl.getTick())
		case <-ticker.C:
			select {
			case waiterCh := <-lbrl.ch:
				waiterCh <- struct{}{}
			default:
				// nobody is waiting; this tick is unused
			}
		}
	}
}

// Limit blocks until the background goroutine releases the caller, i.e. until
// the caller's turn comes up at the configured rate, and returns nil. Once the
// limiter is closed, Limit no longer waits: it returns nil immediately, so
// callers are not blocked forever on a stopped limiter.
func (lbrl *LeakyBucketRateLimiter) Limit() error {
	// Buffered so the background goroutine's release never blocks.
	waiterCh := make(chan struct{}, 1)
	select {
	case lbrl.ch <- waiterCh:
	case <-lbrl.stopCh:
		return nil
	}
	select {
	case <-waiterCh:
	case <-lbrl.stopCh:
	}
	return nil
}

// Close stops the background goroutine. It is safe to call more than once.
// Afterwards, pending and future Limit and ChangeQpsThreshold calls return
// without blocking.
func (lbrl *LeakyBucketRateLimiter) Close() {
	lbrl.closeOnce.Do(func() { close(lbrl.stopCh) })
}
91 |
--------------------------------------------------------------------------------
/pkg/util/ast/ast_util.go:
--------------------------------------------------------------------------------
1 | package ast
2 |
3 | import (
4 | "context"
5 | "encoding/binary"
6 | "strings"
7 |
8 | "github.com/pingcap/parser/ast"
9 | "github.com/pingcap/parser/format"
10 | driver "github.com/pingcap/tidb/types/parser_driver"
11 | )
12 |
// astCtxKey is the unexported key type for values this package stores in a
// context.Context. A dedicated type, rather than a plain string, guarantees
// that no other package can collide with these keys. The context.WithValue
// docs require this, and staticcheck SA1029 flags built-in key types.
type astCtxKey string

const ctxAstTableNameKey astCtxKey = "ctx_ast_table_name"

// CtxWithAstTableName returns a copy of ctx carrying tableName, retrievable
// with GetAstTableNameFromCtx.
func CtxWithAstTableName(ctx context.Context, tableName string) context.Context {
	return context.WithValue(ctx, ctxAstTableNameKey, tableName)
}

// GetAstTableNameFromCtx returns the table name stored by CtxWithAstTableName
// and true, or "" and false when ctx carries none.
func GetAstTableNameFromCtx(ctx context.Context) (string, bool) {
	tableName, ok := ctx.Value(ctxAstTableNameKey).(string)
	return tableName, ok
}
30 |
31 | type FirstTableNameVisitor struct {
32 | table string
33 | found bool
34 | }
35 |
36 | func (f *FirstTableNameVisitor) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
37 | switch nn := n.(type) {
38 | case *ast.TableName:
39 | f.table = nn.Name.String()
40 | f.found = true
41 | return n, true
42 | }
43 | return n, false
44 | }
45 |
46 | func (f *FirstTableNameVisitor) Leave(n ast.Node) (node ast.Node, ok bool) {
47 | return n, !f.found
48 | }
49 |
50 | func (f *FirstTableNameVisitor) TableName() string {
51 | return f.table
52 | }
53 |
54 | func ExtractFirstTableNameFromStmt(stmt ast.StmtNode) string {
55 | visitor := &FirstTableNameVisitor{}
56 | stmt.Accept(visitor)
57 | return visitor.table
58 | }
59 |
// AstVisitor normalizes a statement into an SQL "feature" string: literal
// values are replaced by the value "?" and IN lists whose first item is a
// literal are cut down to that single item, so statements that differ only in
// their constants restore to the same text.
type AstVisitor struct {
	sqlFeature string // restored SQL text of the normalized statement
}

// ExtractAstVisit normalizes stmt with an AstVisitor, restores the resulting
// AST to SQL text and returns the visitor (read the text with SqlFeature).
// It returns the Restore error, if any.
//
// The normalization mutates stmt in place: Enter rewrites nodes. Callers that
// still need the original literals must not reuse stmt afterwards.
func ExtractAstVisit(stmt ast.StmtNode) (*AstVisitor, error) {
	visitor := &AstVisitor{}

	stmt.Accept(visitor)

	sb := strings.Builder{}
	if err := stmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)); err != nil {
		return nil, err
	}
	visitor.sqlFeature = sb.String()

	return visitor, nil
}

// Enter rewrites literal-bearing nodes in place:
//   - *ast.PatternInExpr: if the first list item is a *driver.ValueExpr, the
//     list is truncated to that single item. Only the first item is checked.
//   - *driver.ValueExpr: its value is replaced with the string "?".
//
// Children are never skipped. Assuming the parser visits a node's children
// after Enter, the kept IN item is then replaced by the ValueExpr case too.
func (f *AstVisitor) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
	switch nn := n.(type) {
	case *ast.PatternInExpr:
		if len(nn.List) == 0 {
			return nn, false
		}
		if _, ok := nn.List[0].(*driver.ValueExpr); ok {
			nn.List = nn.List[:1]
		}
	case *driver.ValueExpr:
		nn.SetValue("?")
	}
	return n, false
}

// Leave always continues the traversal.
func (f *AstVisitor) Leave(n ast.Node) (node ast.Node, ok bool) {
	return n, true
}

// SqlFeature returns the normalized SQL text produced by ExtractAstVisit.
func (f *AstVisitor) SqlFeature() string {
	return f.sqlFeature
}
100 |
// UInt322Bytes encodes n as a new 4-byte little-endian slice.
func UInt322Bytes(n uint32) []byte {
	var buf [4]byte
	binary.LittleEndian.PutUint32(buf[:], n)
	return buf[:]
}

// Bytes2Uint32 decodes the first 4 bytes of b as a little-endian uint32; it is
// the inverse of UInt322Bytes. It panics if len(b) < 4.
func Bytes2Uint32(b []byte) uint32 {
	le := binary.LittleEndian
	return le.Uint32(b)
}
110 |
--------------------------------------------------------------------------------
/pkg/util/timer/time_wheel_test.go:
--------------------------------------------------------------------------------
1 | package timer
2 |
3 | import (
4 | "strconv"
5 | "sync/atomic"
6 | "testing"
7 | "time"
8 |
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
// A is a test fixture whose callback records, atomically, that it has run.
// The a and b fields only give instances distinct identities as map keys.
type A struct {
	a            int
	b            string
	isCallbacked int32 // 1 once callback ran; accessed only via sync/atomic
}

// callback marks the fixture as called.
func (fx *A) callback() {
	atomic.StoreInt32(&fx.isCallbacked, 1)
}

// getCallbackValue returns 1 if callback has run and 0 otherwise.
func (fx *A) getCallbackValue() int32 {
	v := atomic.LoadInt32(&fx.isCallbacked)
	return v
}
25 |
26 | func newTimeWheel() *TimeWheel {
27 | tw, err := NewTimeWheel(time.Second, 3600)
28 | if err != nil {
29 | panic(err)
30 | }
31 | tw.Start()
32 | return tw
33 | }
34 |
35 | func TestNewTimeWheel(t *testing.T) {
36 | tests := []struct {
37 | name string
38 | tick time.Duration
39 | bucketNum int
40 | hasErr bool
41 | }{
42 | {tick: time.Second, bucketNum: 0, hasErr: true},
43 | {tick: time.Millisecond, bucketNum: 1, hasErr: true},
44 | {tick: time.Second, bucketNum: 1, hasErr: false},
45 | }
46 | for _, test := range tests {
47 | t.Run(test.name, func(t *testing.T) {
48 | _, err := NewTimeWheel(test.tick, test.bucketNum)
49 | assert.Equal(t, test.hasErr, err != nil)
50 | })
51 | }
52 | }
53 |
54 | func TestAdd(t *testing.T) {
55 | tw := newTimeWheel()
56 | a := &A{}
57 | err := tw.Add(time.Second*1, "test", a.callback)
58 | assert.NoError(t, err)
59 |
60 | time.Sleep(time.Millisecond * 500)
61 | assert.Equal(t, int32(0), a.getCallbackValue())
62 | time.Sleep(time.Second * 2)
63 | assert.Equal(t, int32(1), a.getCallbackValue())
64 | tw.Stop()
65 | }
66 |
67 | func TestAddMultipleTimes(t *testing.T) {
68 | a := &A{}
69 | tw := newTimeWheel()
70 | for i := 0; i < 4; i++ {
71 | err := tw.Add(time.Second, "test", a.callback)
72 | assert.NoError(t, err)
73 | time.Sleep(time.Millisecond * 500)
74 | t.Logf("current: %d", i)
75 | assert.Equal(t, int32(0), a.getCallbackValue())
76 | }
77 |
78 | time.Sleep(time.Second * 2)
79 | assert.Equal(t, int32(1), a.getCallbackValue())
80 | tw.Stop()
81 | }
82 |
83 | func TestRemove(t *testing.T) {
84 | a := &A{a: 10, b: "test"}
85 | tw := newTimeWheel()
86 | err := tw.Add(time.Second*1, a, a.callback)
87 | assert.NoError(t, err)
88 |
89 | time.Sleep(time.Millisecond * 500)
90 | assert.Equal(t, int32(0), a.getCallbackValue())
91 | err = tw.Remove(a)
92 | assert.NoError(t, err)
93 | time.Sleep(time.Second * 2)
94 | assert.Equal(t, int32(0), a.getCallbackValue())
95 | tw.Stop()
96 | }
97 |
98 | func BenchmarkAdd(b *testing.B) {
99 | a := &A{}
100 | tw := newTimeWheel()
101 | for i := 0; i < b.N; i++ {
102 | key := "test" + strconv.Itoa(i)
103 | err := tw.Add(time.Second, key, a.callback)
104 | if err != nil {
105 | b.Fatalf("benchmark Add failed, %v", err)
106 | }
107 | }
108 | }
109 |
--------------------------------------------------------------------------------
/pkg/proxy/server/column.go:
--------------------------------------------------------------------------------
1 | // Copyright 2015 PingCAP, Inc.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // See the License for the specific language governing permissions and
12 | // limitations under the License.
13 |
14 | package server
15 |
16 | import (
17 | "github.com/pingcap/parser/mysql"
18 | )
19 |
// maxColumnNameSize is the byte limit applied by Dump to Name and OrgName.
const maxColumnNameSize = 256

// ColumnInfo contains information of a column, as sent to clients in a column
// definition packet.
type ColumnInfo struct {
	Schema             string
	Table              string
	OrgTable           string
	Name               string
	OrgName            string
	ColumnLength       uint32
	Charset            uint16
	Flag               uint16
	Decimal            uint8
	Type               uint8
	DefaultValueLength uint64 // NOTE(review): not used by Dump, which derives the length from DefaultValue
	DefaultValue       []byte
}

// Dump appends the column definition to buffer and returns the extended
// slice. Field order follows the MySQL Protocol::ColumnDefinition41 layout:
// catalog, schema, table, org_table, name, org_name, then the fixed-length
// part (charset, column length, type, flags, decimals, filler).
func (column *ColumnInfo) Dump(buffer []byte) []byte {
	// Names longer than maxColumnNameSize bytes are cut at that byte offset,
	// which may split a multi-byte UTF-8 character.
	nameDump, orgnameDump := []byte(column.Name), []byte(column.OrgName)
	if len(nameDump) > maxColumnNameSize {
		nameDump = nameDump[0:maxColumnNameSize]
	}
	if len(orgnameDump) > maxColumnNameSize {
		orgnameDump = orgnameDump[0:maxColumnNameSize]
	}
	buffer = dumpLengthEncodedString(buffer, []byte("def")) // catalog, always "def"
	buffer = dumpLengthEncodedString(buffer, []byte(column.Schema))
	buffer = dumpLengthEncodedString(buffer, []byte(column.Table))
	buffer = dumpLengthEncodedString(buffer, []byte(column.OrgTable))
	buffer = dumpLengthEncodedString(buffer, nameDump)
	buffer = dumpLengthEncodedString(buffer, orgnameDump)

	// Length of the fixed-length fields that follow:
	// 2 (charset) + 4 (length) + 1 (type) + 2 (flags) + 1 (decimals) + 2 (filler) = 12.
	buffer = append(buffer, 0x0c)

	buffer = dumpUint16(buffer, column.Charset)
	buffer = dumpUint32(buffer, column.ColumnLength)
	buffer = append(buffer, dumpType(column.Type))
	buffer = dumpUint16(buffer, dumpFlag(column.Type, column.Flag))
	buffer = append(buffer, column.Decimal)
	buffer = append(buffer, 0, 0) // filler

	// Default value, only when set. NOTE(review): presumably only populated
	// for COM_FIELD_LIST responses — confirm at the call sites.
	if column.DefaultValue != nil {
		buffer = dumpUint64(buffer, uint64(len(column.DefaultValue)))
		buffer = append(buffer, column.DefaultValue...)
	}

	return buffer
}
70 |
71 | func dumpFlag(tp byte, flag uint16) uint16 {
72 | switch tp {
73 | case mysql.TypeSet:
74 | return flag | uint16(mysql.SetFlag)
75 | case mysql.TypeEnum:
76 | return flag | uint16(mysql.EnumFlag)
77 | default:
78 | if mysql.HasBinaryFlag(uint(flag)) {
79 | return flag | uint16(mysql.NotNullFlag)
80 | }
81 | return flag
82 | }
83 | }
84 |
85 | func dumpType(tp byte) byte {
86 | switch tp {
87 | case mysql.TypeSet, mysql.TypeEnum:
88 | return mysql.TypeString
89 | default:
90 | return tp
91 | }
92 | }
93 |
--------------------------------------------------------------------------------
/pkg/util/sync2/atomic_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sync2
18 |
19 | import (
20 | "testing"
21 | "time"
22 |
23 | "github.com/stretchr/testify/assert"
24 | )
25 |
26 | func TestAtomicInt32(t *testing.T) {
27 | i := NewAtomicInt32(1)
28 | assert.Equal(t, int32(1), i.Get())
29 |
30 | i.Set(2)
31 | assert.Equal(t, int32(2), i.Get())
32 |
33 | i.Add(1)
34 | assert.Equal(t, int32(3), i.Get())
35 |
36 | i.CompareAndSwap(3, 4)
37 | assert.Equal(t, int32(4), i.Get())
38 |
39 | i.CompareAndSwap(3, 5)
40 | assert.Equal(t, int32(4), i.Get())
41 | }
42 |
43 | func TestAtomicInt64(t *testing.T) {
44 | i := NewAtomicInt64(1)
45 | assert.Equal(t, int64(1), i.Get())
46 |
47 | i.Set(2)
48 | assert.Equal(t, int64(2), i.Get())
49 |
50 | i.Add(1)
51 | assert.Equal(t, int64(3), i.Get())
52 |
53 | i.CompareAndSwap(3, 4)
54 | assert.Equal(t, int64(4), i.Get())
55 |
56 | i.CompareAndSwap(3, 5)
57 | assert.Equal(t, int64(4), i.Get())
58 | }
59 |
60 | func TestAtomicDuration(t *testing.T) {
61 | d := NewAtomicDuration(time.Second)
62 | assert.Equal(t, time.Second, d.Get())
63 |
64 | d.Set(time.Second * 2)
65 | assert.Equal(t, time.Second*2, d.Get())
66 |
67 | d.Add(time.Second)
68 | assert.Equal(t, time.Second*3, d.Get())
69 |
70 | d.CompareAndSwap(time.Second*3, time.Second*4)
71 | assert.Equal(t, time.Second*4, d.Get())
72 |
73 | d.CompareAndSwap(time.Second*3, time.Second*5)
74 | assert.Equal(t, time.Second*4, d.Get())
75 | }
76 |
77 | func TestAtomicString(t *testing.T) {
78 | var s AtomicString
79 | assert.Equal(t, "", s.Get())
80 |
81 | s.Set("a")
82 | assert.Equal(t, "a", s.Get())
83 |
84 | assert.Equal(t, false, s.CompareAndSwap("b", "c"))
85 | assert.Equal(t, "a", s.Get())
86 |
87 | assert.Equal(t, true, s.CompareAndSwap("a", "c"))
88 | assert.Equal(t, "c", s.Get())
89 | }
90 |
91 | func TestAtomicBool(t *testing.T) {
92 | b := NewAtomicBool(true)
93 | assert.Equal(t, true, b.Get())
94 |
95 | b.Set(false)
96 | assert.Equal(t, false, b.Get())
97 |
98 | b.Set(true)
99 | assert.Equal(t, true, b.Get())
100 |
101 | assert.Equal(t, false, b.CompareAndSwap(false, true))
102 |
103 | assert.Equal(t, true, b.CompareAndSwap(true, false))
104 |
105 | assert.Equal(t, true, b.CompareAndSwap(false, false))
106 |
107 | assert.Equal(t, true, b.CompareAndSwap(false, true))
108 |
109 | assert.Equal(t, true, b.CompareAndSwap(true, true))
110 | }
111 |
--------------------------------------------------------------------------------
/pkg/proxy/server/conn_stmt.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "context"
5 | "encoding/binary"
6 |
7 | "github.com/pingcap/errors"
8 | "github.com/pingcap/parser/mysql"
9 | )
10 |
// preparedStmt2String is a placeholder that always returns "". It is
// presumably meant to render prepared statement stmtID, with its bound
// arguments, as SQL text (e.g. for logging).
// TODO(eastfisher): fix me when prepare is implemented
func (cc *clientConn) preparedStmt2String(stmtID uint32) string {
	return ""
}

// preparedStmt2StringNoArgs is a placeholder that always returns "". It is
// presumably meant to render prepared statement stmtID without its arguments.
// TODO(eastfisher): fix me when prepare is implemented
func (cc *clientConn) preparedStmt2StringNoArgs(stmtID uint32) string {
	return ""
}
20 |
21 | func (cc *clientConn) handleStmtPrepare(ctx context.Context, sql string) error {
22 | stmtId, columns, params, err := cc.ctx.Prepare(ctx, sql)
23 | if err != nil {
24 | return err
25 | }
26 | data := make([]byte, 4, 128)
27 |
28 | //status ok
29 | data = append(data, 0)
30 | //stmt id
31 | data = dumpUint32(data, uint32(stmtId))
32 | //number columns
33 | data = dumpUint16(data, uint16(len(columns)))
34 | //number params
35 | data = dumpUint16(data, uint16(len(params)))
36 | //filter [00]
37 | data = append(data, 0)
38 | //warning count
39 | data = append(data, 0, 0) //TODO support warning count
40 |
41 | if err := cc.writePacket(data); err != nil {
42 | return err
43 | }
44 |
45 | if len(params) > 0 {
46 | for i := 0; i < len(params); i++ {
47 | data = data[0:4]
48 | data = params[i].Dump(data)
49 |
50 | if err := cc.writePacket(data); err != nil {
51 | return err
52 | }
53 | }
54 |
55 | if err := cc.writeEOF(0); err != nil {
56 | return err
57 | }
58 | }
59 |
60 | if len(columns) > 0 {
61 | for i := 0; i < len(columns); i++ {
62 | data = data[0:4]
63 | data = columns[i].Dump(data)
64 |
65 | if err := cc.writePacket(data); err != nil {
66 | return err
67 | }
68 | }
69 |
70 | if err := cc.writeEOF(0); err != nil {
71 | return err
72 | }
73 |
74 | }
75 | return cc.flush()
76 | }
77 |
78 | func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) error {
79 | if len(data) < 9 {
80 | return mysql.ErrMalformPacket
81 | }
82 |
83 | stmtID := binary.LittleEndian.Uint32(data[0:4])
84 | ret, err := cc.ctx.StmtExecuteForward(ctx, int(stmtID), data)
85 | if err != nil {
86 | return err
87 | }
88 |
89 | if ret != nil {
90 | err = cc.writeGoMySQLResultset(ctx, ret.Resultset, true, ret.Status, 0)
91 | } else {
92 | err = cc.writeOK()
93 | }
94 | return err
95 | }
96 |
97 | // TODO(eastfisher): implement this function
98 | func (cc *clientConn) handleStmtSendLongData(data []byte) error {
99 | return errors.New("stmt not implemented")
100 | }
101 |
102 | // TODO(eastfisher): implement this function
103 | func (cc *clientConn) handleStmtReset(data []byte) error {
104 | return errors.New("stmt not implemented")
105 | }
106 |
107 | func (cc *clientConn) handleStmtClose(ctx context.Context, data []byte) error {
108 | if len(data) < 4 {
109 | return nil
110 | }
111 |
112 | stmtID := int(binary.LittleEndian.Uint32(data[0:4]))
113 | if err := cc.ctx.StmtClose(ctx, stmtID); err != nil {
114 | return err
115 | }
116 |
117 | return cc.writeOK()
118 | }
119 |
--------------------------------------------------------------------------------
/pkg/proxy/metrics/metrics.go:
--------------------------------------------------------------------------------
1 | package metrics
2 |
3 | import (
4 | "github.com/prometheus/client_golang/prometheus"
5 | )
6 |
const (
	// ModuleWeirProxy is the module name of the weir proxy (presumably used as
	// the metrics namespace; confirm at the metric definitions).
	ModuleWeirProxy = "weirproxy"
)

// metrics labels.
const (
	LabelServer    = "server"
	LabelQueryCtx  = "queryctx"
	LabelBackend   = "backend"
	LabelSession   = "session"
	LabelDomain    = "domain"
	LabelDDLOwner  = "ddl-owner"
	LabelDDL       = "ddl"
	LabelDDLWorker = "ddl-worker"
	LabelDDLSyncer = "ddl-syncer"
	LabelGCWorker  = "gcworker"
	LabelAnalyze   = "analyze"

	LabelBatchRecvLoop = "batch-recv-loop"
	LabelBatchSendLoop = "batch-send-loop"

	// opSucc and opFailed are the values produced by RetLabel.
	opSucc   = "ok"
	opFailed = "err"

	// LabelScope is the "scope" label name. ScopeGlobal and ScopeSession,
	// declared next to it, are the scope values.
	LabelScope   = "scope"
	ScopeGlobal  = "global"
	ScopeSession = "session"

	// LableScope is a misspelled alias of LabelScope, kept so existing callers
	// keep compiling.
	//
	// Deprecated: use LabelScope instead.
	LableScope = LabelScope
)
35 |
36 | // RetLabel returns "ok" when err == nil and "err" when err != nil.
37 | // This could be useful when you need to observe the operation result.
38 | func RetLabel(err error) string {
39 | if err == nil {
40 | return opSucc
41 | }
42 | return opFailed
43 | }
44 |
// RegisterProxyMetrics binds every proxy metric vector to the given cluster and
// registers it with the default Prometheus registry.
//
// For each vector it curries the LblCluster label with cluster, replaces the
// package-level variable with the curried vector, and then registers that
// vector. As a result, later WithLabelValues calls supply only the remaining
// labels.
//
// It panics (via MustCurryWith / prometheus.MustRegister) if a vector lacks the
// LblCluster label or a collector is already registered. A second call would
// therefore also panic.
func RegisterProxyMetrics(cluster string) {
	curryingLabelsWithLblCluster := map[string]string{LblCluster: cluster}

	PanicCounter = PanicCounter.MustCurryWith(curryingLabelsWithLblCluster)
	prometheus.MustRegister(PanicCounter)
	QueryTotalCounter = QueryTotalCounter.MustCurryWith(curryingLabelsWithLblCluster)
	prometheus.MustRegister(QueryTotalCounter)
	ExecuteErrorCounter = ExecuteErrorCounter.MustCurryWith(curryingLabelsWithLblCluster)
	prometheus.MustRegister(ExecuteErrorCounter)
	ConnGauge = ConnGauge.MustCurryWith(curryingLabelsWithLblCluster)
	prometheus.MustRegister(ConnGauge)

	// query ctx metrics
	QueryCtxQueryCounter = QueryCtxQueryCounter.MustCurryWith(curryingLabelsWithLblCluster)
	prometheus.MustRegister(QueryCtxQueryCounter)
	QueryCtxQueryDeniedCounter = QueryCtxQueryDeniedCounter.MustCurryWith(curryingLabelsWithLblCluster)
	prometheus.MustRegister(QueryCtxQueryDeniedCounter)
	// Histogram currying yields an ObserverVec, so it is asserted back to the
	// concrete *HistogramVec type of the package variable.
	QueryCtxQueryDurationHistogram = QueryCtxQueryDurationHistogram.MustCurryWith(curryingLabelsWithLblCluster).(*prometheus.HistogramVec)
	prometheus.MustRegister(QueryCtxQueryDurationHistogram)
	QueryCtxGauge = QueryCtxGauge.MustCurryWith(curryingLabelsWithLblCluster)
	prometheus.MustRegister(QueryCtxGauge)
	QueryCtxAttachedConnGauge = QueryCtxAttachedConnGauge.MustCurryWith(curryingLabelsWithLblCluster)
	prometheus.MustRegister(QueryCtxAttachedConnGauge)
	QueryCtxTransactionDuration = QueryCtxTransactionDuration.MustCurryWith(curryingLabelsWithLblCluster).(*prometheus.HistogramVec)
	prometheus.MustRegister(QueryCtxTransactionDuration)

	// backend metrics
	BackendEventCounter = BackendEventCounter.MustCurryWith(curryingLabelsWithLblCluster)
	prometheus.MustRegister(BackendEventCounter)
	BackendQueryCounter = BackendQueryCounter.MustCurryWith(curryingLabelsWithLblCluster)
	prometheus.MustRegister(BackendQueryCounter)
	BackendConnInUseGauge = BackendConnInUseGauge.MustCurryWith(curryingLabelsWithLblCluster)
	prometheus.MustRegister(BackendConnInUseGauge)
}
79 |
--------------------------------------------------------------------------------
/pkg/util/rate_limit_breaker/sliding_window.go:
--------------------------------------------------------------------------------
1 | package rate_limit_breaker
2 |
3 | import (
4 | "time"
5 | )
6 |
/*
 * A SlidingWindow is made of Size sub-units (Cells); Size is chosen when the
 * SlidingWindow is created.
 * Every Cell covers the same duration (CellIntervalMs), also chosen at
 * creation time. Each Cell holds a map of per-metric counters.
 *
 * The time axis, counted from the epoch, is cut into consecutive segments of
 * CellIntervalMs, and segment k is stored in Cell k % Size.
 * Each Cell records the start time of the segment it currently holds. Cells are
 * not rotated on a timer: they are refreshed lazily when hit, and hits are
 * driven by the caller. Readers must therefore compare a Cell's start time with
 * the current time to ignore data left over from an earlier round of the ring.
 */

// Cell is one slot of a SlidingWindow.
type Cell struct {
	// startMs is the start of the segment whose counts the cell holds.
	startMs int64
	// metricName => count
	stats map[string]int64
}

// Reset clears the cell's start time and all of its counters.
func (cell *Cell) Reset() {
	cell.startMs = 0
	cell.stats = map[string]int64{}
}

// SlidingWindow counts named events over roughly the last
// Size*CellIntervalMs milliseconds using a ring of Size cells.
// It has no internal locking.
type SlidingWindow struct {
	Size           int64
	CellIntervalMs int64
	Cells          []*Cell // invariant: len(Cells) == Size.
}

// NewSlidingWindow creates a window of size cells, each cellIntervalMs long.
// Both arguments must be positive: getCell divides by cellIntervalMs and takes
// the result modulo size.
func NewSlidingWindow(size int64, cellIntervalMs int64) *SlidingWindow {
	cells := make([]*Cell, size)
	for i := range cells {
		cells[i] = &Cell{
			startMs: 0,
			stats:   map[string]int64{},
		}
	}

	return &SlidingWindow{
		Size:           size,
		CellIntervalMs: cellIntervalMs,
		Cells:          cells,
	}
}

// Hit records one action (e.g. a request) at nowMs. It increments each of
// metricNames in the cell covering nowMs, resetting that cell first if it
// still holds an older segment.
func (sw *SlidingWindow) Hit(nowMs int64, metricNames ...string) {
	cell := sw.getCell(nowMs)
	if nowMs-cell.startMs >= sw.CellIntervalMs { // lazily check if cell expired
		cell.startMs = sw.cellStartMs(nowMs)
		cell.stats = map[string]int64{}
	}
	for _, metric := range metricNames {
		cell.stats[metric]++
	}
}

// getCell returns the ring slot that the segment containing nowMs maps to.
func (sw *SlidingWindow) getCell(nowMs int64) *Cell {
	idx := nowMs / sw.CellIntervalMs % sw.Size
	return sw.Cells[idx]
}

// cellStartMs rounds nowMs down to the start of its segment.
func (sw *SlidingWindow) cellStartMs(nowMs int64) int64 {
	return nowMs - nowMs%sw.CellIntervalMs
}

// GetHits sums each of metricNames over the cells of the current window: the
// cell covering nowMs plus the Size-1 cells before it. This is the same span
// that GetActualDurationMs measures.
func (sw *SlidingWindow) GetHits(nowMs int64, metricNames ...string) map[string]int64 {
	windowStart := nowMs - sw.Size*sw.CellIntervalMs
	stats := map[string]int64{}
	for _, cell := range sw.Cells {
		// A cell starting at or before windowStart is from an earlier round.
		// The bound must be exclusive: when nowMs lies exactly on a cell
		// boundary, windowStart equals the start of the previous round of the
		// current slot, and that stale cell must not be counted.
		if cell.startMs <= windowStart { // lazily check if cell expired
			continue
		}
		for _, metricName := range metricNames {
			stats[metricName] += cell.stats[metricName]
		}
	}
	return stats
}

// GetNowHits returns each of metricNames as counted in the cell covering nowMs
// only. If that slot still holds a segment from an earlier round (it has not
// been hit yet in the current interval), every requested metric reads 0.
func (sw *SlidingWindow) GetNowHits(nowMs int64, metricNames ...string) map[string]int64 {
	cell := sw.getCell(nowMs)
	expired := nowMs-cell.startMs >= sw.CellIntervalMs // same test as Hit
	stats := map[string]int64{}
	for _, metricName := range metricNames {
		var count int64
		if !expired {
			count = cell.stats[metricName]
		}
		stats[metricName] += count
	}
	return stats
}

// GetHit is GetHits for a single metric.
func (sw *SlidingWindow) GetHit(nowMs int64, metricName string) int64 {
	return sw.GetHits(nowMs, metricName)[metricName]
}

// GetActualDurationMs returns the time span covered by the current window.
func (sw *SlidingWindow) GetActualDurationMs(nowMs int64) int64 {
	// nowMs may fall in the middle of a cell, so compute the exact span of the
	// cells involved: Size-1 full cells plus the elapsed part of the current one.
	actualDurationMs := (sw.Size-1)*sw.CellIntervalMs + nowMs%sw.CellIntervalMs
	return actualDurationMs
}
102 |
// GetNowMs returns the current wall-clock time in milliseconds since the Unix
// epoch.
func GetNowMs() int64 {
	nanos := time.Now().UnixNano()
	return nanos / time.Millisecond.Nanoseconds()
}
107 |
--------------------------------------------------------------------------------
/pkg/proxy/namespace/namespace.go:
--------------------------------------------------------------------------------
1 | package namespace
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/pingcap/errors"
7 | "github.com/tidb-incubator/weir/pkg/config"
8 | "github.com/tidb-incubator/weir/pkg/proxy/driver"
9 | "github.com/tidb-incubator/weir/pkg/proxy/metrics"
10 | )
11 |
12 | type NamespaceHolder struct {
13 | nss map[string]Namespace
14 | }
15 |
16 | type NamespaceWrapper struct {
17 | nsmgr *NamespaceManager
18 | name string
19 | }
20 |
21 | func CreateNamespaceHolder(cfgs []*config.Namespace, build NamespaceBuilder) (*NamespaceHolder, error) {
22 | nss := make(map[string]Namespace, len(cfgs))
23 |
24 | for _, cfg := range cfgs {
25 | ns, err := build(cfg)
26 | if err != nil {
27 | return nil, errors.WithMessage(err, fmt.Sprintf("create namespace error, namespace: %s", cfg.Namespace))
28 | }
29 | nss[cfg.Namespace] = ns
30 | }
31 |
32 | holder := &NamespaceHolder{
33 | nss: nss,
34 | }
35 | return holder, nil
36 | }
37 |
38 | func (n *NamespaceHolder) Get(name string) (Namespace, bool) {
39 | ns, ok := n.nss[name]
40 | return ns, ok
41 | }
42 |
43 | func (n *NamespaceHolder) Set(name string, ns Namespace) {
44 | n.nss[name] = ns
45 | }
46 |
47 | func (n *NamespaceHolder) Delete(name string) {
48 | delete(n.nss, name)
49 | }
50 |
51 | func (n *NamespaceHolder) Clone() *NamespaceHolder {
52 | nss := make(map[string]Namespace)
53 | for name, ns := range n.nss {
54 | nss[name] = ns
55 | }
56 | return &NamespaceHolder{
57 | nss: nss,
58 | }
59 | }
60 |
61 | func (n *NamespaceWrapper) Name() string {
62 | return n.name
63 | }
64 |
65 | func (n *NamespaceWrapper) IsDatabaseAllowed(db string) bool {
66 | return n.mustGetCurrentNamespace().IsDatabaseAllowed(db)
67 | }
68 |
69 | func (n *NamespaceWrapper) ListDatabases() []string {
70 | return n.mustGetCurrentNamespace().ListDatabases()
71 | }
72 |
73 | func (n *NamespaceWrapper) IsDeniedSQL(sqlFeature uint32) bool {
74 | return n.mustGetCurrentNamespace().IsDeniedSQL(sqlFeature)
75 | }
76 |
77 | func (n *NamespaceWrapper) IsAllowedSQL(sqlFeature uint32) bool {
78 | return n.mustGetCurrentNamespace().IsAllowedSQL(sqlFeature)
79 | }
80 |
81 | func (n *NamespaceWrapper) GetPooledConn(ctx context.Context) (driver.PooledBackendConn, error) {
82 | return n.mustGetCurrentNamespace().GetPooledConn(ctx)
83 | }
84 |
85 | func (n *NamespaceWrapper) IncrConnCount() {
86 | metrics.QueryCtxGauge.WithLabelValues(n.name).Inc()
87 | }
88 |
89 | func (n *NamespaceWrapper) DescConnCount() {
90 | metrics.QueryCtxGauge.WithLabelValues(n.name).Dec()
91 | }
92 |
93 | func (n *NamespaceWrapper) Closed() bool {
94 | _, ok := n.nsmgr.getCurrentNamespaces().Get(n.name)
95 | return !ok
96 | }
97 |
98 | func (n *NamespaceWrapper) GetBreaker() (driver.Breaker, error) {
99 | return n.mustGetCurrentNamespace().GetBreaker()
100 | }
101 |
102 | func (n *NamespaceWrapper) GetRateLimiter() driver.RateLimiter {
103 | return n.mustGetCurrentNamespace().GetRateLimiter()
104 | }
105 |
106 | func (n *NamespaceWrapper) mustGetCurrentNamespace() Namespace {
107 | ns, ok := n.nsmgr.getCurrentNamespaces().Get(n.name)
108 | if !ok {
109 | panic(errors.New("namespace not found"))
110 | }
111 | return ns
112 | }
113 |
--------------------------------------------------------------------------------
/pkg/proxy/driver/sessionvars.go:
--------------------------------------------------------------------------------
1 | package driver
2 |
3 | import (
4 | "fmt"
5 | "sync/atomic"
6 |
7 | "github.com/pingcap/parser/ast"
8 | "github.com/pingcap/tidb/sessionctx/variable"
9 | )
10 |
11 | type SessionVarsWrapper struct {
12 | sessionVarMap map[string]*ast.VariableAssignment
13 | sessionVars *variable.SessionVars
14 | affectedRows uint64
15 | }
16 |
17 | func NewSessionVarsWrapper(sessionVars *variable.SessionVars) *SessionVarsWrapper {
18 | return &SessionVarsWrapper{
19 | sessionVars: sessionVars,
20 | sessionVarMap: make(map[string]*ast.VariableAssignment),
21 | }
22 | }
23 |
24 | func (s *SessionVarsWrapper) SessionVars() *variable.SessionVars {
25 | return s.sessionVars
26 | }
27 |
28 | func (s *SessionVarsWrapper) GetAllSystemVars() map[string]*ast.VariableAssignment {
29 | ret := make(map[string]*ast.VariableAssignment, len(s.sessionVarMap))
30 | for k, v := range s.sessionVarMap {
31 | ret[k] = v
32 | }
33 | return ret
34 | }
35 |
36 | func (s *SessionVarsWrapper) SetSystemVarAST(name string, v *ast.VariableAssignment) {
37 | s.sessionVarMap[name] = v
38 | }
39 |
40 | func (s *SessionVarsWrapper) CheckSessionSysVarValid(name string) error {
41 | if name == ast.SetNames || name == ast.SetCharset {
42 | return nil
43 | }
44 | sysVar := variable.GetSysVar(name)
45 | if sysVar == nil {
46 | return fmt.Errorf("%s is not a valid sysvar", name)
47 | }
48 | if (sysVar.Scope & variable.ScopeSession) == 0 {
49 | return fmt.Errorf("%s is not a session scope sysvar", name)
50 | }
51 | return nil
52 | }
53 |
54 | func (s *SessionVarsWrapper) SetSystemVarDefault(name string) {
55 | delete(s.sessionVarMap, name)
56 | }
57 |
58 | func (s *SessionVarsWrapper) Status() uint16 {
59 | return s.sessionVars.Status
60 | }
61 |
62 | func (s *SessionVarsWrapper) GetStatusFlag(flag uint16) bool {
63 | return s.sessionVars.GetStatusFlag(flag)
64 | }
65 |
66 | func (s *SessionVarsWrapper) SetStatusFlag(flag uint16, on bool) {
67 | s.sessionVars.SetStatusFlag(flag, on)
68 | }
69 |
70 | // TODO(eastfisher): remove this function
71 | func (s *SessionVarsWrapper) GetCharsetInfo() (charset, collation string) {
72 | return s.sessionVars.GetCharsetInfo()
73 | }
74 |
75 | func (s *SessionVarsWrapper) AffectedRows() uint64 {
76 | return s.affectedRows
77 | }
78 |
79 | func (s *SessionVarsWrapper) SetAffectRows(count uint64) {
80 | s.affectedRows = count
81 | }
82 |
83 | func (s *SessionVarsWrapper) LastInsertID() uint64 {
84 | return s.sessionVars.StmtCtx.LastInsertID
85 | }
86 |
87 | func (s *SessionVarsWrapper) SetLastInsertID(id uint64) {
88 | s.sessionVars.StmtCtx.LastInsertID = id
89 | }
90 |
91 | func (s *SessionVarsWrapper) GetMessage() string {
92 | return s.sessionVars.StmtCtx.GetMessage()
93 | }
94 |
95 | func (s *SessionVarsWrapper) SetMessage(msg string) {
96 | s.sessionVars.StmtCtx.SetMessage(msg)
97 | }
98 |
99 | func (s *SessionVarsWrapper) GetClientCapability() uint32 {
100 | return s.sessionVars.ClientCapability
101 | }
102 |
103 | func (s *SessionVarsWrapper) SetClientCapability(capability uint32) {
104 | s.sessionVars.ClientCapability = capability
105 | }
106 |
107 | func (s *SessionVarsWrapper) SetCommandValue(command byte) {
108 | atomic.StoreUint32(&s.sessionVars.CommandValue, uint32(command))
109 | }
110 |
--------------------------------------------------------------------------------
/pkg/proxy/server/conn_util.go:
--------------------------------------------------------------------------------
1 | // Copyright 2015 PingCAP, Inc.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // See the License for the specific language governing permissions and
12 | // limitations under the License.
13 |
14 | package server
15 |
16 | import (
17 | "encoding/binary"
18 | "fmt"
19 | "strconv"
20 |
21 | "github.com/pingcap/errors"
22 | "github.com/pingcap/parser"
23 | "github.com/pingcap/parser/mysql"
24 | "github.com/pingcap/tidb/kv"
25 | "github.com/pingcap/tidb/util/hack"
26 | )
27 |
28 | var _ fmt.Stringer = getLastStmtInConn{}
29 |
30 | type getLastStmtInConn struct {
31 | *clientConn
32 | }
33 |
34 | func (cc getLastStmtInConn) String() string {
35 | if len(cc.lastPacket) == 0 {
36 | return ""
37 | }
38 | cmd, data := cc.lastPacket[0], cc.lastPacket[1:]
39 | switch cmd {
40 | case mysql.ComInitDB:
41 | return "Use " + string(data)
42 | case mysql.ComFieldList:
43 | return "ListFields " + string(data)
44 | case mysql.ComQuery, mysql.ComStmtPrepare:
45 | sql := string(hack.String(data))
46 | if cc.ctx.GetSessionVars().EnableLogDesensitization {
47 | sql, _ = parser.NormalizeDigest(sql)
48 | }
49 | return queryStrForLog(sql)
50 | case mysql.ComStmtExecute, mysql.ComStmtFetch:
51 | stmtID := binary.LittleEndian.Uint32(data[0:4])
52 | return queryStrForLog(cc.preparedStmt2String(stmtID))
53 | case mysql.ComStmtClose, mysql.ComStmtReset:
54 | stmtID := binary.LittleEndian.Uint32(data[0:4])
55 | return mysql.Command2Str[cmd] + " " + strconv.Itoa(int(stmtID))
56 | default:
57 | if cmdStr, ok := mysql.Command2Str[cmd]; ok {
58 | return cmdStr
59 | }
60 | return string(hack.String(data))
61 | }
62 | }
63 |
64 | // PProfLabel return sql label used to tag pprof.
65 | func (cc getLastStmtInConn) PProfLabel() string {
66 | if len(cc.lastPacket) == 0 {
67 | return ""
68 | }
69 | cmd, data := cc.lastPacket[0], cc.lastPacket[1:]
70 | switch cmd {
71 | case mysql.ComInitDB:
72 | return "UseDB"
73 | case mysql.ComFieldList:
74 | return "ListFields"
75 | case mysql.ComStmtClose:
76 | return "CloseStmt"
77 | case mysql.ComStmtReset:
78 | return "ResetStmt"
79 | case mysql.ComQuery, mysql.ComStmtPrepare:
80 | return parser.Normalize(queryStrForLog(string(hack.String(data))))
81 | case mysql.ComStmtExecute, mysql.ComStmtFetch:
82 | stmtID := binary.LittleEndian.Uint32(data[0:4])
83 | return queryStrForLog(cc.preparedStmt2StringNoArgs(stmtID))
84 | default:
85 | return ""
86 | }
87 | }
88 |
// queryStrForLog caps query at 4096 bytes for logging. A longer query is cut to
// its first 4096 bytes and suffixed with "(len: N)", where N is its original
// byte length. Shorter queries are returned unchanged.
func queryStrForLog(query string) string {
	const maxLogLen = 4096
	if len(query) <= maxLogLen {
		return query
	}
	return fmt.Sprintf("%s(len: %d)", query[:maxLogLen], len(query))
}
96 |
97 | func errStrForLog(err error) string {
98 | if kv.ErrKeyExists.Equal(err) || parser.ErrParse.Equal(err) {
99 | // Do not log stack for duplicated entry error.
100 | return err.Error()
101 | }
102 | return errors.ErrorStack(err)
103 | }
104 |
--------------------------------------------------------------------------------
/.golangci.yaml:
--------------------------------------------------------------------------------
1 | # options for analysis running
2 | run:
3 | # default concurrency is the number of available CPUs
4 | concurrency: 4
5 |
6 | # timeout for analysis, e.g. 30s, 5m, default is 1m
7 | timeout: 20m
8 |
9 | # exit code when at least one issue was found, default is 1
10 | issues-exit-code: 1
11 |
12 | # include test files or not, default is true
13 | tests: false
14 |
15 | # list of build tags, all linters use it. Default is empty list.
16 | build-tags:
17 |
18 | # default is true. Enables skipping of directories:
19 | # vendor$, third_party$, testdata$, examples$, Godeps$, builtin$
20 | skip-dirs-use-default: true
21 |
22 | # which dirs to skip: they won't be analyzed;
23 | # can use regexp here: generated.*, regexp is applied on full path;
24 | # default value is empty list, but next dirs are always skipped independently
25 | # from this option's value:
26 | # vendor$, third_party$, testdata$, examples$, Godeps$, builtin$
27 | # skip-dirs:
28 | # - ^test.*
29 |
30 | # by default isn't set. If set we pass it to "go list -mod={option}". From "go help modules":
31 | # If invoked with -mod=readonly, the go command is disallowed from the implicit
32 | # automatic updating of go.mod described above. Instead, it fails when any changes
33 | # to go.mod are needed. This setting is most useful to check that go.mod does
34 | # not need updates, such as in a continuous integration and testing system.
35 | # If invoked with -mod=vendor, the go command assumes that the vendor
36 | # directory holds the correct copies of dependencies and ignores
37 | # the dependency descriptions in go.mod.
38 | modules-download-mode: readonly
39 |
40 | # which files to skip: they will be analyzed, but issues from them
41 | # won't be reported. Default value is empty list, but there is
42 | # no need to include all autogenerated files, we confidently recognize
43 | # autogenerated files. If it's not please let us know.
44 | skip-files:
45 | # - ".*\\.my\\.go$"
46 | # - lib/bad.go
47 |
48 | # all available settings of specific linters
49 | linters-settings:
50 | misspell:
51 | # Correct spellings using locale preferences for US or UK.
52 | # Default is to use a neutral variety of English.
53 | # Setting locale to US will correct the British spelling of 'colour' to 'color'.
54 | # locale: US
55 | ignore-words:
56 | - rela # This is for elf.SHT_RELA
57 |
58 | issues:
59 | # Excluding configuration per-path, per-linter, per-text and per-source
60 | exclude-rules:
61 | - linters: [staticcheck]
62 | text: "SA1019" # this is rule for deprecated method
63 | - linters: [staticcheck]
64 | text: "SA9003: empty branch"
65 | - linters: [staticcheck]
66 | text: "SA2001: empty critical section"
67 | - linters: [goerr113]
68 | text: "do not define dynamic errors, use wrapped static errors instead" # This rule to avoid opinionated check fmt.Errorf("text")
69 | - path: _test\.go # Skip 1.13 errors check for test files
70 | linters:
71 | - goerr113
72 | linters:
73 | disable-all: true
74 | enable:
75 | - deadcode
76 | - goerr113
77 | - gofmt
78 | - goimports
79 | - gosimple
80 | - govet
81 | - ineffassign
82 | - misspell
83 | - staticcheck
84 | - stylecheck
85 | - varcheck
86 |
87 | # To enable later if makes sense
88 | # - errcheck
89 | # - gocyclo
90 | # - golint
91 | # - gosec
92 | # - gosimple
93 | # - lll
94 | # - maligned
95 | # - misspell
96 | # - prealloc
97 | # - structcheck
98 | # - typecheck
99 | # - unused
100 |
--------------------------------------------------------------------------------
/pkg/proxy/server/packetio_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2019 PingCAP, Inc.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // See the License for the specific language governing permissions and
12 | // limitations under the License.
13 |
14 | package server
15 |
16 | import (
17 | "bufio"
18 | "bytes"
19 | "net"
20 | "time"
21 |
22 | . "github.com/pingcap/check"
23 | "github.com/pingcap/parser/mysql"
24 | )
25 |
26 | type PacketIOTestSuite struct {
27 | }
28 |
29 | var _ = Suite(new(PacketIOTestSuite))
30 |
31 | func (s *PacketIOTestSuite) TestWrite(c *C) {
32 | // Test write one packet
33 | var outBuffer bytes.Buffer
34 | pkt := &packetIO{bufWriter: bufio.NewWriter(&outBuffer)}
35 | err := pkt.writePacket([]byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03})
36 | c.Assert(err, IsNil)
37 | err = pkt.flush()
38 | c.Assert(err, IsNil)
39 | c.Assert(outBuffer.Bytes(), DeepEquals, []byte{0x03, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03})
40 |
41 | // Test write more than one packet
42 | outBuffer.Reset()
43 | largeInput := make([]byte, mysql.MaxPayloadLen+4)
44 | pkt = &packetIO{bufWriter: bufio.NewWriter(&outBuffer)}
45 | err = pkt.writePacket(largeInput)
46 | c.Assert(err, IsNil)
47 | err = pkt.flush()
48 | c.Assert(err, IsNil)
49 | res := outBuffer.Bytes()
50 | c.Assert(res[0], Equals, byte(0xff))
51 | c.Assert(res[1], Equals, byte(0xff))
52 | c.Assert(res[2], Equals, byte(0xff))
53 | c.Assert(res[3], Equals, byte(0))
54 | }
55 |
56 | func (s *PacketIOTestSuite) TestRead(c *C) {
57 | var inBuffer bytes.Buffer
58 | _, err := inBuffer.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x01})
59 | c.Assert(err, IsNil)
60 | // Test read one packet
61 | brc := newBufferedReadConn(&bytesConn{inBuffer})
62 | pkt := newPacketIO(brc)
63 | bytes, err := pkt.readPacket()
64 | c.Assert(err, IsNil)
65 | c.Assert(pkt.sequence, Equals, uint8(1))
66 | c.Assert(bytes, DeepEquals, []byte{0x01})
67 |
68 | inBuffer.Reset()
69 | buf := make([]byte, mysql.MaxPayloadLen+9)
70 | buf[0] = 0xff
71 | buf[1] = 0xff
72 | buf[2] = 0xff
73 | buf[3] = 0
74 | buf[2+mysql.MaxPayloadLen] = 0x00
75 | buf[3+mysql.MaxPayloadLen] = 0x00
76 | buf[4+mysql.MaxPayloadLen] = 0x01
77 | buf[7+mysql.MaxPayloadLen] = 0x01
78 | buf[8+mysql.MaxPayloadLen] = 0x0a
79 |
80 | _, err = inBuffer.Write(buf)
81 | c.Assert(err, IsNil)
82 | // Test read multiple packets
83 | brc = newBufferedReadConn(&bytesConn{inBuffer})
84 | pkt = newPacketIO(brc)
85 | bytes, err = pkt.readPacket()
86 | c.Assert(err, IsNil)
87 | c.Assert(pkt.sequence, Equals, uint8(2))
88 | c.Assert(len(bytes), Equals, mysql.MaxPayloadLen+1)
89 | c.Assert(bytes[mysql.MaxPayloadLen], DeepEquals, byte(0x0a))
90 | }
91 |
// bytesConn is an in-memory net.Conn for tests. Reads are served from b,
// writes are discarded, and the remaining methods are no-ops.
type bytesConn struct {
	b bytes.Buffer
}

var _ net.Conn = (*bytesConn)(nil)

// Read reads from the buffered test input.
func (c *bytesConn) Read(b []byte) (n int, err error) {
	return c.b.Read(b)
}

// Write discards b but reports it as fully written. The io.Writer contract
// requires a non-nil error whenever n < len(b); returning 0, nil would make
// wrappers such as bufio.Writer fail with io.ErrShortWrite.
func (c *bytesConn) Write(b []byte) (n int, err error) {
	return len(b), nil
}

// Close is a no-op.
func (c *bytesConn) Close() error {
	return nil
}

// LocalAddr returns nil; the fake connection has no address.
func (c *bytesConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil; the fake connection has no address.
func (c *bytesConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline is a no-op.
func (c *bytesConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline is a no-op.
func (c *bytesConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op.
func (c *bytesConn) SetWriteDeadline(t time.Time) error {
	return nil
}
127 |
--------------------------------------------------------------------------------
/docs/cn/namespace-config.md:
--------------------------------------------------------------------------------
1 | # Namespace配置详解
2 |
3 | ## 配置说明
4 |
5 | ### 客户端连接配置
6 |
7 | ```
8 | version: "v1"
9 | namespace: "test_namespace"
10 | frontend:
11 | allowed_dbs:
12 | - "test_weir_db"
13 | slow_sql_time: 50
14 | sql_blacklist:
15 | - sql: "select * from tbl0"
16 | - sql: "select * from tbl1"
17 | sql_whitelist:
18 | - sql: "select * from tbl2"
19 | - sql: "select * from tbl3"
20 | denied_ips:
21 | users:
22 | - username: "hello"
23 | password: "world"
24 | - username: "hello1"
25 | password: "world1"
26 | ```
27 |
28 | 字段说明
29 |
30 | | 配置 | 说明 |
31 | | --- | --- |
32 | | namespace | Namespace名称, 要求Proxy集群内唯一 |
33 | | frontend | 客户端连接相关配置 |
34 | | frontend.allowed_dbs | 客户端允许访问的Database列表 |
35 | | frontend.sql_blacklist | SQL黑名单列表 |
36 | | frontend.sql_whitelist | SQL白名单列表 |
37 | | frontend.denied_ips | 客户端连接 IP 黑名单列表 |
38 | | frontend.users | 用户连接信息列表 |
39 | | frontend.users.username | 用户名 (要求Proxy集群内唯一) |
40 | | frontend.users.password | 密码 |
41 |
42 | ### 后端连接池配置
43 |
44 | ```
45 | backend:
46 | instances:
47 | - "127.0.0.1:3306"
48 | username: "root"
49 | password: "12344321"
50 | selector_type: "random"
51 | pool_size: 10
52 | idle_timeout: 60
53 | ```
54 |
55 | 字段说明
56 |
57 | | 配置 | 说明 |
58 | | --- | --- |
59 | | instances | TiDB Server实例地址列表 |
60 | | username | 连接TiDB Server用户名|
61 | | password | 连接TiDB Server密码 |
62 | | selector_type | 负载均衡策略, 目前只支持random |
63 | | pool_size | 连接池最大连接数 (针对每个TiDB Server) |
64 | | idle_timeout | TiDB 连接池中空闲连接的超时关闭时间 (单位: 秒) |
65 |
66 | ### 熔断器配置
67 |
68 | 关于熔断的概念可以关注伴鱼技术团队的过往博客[点击了解熔断](https://tech.ipalfish.com/blog/2020/08/23/dolphin/)
69 |
70 | ```
71 | breaker:
72 | scope: "sql"
73 | strategies:
74 | - min_qps: 3
75 | failure_rate_threshold: 0
76 | failure_num: 5
77 | sql_timeout_ms: 2000
78 | open_status_duration_ms: 5000
79 | size: 10
80 | cell_interval_ms: 1000
81 | ```
82 |
83 | 字段说明
84 |
85 | | 配置 | 说明 |
86 | | --- | --- |
87 | | scope | 熔断器粒度, 支持参数: namespace, db, table, sql |
88 | | strategies | 熔断策略列表 |
89 | | strategies.min_qps | 熔断能被触发的最小 QPS |
90 | | strategies.failure_rate_threshold | 需要达到可以熔断的错误率阈值百分数 (0~100) |
91 | | strategies.failure_num | 需要达到可以熔断的错误数阈值 (与failure_rate_threshold只能使用其一) |
92 | | strategies.sql_timeout_ms | SQL超时阈值, (单位: 毫秒) |
93 | | strategies.open_status_duration_ms | 熔断器开启状态持续时间 (单位: 毫秒) |
94 | | strategies.size| 滑动窗口计数器的统计单元数 |
95 | | strategies.cell_interval_ms | 每个单元的时长 (单位: 毫秒), 与size字段共同组成了熔断器时间统计的滑窗 |
96 |
97 | ### 限流器配置
98 |
99 | ```
100 | rate_limiter:
101 | scope: "db"
102 | qps: 1000
103 | ```
104 |
105 | 字段说明
106 |
107 | | 配置 | 说明 |
108 | | --- | --- |
109 | | scope | 限流器粒度, 支持参数: namespace, db, table |
110 | | qps | 限流QPS (超过阈值的请求会直接返回错误) |
111 |
112 |
113 | ## 完整配置示例
114 |
115 | ```
116 | version: "v1"
117 | namespace: "test_namespace"
118 | frontend:
119 | allowed_dbs:
120 | - "test_weir_db"
121 | slow_sql_time: 50
122 | sql_blacklist:
123 | - sql: "select * from tbl0"
124 | - sql: "select * from tbl1"
125 | sql_whitelist:
126 | - sql: "select * from tbl2"
127 | - sql: "select * from tbl3"
128 | denied_ips:
129 | users:
130 | - username: "hello"
131 | password: "world"
132 | - username: "hello1"
133 | password: "world1"
134 | backend:
135 | instances:
136 | - "127.0.0.1:4000"
137 | username: "root"
138 | password: ""
139 | selector_type: "random"
140 | pool_size: 10
141 | idle_timeout: 60
142 | breaker:
143 | scope: "sql"
144 | strategies:
145 | - min_qps: 3
146 | failure_rate_threshold: 0
147 | failure_num: 5
148 | sql_timeout_ms: 2000
149 | open_status_duration_ms: 5000
150 | size: 10
151 | cell_interval_ms: 1000
152 | rate_limiter:
153 | scope: "db"
154 | qps: 1000
155 | ```
156 |
--------------------------------------------------------------------------------
/pkg/configcenter/etcd.go:
--------------------------------------------------------------------------------
1 | package configcenter
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "path"
7 | "time"
8 |
9 | "github.com/tidb-incubator/weir/pkg/config"
10 | "github.com/pingcap/errors"
11 | "github.com/pingcap/tidb/util/logutil"
12 | "go.etcd.io/etcd/clientv3"
13 | "go.etcd.io/etcd/mvcc/mvccpb"
14 | "go.uber.org/zap"
15 | )
16 |
// DefaultEtcdDialTimeout is the dial timeout used when creating the etcd
// client for the config center.
const DefaultEtcdDialTimeout = 3 * time.Second
20 |
21 | type EtcdConfigCenter struct {
22 | etcdClient *clientv3.Client
23 | kv clientv3.KV
24 | basePath string
25 | strictParse bool
26 | }
27 |
28 | func CreateEtcdConfigCenter(cfg config.ConfigEtcd) (*EtcdConfigCenter, error) {
29 | etcdConfig := clientv3.Config{
30 | Endpoints: cfg.Addrs,
31 | Username: cfg.Username,
32 | Password: cfg.Password,
33 | DialTimeout: DefaultEtcdDialTimeout,
34 | }
35 | etcdClient, err := clientv3.New(etcdConfig)
36 | if err != nil {
37 | return nil, errors.WithMessage(err, "create etcd config center error")
38 | }
39 |
40 | center := NewEtcdConfigCenter(etcdClient, cfg.BasePath, cfg.StrictParse)
41 | return center, nil
42 | }
43 |
44 | func NewEtcdConfigCenter(etcdClient *clientv3.Client, basePath string, strictParse bool) *EtcdConfigCenter {
45 | return &EtcdConfigCenter{
46 | etcdClient: etcdClient,
47 | kv: clientv3.NewKV(etcdClient),
48 | basePath: basePath,
49 | strictParse: strictParse,
50 | }
51 | }
52 |
53 | func (e *EtcdConfigCenter) get(ctx context.Context, key string) (*mvccpb.KeyValue, error) {
54 | resp, err := e.kv.Get(ctx, getNamespacePath(e.basePath, key))
55 | if err != nil {
56 | return nil, err
57 | }
58 | if len(resp.Kvs) == 0 {
59 | return nil, fmt.Errorf("key not found")
60 | }
61 | return resp.Kvs[0], nil
62 | }
63 |
64 | func (e *EtcdConfigCenter) list(ctx context.Context) ([]*mvccpb.KeyValue, error) {
65 | baseDir := appendSlashToDirPath(e.basePath)
66 | resp, err := e.kv.Get(ctx, baseDir, clientv3.WithPrefix())
67 | if err != nil {
68 | return nil, err
69 | }
70 | return resp.Kvs, nil
71 | }
72 |
73 | func (e *EtcdConfigCenter) GetNamespace(ns string) (*config.Namespace, error) {
74 | ctx := context.Background()
75 | etcdKeyValue, err := e.get(ctx, ns)
76 | if err != nil {
77 | return nil, err
78 | }
79 |
80 | return config.UnmarshalNamespaceConfig(etcdKeyValue.Value)
81 | }
82 |
83 | func (e *EtcdConfigCenter) ListAllNamespace() ([]*config.Namespace, error) {
84 | ctx := context.Background()
85 | etcdKeyValues, err := e.list(ctx)
86 | if err != nil {
87 | return nil, err
88 | }
89 |
90 | var ret []*config.Namespace
91 | for _, kv := range etcdKeyValues {
92 | nsCfg, err := config.UnmarshalNamespaceConfig(kv.Value)
93 | if err != nil {
94 | if e.strictParse {
95 | return nil, err
96 | } else {
97 | logutil.BgLogger().Warn("parse namespace config error", zap.Error(err), zap.ByteString("namespace", kv.Key))
98 | continue
99 | }
100 | }
101 | ret = append(ret, nsCfg)
102 | }
103 |
104 | return ret, nil
105 | }
106 |
107 | func (e *EtcdConfigCenter) SetNamespace(ns string, value string) error {
108 | ctx := context.Background()
109 | _, err := e.kv.Put(ctx, ns, value)
110 | return err
111 | }
112 |
113 | func (e *EtcdConfigCenter) DelNamespace(ns string) error {
114 | ctx := context.Background()
115 | _, err := e.kv.Delete(ctx, ns)
116 | return err
117 | }
118 |
119 | func (e *EtcdConfigCenter) Close() {
120 | if err := e.etcdClient.Close(); err != nil {
121 | logutil.BgLogger().Error("close etcd client error", zap.Error(err))
122 | }
123 | }
124 |
// getNamespacePath returns the etcd key of namespace ns under basePath,
// joined and cleaned with slash-separated path semantics.
func getNamespacePath(basePath, ns string) string {
	key := path.Join(basePath, ns)
	return key
}
128 |
// appendSlashToDirPath makes sure a non-empty dir ends with '/'. Used as a
// prefix for etcd range reads, this keeps a base path like "/weir/ns" from
// also matching sibling keys such as "/weir/ns2/...". An empty dir stays empty.
func appendSlashToDirPath(dir string) string {
	switch {
	case dir == "":
		return ""
	case dir[len(dir)-1] == '/':
		return dir
	default:
		return dir + "/"
	}
}
139 |
--------------------------------------------------------------------------------
/pkg/proxy/driver/resultset.go:
--------------------------------------------------------------------------------
1 | package driver
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "sync/atomic"
7 |
8 | "github.com/tidb-incubator/weir/pkg/proxy/server"
9 | "github.com/pingcap/tidb/types"
10 | "github.com/pingcap/tidb/util/chunk"
11 | "github.com/pingcap/tidb/util/hack"
12 | "github.com/siddontang/go-mysql/mysql"
13 | )
14 |
15 | type weirResultSet struct {
16 | result *mysql.Result
17 | columnInfos []*server.ColumnInfo
18 | closed int32
19 | readed bool
20 | }
21 |
22 | func wrapMySQLResult(result *mysql.Result) *weirResultSet {
23 | resultSet := &weirResultSet{
24 | result: result,
25 | }
26 | columnInfos := convertFieldsToColumnInfos(result.Fields)
27 | resultSet.columnInfos = columnInfos
28 | return resultSet
29 | }
30 |
31 | func createBinaryPrepareColumns(cnt int) []*server.ColumnInfo {
32 | info := &server.ColumnInfo{}
33 | return copyColumnInfo(info, cnt)
34 | }
35 |
36 | func createBinaryPrepareParams(cnt int) []*server.ColumnInfo {
37 | info := &server.ColumnInfo{Name: "?"}
38 | return copyColumnInfo(info, cnt)
39 | }
40 |
41 | func copyColumnInfo(info *server.ColumnInfo, cnt int) []*server.ColumnInfo {
42 | ret := make([]*server.ColumnInfo, 0, cnt)
43 | for i := 0; i < cnt; i++ {
44 | ret = append(ret, info)
45 | }
46 | return ret
47 | }
48 |
49 | func convertFieldsToColumnInfos(fields []*mysql.Field) []*server.ColumnInfo {
50 | var columnInfos []*server.ColumnInfo
51 | for _, field := range fields {
52 | columnInfo := &server.ColumnInfo{
53 | Schema: string(hack.String(field.Schema)),
54 | Table: string(hack.String(field.Table)),
55 | OrgTable: string(hack.String(field.OrgTable)),
56 | Name: string(hack.String(field.Name)),
57 | OrgName: string(hack.String(field.OrgName)),
58 | ColumnLength: field.ColumnLength,
59 | Charset: field.Charset,
60 | Flag: field.Flag,
61 | Decimal: field.Decimal,
62 | Type: field.Type,
63 | DefaultValueLength: field.DefaultValueLength,
64 | DefaultValue: field.DefaultValue,
65 | }
66 | columnInfos = append(columnInfos, columnInfo)
67 | }
68 | return columnInfos
69 | }
70 |
71 | func convertFieldTypes(fields []*mysql.Field) []*types.FieldType {
72 | var ret []*types.FieldType
73 | for _, f := range fields {
74 | ft := types.NewFieldType(f.Type)
75 | ft.Flag = uint(f.Flag)
76 | ret = append(ret, ft)
77 | }
78 | return ret
79 | }
80 |
81 | func writeResultSetDataToTrunk(r *mysql.Resultset, c *chunk.Chunk) {
82 | for _, rowValue := range r.Values {
83 | for colIdx, colValue := range rowValue {
84 | switch colValue.Type {
85 | case mysql.FieldValueTypeNull:
86 | c.Column(colIdx).AppendNull()
87 | case mysql.FieldValueTypeUnsigned:
88 | c.Column(colIdx).AppendUint64(colValue.AsUint64())
89 | case mysql.FieldValueTypeSigned:
90 | c.Column(colIdx).AppendInt64(colValue.AsInt64())
91 | case mysql.FieldValueTypeFloat:
92 | c.Column(colIdx).AppendFloat64(colValue.AsFloat64())
93 | case mysql.FieldValueTypeString:
94 | c.Column(colIdx).AppendBytes(colValue.AsString())
95 | default:
96 | panic(fmt.Errorf("invalid col value type: %v", colValue.Type))
97 | }
98 | }
99 | }
100 | }
101 |
102 | func (w *weirResultSet) Columns() []*server.ColumnInfo {
103 | return w.columnInfos
104 | }
105 |
106 | func (w *weirResultSet) NewChunk() *chunk.Chunk {
107 | columns := convertFieldTypes(w.result.Fields)
108 | rowCount := len(w.result.RowDatas)
109 | c := chunk.NewChunkWithCapacity(columns, rowCount)
110 | writeResultSetDataToTrunk(w.result.Resultset, c)
111 | return c
112 | }
113 |
114 | func (w *weirResultSet) Next(ctx context.Context, c *chunk.Chunk) error {
115 | // all the data has been converted and set into chunk when calling NewChunk(),
116 | // so we need to do nothing here
117 | if w.readed {
118 | c.Reset()
119 | }
120 | w.readed = true
121 | return nil
122 | }
123 |
124 | func (*weirResultSet) StoreFetchedRows(rows []chunk.Row) {
125 | panic("implement me")
126 | }
127 |
128 | func (*weirResultSet) GetFetchedRows() []chunk.Row {
129 | panic("implement me")
130 | }
131 |
132 | func (w *weirResultSet) Close() error {
133 | atomic.StoreInt32(&w.closed, 1)
134 | return nil
135 | }
136 |
--------------------------------------------------------------------------------
/pkg/util/timer/time_wheel.go:
--------------------------------------------------------------------------------
1 | package timer
2 |
3 | import (
4 | "errors"
5 | "time"
6 | )
7 |
// Task is one scheduled callback held in a TimeWheel bucket.
type Task struct {
	delay    time.Duration // requested delay; determines the bucket and round
	key      interface{}   // caller-supplied identity; re-adding a key replaces its task
	round    int           // full wheel revolutions left before firing, so delays may exceed bucketsNum*tick
	callback func()        // run in a new goroutine when the task fires
}
15 |
16 | // TimeWheel means time wheel
17 | type TimeWheel struct {
18 | tick time.Duration
19 | ticker *time.Ticker
20 |
21 | bucketsNum int
22 | buckets []map[interface{}]*Task // key: added item, value: *Task
23 | bucketIndexes map[interface{}]int // key: added item, value: bucket position
24 |
25 | currentIndex int
26 |
27 | addC chan *Task
28 | removeC chan interface{}
29 | stopC chan struct{}
30 | }
31 |
32 | // NewTimeWheel create new time wheel
33 | func NewTimeWheel(tick time.Duration, bucketsNum int) (*TimeWheel, error) {
34 | if bucketsNum <= 0 {
35 | return nil, errors.New("bucket number must be greater than 0")
36 | }
37 | // if int(tick.Seconds()) < 1 {
38 | // return nil, errors.New("tick cannot be less than 1s")
39 | // }
40 |
41 | tw := &TimeWheel{
42 | tick: tick,
43 | bucketsNum: bucketsNum,
44 | bucketIndexes: make(map[interface{}]int, 1024),
45 | buckets: make([]map[interface{}]*Task, bucketsNum),
46 | currentIndex: 0,
47 | addC: make(chan *Task, 1024),
48 | removeC: make(chan interface{}, 1024),
49 | stopC: make(chan struct{}),
50 | }
51 |
52 | for i := 0; i < bucketsNum; i++ {
53 | tw.buckets[i] = make(map[interface{}]*Task, 16)
54 | }
55 |
56 | return tw, nil
57 | }
58 |
59 | // Start start the time wheel
60 | func (tw *TimeWheel) Start() {
61 | tw.ticker = time.NewTicker(tw.tick)
62 | go tw.start()
63 | }
64 |
65 | func (tw *TimeWheel) start() {
66 | for {
67 | select {
68 | case <-tw.ticker.C:
69 | tw.handleTick()
70 | case task := <-tw.addC:
71 | tw.add(task)
72 | case key := <-tw.removeC:
73 | tw.remove(key)
74 | case <-tw.stopC:
75 | tw.ticker.Stop()
76 | return
77 | }
78 | }
79 | }
80 |
81 | // Stop stop the time wheel
82 | func (tw *TimeWheel) Stop() {
83 | tw.stopC <- struct{}{}
84 | }
85 |
86 | func (tw *TimeWheel) handleTick() {
87 | bucket := tw.buckets[tw.currentIndex]
88 | for k := range bucket {
89 | if bucket[k].round > 0 {
90 | bucket[k].round--
91 | continue
92 | }
93 | go bucket[k].callback()
94 | delete(bucket, k)
95 | delete(tw.bucketIndexes, k)
96 | }
97 | if tw.currentIndex == tw.bucketsNum-1 {
98 | tw.currentIndex = 0
99 | return
100 | }
101 | tw.currentIndex++
102 | }
103 |
104 | // Add add an item into time wheel
105 | func (tw *TimeWheel) Add(delay time.Duration, key interface{}, callback func()) error {
106 | if delay <= 0 || key == nil {
107 | return errors.New("invalid params")
108 | }
109 | tw.addC <- &Task{delay: delay, key: key, callback: callback}
110 | return nil
111 | }
112 |
113 | func (tw *TimeWheel) add(task *Task) {
114 | round := tw.calculateRound(task.delay)
115 | index := tw.calculateIndex(task.delay)
116 | task.round = round
117 | if originIndex, ok := tw.bucketIndexes[task.key]; ok {
118 | delete(tw.buckets[originIndex], task.key)
119 | }
120 | tw.bucketIndexes[task.key] = index
121 | tw.buckets[index][task.key] = task
122 | }
123 |
124 | func (tw *TimeWheel) calculateRound(delay time.Duration) (round int) {
125 | delaySeconds := int(delay.Milliseconds())
126 | tickSeconds := int(tw.tick.Milliseconds())
127 | round = delaySeconds / tickSeconds / tw.bucketsNum
128 | return
129 | }
130 |
131 | func (tw *TimeWheel) calculateIndex(delay time.Duration) (index int) {
132 | delaySeconds := int(delay.Milliseconds())
133 | tickSeconds := int(tw.tick.Milliseconds())
134 | index = (tw.currentIndex + delaySeconds/tickSeconds) % tw.bucketsNum
135 | return
136 | }
137 |
138 | // Remove remove an item from time wheel
139 | func (tw *TimeWheel) Remove(key interface{}) error {
140 | if key == nil {
141 | return errors.New("invalid params")
142 | }
143 | tw.removeC <- key
144 | return nil
145 | }
146 |
147 | // don't need to call callback
148 | func (tw *TimeWheel) remove(key interface{}) {
149 | if index, ok := tw.bucketIndexes[key]; ok {
150 | delete(tw.bucketIndexes, key)
151 | delete(tw.buckets[index], key)
152 | }
153 | return
154 | }
155 |
--------------------------------------------------------------------------------
/pkg/util/timer/timer.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | // Package timer provides various enhanced timer functions.
18 | package timer
19 |
20 | import (
21 | "sync"
22 | "time"
23 |
24 | "github.com/tidb-incubator/weir/pkg/util/sync2"
25 | )
26 |
// typeAction is an out-of-band message sent to a Timer's run goroutine.
type typeAction int

// Messages understood by Timer.run.
const (
	timerStop    typeAction = iota // exit the run loop
	timerReset                     // restart the wait with the current interval
	timerTrigger                   // run keephouse immediately
)
35 |
36 | /*
37 | Timer provides timer functionality that can be controlled
38 | by the user. You start the timer by providing it a callback function,
39 | which it will call at the specified interval.
40 |
41 | var t = timer.NewTimer(1e9)
42 | t.Start(KeepHouse)
43 |
44 | func KeepHouse() {
45 | // do house keeping work
46 | }
47 |
48 | You can stop the timer by calling t.Stop, which is guaranteed to
49 | wait if KeepHouse is being executed.
50 |
51 | You can create an untimely trigger by calling t.Trigger. You can also
52 | schedule an untimely trigger by calling t.TriggerAfter.
53 |
54 | The timer interval can be changed on the fly by calling t.SetInterval.
55 | A zero value interval will cause the timer to wait indefinitely, and it
56 | will react only to an explicit Trigger or Stop.
57 | */
58 | type Timer struct {
59 | interval sync2.AtomicDuration
60 |
61 | // state management
62 | mu sync.Mutex
63 | running bool
64 |
65 | // msg is used for out-of-band messages
66 | msg chan typeAction
67 | }
68 |
69 | // NewTimer creates a new Timer object
70 | func NewTimer(interval time.Duration) *Timer {
71 | tm := &Timer{
72 | msg: make(chan typeAction),
73 | }
74 | tm.interval.Set(interval)
75 | return tm
76 | }
77 |
78 | // Start starts the timer.
79 | func (tm *Timer) Start(keephouse func()) {
80 | tm.mu.Lock()
81 | defer tm.mu.Unlock()
82 | if tm.running {
83 | return
84 | }
85 | tm.running = true
86 | go tm.run(keephouse)
87 | }
88 |
// run is the timer goroutine launched by Start.
//
// Each iteration arms a fresh time.Timer for the current interval. When the
// interval is <= 0, no timer is armed: ch stays nil, and receiving from a nil
// channel never proceeds, so only a message can wake the loop. The loop then
// waits for either the timer or an out-of-band message, and finally calls
// keephouse unless the message was stop or reset.
func (tm *Timer) run(keephouse func()) {
	var timer *time.Timer
	for {
		var ch <-chan time.Time
		interval := tm.interval.Get()
		if interval > 0 {
			timer = time.NewTimer(interval)
			ch = timer.C
		}
		select {
		case action := <-tm.msg:
			// Woken by a message: discard the still-pending timer first.
			if timer != nil {
				timer.Stop()
				timer = nil
			}
			switch action {
			case timerStop:
				return
			case timerReset:
				// Re-read the interval and wait again without calling keephouse.
				continue
			}
			// timerTrigger falls through to the keephouse call below.
		case <-ch:
		}
		keephouse()
	}
}
115 |
116 | // SetInterval changes the wait interval.
117 | // It will cause the timer to restart the wait.
118 | func (tm *Timer) SetInterval(ns time.Duration) {
119 | tm.interval.Set(ns)
120 | tm.mu.Lock()
121 | defer tm.mu.Unlock()
122 | if tm.running {
123 | tm.msg <- timerReset
124 | }
125 | }
126 |
127 | // Trigger will cause the timer to immediately execute the keephouse function.
128 | // It will then cause the timer to restart the wait.
129 | func (tm *Timer) Trigger() {
130 | tm.mu.Lock()
131 | defer tm.mu.Unlock()
132 | if tm.running {
133 | tm.msg <- timerTrigger
134 | }
135 | }
136 |
137 | // TriggerAfter waits for the specified duration and triggers the next event.
138 | func (tm *Timer) TriggerAfter(duration time.Duration) {
139 | go func() {
140 | time.Sleep(duration)
141 | tm.Trigger()
142 | }()
143 | }
144 |
145 | // Stop will stop the timer. It guarantees that the timer will not execute
146 | // any more calls to keephouse once it has returned.
147 | func (tm *Timer) Stop() {
148 | tm.mu.Lock()
149 | defer tm.mu.Unlock()
150 | if tm.running {
151 | tm.msg <- timerStop
152 | tm.running = false
153 | }
154 | }
155 |
156 | // Interval returns the current interval.
157 | func (tm *Timer) Interval() time.Duration {
158 | return tm.interval.Get()
159 | }
160 |
161 | func (tm *Timer) Running() bool {
162 | tm.mu.Lock()
163 | defer tm.mu.Unlock()
164 | return tm.running
165 | }
166 |
--------------------------------------------------------------------------------
/pkg/util/rate_limit_breaker/circuit_breaker/circuit_breaker_test.go:
--------------------------------------------------------------------------------
1 | package circuit_breaker
2 |
3 | import (
4 | "context"
5 | "errors"
6 | rateLimitBreaker "github.com/tidb-incubator/weir/pkg/util/rate_limit_breaker"
7 | "github.com/stretchr/testify/assert"
8 | "testing"
9 | )
10 |
11 | func TestCircuitBreaker_Do_NoError(t *testing.T) {
12 | ctx := context.Background()
13 | cb := NewCircuitBreaker(&CircuitBreakerConfig{
14 | minQPS: 10,
15 | failureRateThreshold: 10,
16 | failureNum: 5,
17 | OpenStatusDurationMs: 10000, // 10s
18 | forceOpen: false,
19 | size: 10,
20 | cellIntervalMs: 1000,
21 | })
22 |
23 | successCount := 0
24 | errCount := 0
25 | for i := 0; i < 1000; i++ {
26 | cb.Do(ctx, func(ctx context.Context) error {
27 | successCount++
28 | return nil
29 | }, func(ctx context.Context, err error) error {
30 | errCount++
31 | return err
32 | })
33 | }
34 | assert.Equal(t, successCount, 1000)
35 | assert.Equal(t, errCount, 0)
36 | assert.Equal(t, cb.Status(), CircuitBreakerStatusOpen)
37 | }
38 |
39 | func TestCircuitBreaker_Do_AlwaysError(t *testing.T) {
40 | ctx := context.Background()
41 | cb := NewCircuitBreaker(&CircuitBreakerConfig{
42 | minQPS: 10,
43 | failureRateThreshold: 10,
44 | OpenStatusDurationMs: 10000, // 10s
45 | forceOpen: false,
46 | })
47 |
48 | successCount := 0
49 | errCount := 0
50 | for i := 0; i < 1000; i++ {
51 | cb.Do(ctx, func(ctx context.Context) error {
52 | successCount++
53 | return errors.New("just_error")
54 | }, func(ctx context.Context, err error) error {
55 | errCount++
56 | return err
57 | })
58 | }
59 |
60 | assert.True(t, successCount < 1000)
61 | assert.Equal(t, errCount, 1000)
62 | assert.Equal(t, cb.Status(), CircuitBreakerStatusOpen)
63 | }
64 |
65 | func TestCircuitBreaker_ForceOpen(t *testing.T) {
66 | ctx := context.Background()
67 | cb := NewCircuitBreaker(&CircuitBreakerConfig{
68 | minQPS: 10,
69 | failureRateThreshold: 10,
70 | OpenStatusDurationMs: 10000, // 10s
71 | forceOpen: true,
72 | })
73 |
74 | errCount := 0
75 | for i := 0; i < 1000; i++ {
76 | cb.Do(ctx, func(ctx context.Context) error {
77 | return nil
78 | }, func(ctx context.Context, err error) error {
79 | errCount++
80 | return nil
81 | })
82 | }
83 | assert.Equal(t, errCount, 1000)
84 | assert.Equal(t, cb.Status(), CircuitBreakerStatusForceOpen)
85 | }
86 |
87 | func TestCircuitBreaker_ChangeConfig_WithForceOpen(t *testing.T) {
88 | cb := NewCircuitBreaker(&CircuitBreakerConfig{
89 | minQPS: 10,
90 | failureRateThreshold: 10,
91 | OpenStatusDurationMs: 10000, // 10s
92 | forceOpen: true,
93 | })
94 | assert.Equal(t, cb.status, CircuitBreakerStatusForceOpen)
95 | assert.Equal(t, cb.Status(), CircuitBreakerStatusForceOpen)
96 |
97 | // cancel forceOpen, status goes back to closed
98 | cb.ChangeConfig(&CircuitBreakerConfig{
99 | minQPS: 10,
100 | failureRateThreshold: 10,
101 | OpenStatusDurationMs: 10000,
102 | forceOpen: false,
103 | })
104 | assert.Equal(t, cb.status, CircuitBreakerStatusClosed)
105 | assert.Equal(t, cb.Status(), CircuitBreakerStatusClosed)
106 | }
107 |
108 | func TestCircuitBreaker_ChangeConfig_WithoutForceOpen(t *testing.T) {
109 | cb := NewCircuitBreaker(&CircuitBreakerConfig{
110 | minQPS: 10,
111 | failureRateThreshold: 10,
112 | OpenStatusDurationMs: 10000, // 10s
113 | forceOpen: false,
114 | })
115 |
116 | // 当前为 open 状态,然后修改配置(没有强制开启),检查仍为 open 状态
117 | cb.status = CircuitBreakerStatusOpen
118 | cb.openStartMs = rateLimitBreaker.GetNowMs()
119 | cb.ChangeConfig(&CircuitBreakerConfig{
120 | minQPS: 10,
121 | failureRateThreshold: 10,
122 | OpenStatusDurationMs: 100000, // 100s
123 | forceOpen: false,
124 | })
125 | assert.Equal(t, cb.status, CircuitBreakerStatusOpen)
126 | assert.Equal(t, cb.Status(), CircuitBreakerStatusOpen)
127 |
128 | // 当前为 closed 状态,然后修改配置(没有强制开启),检查仍为 closed 状态
129 | cb.status = CircuitBreakerStatusClosed
130 | cb.ChangeConfig(&CircuitBreakerConfig{
131 | minQPS: 10,
132 | failureRateThreshold: 10,
133 | OpenStatusDurationMs: 10000, // 10s
134 | forceOpen: false,
135 | })
136 | assert.Equal(t, cb.status, CircuitBreakerStatusClosed)
137 | assert.Equal(t, cb.Status(), CircuitBreakerStatusClosed)
138 | }
139 |
--------------------------------------------------------------------------------
/pkg/proxy/metrics/queryctx.go:
--------------------------------------------------------------------------------
1 | package metrics
2 |
3 | import (
4 | "github.com/pingcap/parser/ast"
5 | "github.com/prometheus/client_golang/prometheus"
6 | )
7 |
// AstStmtType is the statement category used for query metrics.
type AstStmtType int

// Statement categories, numbered from 0 (StmtTypeUnknown) in the order listed.
const (
	StmtTypeUnknown AstStmtType = iota
	StmtTypeSelect
	StmtTypeInsert
	StmtTypeUpdate
	StmtTypeDelete
	StmtTypeDDL
	StmtTypeBegin
	StmtTypeCommit
	StmtTypeRollback
	StmtTypeSet
	StmtTypeShow
	StmtTypeUse
	StmtTypeComment
)
25 |
// SQL-type label values for query metrics, one per AstStmtType.
const (
	StmtNameUnknown  = "unknown"
	StmtNameSelect   = "select"
	StmtNameInsert   = "insert"
	StmtNameUpdate   = "update"
	StmtNameDelete   = "delete"
	StmtNameDDL      = "ddl"
	StmtNameBegin    = "begin"
	StmtNameCommit   = "commit"
	StmtNameRollback = "rollback"
	StmtNameSet      = "set"
	StmtNameShow     = "show"
	StmtNameUse      = "use"
	StmtNameComment  = "comment"
)
41 |
// Prometheus collectors for query-context (client connection) metrics.
var (
	// QueryCtxQueryCounter counts queries, labelled by cluster, namespace,
	// database, table, SQL type and result.
	QueryCtxQueryCounter = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: ModuleWeirProxy,
			Subsystem: LabelQueryCtx,
			Name:      "query_total",
			Help:      "Counter of queries.",
		}, []string{LblCluster, LblNamespace, LblDb, LblTable, LblSQLType, LblResult})

	// QueryCtxQueryDeniedCounter counts denied queries, with the same labels
	// as QueryCtxQueryCounter minus the result.
	QueryCtxQueryDeniedCounter = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: ModuleWeirProxy,
			Subsystem: LabelQueryCtx,
			Name:      "query_denied",
			Help:      "Counter of denied queries.",
		}, []string{LblCluster, LblNamespace, LblDb, LblTable, LblSQLType})

	// QueryCtxQueryDurationHistogram records query handling time in seconds.
	// 29 buckets doubling from 0.5ms reach roughly 1.5 days.
	QueryCtxQueryDurationHistogram = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Namespace: ModuleWeirProxy,
			Subsystem: LabelQueryCtx,
			Name:      "handle_query_duration_seconds",
			Help:      "Bucketed histogram of processing time (s) of handled queries.",
			Buckets:   prometheus.ExponentialBuckets(0.0005, 2, 29), // 0.5ms ~ 1.5days
		}, []string{LblCluster, LblNamespace, LblDb, LblTable, LblSQLType})

	// QueryCtxGauge tracks the number of query contexts (one per client
	// connection) per cluster and namespace.
	QueryCtxGauge = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Namespace: ModuleWeirProxy,
			Subsystem: LabelQueryCtx,
			Name:      "queryctx",
			Help:      "Number of queryctx (equals to client connection).",
		}, []string{LblCluster, LblNamespace})

	// QueryCtxAttachedConnGauge tracks backend connections currently attached
	// to query contexts.
	QueryCtxAttachedConnGauge = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Namespace: ModuleWeirProxy,
			Subsystem: LabelQueryCtx,
			Name:      "attached_connections",
			Help:      "Number of attached backend connections.",
		}, []string{LblCluster, LblNamespace})

	// QueryCtxTransactionDuration records transaction durations in seconds.
	// 28 buckets doubling from 1ms reach roughly 1.5 days.
	//
	// NOTE(review): unlike every other collector here it uses the "tidb" /
	// "session" namespace and subsystem, so it is exported as
	// tidb_session_transaction_duration_seconds. Confirm this is intended;
	// renaming it would break existing dashboards.
	QueryCtxTransactionDuration = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Namespace: "tidb",
			Subsystem: "session",
			Name:      "transaction_duration_seconds",
			Help:      "Bucketed histogram of a transaction execution duration, including retry.",
			Buckets:   prometheus.ExponentialBuckets(0.001, 2, 28), // 1ms ~ 1.5days
		}, []string{LblCluster, LblNamespace, LblDb, LblSQLType})
)
93 |
94 | func GetStmtType(stmt ast.StmtNode) AstStmtType {
95 | switch stmt.(type) {
96 | case *ast.SelectStmt:
97 | return StmtTypeSelect
98 | case *ast.InsertStmt:
99 | return StmtTypeInsert
100 | case *ast.UpdateStmt:
101 | return StmtTypeUpdate
102 | case *ast.DeleteStmt:
103 | return StmtTypeDelete
104 | case *ast.BeginStmt:
105 | return StmtTypeBegin
106 | case *ast.CommitStmt:
107 | return StmtTypeCommit
108 | case *ast.RollbackStmt:
109 | return StmtTypeRollback
110 | case *ast.SetStmt:
111 | return StmtTypeSet
112 | case *ast.ShowStmt:
113 | return StmtTypeShow
114 | case *ast.UseStmt:
115 | return StmtTypeUse
116 | default:
117 | return StmtTypeUnknown
118 | }
119 | }
120 |
121 | func GetStmtTypeName(stmt ast.StmtNode) string {
122 | switch stmt.(type) {
123 | case *ast.SelectStmt:
124 | return StmtNameSelect
125 | case *ast.InsertStmt:
126 | return StmtNameInsert
127 | case *ast.UpdateStmt:
128 | return StmtNameUpdate
129 | case *ast.DeleteStmt:
130 | return StmtNameDelete
131 | case *ast.BeginStmt:
132 | return StmtNameBegin
133 | case *ast.CommitStmt:
134 | return StmtNameCommit
135 | case *ast.RollbackStmt:
136 | return StmtNameRollback
137 | case *ast.SetStmt:
138 | return StmtNameSet
139 | case *ast.ShowStmt:
140 | return StmtNameShow
141 | case *ast.UseStmt:
142 | return StmtNameUse
143 | default:
144 | return StmtNameUnknown
145 | }
146 | }
147 |
--------------------------------------------------------------------------------
/pkg/proxy/namespace/builder.go:
--------------------------------------------------------------------------------
1 | package namespace
2 |
3 | import (
4 | "hash/crc32"
5 | "time"
6 |
7 | "github.com/tidb-incubator/weir/pkg/config"
8 | "github.com/tidb-incubator/weir/pkg/proxy/backend"
9 | "github.com/tidb-incubator/weir/pkg/proxy/driver"
10 | wast "github.com/tidb-incubator/weir/pkg/util/ast"
11 | "github.com/tidb-incubator/weir/pkg/util/datastructure"
12 | "github.com/pingcap/errors"
13 | "github.com/pingcap/parser"
14 | )
15 |
16 | type NamespaceImpl struct {
17 | name string
18 | Br driver.Breaker
19 | Backend
20 | Frontend
21 | rateLimiter *NamespaceRateLimiter
22 | }
23 |
24 | func BuildNamespace(cfg *config.Namespace) (Namespace, error) {
25 | be, err := BuildBackend(cfg.Namespace, &cfg.Backend)
26 | if err != nil {
27 | return nil, errors.WithMessage(err, "build backend error")
28 | }
29 | fe, err := BuildFrontend(&cfg.Frontend)
30 | if err != nil {
31 | return nil, errors.WithMessage(err, "build frontend error")
32 | }
33 | wrapper := &NamespaceImpl{
34 | name: cfg.Namespace,
35 | Backend: be,
36 | Frontend: fe,
37 | }
38 | brm, err := NewBreaker(&cfg.Breaker)
39 | if err != nil {
40 | return nil, err
41 | }
42 | br, err := brm.GetBreaker()
43 | if err != nil {
44 | return nil, err
45 | }
46 | wrapper.Br = br
47 |
48 | rateLimiter := NewNamespaceRateLimiter(cfg.RateLimiter.Scope, cfg.RateLimiter.QPS)
49 | wrapper.rateLimiter = rateLimiter
50 |
51 | return wrapper, nil
52 | }
53 |
54 | func (n *NamespaceImpl) Name() string {
55 | return n.name
56 | }
57 |
58 | func (n *NamespaceImpl) GetBreaker() (driver.Breaker, error) {
59 | return n.Br, nil
60 | }
61 |
62 | func (n *NamespaceImpl) GetRateLimiter() driver.RateLimiter {
63 | return n.rateLimiter
64 | }
65 |
66 | func BuildBackend(ns string, cfg *config.BackendNamespace) (Backend, error) {
67 | bcfg, err := parseBackendConfig(cfg)
68 | if err != nil {
69 | return nil, err
70 | }
71 |
72 | b := backend.NewBackendImpl(ns, bcfg)
73 | if err := b.Init(); err != nil {
74 | return nil, err
75 | }
76 |
77 | return b, nil
78 | }
79 |
80 | func BuildFrontend(cfg *config.FrontendNamespace) (Frontend, error) {
81 | fns := &FrontendNamespace{
82 | allowedDBs: cfg.AllowedDBs,
83 | }
84 | fns.allowedDBSet = datastructure.StringSliceToSet(cfg.AllowedDBs)
85 |
86 | userPasswds := make(map[string]string)
87 | for _, u := range cfg.Users {
88 | userPasswds[u.Username] = u.Password
89 | }
90 | fns.userPasswd = userPasswds
91 |
92 | sqlBlacklist := make(map[uint32]SQLInfo)
93 | fns.sqlBlacklist = sqlBlacklist
94 |
95 | p := parser.New()
96 | for _, deniedSQL := range cfg.SQLBlackList {
97 | stmtNodes, _, err := p.Parse(deniedSQL.SQL, "", "")
98 | if err != nil {
99 | return nil, err
100 | }
101 | if len(stmtNodes) != 1 {
102 | return nil, nil
103 | }
104 | v, err := wast.ExtractAstVisit(stmtNodes[0])
105 | if err != nil {
106 | return nil, err
107 | }
108 | fns.sqlBlacklist[crc32.ChecksumIEEE([]byte(v.SqlFeature()))] = SQLInfo{SQL: deniedSQL.SQL}
109 | }
110 |
111 | sqlWhitelist := make(map[uint32]SQLInfo)
112 | fns.sqlWhitelist = sqlWhitelist
113 | for _, allowedSQL := range cfg.SQLWhiteList {
114 | stmtNodes, _, err := p.Parse(allowedSQL.SQL, "", "")
115 | if err != nil {
116 | return nil, err
117 | }
118 | if len(stmtNodes) != 1 {
119 | return nil, nil
120 | }
121 | v, err := wast.ExtractAstVisit(stmtNodes[0])
122 | if err != nil {
123 | return nil, err
124 | }
125 | fns.sqlWhitelist[crc32.ChecksumIEEE([]byte(v.SqlFeature()))] = SQLInfo{SQL: allowedSQL.SQL}
126 | }
127 |
128 | return fns, nil
129 | }
130 |
131 | func parseBackendConfig(cfg *config.BackendNamespace) (*backend.BackendConfig, error) {
132 | selectorType, valid := backend.SelectorNameToType(cfg.SelectorType)
133 | if !valid {
134 | return nil, ErrInvalidSelectorType
135 | }
136 |
137 | addrs := make(map[string]struct{})
138 | for _, ins := range cfg.Instances {
139 | addrs[ins] = struct{}{}
140 | }
141 |
142 | bcfg := &backend.BackendConfig{
143 | Addrs: addrs,
144 | UserName: cfg.Username,
145 | Password: cfg.Password,
146 | Capacity: cfg.PoolSize,
147 | IdleTimeout: time.Duration(cfg.IdleTimeout) * time.Second,
148 | SelectorType: selectorType,
149 | }
150 | return bcfg, nil
151 | }
152 |
153 | func DefaultAsyncCloseNamespace(ns Namespace) error {
154 | nsWrapper, ok := ns.(*NamespaceImpl)
155 | if !ok {
156 | return errors.Errorf("invalid namespace type: %T", ns)
157 | }
158 | go func() {
159 | time.Sleep(30 * time.Second)
160 | //nsWrapper.BreakerHolder.CloseBreaker()
161 | nsWrapper.Backend.Close()
162 | }()
163 | return nil
164 | }
165 |
--------------------------------------------------------------------------------
/pkg/proxy/namespace/manager.go:
--------------------------------------------------------------------------------
1 | package namespace
2 |
3 | import (
4 | "sync"
5 |
6 | "github.com/pingcap/errors"
7 | "github.com/pingcap/tidb/util/logutil"
8 | "github.com/tidb-incubator/weir/pkg/config"
9 | "github.com/tidb-incubator/weir/pkg/proxy/driver"
10 | "github.com/tidb-incubator/weir/pkg/util/sync2"
11 | "go.uber.org/zap"
12 | )
13 |
14 | type NamespaceManager struct {
15 | switchIndex sync2.BoolIndex
16 | users [2]*UserNamespaceMapper
17 | nss [2]*NamespaceHolder
18 | build NamespaceBuilder
19 | close NamespaceCloser
20 |
21 | reloadLock sync.Mutex
22 | reloadPrepared map[string]bool
23 | }
24 |
25 | type NamespaceBuilder func(cfg *config.Namespace) (Namespace, error)
26 | type NamespaceCloser func(ns Namespace) error
27 |
28 | func CreateNamespaceManager(cfgs []*config.Namespace, builder NamespaceBuilder, closer NamespaceCloser) (*NamespaceManager, error) {
29 | users, err := CreateUserNamespaceMapper(cfgs)
30 | if err != nil {
31 | return nil, errors.WithMessage(err, "create UserNamespaceMapper error")
32 | }
33 |
34 | nss, err := CreateNamespaceHolder(cfgs, builder)
35 | if err != nil {
36 | return nil, errors.WithMessage(err, "create NamespaceHolder error")
37 | }
38 |
39 | mgr := NewNamespaceManager(users, nss, builder, closer)
40 | return mgr, nil
41 | }
42 |
43 | func NewNamespaceManager(users *UserNamespaceMapper, nss *NamespaceHolder, builder NamespaceBuilder, closer NamespaceCloser) *NamespaceManager {
44 | mgr := &NamespaceManager{
45 | build: builder,
46 | close: closer,
47 | reloadPrepared: make(map[string]bool),
48 | }
49 | mgr.users[0] = users
50 | mgr.nss[0] = nss
51 | return mgr
52 | }
53 |
54 | func (n *NamespaceManager) Auth(username string, pwd, salt []byte) (driver.Namespace, bool) {
55 | nsName, ok := n.getNamespaceByUsername(username)
56 | if !ok {
57 | return nil, false
58 | }
59 |
60 | wrapper := &NamespaceWrapper{
61 | nsmgr: n,
62 | name: nsName,
63 | }
64 |
65 | return wrapper, true
66 | }
67 |
68 | func (n *NamespaceManager) PrepareReloadNamespace(namespace string, cfg *config.Namespace) error {
69 | n.reloadLock.Lock()
70 | defer n.reloadLock.Unlock()
71 |
72 | newUsers := n.getCurrentUsers().Clone()
73 | newUsers.RemoveNamespaceUsers(namespace)
74 | if err := newUsers.AddNamespaceUsers(namespace, &cfg.Frontend); err != nil {
75 | return errors.WithMessage(err, "add namespace users error")
76 | }
77 |
78 | newNs, err := n.build(cfg)
79 | if err != nil {
80 | return errors.WithMessage(err, "build namespace error")
81 | }
82 |
83 | newNss := n.getCurrentNamespaces().Clone()
84 | newNss.Set(namespace, newNs)
85 |
86 | n.setOther(newUsers, newNss)
87 | n.reloadPrepared[namespace] = true
88 |
89 | return nil
90 | }
91 |
92 | func (n *NamespaceManager) CommitReloadNamespaces(namespaces []string) error {
93 | n.reloadLock.Lock()
94 | defer n.reloadLock.Unlock()
95 |
96 | for _, namespace := range namespaces {
97 | if !n.reloadPrepared[namespace] {
98 | return errors.Errorf("namespace is not prepared: %s", namespace)
99 | }
100 | }
101 |
102 | n.toggle()
103 | return nil
104 | }
105 |
106 | func (n *NamespaceManager) RemoveNamespace(name string) {
107 | n.reloadLock.Lock()
108 | defer n.reloadLock.Unlock()
109 |
110 | n.getCurrentUsers().RemoveNamespaceUsers(name)
111 | nss := n.getCurrentNamespaces()
112 | ns, ok := nss.Get(name)
113 | if !ok {
114 | return
115 | }
116 |
117 | if err := n.close(ns); err != nil {
118 | logutil.BgLogger().Error("remove namespace error", zap.Error(err), zap.String("namespace", name))
119 | return
120 | }
121 |
122 | nss.Delete(name)
123 | }
124 |
125 | func (n *NamespaceManager) getNamespaceByUsername(username string) (string, bool) {
126 | return n.getCurrentUsers().GetUserNamespace(username)
127 | }
128 |
129 | func (n *NamespaceManager) getCurrent() (*UserNamespaceMapper, *NamespaceHolder) {
130 | current, _, _ := n.switchIndex.Get()
131 | return n.users[current], n.nss[current]
132 | }
133 |
134 | func (n *NamespaceManager) getCurrentUsers() *UserNamespaceMapper {
135 | current, _, _ := n.switchIndex.Get()
136 | return n.users[current]
137 | }
138 |
139 | func (n *NamespaceManager) getCurrentNamespaces() *NamespaceHolder {
140 | current, _, _ := n.switchIndex.Get()
141 | return n.nss[current]
142 | }
143 |
144 | func (n *NamespaceManager) setOther(users *UserNamespaceMapper, nss *NamespaceHolder) {
145 | _, other, _ := n.switchIndex.Get()
146 | n.users[other] = users
147 | n.nss[other] = nss
148 | }
149 |
150 | func (n *NamespaceManager) toggle() {
151 | _, _, currentFlag := n.switchIndex.Get()
152 | n.switchIndex.Set(!currentFlag)
153 | }
154 |
--------------------------------------------------------------------------------
/pkg/proxy/server/column_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2019 PingCAP, Inc.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // See the License for the specific language governing permissions and
12 | // limitations under the License.
13 |
14 | package server
15 |
16 | import (
17 | . "github.com/pingcap/check"
18 | "github.com/pingcap/parser/mysql"
19 | )
20 |
21 | type ColumnTestSuite struct {
22 | }
23 |
24 | var _ = Suite(new(ColumnTestSuite))
25 |
26 | func (s ColumnTestSuite) TestDumpColumn(c *C) {
27 | info := ColumnInfo{
28 | Schema: "testSchema",
29 | Table: "testTable",
30 | OrgTable: "testOrgTable",
31 | Name: "testName",
32 | OrgName: "testOrgName",
33 | ColumnLength: 1,
34 | Charset: 106,
35 | Flag: 0,
36 | Decimal: 1,
37 | Type: 14,
38 | DefaultValueLength: 2,
39 | DefaultValue: []byte{5, 2},
40 | }
41 | r := info.Dump(nil)
42 | exp := []byte{0x3, 0x64, 0x65, 0x66, 0xa, 0x74, 0x65, 0x73, 0x74, 0x53, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x9, 0x74, 0x65, 0x73, 0x74, 0x54, 0x61, 0x62, 0x6c, 0x65, 0xc, 0x74, 0x65, 0x73, 0x74, 0x4f, 0x72, 0x67, 0x54, 0x61, 0x62, 0x6c, 0x65, 0x8, 0x74, 0x65, 0x73, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0xb, 0x74, 0x65, 0x73, 0x74, 0x4f, 0x72, 0x67, 0x4e, 0x61, 0x6d, 0x65, 0xc, 0x6a, 0x0, 0x1, 0x0, 0x0, 0x0, 0xe, 0x0, 0x0, 0x1, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x5, 0x2}
43 | c.Assert(r, DeepEquals, exp)
44 |
45 | c.Assert(dumpFlag(mysql.TypeSet, 0), Equals, uint16(mysql.SetFlag))
46 | c.Assert(dumpFlag(mysql.TypeEnum, 0), Equals, uint16(mysql.EnumFlag))
47 | c.Assert(dumpFlag(mysql.TypeString, 0), Equals, uint16(0))
48 |
49 | c.Assert(dumpType(mysql.TypeSet), Equals, mysql.TypeString)
50 | c.Assert(dumpType(mysql.TypeEnum), Equals, mysql.TypeString)
51 | c.Assert(dumpType(mysql.TypeBit), Equals, mysql.TypeBit)
52 | }
53 |
54 | func (s ColumnTestSuite) TestColumnNameLimit(c *C) {
55 | aLongName := make([]byte, 0, 300)
56 | for i := 0; i < 300; i++ {
57 | aLongName = append(aLongName, 'a')
58 | }
59 | info := ColumnInfo{
60 | Schema: "testSchema",
61 | Table: "testTable",
62 | OrgTable: "testOrgTable",
63 | Name: string(aLongName),
64 | OrgName: "testOrgName",
65 | ColumnLength: 1,
66 | Charset: 106,
67 | Flag: 0,
68 | Decimal: 1,
69 | Type: 14,
70 | DefaultValueLength: 2,
71 | DefaultValue: []byte{5, 2},
72 | }
73 | r := info.Dump(nil)
74 | exp := []byte{0x3, 0x64, 0x65, 0x66, 0xa, 0x74, 0x65, 0x73, 0x74, 0x53, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x9, 0x74, 0x65, 0x73, 0x74, 0x54, 0x61, 0x62, 0x6c, 0x65, 0xc, 0x74, 0x65, 0x73, 0x74, 0x4f, 0x72, 0x67, 0x54, 0x61, 0x62, 0x6c, 0x65, 0xfc, 0x0, 0x1, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0xb, 0x74, 0x65, 0x73, 0x74, 0x4f, 0x72, 0x67, 0x4e, 0x61, 0x6d, 0x65, 0xc, 0x6a, 0x0, 0x1, 0x0, 0x0, 0x0, 0xe, 0x0, 0x0, 0x1, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x5, 0x2}
75 | c.Assert(r, DeepEquals, exp)
76 | }
77 |
--------------------------------------------------------------------------------
/pkg/proxy/server/packetio.go:
--------------------------------------------------------------------------------
1 | // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
2 | //
3 | // This Source Code Form is subject to the terms of the Mozilla Public
4 | // License, v. 2.0. If a copy of the MPL was not distributed with this file,
5 | // You can obtain one at http://mozilla.org/MPL/2.0/.
6 |
7 | // The MIT License (MIT)
8 | //
9 | // Copyright (c) 2014 wandoulabs
10 | // Copyright (c) 2014 siddontang
11 | //
12 | // Permission is hereby granted, free of charge, to any person obtaining a copy of
13 | // this software and associated documentation files (the "Software"), to deal in
14 | // the Software without restriction, including without limitation the rights to
15 | // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
16 | // the Software, and to permit persons to whom the Software is furnished to do so,
17 | // subject to the following conditions:
18 | //
19 | // The above copyright notice and this permission notice shall be included in all
20 | // copies or substantial portions of the Software.
21 |
22 | // Copyright 2015 PingCAP, Inc.
23 | //
24 | // Licensed under the Apache License, Version 2.0 (the "License");
25 | // you may not use this file except in compliance with the License.
26 | // You may obtain a copy of the License at
27 | //
28 | // http://www.apache.org/licenses/LICENSE-2.0
29 | //
30 | // Unless required by applicable law or agreed to in writing, software
31 | // distributed under the License is distributed on an "AS IS" BASIS,
32 | // See the License for the specific language governing permissions and
33 | // limitations under the License.
34 |
35 | package server
36 |
37 | import (
38 | "bufio"
39 | "io"
40 | "time"
41 |
42 | "github.com/pingcap/errors"
43 | "github.com/pingcap/parser/mysql"
44 | "github.com/pingcap/parser/terror"
45 | )
46 |
// defaultWriterSize is the buffer size (16 KiB) of the packet writer.
const defaultWriterSize = 16 << 10
48 |
// packetIO is a helper to read and write data in packet format.
//
// Reads are done directly on bufReadConn; writes are buffered in bufWriter,
// which wraps the same connection, until flush is called.
type packetIO struct {
	bufReadConn *bufferedReadConn
	bufWriter   *bufio.Writer
	// sequence is the packet sequence id expected on the next read and
	// stamped on the next written packet; both paths increment it.
	sequence uint8
	// readTimeout is the value stored by setReadTimeout.
	readTimeout time.Duration
}
56 |
57 | func newPacketIO(bufReadConn *bufferedReadConn) *packetIO {
58 | p := &packetIO{sequence: 0}
59 | p.setBufferedReadConn(bufReadConn)
60 | return p
61 | }
62 |
63 | func (p *packetIO) setBufferedReadConn(bufReadConn *bufferedReadConn) {
64 | p.bufReadConn = bufReadConn
65 | p.bufWriter = bufio.NewWriterSize(bufReadConn, defaultWriterSize)
66 | }
67 |
68 | func (p *packetIO) setReadTimeout(timeout time.Duration) {
69 | p.readTimeout = timeout
70 | }
71 |
72 | func (p *packetIO) readOnePacket() ([]byte, error) {
73 | var header [4]byte
74 |
75 | if _, err := io.ReadFull(p.bufReadConn, header[:]); err != nil {
76 | return nil, errors.Trace(err)
77 | }
78 |
79 | sequence := header[3]
80 | if sequence != p.sequence {
81 | return nil, errInvalidSequence.GenWithStack("invalid sequence %d != %d", sequence, p.sequence)
82 | }
83 |
84 | p.sequence++
85 |
86 | length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
87 |
88 | data := make([]byte, length)
89 | if _, err := io.ReadFull(p.bufReadConn, data); err != nil {
90 | return nil, errors.Trace(err)
91 | }
92 | return data, nil
93 | }
94 |
95 | func (p *packetIO) readPacket() ([]byte, error) {
96 | data, err := p.readOnePacket()
97 | if err != nil {
98 | return nil, errors.Trace(err)
99 | }
100 |
101 | if len(data) < mysql.MaxPayloadLen {
102 | return data, nil
103 | }
104 |
105 | // handle multi-packet
106 | for {
107 | buf, err := p.readOnePacket()
108 | if err != nil {
109 | return nil, errors.Trace(err)
110 | }
111 |
112 | data = append(data, buf...)
113 |
114 | if len(buf) < mysql.MaxPayloadLen {
115 | break
116 | }
117 | }
118 |
119 | return data, nil
120 | }
121 |
122 | // writePacket writes data that already have header
123 | func (p *packetIO) writePacket(data []byte) error {
124 | length := len(data) - 4
125 |
126 | for length >= mysql.MaxPayloadLen {
127 | data[0] = 0xff
128 | data[1] = 0xff
129 | data[2] = 0xff
130 |
131 | data[3] = p.sequence
132 |
133 | if n, err := p.bufWriter.Write(data[:4+mysql.MaxPayloadLen]); err != nil {
134 | return errors.Trace(mysql.ErrBadConn)
135 | } else if n != (4 + mysql.MaxPayloadLen) {
136 | return errors.Trace(mysql.ErrBadConn)
137 | } else {
138 | p.sequence++
139 | length -= mysql.MaxPayloadLen
140 | data = data[mysql.MaxPayloadLen:]
141 | }
142 | }
143 |
144 | data[0] = byte(length)
145 | data[1] = byte(length >> 8)
146 | data[2] = byte(length >> 16)
147 | data[3] = p.sequence
148 |
149 | if n, err := p.bufWriter.Write(data); err != nil {
150 | terror.Log(errors.Trace(err))
151 | return errors.Trace(mysql.ErrBadConn)
152 | } else if n != len(data) {
153 | return errors.Trace(mysql.ErrBadConn)
154 | } else {
155 | p.sequence++
156 | return nil
157 | }
158 | }
159 |
160 | func (p *packetIO) flush() error {
161 | err := p.bufWriter.Flush()
162 | if err != nil {
163 | return errors.Trace(err)
164 | }
165 | return err
166 | }
167 |
--------------------------------------------------------------------------------
/pkg/proxy/server/driver.go:
--------------------------------------------------------------------------------
1 | // Copyright 2015 PingCAP, Inc.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // See the License for the specific language governing permissions and
12 | // limitations under the License.
13 |
14 | package server
15 |
16 | import (
17 | "context"
18 | "crypto/tls"
19 | "fmt"
20 | "time"
21 |
22 | "github.com/pingcap/parser/auth"
23 | "github.com/pingcap/tidb/sessionctx/variable"
24 | "github.com/pingcap/tidb/types"
25 | "github.com/pingcap/tidb/util"
26 | "github.com/pingcap/tidb/util/chunk"
27 | "github.com/siddontang/go-mysql/mysql"
28 | )
29 |
30 | // IDriver opens IContext.
31 | type IDriver interface {
32 | // OpenCtx opens an IContext with connection id, client capability, collation, dbname and optionally the tls state.
33 | OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState) (QueryCtx, error)
34 | }
35 |
36 | // QueryCtx is the interface to execute command.
37 | type QueryCtx interface {
38 | // Status returns server status code.
39 | Status() uint16
40 |
41 | // LastInsertID returns last inserted ID.
42 | LastInsertID() uint64
43 |
44 | // LastMessage returns last info message generated by some commands
45 | LastMessage() string
46 |
47 | // AffectedRows returns affected rows of last executed command.
48 | AffectedRows() uint64
49 |
50 | // Value returns the value associated with this context for key.
51 | Value(key fmt.Stringer) interface{}
52 |
53 | // SetValue saves a value associated with this context for key.
54 | SetValue(key fmt.Stringer, value interface{})
55 |
56 | SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64)
57 |
58 | // CommitTxn commits the transaction operations.
59 | CommitTxn(ctx context.Context) error
60 |
61 | // RollbackTxn undoes the transaction operations.
62 | RollbackTxn()
63 |
64 | // WarningCount returns warning count of last executed command.
65 | WarningCount() uint16
66 |
67 | // CurrentDB returns current DB.
68 | CurrentDB() string
69 |
70 | // Execute executes a SQL statement.
71 | Execute(ctx context.Context, sql string) (*mysql.Result, error)
72 |
73 | // ExecuteInternal executes a internal SQL statement.
74 | ExecuteInternal(ctx context.Context, sql string) ([]ResultSet, error)
75 |
76 | // SetClientCapability sets client capability flags
77 | SetClientCapability(uint32)
78 |
79 | // Prepare prepares a statement.
80 | Prepare(ctx context.Context, sql string) (stmtId int, columns, params []*ColumnInfo, err error)
81 |
82 | StmtExecuteForward(ctx context.Context, stmtId int, data []byte) (*mysql.Result, error)
83 |
84 | StmtClose(ctx context.Context, stmtId int) error
85 |
86 | // FieldList returns columns of a table.
87 | FieldList(tableName string) (columns []*ColumnInfo, err error)
88 |
89 | // Close closes the QueryCtx.
90 | Close() error
91 |
92 | // Auth verifies user's authentication.
93 | Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool
94 |
95 | // ShowProcess shows the information about the session.
96 | ShowProcess() *util.ProcessInfo
97 |
98 | // GetSessionVars return SessionVars.
99 | GetSessionVars() *variable.SessionVars
100 |
101 | SetCommandValue(command byte)
102 |
103 | SetSessionManager(util.SessionManager)
104 | }
105 |
106 | // PreparedStatement is the interface to use a prepared statement.
107 | type PreparedStatement interface {
108 | // ID returns statement ID
109 | ID() int
110 |
111 | // Execute executes the statement.
112 | Execute(context.Context, []types.Datum) (ResultSet, error)
113 |
114 | // AppendParam appends parameter to the statement.
115 | AppendParam(paramID int, data []byte) error
116 |
117 | // NumParams returns number of parameters.
118 | NumParams() int
119 |
120 | // BoundParams returns bound parameters.
121 | BoundParams() [][]byte
122 |
123 | // SetParamsType sets type for parameters.
124 | SetParamsType([]byte)
125 |
126 | // GetParamsType returns the type for parameters.
127 | GetParamsType() []byte
128 |
129 | // StoreResultSet stores ResultSet for subsequent stmt fetching
130 | StoreResultSet(rs ResultSet)
131 |
132 | // GetResultSet gets ResultSet associated this statement
133 | GetResultSet() ResultSet
134 |
135 | // Reset removes all bound parameters.
136 | Reset()
137 |
138 | // Close closes the statement.
139 | Close() error
140 | }
141 |
142 | // ResultSet is the result set of an query.
143 | type ResultSet interface {
144 | Columns() []*ColumnInfo
145 | NewChunk() *chunk.Chunk
146 | Next(context.Context, *chunk.Chunk) error
147 | StoreFetchedRows(rows []chunk.Row)
148 | GetFetchedRows() []chunk.Row
149 | Close() error
150 | }
151 |
// fetchNotifier is implemented by types that want a callback when a
// COM_FETCH command returns.
type fetchNotifier interface {
	// OnFetchReturned is invoked once COM_FETCH has returned; it is intended
	// for server-side cursors.
	OnFetchReturned()
}
158 |
--------------------------------------------------------------------------------
/pkg/proxy/backend/backend.go:
--------------------------------------------------------------------------------
1 | package backend
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "sync"
7 | "time"
8 |
9 | "github.com/tidb-incubator/weir/pkg/proxy/backend/client"
10 | "github.com/tidb-incubator/weir/pkg/proxy/driver"
11 | "github.com/tidb-incubator/weir/pkg/proxy/metrics"
12 | "github.com/tidb-incubator/weir/pkg/util/sync2"
13 | "github.com/pingcap/tidb/util/logutil"
14 | "go.uber.org/zap"
15 | )
16 |
// ErrNoBackendAddr is returned when a BackendConfig lists no addresses.
var ErrNoBackendAddr = errors.New("no backend addr")

// ErrBackendClosed is returned by the connection getters after Close.
var ErrBackendClosed = errors.New("backend is closed")

// ErrBackendNotFound is returned when the routed address has no conn pool.
var ErrBackendNotFound = errors.New("backend not found")
22 |
// BackendConfig describes the backend instances of a namespace and how
// connections to them are opened and pooled.
type BackendConfig struct {
	// Addrs is the set of backend addresses; the map values carry no data.
	Addrs map[string]struct{}
	// UserName and Password are the credentials used both for unpooled
	// connections (GetConn) and for every connection pool.
	UserName string
	Password string
	// Capacity is passed to each address's ConnPoolConfig.
	Capacity int
	// IdleTimeout is passed to each address's ConnPoolConfig.
	IdleTimeout time.Duration
	// SelectorType is passed to CreateSelector to choose the routing strategy.
	SelectorType int
}
31 |
// BackendImpl routes requests of a namespace to its backend instances and
// hands out either fresh (GetConn) or pooled (GetPooledConn) connections.
type BackendImpl struct {
	// ns is the namespace name, used as a metrics label and passed to NewConnPool.
	ns        string
	cfg       *BackendConfig
	connPools map[string]*ConnPool // key: addr
	instances []*Instance
	selector  Selector

	// lock is held exclusively by Init and Close and shared by connPools lookups.
	lock sync.RWMutex
	// closed is flipped to true by the first Close call; the getters then
	// return ErrBackendClosed.
	closed sync2.AtomicBool
}
42 |
43 | func NewBackendImpl(ns string, cfg *BackendConfig) *BackendImpl {
44 | return &BackendImpl{
45 | cfg: cfg,
46 | closed: sync2.NewAtomicBool(false),
47 | ns: ns,
48 | }
49 | }
50 |
51 | func (b *BackendImpl) Init() error {
52 | b.lock.Lock()
53 | defer b.lock.Unlock()
54 | metrics.BackendEventCounter.WithLabelValues(b.ns, metrics.BackendEventIniting).Inc()
55 |
56 | if err := b.initSelector(); err != nil {
57 | return err
58 | }
59 | if err := b.initInstances(); err != nil {
60 | return err
61 | }
62 | if err := b.initConnPools(); err != nil {
63 | return err
64 | }
65 |
66 | metrics.BackendEventCounter.WithLabelValues(b.ns, metrics.BackendEventInited).Inc()
67 | return nil
68 | }
69 |
70 | func (b *BackendImpl) initSelector() error {
71 | selector, err := CreateSelector(b.cfg.SelectorType)
72 | if err != nil {
73 | return err
74 | }
75 | b.selector = selector
76 | return nil
77 | }
78 |
79 | func (b *BackendImpl) initInstances() error {
80 | instances, err := createInstances(b.cfg)
81 | if err != nil {
82 | return err
83 | }
84 | b.instances = instances
85 | return nil
86 | }
87 |
88 | func (b *BackendImpl) initConnPools() error {
89 | connPools := make(map[string]*ConnPool)
90 | for addr := range b.cfg.Addrs {
91 | poolCfg := &ConnPoolConfig{
92 | Config: Config{Addr: addr, UserName: b.cfg.UserName, Password: b.cfg.Password},
93 | Capacity: b.cfg.Capacity,
94 | IdleTimeout: b.cfg.IdleTimeout,
95 | }
96 | connPool := NewConnPool(b.ns, poolCfg)
97 | connPools[addr] = connPool
98 | }
99 |
100 | successfulInitConnPoolAddrs := make(map[string]struct{})
101 | var initConnPoolErr error
102 | for addr, connPool := range connPools {
103 | if err := connPool.Init(); err != nil {
104 | initConnPoolErr = err
105 | break
106 | }
107 | successfulInitConnPoolAddrs[addr] = struct{}{}
108 | }
109 |
110 | if initConnPoolErr != nil {
111 | for addr := range successfulInitConnPoolAddrs {
112 | if err := connPools[addr].Close(); err != nil {
113 | logutil.BgLogger().Sugar().Error("close inited conn pool error, addr: %s, err: %v", addr, err)
114 | }
115 | }
116 | return initConnPoolErr
117 | }
118 |
119 | b.connPools = connPools
120 | return nil
121 | }
122 |
123 | func (b *BackendImpl) GetConn(ctx context.Context) (driver.SimpleBackendConn, error) {
124 | if b.closed.Get() {
125 | return nil, ErrBackendClosed
126 | }
127 |
128 | instance, err := b.route(b.instances)
129 | if err != nil {
130 | return nil, err
131 | }
132 |
133 | conn, err := client.Connect(instance.Addr(), b.cfg.UserName, b.cfg.Password, "")
134 | return conn, err
135 | }
136 |
137 | func (b *BackendImpl) GetPooledConn(ctx context.Context) (driver.PooledBackendConn, error) {
138 | if b.closed.Get() {
139 | return nil, ErrBackendClosed
140 | }
141 |
142 | instance, err := b.route(b.instances)
143 | if err != nil {
144 | return nil, err
145 | }
146 |
147 | b.lock.RLock()
148 | connPool, ok := b.connPools[instance.Addr()]
149 | b.lock.RUnlock()
150 | if !ok {
151 | return nil, ErrBackendNotFound
152 | }
153 |
154 | return connPool.GetConn(ctx)
155 | }
156 |
157 | func (b *BackendImpl) Close() {
158 | metrics.BackendEventCounter.WithLabelValues(b.ns, metrics.BackendEventClosing).Inc()
159 | if !b.closed.CompareAndSwap(false, true) {
160 | return
161 | }
162 |
163 | b.lock.Lock()
164 | defer b.lock.Unlock()
165 |
166 | for addr, connPool := range b.connPools {
167 | if err := connPool.Close(); err != nil {
168 | logutil.BgLogger().Error("close conn pool error, addr: %s, err: %v", zap.String("addr", addr), zap.Error(err))
169 | }
170 | }
171 |
172 | metrics.BackendEventCounter.WithLabelValues(b.ns, metrics.BackendEventClosed).Inc()
173 | }
174 |
175 | func (b *BackendImpl) route(instances []*Instance) (*Instance, error) {
176 | instance, err := b.selector.Select(b.instances)
177 | if err != nil {
178 | return nil, err
179 | }
180 |
181 | return instance, nil
182 | }
183 |
184 | func createInstances(cfg *BackendConfig) ([]*Instance, error) {
185 | if len(cfg.Addrs) == 0 {
186 | return nil, ErrNoBackendAddr
187 | }
188 |
189 | var ret []*Instance
190 | for addr := range cfg.Addrs {
191 | ins := &Instance{addr: addr}
192 | ret = append(ret, ins)
193 | }
194 | return ret, nil
195 | }
196 |
--------------------------------------------------------------------------------
/pkg/proxy/driver/connmgr.go:
--------------------------------------------------------------------------------
1 | package driver
2 |
3 | import (
4 | "context"
5 | "database/sql/driver"
6 | "sync"
7 |
8 | "github.com/tidb-incubator/weir/pkg/proxy/metrics"
9 | utilerrors "github.com/tidb-incubator/weir/pkg/util/errors"
10 | "github.com/pingcap/parser/mysql"
11 | "github.com/pingcap/tidb/util/logutil"
12 | gomysql "github.com/siddontang/go-mysql/mysql"
13 | "go.uber.org/zap"
14 | )
15 |
// BackendConnManager dispatches queries, transaction control and
// prepared-statement commands as events to an FSM, holding mu for the
// duration of each public call. Presumably there is one per client session;
// confirm with the caller.
type BackendConnManager struct {
	fsm   *FSM
	state FSMState // current FSM state; reset to stateInitial by Close

	ns Namespace

	// mu is locked for the whole duration of every exported method.
	mu      sync.Mutex
	txnConn PooledBackendConn // attached pooled conn, managed via set/unsetAttachedConn

	// TODO: use stmt id set
	isPrepared bool
}
28 |
29 | func NewBackendConnManager(fsm *FSM, ns Namespace) *BackendConnManager {
30 | return &BackendConnManager{
31 | fsm: fsm,
32 | state: stateInitial,
33 | ns: ns,
34 | isPrepared: false,
35 | }
36 | }
37 |
38 | func (f *BackendConnManager) MergeStatus(svw *SessionVarsWrapper) {
39 | f.mu.Lock()
40 | defer f.mu.Unlock()
41 |
42 | svw.SetStatusFlag(mysql.ServerStatusInTrans, f.state.IsInTransaction())
43 | svw.SetStatusFlag(mysql.ServerStatusAutocommit, f.state.IsAutoCommit())
44 | }
45 |
46 | func (f *BackendConnManager) Query(ctx context.Context, db, sql string) (*gomysql.Result, error) {
47 | f.mu.Lock()
48 | defer f.mu.Unlock()
49 |
50 | ret, err := f.fsm.Call(ctx, EventQuery, f, db, sql)
51 | if err != nil {
52 | return nil, err
53 | }
54 | return ret.(*gomysql.Result), nil
55 | }
56 |
57 | func (f *BackendConnManager) SetAutoCommit(ctx context.Context, autocommit bool) error {
58 | f.mu.Lock()
59 | defer f.mu.Unlock()
60 |
61 | var err error
62 | if autocommit {
63 | _, err = f.fsm.Call(ctx, EventEnableAutoCommit, f)
64 | } else {
65 | _, err = f.fsm.Call(ctx, EventDisableAutoCommit, f)
66 | }
67 | return err
68 | }
69 |
70 | func (f *BackendConnManager) Begin(ctx context.Context) error {
71 | f.mu.Lock()
72 | defer f.mu.Unlock()
73 |
74 | _, err := f.fsm.Call(ctx, EventBegin, f)
75 | return err
76 | }
77 |
78 | func (f *BackendConnManager) CommitOrRollback(ctx context.Context, commit bool) error {
79 | f.mu.Lock()
80 | defer f.mu.Unlock()
81 |
82 | _, err := f.fsm.Call(ctx, EventCommitOrRollback, f, commit)
83 | return err
84 | }
85 |
86 | func (f *BackendConnManager) StmtPrepare(ctx context.Context, db, sql string) (Stmt, error) {
87 | f.mu.Lock()
88 | defer f.mu.Unlock()
89 |
90 | ret, err := f.fsm.Call(ctx, EventStmtPrepare, f, db, sql)
91 | if err != nil {
92 | return nil, err
93 | }
94 | return ret.(Stmt), nil
95 | }
96 |
97 | func (f *BackendConnManager) StmtExecuteForward(ctx context.Context, stmtId int, data []byte) (*gomysql.Result, error) {
98 | f.mu.Lock()
99 | defer f.mu.Unlock()
100 |
101 | ret, err := f.fsm.Call(ctx, EventStmtForwardData, f, stmtId, data)
102 | if err != nil {
103 | return nil, err
104 | }
105 | if ret == nil {
106 | return nil, nil
107 | }
108 | return ret.(*gomysql.Result), nil
109 | }
110 |
111 | func (f *BackendConnManager) StmtClose(ctx context.Context, stmtId int) error {
112 | f.mu.Lock()
113 | defer f.mu.Unlock()
114 |
115 | _, err := f.fsm.Call(ctx, EventStmtClose, f, stmtId)
116 | return err
117 | }
118 |
119 | // TODO(eastfisher): is it possible to use FSM to manage close?
120 | func (f *BackendConnManager) Close() error {
121 | f.mu.Lock()
122 | defer f.mu.Unlock()
123 |
124 | if f.txnConn != nil {
125 | errClosePooledBackendConn(f.txnConn, f.ns.Name())
126 | }
127 | f.state = stateInitial
128 | f.unsetAttachedConn()
129 | return nil
130 | }
131 |
132 | func (f *BackendConnManager) queryWithoutTxn(ctx context.Context, db, sql string) (*gomysql.Result, error) {
133 | var err error
134 | conn, err := f.ns.GetPooledConn(ctx)
135 | if err != nil {
136 | return nil, err
137 | }
138 |
139 | defer func() {
140 | if err != nil && isConnError(err) {
141 | if errClose := conn.ErrorClose(); errClose != nil {
142 | logutil.BgLogger().Error("close backend conn error", zap.Error(errClose))
143 | }
144 | } else {
145 | conn.PutBack()
146 | }
147 | }()
148 |
149 | if err = conn.UseDB(db); err != nil {
150 | return nil, err
151 | }
152 |
153 | var ret *gomysql.Result
154 | ret, err = conn.Execute(sql)
155 | return ret, err
156 | }
157 |
158 | func (f *BackendConnManager) queryInTxn(ctx context.Context, db, sql string) (*gomysql.Result, error) {
159 | if err := f.txnConn.UseDB(db); err != nil {
160 | return nil, err
161 | }
162 | return f.txnConn.Execute(sql)
163 | }
164 |
165 | func (f *BackendConnManager) releaseAttachedConn(err error) {
166 | if err != nil {
167 | errClosePooledBackendConn(f.txnConn, f.ns.Name())
168 | } else {
169 | f.txnConn.PutBack()
170 | }
171 | f.unsetAttachedConn()
172 | }
173 |
174 | func (f *BackendConnManager) setAttachedConn(conn PooledBackendConn) {
175 | f.txnConn = conn
176 | metrics.QueryCtxAttachedConnGauge.WithLabelValues(f.ns.Name()).Inc()
177 | }
178 |
179 | func (f *BackendConnManager) unsetAttachedConn() {
180 | if f.txnConn != nil {
181 | metrics.QueryCtxAttachedConnGauge.WithLabelValues(f.ns.Name()).Dec()
182 | }
183 | f.txnConn = nil
184 | }
185 |
186 | func errClosePooledBackendConn(conn PooledBackendConn, ns string) {
187 | if err := conn.ErrorClose(); err != nil {
188 | logutil.BgLogger().Error("close backend conn error", zap.Error(err), zap.String("namespace", ns))
189 | }
190 | }
191 |
192 | func isConnError(err error) bool {
193 | return utilerrors.Is(err, gomysql.ErrBadConn) || utilerrors.Is(err, driver.ErrBadConn)
194 | }
195 |
--------------------------------------------------------------------------------
/pkg/util/sync2/atomic.go:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2019 The Vitess Authors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 |
17 | package sync2
18 |
19 | import (
20 | "sync"
21 | "sync/atomic"
22 | "time"
23 | )
24 |
// AtomicInt32 holds an int32 that is only ever accessed through sync/atomic
// (Add, Store, Load and CompareAndSwap).
type AtomicInt32 struct {
	int32
}

// NewAtomicInt32 returns an AtomicInt32 holding n.
func NewAtomicInt32(n int32) AtomicInt32 {
	return AtomicInt32{int32: n}
}

// Add atomically adds n and returns the resulting value.
func (a *AtomicInt32) Add(n int32) int32 {
	return atomic.AddInt32(&a.int32, n)
}

// Set atomically stores n.
func (a *AtomicInt32) Set(n int32) {
	atomic.StoreInt32(&a.int32, n)
}

// Get atomically loads the current value.
func (a *AtomicInt32) Get() int32 {
	return atomic.LoadInt32(&a.int32)
}

// CompareAndSwap atomically replaces oldval with newval and reports whether
// the swap took place.
func (a *AtomicInt32) CompareAndSwap(oldval, newval int32) (swapped bool) {
	return atomic.CompareAndSwapInt32(&a.int32, oldval, newval)
}
54 |
// AtomicInt64 holds an int64 that is only ever accessed through sync/atomic
// (Add, Store, Load and CompareAndSwap).
type AtomicInt64 struct {
	int64
}

// NewAtomicInt64 returns an AtomicInt64 holding n.
func NewAtomicInt64(n int64) AtomicInt64 {
	return AtomicInt64{int64: n}
}

// Add atomically adds n and returns the resulting value.
func (a *AtomicInt64) Add(n int64) int64 {
	return atomic.AddInt64(&a.int64, n)
}

// Set atomically stores n.
func (a *AtomicInt64) Set(n int64) {
	atomic.StoreInt64(&a.int64, n)
}

// Get atomically loads the current value.
func (a *AtomicInt64) Get() int64 {
	return atomic.LoadInt64(&a.int64)
}

// CompareAndSwap atomically replaces oldval with newval and reports whether
// the swap took place.
func (a *AtomicInt64) CompareAndSwap(oldval, newval int64) (swapped bool) {
	return atomic.CompareAndSwapInt64(&a.int64, oldval, newval)
}
84 |
// AtomicDuration holds a time.Duration, stored as an int64, that is only ever
// accessed through sync/atomic (Add, Store, Load and CompareAndSwap).
type AtomicDuration struct {
	int64
}

// NewAtomicDuration returns an AtomicDuration holding duration.
func NewAtomicDuration(duration time.Duration) AtomicDuration {
	return AtomicDuration{int64: int64(duration)}
}

// Add atomically adds duration and returns the resulting value.
func (d *AtomicDuration) Add(duration time.Duration) time.Duration {
	sum := atomic.AddInt64(&d.int64, int64(duration))
	return time.Duration(sum)
}

// Set atomically stores duration.
func (d *AtomicDuration) Set(duration time.Duration) {
	atomic.StoreInt64(&d.int64, int64(duration))
}

// Get atomically loads the current value.
func (d *AtomicDuration) Get() time.Duration {
	raw := atomic.LoadInt64(&d.int64)
	return time.Duration(raw)
}

// CompareAndSwap atomically replaces oldval with newval and reports whether
// the swap took place.
func (d *AtomicDuration) CompareAndSwap(oldval, newval time.Duration) (swapped bool) {
	return atomic.CompareAndSwapInt64(&d.int64, int64(oldval), int64(newval))
}
114 |
// AtomicBool is a boolean stored as an int32 (0 = false, 1 = true) and
// accessed only through sync/atomic.
type AtomicBool struct {
	int32
}

// NewAtomicBool returns an AtomicBool holding n.
func NewAtomicBool(n bool) AtomicBool {
	var b AtomicBool
	if n {
		b.int32 = 1
	}
	return b
}

// Set atomically stores n.
func (b *AtomicBool) Set(n bool) {
	var v int32
	if n {
		v = 1
	}
	atomic.StoreInt32(&b.int32, v)
}

// Get atomically loads the current value; any non-zero word reads as true.
func (b *AtomicBool) Get() bool {
	return atomic.LoadInt32(&b.int32) != 0
}

// CompareAndSwap atomically replaces o with n and reports whether the swap
// took place.
func (b *AtomicBool) CompareAndSwap(o, n bool) bool {
	var from, to int32
	if o {
		from = 1
	}
	if n {
		to = 1
	}
	return atomic.CompareAndSwapInt32(&b.int32, from, to)
}
153 |
// AtomicString gives you atomic-style APIs for string, but
// it's only a convenience wrapper that uses a mutex. So, it's
// not as efficient as the rest of the atomic types.
// The zero value (an empty string) is ready to use. Because it contains a
// sync.Mutex, an AtomicString must not be copied after first use.
type AtomicString struct {
	mu  sync.Mutex // guards str
	str string
}
161 |
162 | // Set atomically sets str as new value.
163 | func (s *AtomicString) Set(str string) {
164 | s.mu.Lock()
165 | s.str = str
166 | s.mu.Unlock()
167 | }
168 |
169 | // Get atomically returns the current value.
170 | func (s *AtomicString) Get() string {
171 | s.mu.Lock()
172 | str := s.str
173 | s.mu.Unlock()
174 | return str
175 | }
176 |
177 | // CompareAndSwap automatically swaps the old with the new value.
178 | func (s *AtomicString) CompareAndSwap(oldval, newval string) (swqpped bool) {
179 | s.mu.Lock()
180 | defer s.mu.Unlock()
181 | if s.str == oldval {
182 | s.str = newval
183 | return true
184 | }
185 | return false
186 | }
187 |
--------------------------------------------------------------------------------
/pkg/proxy/backend/client/conn.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "crypto/tls"
5 | "fmt"
6 | "net"
7 | "strings"
8 | "time"
9 |
10 | "github.com/pingcap/errors"
11 | . "github.com/siddontang/go-mysql/mysql"
12 | "github.com/siddontang/go-mysql/packet"
13 | "github.com/tidb-incubator/weir/pkg/proxy/constant"
14 | )
15 |
// Conn is a client connection to a MySQL-protocol server. It embeds the
// underlying *packet.Conn, through which packets are read and written.
type Conn struct {
	*packet.Conn

	// user, password and db are the values passed to Connect; db is also
	// updated by a successful UseDB.
	user     string
	password string
	db       string
	// tlsConfig is set by UseSSL or SetTLSConfig, which are meant to be
	// called from Connect options.
	tlsConfig *tls.Config
	// proto is the dial network chosen by getNetProto: "tcp" or "unix".
	proto string

	// capability presumably holds the negotiated client/server capability
	// flags — TODO confirm against the handshake code.
	capability uint32

	// status holds server status flags; IsAutoCommit and IsInTransaction
	// test the SERVER_STATUS_AUTOCOMMIT / SERVER_STATUS_IN_TRANS bits in it.
	status uint16

	// charset starts as constant.DefaultCharset (set in Connect) and is
	// updated by SetCharset.
	charset string

	// salt and authPluginName presumably come from the server's initial
	// handshake and are used for authentication — NOTE(review): confirm.
	salt           []byte
	authPluginName string

	// connectionID is exposed through GetConnectionID; presumably the
	// server-assigned connection ID from the handshake — verify.
	connectionID uint32
}
36 |
// getNetProto picks the dial network for addr. Any address containing a
// slash is treated as a unix socket path; everything else is dialled as TCP.
func getNetProto(addr string) string {
	if strings.Contains(addr, "/") {
		return "unix"
	}
	return "tcp"
}
44 |
45 | // Connect to a MySQL server, addr can be ip:port, or a unix socket domain like /var/sock.
46 | // Accepts a series of configuration functions as a variadic argument.
47 | func Connect(addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
48 | proto := getNetProto(addr)
49 |
50 | c := new(Conn)
51 |
52 | var err error
53 | conn, err := net.DialTimeout(proto, addr, 10*time.Second)
54 | if err != nil {
55 | return nil, errors.Trace(err)
56 | }
57 |
58 | if c.tlsConfig != nil {
59 | c.Conn = packet.NewTLSConn(conn)
60 | } else {
61 | c.Conn = packet.NewConn(conn)
62 | }
63 |
64 | c.user = user
65 | c.password = password
66 | c.db = dbName
67 | c.proto = proto
68 |
69 | //use default charset here
70 | c.charset = constant.DefaultCharset
71 |
72 | // Apply configuration functions.
73 | for i := range options {
74 | options[i](c)
75 | }
76 |
77 | if err = c.handshake(); err != nil {
78 | return nil, errors.Trace(err)
79 | }
80 |
81 | return c, nil
82 | }
83 |
84 | func (c *Conn) handshake() error {
85 | var err error
86 | if err = c.readInitialHandshake(); err != nil {
87 | c.Close()
88 | return errors.Trace(err)
89 | }
90 |
91 | if err := c.writeAuthHandshake(); err != nil {
92 | c.Close()
93 |
94 | return errors.Trace(err)
95 | }
96 |
97 | if err := c.handleAuthResult(); err != nil {
98 | c.Close()
99 | return errors.Trace(err)
100 | }
101 |
102 | return nil
103 | }
104 |
// Close closes the embedded packet connection and returns its error.
func (c *Conn) Close() error {
	return c.Conn.Close()
}
108 |
109 | func (c *Conn) Ping() error {
110 | if err := c.writeCommand(COM_PING); err != nil {
111 | return errors.Trace(err)
112 | }
113 |
114 | if _, err := c.readOK(); err != nil {
115 | return errors.Trace(err)
116 | }
117 |
118 | return nil
119 | }
120 |
121 | // UseSSL: use default SSL
122 | // pass to options when connect
123 | func (c *Conn) UseSSL(insecureSkipVerify bool) {
124 | c.tlsConfig = &tls.Config{InsecureSkipVerify: insecureSkipVerify}
125 | }
126 |
// SetTLSConfig makes the connection use the caller-supplied TLS config
// instead of the default one built by UseSSL. It only records the config,
// so it is meant to be called from a Connect option, which runs before the
// handshake.
func (c *Conn) SetTLSConfig(config *tls.Config) {
	c.tlsConfig = config
}
132 |
133 | func (c *Conn) UseDB(dbName string) error {
134 | if c.db == dbName {
135 | return nil
136 | }
137 |
138 | if err := c.writeCommandStr(COM_INIT_DB, dbName); err != nil {
139 | return errors.Trace(err)
140 | }
141 |
142 | if _, err := c.readOK(); err != nil {
143 | return errors.Trace(err)
144 | }
145 |
146 | c.db = dbName
147 | return nil
148 | }
149 |
// GetDB returns the cached current database name: the dbName given to
// Connect, or the last one set successfully by UseDB.
func (c *Conn) GetDB() string {
	return c.db
}
153 |
154 | func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {
155 | if len(args) == 0 {
156 | return c.exec(command)
157 | } else {
158 | if s, err := c.Prepare(command); err != nil {
159 | return nil, errors.Trace(err)
160 | } else {
161 | var r *Result
162 | r, err = s.Execute(args...)
163 | s.Close()
164 | return r, err
165 | }
166 | }
167 | }
168 |
169 | func (c *Conn) Begin() error {
170 | _, err := c.exec("BEGIN")
171 | return errors.Trace(err)
172 | }
173 |
174 | func (c *Conn) Commit() error {
175 | _, err := c.exec("COMMIT")
176 | return errors.Trace(err)
177 | }
178 |
179 | func (c *Conn) Rollback() error {
180 | _, err := c.exec("ROLLBACK")
181 | return errors.Trace(err)
182 | }
183 |
184 | func (c *Conn) SetCharset(charset string) error {
185 | if c.charset == charset {
186 | return nil
187 | }
188 |
189 | if _, err := c.exec(fmt.Sprintf("SET NAMES %s", charset)); err != nil {
190 | return errors.Trace(err)
191 | } else {
192 | c.charset = charset
193 | return nil
194 | }
195 | }
196 |
197 | func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
198 | if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
199 | return nil, errors.Trace(err)
200 | }
201 | fs := make([]*Field, 0, 4)
202 | for {
203 | data, err := c.ReadPacket()
204 | if err != nil {
205 | return nil, errors.Trace(err)
206 | }
207 | if data[0] == ERR_HEADER {
208 | return nil, c.handleErrorPacket(data)
209 | }
210 |
211 | // EOF Packet
212 | if c.isEOFPacket(data) {
213 | break
214 | }
215 | f, err := FieldData(data).Parse()
216 | if err != nil {
217 | return nil, errors.Trace(err)
218 | }
219 | fs = append(fs, f)
220 | }
221 | return fs, nil
222 | }
223 |
224 | func (c *Conn) SetAutoCommit(autocommit bool) error {
225 | if c.IsAutoCommit() == autocommit {
226 | return nil
227 | }
228 |
229 | if autocommit {
230 | if _, err := c.exec("SET AUTOCOMMIT = 1"); err != nil {
231 | return errors.Trace(err)
232 | }
233 | return nil
234 | } else {
235 | if _, err := c.exec("SET AUTOCOMMIT = 0"); err != nil {
236 | return errors.Trace(err)
237 | }
238 | return nil
239 | }
240 | }
241 |
242 | func (c *Conn) IsAutoCommit() bool {
243 | return c.status&SERVER_STATUS_AUTOCOMMIT > 0
244 | }
245 |
246 | func (c *Conn) IsInTransaction() bool {
247 | return c.status&SERVER_STATUS_IN_TRANS > 0
248 | }
249 |
// GetCharset returns the cached session charset. Connect initialises it to
// constant.DefaultCharset (options may override it), and SetCharset updates it.
func (c *Conn) GetCharset() string {
	return c.charset
}
253 |
// GetConnectionID returns the stored connection ID. It is presumably the
// server-assigned ID from the initial handshake — NOTE(review): confirm.
func (c *Conn) GetConnectionID() uint32 {
	return c.connectionID
}
257 |
// GetStatus returns the cached server status flags, the same bits that
// IsAutoCommit and IsInTransaction test.
func (c *Conn) GetStatus() uint16 {
	return c.status
}
261 |
// HandleOKPacket parses data with handleOKPacket and returns the resulting
// *Result.
// NOTE(review): the error from handleOKPacket is discarded, so on a malformed
// packet the caller gets whatever *Result handleOKPacket returned alongside
// the error. Confirm that callers only pass packets already known to be OK
// packets.
func (c *Conn) HandleOKPacket(data []byte) *Result {
	r, _ := c.handleOKPacket(data)
	return r
}
266 |
// HandleErrorPacket converts an ERR packet payload into an error, delegating
// to handleErrorPacket (which FieldList also uses for ERR_HEADER packets).
func (c *Conn) HandleErrorPacket(data []byte) error {
	return c.handleErrorPacket(data)
}
270 |
// ReadOKPacket reads the server's reply with readOK, the same reader Ping and
// UseDB use after sending their commands, and returns its result unchanged.
func (c *Conn) ReadOKPacket() (*Result, error) {
	return c.readOK()
}
274 |
275 | func (c *Conn) exec(query string) (*Result, error) {
276 | if err := c.writeCommandStr(COM_QUERY, query); err != nil {
277 | return nil, errors.Trace(err)
278 | }
279 |
280 | return c.readResult(false)
281 | }
282 |
--------------------------------------------------------------------------------