├── OWNERS ├── docs ├── cn │ ├── dashboard-workflow.md │ ├── assets │ │ ├── ast.png │ │ ├── reload.png │ │ ├── multi-tenant.png │ │ ├── breaker_process.png │ │ ├── deployment_idc.png │ │ ├── deployment_k8s.png │ │ ├── sliding_window.png │ │ ├── backend-conn-pool.png │ │ └── rateLimiterAndBreaker.png │ ├── monitoring.md │ ├── proxy.md │ ├── cluster_deployment.md │ ├── config-dynamic-reload.md │ ├── multi-tenant.md │ ├── connection-management.md │ ├── RESTful_api.md │ ├── fault-tolerant.md │ ├── proxy-config.md │ ├── quickstart.md │ └── namespace-config.md └── en │ └── assets │ └── weir-architecture.png ├── .gitignore ├── CONTRIBUTING.md ├── pkg ├── proxy │ ├── constant │ │ ├── context.go │ │ └── charset.go │ ├── backend │ │ ├── instance.go │ │ ├── client │ │ │ ├── tls.go │ │ │ ├── req.go │ │ │ └── conn.go │ │ ├── selector_test.go │ │ ├── selector.go │ │ └── backend.go │ ├── driver │ │ ├── driver.go │ │ ├── mock_NamespaceManager.go │ │ ├── queryctx_exec_test.go │ │ ├── mock_Stmt.go │ │ ├── queryctx_metrics.go │ │ ├── mock_Namespace.go │ │ ├── domain.go │ │ ├── sessionvars.go │ │ ├── resultset.go │ │ └── connmgr.go │ ├── namespace │ │ ├── errcode.go │ │ ├── ratelimiter.go │ │ ├── ratelimiter_test.go │ │ ├── domain.go │ │ ├── frontend.go │ │ ├── user.go │ │ ├── namespace.go │ │ ├── builder.go │ │ └── manager.go │ ├── metrics │ │ ├── session.go │ │ ├── backend.go │ │ ├── server.go │ │ ├── metrics.go │ │ └── queryctx.go │ ├── server │ │ ├── buffered_read_conn.go │ │ ├── tokenlimiter.go │ │ ├── server_util.go │ │ ├── column.go │ │ ├── conn_stmt.go │ │ ├── conn_util.go │ │ ├── packetio_test.go │ │ ├── column_test.go │ │ ├── packetio.go │ │ └── driver.go │ └── proxy.go ├── util │ ├── datastructure │ │ ├── dsutil.go │ │ └── dsutil_test.go │ ├── sync2 │ │ ├── boolindex.go │ │ ├── doc.go │ │ ├── toggle.go │ │ ├── toggle_test.go │ │ ├── semaphore_flaky_test.go │ │ ├── semaphore.go │ │ ├── atomic_test.go │ │ └── atomic.go │ ├── rand2 │ │ ├── rand_test.go │ │ └── rand.go │ ├── 
passwd │ │ └── passwd.go │ ├── rate_limit_breaker │ │ ├── rate_limit │ │ │ ├── leaky_bucket_test.go │ │ │ ├── sliding_window.go │ │ │ ├── sliding_window_test.go │ │ │ └── leaky_bucket.go │ │ ├── sliding_window.go │ │ └── circuit_breaker │ │ │ └── circuit_breaker_test.go │ ├── errors │ │ ├── errors_test.go │ │ └── errors.go │ ├── timer │ │ ├── randticker.go │ │ ├── randticker_flaky_test.go │ │ ├── timer_flaky_test.go │ │ ├── time_wheel_test.go │ │ ├── time_wheel.go │ │ └── timer.go │ └── ast │ │ └── ast_util.go ├── config │ ├── marshaller.go │ ├── namespace_example.yaml │ ├── proxy_example.yaml │ ├── namespace.go │ ├── proxy.go │ └── marshaller_test.go └── configcenter │ ├── factory.go │ ├── file.go │ └── etcd.go ├── code-of-conduct.md ├── .github ├── ISSUE_TEMPLATE │ ├── development-task.md │ ├── bug-report.md │ └── feature-request.md └── pull_request_template.md ├── conf ├── weirproxy.yaml ├── weirproxy_etcd.yaml └── namespace │ └── test_namespace.yaml ├── README-CN.md ├── tests └── proxy │ └── backend │ └── connpool_test.go ├── Makefile ├── README.md ├── go.mod ├── cmd └── weirproxy │ └── main.go └── .golangci.yaml /OWNERS: -------------------------------------------------------------------------------- 1 | committers: 2 | -------------------------------------------------------------------------------- /docs/cn/dashboard-workflow.md: -------------------------------------------------------------------------------- 1 | # Dashboard工作流 2 | 3 | (TODO) 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bin 2 | *.iml 3 | .idea 4 | *.swp 5 | .DS_Store 6 | vendor 7 | .vscode/ 8 | -------------------------------------------------------------------------------- /docs/cn/assets/ast.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/ast.png -------------------------------------------------------------------------------- /docs/cn/assets/reload.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/reload.png -------------------------------------------------------------------------------- /docs/cn/assets/multi-tenant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/multi-tenant.png -------------------------------------------------------------------------------- /docs/cn/assets/breaker_process.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/breaker_process.png -------------------------------------------------------------------------------- /docs/cn/assets/deployment_idc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/deployment_idc.png -------------------------------------------------------------------------------- /docs/cn/assets/deployment_k8s.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/deployment_k8s.png -------------------------------------------------------------------------------- /docs/cn/assets/sliding_window.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/sliding_window.png -------------------------------------------------------------------------------- /docs/cn/assets/backend-conn-pool.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/backend-conn-pool.png -------------------------------------------------------------------------------- /docs/en/assets/weir-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/en/assets/weir-architecture.png -------------------------------------------------------------------------------- /docs/cn/assets/rateLimiterAndBreaker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidb-incubator/weir/HEAD/docs/cn/assets/rateLimiterAndBreaker.png -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Please refer to PingCAP [CONTRIBUTING.md](https://github.com/pingcap/community/blob/master/CONTRIBUTING.md) -------------------------------------------------------------------------------- /docs/cn/monitoring.md: -------------------------------------------------------------------------------- 1 | # 监控与告警 2 | 3 | ## Grafana监控 4 | 5 | Grafana监控分为proxy级别和namespace级别, [这里](assets/grafana_proxy.json) 给出了一个proxy级别监控大盘的配置示例. 
6 | -------------------------------------------------------------------------------- /pkg/proxy/constant/context.go: -------------------------------------------------------------------------------- 1 | package constant 2 | 3 | const ContextKeyPrefix = "__w_" 4 | 5 | const ContextKeySessionVariable = ContextKeyPrefix + "session_sysvars" 6 | -------------------------------------------------------------------------------- /pkg/proxy/backend/instance.go: -------------------------------------------------------------------------------- 1 | package backend 2 | 3 | type Instance struct { 4 | addr string 5 | } 6 | 7 | func (i *Instance) Addr() string { 8 | return i.addr 9 | } 10 | -------------------------------------------------------------------------------- /code-of-conduct.md: -------------------------------------------------------------------------------- 1 | # PingCAP Community Code of Conduct 2 | 3 | Please refer to our [PingCAP Community Code of Conduct](https://github.com/pingcap/community/blob/master/CODE_OF_CONDUCT.md) 4 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/development-task.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "🗒️ Development Task" 3 | about: As a developer, I want to record a development task. 
4 | labels: type/enhancement 5 | --- 6 | 7 | ## Development Task 8 | -------------------------------------------------------------------------------- /pkg/proxy/constant/charset.go: -------------------------------------------------------------------------------- 1 | package constant 2 | 3 | import "github.com/pingcap/parser/charset" 4 | 5 | const ( 6 | DefaultCharset = charset.CharsetUTF8MB4 7 | DefaultCollationID = charset.CollationUTF8MB4 8 | ) -------------------------------------------------------------------------------- /docs/cn/proxy.md: -------------------------------------------------------------------------------- 1 | # 应用层代理 2 | 3 | Weir作为TiDB分布式数据库治理平台, 其代理组件weir-proxy在实现时复用了TiDB 4.0的协议层和SQL解析器, 因此weir-proxy在协议和语法层面, 对MySQL的兼容性与TiDB 4.0相同. 4 | 5 | 使用原生MySQL客户端即可连接weir-proxy, 通过weir-proxy将客户端的SQL请求转发到后端的TiDB Server集群. 6 | -------------------------------------------------------------------------------- /pkg/util/datastructure/dsutil.go: -------------------------------------------------------------------------------- 1 | package datastructure 2 | 3 | func StringSliceToSet(ss []string) map[string]struct{} { 4 | sset := make(map[string]struct{}, len(ss)) 5 | for _, s := range ss { 6 | sset[s] = struct{}{} 7 | } 8 | return sset 9 | } 10 | -------------------------------------------------------------------------------- /docs/cn/cluster_deployment.md: -------------------------------------------------------------------------------- 1 | # 部署 2 | 3 | ## 物理机部署 4 | 5 | 物理机部署应尽可能跨多个机房, 以实现高可用; 各实例应保持相同配置, 上游可选用带有健康检查的 Server Load Balancer. 6 | 7 | ## Kubernetes 下部署 8 | 9 | 在 Kubernetes 下部署时, 可以先为 Node 对象添加标签, 再通过 nodeSelector 将 Pod 限定调度到带有这些标签的节点或节点组上, 以确保指定的 Pod 只运行在具有一定隔离性、安全性或监管属性的节点上. 注意 nodeSelector 只负责筛选节点, 并不会主动把 Pod 分散到不同节点; 如需将多个 Pod 尽量分散部署到不同节点, 可配合使用 Pod 反亲和性 (podAntiAffinity) 或拓扑分布约束 (topologySpreadConstraints).
10 | 其中上游可以直接通过 Service 或者其他转发组件进行转发 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "🐛 Bug Report" 3 | about: Something isn't working as expected 4 | labels: 'type/bug' 5 | --- 6 | 7 | ## Bug Report 8 | 9 | 10 | 11 | ### What did you do? 12 | 13 | 14 | 15 | ### What did you expect to see? 16 | 17 | ### What did you see instead? 18 | 19 | ### What version of Weir are you using (`weir-proxy -V`)? 20 | -------------------------------------------------------------------------------- /pkg/proxy/driver/driver.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "crypto/tls" 5 | 6 | "github.com/tidb-incubator/weir/pkg/proxy/server" 7 | ) 8 | 9 | type DriverImpl struct { 10 | nsmgr NamespaceManager 11 | } 12 | 13 | func NewDriverImpl(nsmgr NamespaceManager) *DriverImpl { 14 | return &DriverImpl{ 15 | nsmgr: nsmgr, 16 | } 17 | } 18 | 19 | func (d *DriverImpl) OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState) (server.QueryCtx, error) { 20 | return NewQueryCtxImpl(d.nsmgr, connID), nil 21 | } 22 | -------------------------------------------------------------------------------- /conf/weirproxy.yaml: -------------------------------------------------------------------------------- 1 | version: "v1" 2 | proxy_server: 3 | addr: "0.0.0.0:6000" 4 | max_connections: 1000 5 | session_timeout: 600 6 | admin_server: 7 | addr: "0.0.0.0:6001" 8 | enable_basic_auth: false 9 | user: "" 10 | password: "" 11 | log: 12 | level: "debug" 13 | format: "console" 14 | log_file: 15 | filename: "" 16 | max_size: 300 17 | max_days: 1 18 | max_backups: 1 19 | registry: 20 | enable: false 21 | config_center: 22 | type: "file" 23 | config_file: 24 | path: "./conf/namespace" 25 | strict_parse: false 26 | 
performance: 27 | tcp_keep_alive: true 28 | -------------------------------------------------------------------------------- /pkg/util/sync2/boolindex.go: -------------------------------------------------------------------------------- 1 | package sync2 2 | 3 | import "sync/atomic" 4 | 5 | // BoolIndex rolled array switch mark 6 | type BoolIndex struct { 7 | index int32 8 | } 9 | 10 | // Set set index value 11 | func (b *BoolIndex) Set(index bool) { 12 | if index { 13 | atomic.StoreInt32(&b.index, 1) 14 | } else { 15 | atomic.StoreInt32(&b.index, 0) 16 | } 17 | } 18 | 19 | // Get return current, next, current bool value 20 | func (b *BoolIndex) Get() (int32, int32, bool) { 21 | index := atomic.LoadInt32(&b.index) 22 | if index == 1 { 23 | return 1, 0, true 24 | } 25 | return 0, 1, false 26 | } 27 | -------------------------------------------------------------------------------- /pkg/proxy/namespace/errcode.go: -------------------------------------------------------------------------------- 1 | package namespace 2 | 3 | import ( 4 | "github.com/pingcap/errors" 5 | ) 6 | 7 | var ( 8 | ErrDuplicatedUser = errors.New("duplicated user") 9 | ErrInvalidSelectorType = errors.New("invalid selector type") 10 | 11 | ErrNilBreakerName = errors.New("breaker name nil") 12 | ErrInvalidFailureRateThreshold = errors.New("invalid FailureRateThreshold") 13 | ErrInvalidopenStatusDurationMs = errors.New("invalid OpenStatusDurationMs") 14 | ErrInvalidSqlTimeout = errors.New("invalid sql timeout") 15 | 16 | ErrInvalidScope = errors.New("invalid scope") 17 | ) 18 | -------------------------------------------------------------------------------- /conf/weirproxy_etcd.yaml: -------------------------------------------------------------------------------- 1 | version: "v1" 2 | proxy_server: 3 | addr: "0.0.0.0:6000" 4 | max_connections: 10 5 | admin_server: 6 | addr: "0.0.0.0:6001" 7 | enable_basic_auth: false 8 | user: "" 9 | password: "" 10 | log: 11 | level: "debug" 12 | format: "console" 13 
| log_file: 14 | filename: "" 15 | max_size: 300 16 | max_days: 1 17 | max_backups: 1 18 | registry: 19 | enable: false 20 | config_center: 21 | type: "etcd" 22 | config_etcd: 23 | addrs: 24 | - "127.0.0.1:2379" 25 | base_path: "/weir/defaultcluster" 26 | username: "" 27 | password: "" 28 | strict_parse: false 29 | performance: 30 | tcp_keep_alive: true 31 | -------------------------------------------------------------------------------- /pkg/config/marshaller.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import "github.com/goccy/go-yaml" 4 | 5 | func UnmarshalNamespaceConfig(data []byte) (*Namespace, error) { 6 | var cfg Namespace 7 | if err := yaml.Unmarshal(data, &cfg); err != nil { 8 | return nil, err 9 | } 10 | return &cfg, nil 11 | } 12 | 13 | func MarshalNamespaceConfig(cfg *Namespace) ([]byte, error) { 14 | return yaml.Marshal(cfg) 15 | } 16 | 17 | func UnmarshalProxyConfig(data []byte) (*Proxy, error) { 18 | var cfg Proxy 19 | if err := yaml.Unmarshal(data, &cfg); err != nil { 20 | return nil, err 21 | } 22 | return &cfg, nil 23 | } 24 | 25 | func MarshalProxyConfig(cfg *Proxy) ([]byte, error) { 26 | return yaml.Marshal(cfg) 27 | } 28 | -------------------------------------------------------------------------------- /docs/cn/config-dynamic-reload.md: -------------------------------------------------------------------------------- 1 | # 配置热加载 2 | 3 | Weir 作为多租户的 TiDB 数据库治理平台, 租户配置变更是一项比较频繁的操作. 如果每次增加, 修改, 删除租户配置都需要重新启动 Weir Proxy 才能使配置生效, 无疑会对用户体验和服务稳定性带来很大影响. 因此, 我们为 Weir 的 Namespace 配置提供了热加载支持. (注: 热加载要求配置中心使用 etcd ) 4 | 5 | 6 | 7 | Weir 支持通过[管理接口](RESTful_api.md)手动触发配置热加载, 整个过程分为准备 (Prepare) 和提交 (Commit) 两个阶段. Weir 维护了一个双指针队列, 两个指针分别指向当前和准备阶段的 Namespace 队列. 8 | - 在准备阶段, Weir Proxy 会从配置中心拉取 Namespace 的最新配置, 解析配置并初始化 Namespace, 存储在准备阶段队列中. 9 | - 在提交阶段, Weir Proxy 会执行一次原子切换操作, 交换当前队列和准备阶段队列的指针, 使用新的 Namespace 处理客户端的请求, 同时将旧的 Namespace 延迟关闭.
10 | 11 | 整个热加载过程, Weir Proxy不会主动关闭客户端连接, 客户端是无感知的, 对于一些非核心配置的调整, 甚至不需要重建后端数据库连接池, 对提升客户端体验和保持Weir本身以及后端TiDB集群稳定性都有比较大的帮助. 12 | -------------------------------------------------------------------------------- /pkg/util/datastructure/dsutil_test.go: -------------------------------------------------------------------------------- 1 | package datastructure 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestStringSliceToSet(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | args []string 13 | want map[string]struct{} 14 | }{ 15 | {name: "nil", args: nil, want: map[string]struct{}{}}, 16 | {name: "one", args: []string{"db0"}, want: map[string]struct{}{"db0": {}}}, 17 | {name: "two", args: []string{"db0", "db1"}, want: map[string]struct{}{"db0": {}, "db1": {}}}, 18 | } 19 | for _, tt := range tests { 20 | t.Run(tt.name, func(t *testing.T) { 21 | ret := StringSliceToSet(tt.args) 22 | assert.Equal(t, ret, tt.want) 23 | }) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /docs/cn/multi-tenant.md: -------------------------------------------------------------------------------- 1 | # 多租户软隔离 2 | 3 | Weir作为TiDB分布式数据库的统一接入层和治理平台, 需要具备一定的多租户管理能力, 我们将Weir设计成一个共享多租户系统. 4 | 5 | ## 基本概念 6 | 7 | 在Weir系统中使用Namespace表示租户. 每个Weir集群可以管理多个Namespace, 每个Namespace有其独立管理的用户 (User) 和资源配置参数. Namespace中的资源主要是指TiDB集群的访问能力. 8 | 9 | 在Weir集群中, User名称是全局唯一的. 每个User只属于一个Namespace, 每个Namespace可以有多个User. 这样便于在客户端连接Weir Proxy时, 根据MySQL的Username找到其Namespace, 从而确定该User的一些访问权限, 如可访问Database, 访问IP等. 10 | 11 | ## 软隔离 12 | 13 | 多个Namespace关联的TiDB集群可以是相同也可以不同, 但Weir Proxy会为每个Namespace创建该TiDB集群的后端连接池, 在后端数据库连接层面将各个Namespace软隔离, 在一定程度上使各个Namespace对同一集群的访问互相不受影响. 14 | 15 | 16 | 17 | 各个Namespace可以独立动态加载, 可以动态调整Namespace配置的各项参数, Namespace重新加载时会初始化资源, 并将原有Namespace的资源回收. 
18 | -------------------------------------------------------------------------------- /README-CN.md: -------------------------------------------------------------------------------- 1 | # Weir 2 | 3 | Weir是伴鱼公司研发的开源数据库代理平台, 主要为TiDB分布式数据库提供数据库治理功能. 4 | 5 | ## 功能特性 6 | 7 | > 1. Weir 为 MySQL 协议提供应用层代理,兼容 TiDB 4.0。[L7层负载](docs/cn/proxy.md) 8 | > 2. Weir 使用连接池进行后端连接管理,并支持负载均衡。[连接管理](docs/cn/connection-management.md) 9 | > 3. Weir 支持多租户管理,所有命名空间都可以在运行时动态重新加载。[多租户软隔离](docs/cn/multi-tenant.md) 10 | > 4. Weir 支持 QPS 限流和熔断机制来保护客户端和 TiDB 服务器。[熔断限流机制](docs/cn/fault-tolerant.md) 11 | 12 | ## 使用手册 13 | 14 | - [快速上手](docs/cn/quickstart.md) 15 | - [Proxy配置详解](docs/cn/proxy-config.md) 16 | - [Namespace配置详解](docs/cn/namespace-config.md) 17 | - [集群部署](docs/cn/cluster_deployment.md) 18 | - [配置热加载](docs/cn/config-dynamic-reload.md) 19 | - [监控与告警](docs/cn/monitoring.md) 20 | - [RESTful-Api](docs/cn/RESTful_api.md) 21 | 22 | ## FAQ 23 | -------------------------------------------------------------------------------- /pkg/util/sync2/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | // Package sync2 provides extra functionality along the same lines as sync.
18 | package sync2 19 | -------------------------------------------------------------------------------- /pkg/configcenter/factory.go: -------------------------------------------------------------------------------- 1 | package configcenter 2 | 3 | import ( 4 | "github.com/tidb-incubator/weir/pkg/config" 5 | "github.com/pingcap/errors" 6 | ) 7 | 8 | const ( 9 | ConfigCenterTypeFile = "file" 10 | ConfigCenterTypeEtcd = "etcd" 11 | ) 12 | 13 | type ConfigCenter interface { 14 | GetNamespace(ns string) (*config.Namespace, error) 15 | ListAllNamespace() ([]*config.Namespace, error) 16 | } 17 | 18 | func CreateConfigCenter(cfg config.ConfigCenter) (ConfigCenter, error) { 19 | switch cfg.Type { 20 | case ConfigCenterTypeFile: 21 | return CreateFileConfigCenter(cfg.ConfigFile.Path) 22 | case ConfigCenterTypeEtcd: 23 | return CreateEtcdConfigCenter(cfg.ConfigEtcd) 24 | default: 25 | return nil, errors.New("invalid config center type") 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /pkg/config/namespace_example.yaml: -------------------------------------------------------------------------------- 1 | version: "v1" 2 | namespace: "test_namespace" 3 | frontend: 4 | allowed_dbs: 5 | - "test_weir_db" 6 | slow_sql_time: 50 7 | denied_sqls: 8 | denied_ips: 9 | idle_timeout: 3600 10 | users: 11 | - username: "hello" 12 | password: "world" 13 | - username: "hello1" 14 | password: "world1" 15 | backend: 16 | instances: 17 | - "127.0.0.1:4000" 18 | username: "root" 19 | password: "" 20 | selector_type: "random" 21 | pool_size: 10 22 | idle_timeout: 60 23 | breaker: 24 | scope: "sql" 25 | strategies: 26 | - min_qps: 3 27 | failure_rate_threshold: 0 28 | failure_num: 5 29 | sql_timeoutMs: 2000 30 | open_status_duration_ms: 5000 31 | size: 10 32 | cell_interval_ms: 1000 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: 
-------------------------------------------------------------------------------- 1 | --- 2 | name: "🚀 Feature Request" 3 | about: I have a suggestion 4 | labels: type/enhancement 5 | --- 6 | 7 | ## Feature Request 8 | 9 | ### Describe your feature request related problem 10 | 11 | 12 | 13 | ### Describe the feature you'd like 14 | 15 | 16 | 17 | ### Describe alternatives you've considered 18 | 19 | 20 | 21 | ### Teachability, Documentation, Adoption, Migration Strategy 22 | 23 | -------------------------------------------------------------------------------- /pkg/util/rand2/rand_test.go: -------------------------------------------------------------------------------- 1 | package rand2 2 | 3 | import ( 4 | "math/rand" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestNewRand(t *testing.T) { 11 | src1 := rand.NewSource(1) 12 | stdRd := rand.New(src1) 13 | src2 := rand.NewSource(1) 14 | rd := New(rand.New(src2)) 15 | 16 | assert.Equal(t, stdRd.Int63(), rd.Int63()) 17 | assert.Equal(t, stdRd.Uint32(), rd.Uint32()) 18 | assert.Equal(t, stdRd.Uint64(), rd.Uint64()) 19 | assert.Equal(t, stdRd.Int31(), rd.Int31()) 20 | assert.Equal(t, stdRd.Int63n(100), rd.Int63n(100)) 21 | assert.Equal(t, stdRd.Int31n(100), rd.Int31n(100)) 22 | assert.Equal(t, stdRd.Intn(20), rd.Intn(20)) 23 | assert.Equal(t, stdRd.Float64(), rd.Float64()) 24 | assert.Equal(t, stdRd.Float32(), rd.Float32()) 25 | } 26 | -------------------------------------------------------------------------------- /pkg/util/passwd/passwd.go: -------------------------------------------------------------------------------- 1 | package passwd 2 | 3 | import "crypto/sha1" 4 | 5 | // calculatePassword calculate password hash 6 | func CalculatePassword(scramble, password []byte) []byte { 7 | if len(password) == 0 { 8 | return nil 9 | } 10 | 11 | // stage1Hash = SHA1(password) 12 | crypt := sha1.New() 13 | crypt.Write(password) 14 | stage1 := crypt.Sum(nil) 15 | 16 | // scrambleHash = 
SHA1(scramble + SHA1(stage1Hash)) 17 | // inner Hash 18 | crypt.Reset() 19 | crypt.Write(stage1) 20 | hash := crypt.Sum(nil) 21 | 22 | // outer Hash 23 | crypt.Reset() 24 | crypt.Write(scramble) 25 | crypt.Write(hash) 26 | scramble = crypt.Sum(nil) 27 | 28 | // token = scrambleHash XOR stage1Hash 29 | for i := range scramble { 30 | scramble[i] ^= stage1[i] 31 | } 32 | return scramble 33 | } 34 | -------------------------------------------------------------------------------- /pkg/proxy/backend/client/tls.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | ) 7 | 8 | // NewClientTLSConfig: generate TLS config for client side 9 | // if insecureSkipVerify is set to true, serverName will not be validated 10 | func NewClientTLSConfig(caPem, certPem, keyPem []byte, insecureSkipVerify bool, serverName string) *tls.Config { 11 | pool := x509.NewCertPool() 12 | if !pool.AppendCertsFromPEM(caPem) { 13 | panic("failed to add ca PEM") 14 | } 15 | 16 | cert, err := tls.X509KeyPair(certPem, keyPem) 17 | if err != nil { 18 | panic(err) 19 | } 20 | 21 | config := &tls.Config{ 22 | Certificates: []tls.Certificate{cert}, 23 | RootCAs: pool, 24 | InsecureSkipVerify: insecureSkipVerify, 25 | ServerName: serverName, 26 | } 27 | return config 28 | } 29 | -------------------------------------------------------------------------------- /pkg/proxy/metrics/session.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | // Label constants. 
4 | const ( 5 | LblUnretryable = "unretryable" 6 | LblReachMax = "reach_max" 7 | LblOK = "ok" 8 | LblError = "error" 9 | LblCommit = "commit" 10 | LblAbort = "abort" 11 | LblRollback = "rollback" 12 | LblType = "type" 13 | LblDb = "db" 14 | LblTable = "table" 15 | LblResult = "result" 16 | LblSQLType = "sql_type" 17 | LblGeneral = "general" 18 | LblInternal = "internal" 19 | LbTxnMode = "txn_mode" 20 | LblPessimistic = "pessimistic" 21 | LblOptimistic = "optimistic" 22 | LblStore = "store" 23 | LblAddress = "address" 24 | LblBatchGet = "batch_get" 25 | LblGet = "get" 26 | LblNamespace = "namespace" 27 | LblCluster = "cluster" 28 | 29 | LblBackendAddr = "backend_addr" 30 | ) 31 | -------------------------------------------------------------------------------- /pkg/proxy/namespace/ratelimiter.go: -------------------------------------------------------------------------------- 1 | package namespace 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | 7 | "github.com/tidb-incubator/weir/pkg/util/rate_limit_breaker/rate_limit" 8 | ) 9 | 10 | type NamespaceRateLimiter struct { 11 | limiters *sync.Map 12 | qpsThreshold int 13 | scope string 14 | } 15 | 16 | func NewNamespaceRateLimiter(scope string, qpsThreshold int) *NamespaceRateLimiter { 17 | return &NamespaceRateLimiter{ 18 | limiters: &sync.Map{}, 19 | scope: scope, 20 | qpsThreshold: qpsThreshold, 21 | } 22 | } 23 | 24 | func (n *NamespaceRateLimiter) Scope() string { 25 | return n.scope 26 | } 27 | 28 | func (n *NamespaceRateLimiter) Limit(ctx context.Context, key string) error { 29 | if n.qpsThreshold <= 0 { 30 | return nil 31 | } 32 | limiter, _ := n.limiters.LoadOrStore(key, rate_limit.NewSlidingWindowRateLimiter(int64(n.qpsThreshold))) 33 | return limiter.(*rate_limit.SlidingWindowRateLimiter).Limit() 34 | } 35 | -------------------------------------------------------------------------------- /pkg/proxy/namespace/ratelimiter_test.go: -------------------------------------------------------------------------------- 1 | 
package namespace 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestNamespaceRateLimiter_Limit(t *testing.T) { 12 | ctx := context.Background() 13 | key1 := "hello" 14 | key2 := "world" 15 | rateLimiter := NewNamespaceRateLimiter("namespace", 2) 16 | require.NoError(t, rateLimiter.Limit(ctx, key1)) 17 | require.NoError(t, rateLimiter.Limit(ctx, key1)) 18 | require.Error(t, rateLimiter.Limit(ctx, key1)) 19 | require.NoError(t, rateLimiter.Limit(ctx, key2)) 20 | time.Sleep(time.Second) 21 | require.NoError(t, rateLimiter.Limit(ctx, key1)) 22 | } 23 | 24 | func TestNamespaceRateLimiter_ZeroThreshold(t *testing.T) { 25 | ctx := context.Background() 26 | key1 := "hello" 27 | rateLimiter := NewNamespaceRateLimiter("namespace", 0) 28 | require.NoError(t, rateLimiter.Limit(ctx, key1)) 29 | require.NoError(t, rateLimiter.Limit(ctx, key1)) 30 | } 31 | -------------------------------------------------------------------------------- /pkg/proxy/namespace/domain.go: -------------------------------------------------------------------------------- 1 | package namespace 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/tidb-incubator/weir/pkg/proxy/driver" 7 | ) 8 | 9 | // Namespace is the complete view of one tenant: the frontend checks 10 | // (authentication, database and SQL access), pooled backend connections, 11 | // and the tenant's circuit breaker and rate limiter. 12 | type Namespace interface { 13 | Name() string 14 | Auth(username string, passwdBytes []byte, salt []byte) bool 15 | IsDatabaseAllowed(db string) bool 16 | ListDatabases() []string 17 | IsDeniedSQL(sqlFeature uint32) bool 18 | IsAllowedSQL(sqlFeature uint32) bool 19 | GetPooledConn(context.Context) (driver.PooledBackendConn, error) 20 | Close() 21 | GetBreaker() (driver.Breaker, error) 22 | GetRateLimiter() driver.RateLimiter 23 | } 24 | 25 | // Frontend holds the client-facing checks of a namespace: user 26 | // authentication, the allowed databases, and the SQL blacklist/whitelist. 27 | type Frontend interface { 28 | Auth(username string, passwdBytes []byte, salt []byte) bool 29 | IsDatabaseAllowed(db string) bool 30 | ListDatabases() []string 31 | IsDeniedSQL(sqlFeature uint32) bool 32 | IsAllowedSQL(sqlFeature uint32) bool 33 | } 34 | 35 | // Backend hands out pooled connections to a namespace's backend. 36 | type Backend interface { 37 | Close() 38 |
GetPooledConn(context.Context) (driver.PooledBackendConn, error) 39 | } 40 | -------------------------------------------------------------------------------- /pkg/proxy/driver/mock_NamespaceManager.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v2.3.0. DO NOT EDIT. 2 | 3 | package driver 4 | 5 | import mock "github.com/stretchr/testify/mock" 6 | 7 | // MockNamespaceManager is an autogenerated mock type for the NamespaceManager type 8 | type MockNamespaceManager struct { 9 | mock.Mock 10 | } 11 | 12 | // Auth provides a mock function with given fields: username, pwd, salt 13 | func (_m *MockNamespaceManager) Auth(username string, pwd []byte, salt []byte) (Namespace, bool) { 14 | ret := _m.Called(username, pwd, salt) 15 | 16 | var r0 Namespace 17 | if rf, ok := ret.Get(0).(func(string, []byte, []byte) Namespace); ok { 18 | r0 = rf(username, pwd, salt) 19 | } else { 20 | if ret.Get(0) != nil { 21 | r0 = ret.Get(0).(Namespace) 22 | } 23 | } 24 | 25 | var r1 bool 26 | if rf, ok := ret.Get(1).(func(string, []byte, []byte) bool); ok { 27 | r1 = rf(username, pwd, salt) 28 | } else { 29 | r1 = ret.Get(1).(bool) 30 | } 31 | 32 | return r0, r1 33 | } 34 | -------------------------------------------------------------------------------- /conf/namespace/test_namespace.yaml: -------------------------------------------------------------------------------- 1 | version: "v1" 2 | namespace: "test_namespace" 3 | frontend: 4 | allowed_dbs: 5 | - "test_weir_db" 6 | slow_sql_time: 50 7 | sql_blacklist: 8 | - sql: "select * from tbl0" 9 | - sql: "select * from tbl1" 10 | sql_whitelist: 11 | - sql: "select * from tbl2" 12 | - sql: "select * from tbl3" 13 | denied_ips: 14 | idle_timeout: 3600 15 | users: 16 | - username: "hello" 17 | password: "world" 18 | - username: "hello1" 19 | password: "world1" 20 | backend: 21 | instances: 22 | - "127.0.0.1:4000" 23 | username: "root" 24 | password: "" 25 | selector_type: "random"
26 | pool_size: 10 27 | idle_timeout: 60 28 | breaker: 29 | scope: "sql" 30 | strategies: 31 | - min_qps: 3 32 | failure_rate_threshold: 0 33 | failure_num: 5 34 | sql_timeout_ms: 2000 35 | open_status_duration_ms: 5000 36 | size: 10 37 | cell_interval_ms: 1000 38 | rate_limiter: 39 | scope: "db" 40 | qps: 1000 41 | -------------------------------------------------------------------------------- /pkg/config/proxy_example.yaml: -------------------------------------------------------------------------------- 1 | version: "v1" 2 | proxy_server: 3 | addr: "0.0.0.0:6000" 4 | max_connections: 1 5 | admin_server: 6 | addr: "0.0.0.0:6001" 7 | enable_basic_auth: false 8 | user: "hello" 9 | password: "world" 10 | performance: 11 | tcp_keep_alive: true 12 | log: 13 | # Log level: debug, info, warn, error, fatal. 14 | level: "debug" 15 | # Log format, one of json, text, console. 16 | format: "console" 17 | # File logging. 18 | log_file: 19 | # Log file name. 20 | filename: "" 21 | # Max log file size in MB (upper limit to 4096MB). 22 | max_size: 300 23 | # Maximum number of days to keep old log files. No cleanup by default. 24 | max_days: 1 25 | # Maximum number of old log files to retain. No cleanup by default.
26 | max_backups: 1 27 | registry: 28 | enable: false 29 | type: "etcd" 30 | addrs: 31 | - "192.168.0.1:2379" 32 | - "192.168.0.2:2379" 33 | - "192.168.0.3:2379" 34 | config_center: 35 | type: "file" 36 | config_file: 37 | path: "./etc" -------------------------------------------------------------------------------- /pkg/proxy/driver/queryctx_exec_test.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/pingcap/parser" 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestFirstTableNameVisitor_TableName(t *testing.T) { 12 | type fields struct { 13 | sql string 14 | want string 15 | } 16 | tests := []fields{ 17 | {sql: "SELECT 1", want: ""}, 18 | {sql: "SELECT * FROM tbl1", want: "tbl1"}, 19 | {sql: "SELECT * FROM tbl1,tbl2", want: "tbl1"}, 20 | {sql: "SELECT * FROM tbl1 INNER JOIN tbl2 on tbl1.a = tbl2.a", want: "tbl1"}, 21 | {sql: "INSERT INTO tbl1 VALUES (1,2,3)", want: "tbl1"}, 22 | {sql: "DELETE FROM tbl1 WHERE id=1", want: "tbl1"}, 23 | {sql: "UPDATE tbl1 SET a=1 WHERE id=1", want: "tbl1"}, 24 | } 25 | for _, tt := range tests { 26 | t.Run(tt.sql, func(t *testing.T) { 27 | stmt, err := parser.New().ParseOneStmt(tt.sql, "", "") 28 | // stmt is unusable when parsing fails; stop before Accept can panic. 29 | require.NoError(t, err) 30 | f := &FirstTableNameVisitor{} 31 | stmt.Accept(f) 32 | assert.Equal(t, tt.want, f.TableName()) 33 | }) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /tests/proxy/backend/connpool_test.go: -------------------------------------------------------------------------------- 1 | package backend 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/require" 9 | "github.com/tidb-incubator/weir/pkg/proxy/backend" 10 | ) 11 | 12 | func TestConnPool_ErrorClose_Success(t *testing.T) { 13 | cfg := backend.ConnPoolConfig{ 14 | Config: backend.Config{ 15 | Addr: "127.0.0.1:3306", 16 | UserName: "root", 17 |
Password: "123456", 18 | }, 19 | Capacity: 1, // pool size is set to 1 20 | IdleTimeout: 0, 21 | } 22 | pool := backend.NewConnPool("test", &cfg) 23 | err := pool.Init() 24 | require.NoError(t, err) 25 | 26 | ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second) 27 | defer cancelFunc() 28 | 29 | conn1, err := pool.GetConn(ctx) 30 | require.NoError(t, err) 31 | 32 | // conn is closed, and another conn is created by pool 33 | err = conn1.ErrorClose() 34 | require.NoError(t, err) 35 | 36 | conn2, err := pool.GetConn(ctx) 37 | require.NoError(t, err) 38 | 39 | err = conn2.ErrorClose() 40 | require.NoError(t, err) 41 | } 42 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PROJECTNAME = $(shell basename "$(PWD)") 2 | TOOL_BIN_PATH := $(shell pwd)/.tools/bin 3 | GOBASE = $(shell pwd) 4 | BUILD_TAGS ?= 5 | LDFLAGS ?= 6 | export GOBIN := $(TOOL_BIN_PATH) 7 | export PATH := $(TOOL_BIN_PATH):$(PATH) 8 | 9 | default: weirproxy 10 | 11 | weirproxy: 12 | ifeq ("$(WITH_RACE)", "1") 13 | go build -race -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -tags '${BUILD_TAGS}' -o bin/weirproxy cmd/weirproxy/main.go 14 | else 15 | go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -tags '${BUILD_TAGS}' -o bin/weirproxy cmd/weirproxy/main.go 16 | endif 17 | 18 | go-test: 19 | go test -coverprofile=.coverage.out ./...
20 | go tool cover -func=.coverage.out -o .coverage.func 21 | tail -1 .coverage.func 22 | go tool cover -html=.coverage.out -o .coverage.html 23 | 24 | go-lint-check: install-tools 25 | golangci-lint run 26 | 27 | go-lint-fix: install-tools 28 | golangci-lint run --fix 29 | 30 | install-tools: 31 | @mkdir -p $(TOOL_BIN_PATH) 32 | @test -e $(TOOL_BIN_PATH)/golangci-lint >/dev/null 2>&1 || curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(TOOL_BIN_PATH) v1.30.0 33 | -------------------------------------------------------------------------------- /pkg/proxy/driver/mock_Stmt.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v2.3.0. DO NOT EDIT. 2 | 3 | package driver 4 | 5 | import mock "github.com/stretchr/testify/mock" 6 | 7 | // MockStmt is an autogenerated mock type for the Stmt type 8 | type MockStmt struct { 9 | mock.Mock 10 | } 11 | 12 | // ColumnNum provides a mock function with given fields: 13 | func (_m *MockStmt) ColumnNum() int { 14 | ret := _m.Called() 15 | 16 | var r0 int 17 | if rf, ok := ret.Get(0).(func() int); ok { 18 | r0 = rf() 19 | } else { 20 | r0 = ret.Get(0).(int) 21 | } 22 | 23 | return r0 24 | } 25 | 26 | // ID provides a mock function with given fields: 27 | func (_m *MockStmt) ID() int { 28 | ret := _m.Called() 29 | 30 | var r0 int 31 | if rf, ok := ret.Get(0).(func() int); ok { 32 | r0 = rf() 33 | } else { 34 | r0 = ret.Get(0).(int) 35 | } 36 | 37 | return r0 38 | } 39 | 40 | // ParamNum provides a mock function with given fields: 41 | func (_m *MockStmt) ParamNum() int { 42 | ret := _m.Called() 43 | 44 | var r0 int 45 | if rf, ok := ret.Get(0).(func() int); ok { 46 | r0 = rf() 47 | } else { 48 | r0 = ret.Get(0).(int) 49 | } 50 | 51 | return r0 52 | } 53 | -------------------------------------------------------------------------------- /pkg/proxy/driver/queryctx_metrics.go:
-------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/pingcap/parser/ast" 7 | "github.com/tidb-incubator/weir/pkg/proxy/metrics" 8 | wast "github.com/tidb-incubator/weir/pkg/util/ast" 9 | ) 10 | 11 | // recordQueryMetrics increments the query counter (labelled with the result 12 | // derived from err) and observes the duration in milliseconds for stmt, both 13 | // keyed by namespace, current database, first table name and statement type. 14 | func (q *QueryCtxImpl) recordQueryMetrics(ctx context.Context, stmt ast.StmtNode, err error, durationMilliSecond float64) { 15 | ns := q.ns.Name() 16 | db := q.currentDB 17 | firstTableName, _ := wast.GetAstTableNameFromCtx(ctx) 18 | stmtType := metrics.GetStmtTypeName(stmt) 19 | retLabel := metrics.RetLabel(err) 20 | 21 | metrics.QueryCtxQueryCounter.WithLabelValues(ns, db, firstTableName, stmtType, retLabel).Inc() 22 | metrics.QueryCtxQueryDurationHistogram.WithLabelValues(ns, db, firstTableName, stmtType).Observe(durationMilliSecond) 23 | } 24 | 25 | // recordDeniedQueryMetrics increments the denied-query counter for stmt, keyed 26 | // by namespace, current database, first table name and statement type. 27 | func (q *QueryCtxImpl) recordDeniedQueryMetrics(ctx context.Context, stmt ast.StmtNode) { 28 | ns := q.ns.Name() 29 | db := q.currentDB 30 | firstTableName, _ := wast.GetAstTableNameFromCtx(ctx) 31 | stmtType := metrics.GetStmtTypeName(stmt) 32 | 33 | metrics.QueryCtxQueryDeniedCounter.WithLabelValues(ns, db, firstTableName, stmtType).Inc() 34 | } 35 | -------------------------------------------------------------------------------- /pkg/util/sync2/toggle.go: -------------------------------------------------------------------------------- 1 | package sync2 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | ) 7 | 8 | var ( 9 | // ErrToggleNotPrepared is returned by Toggle.Toggle when no value has been 10 | // staged with SwapOther since the last successful toggle. 11 | ErrToggleNotPrepared = errors.New("not prepared") 12 | ) 13 | 14 | // Toggle holds two value slots: the current one and a staged "other" one. 15 | // A value is staged with SwapOther and made current with Toggle. All access 16 | // is guarded by an internal RWMutex. 17 | type Toggle struct { 18 | data [2]interface{} 19 | idx int32 20 | prepared bool 21 | lock sync.RWMutex 22 | } 23 | 24 | // NewToggle returns a Toggle whose current value is o. 25 | func NewToggle(o interface{}) *Toggle { 26 | return &Toggle{ 27 | data: [2]interface{}{o}, 28 | } 29 | } 30 | 31 | // Current returns the value in the current slot. 32 | func (t *Toggle) Current() interface{} { 33 | t.lock.RLock() 34 | ret := t.data[t.idx] 35 | t.lock.RUnlock() 36 | return ret 37 | } 38 | 39 | // SwapOther stores o in the non-current slot, marks the Toggle as prepared, 40 | // and returns the value that slot held before (nil if it was never set). 41 | func (t *Toggle) SwapOther(o interface{}) interface{} { 42 |
t.lock.Lock() 43 | defer t.lock.Unlock() 44 | 45 | tidx := toggleIdx(t.idx) 46 | origin := t.data[tidx] 47 | t.data[tidx] = o 48 | t.prepared = true 49 | return origin 50 | } 51 | 52 | // Toggle makes the staged slot current and clears the prepared flag. It 53 | // returns ErrToggleNotPrepared, leaving the current value unchanged, if 54 | // SwapOther has not been called since the last successful Toggle. 55 | func (t *Toggle) Toggle() error { 56 | t.lock.Lock() 57 | defer t.lock.Unlock() 58 | 59 | currIdx := t.idx 60 | if !t.prepared { 61 | return ErrToggleNotPrepared 62 | } 63 | 64 | t.idx = toggleIdx(currIdx) 65 | t.prepared = false 66 | return nil 67 | } 68 | 69 | // toggleIdx returns the index of the other slot (0 <-> 1). 70 | func toggleIdx(idx int32) int32 { 71 | return (idx + 1) % 2 72 | } 73 | -------------------------------------------------------------------------------- /pkg/proxy/server/buffered_read_conn.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package server 15 | 16 | import ( 17 | "bufio" 18 | "net" 19 | ) 20 | 21 | const defaultReaderSize = 16 * 1024 22 | 23 | // bufferedReadConn is a net.Conn compatible structure that reads from bufio.Reader.
24 | type bufferedReadConn struct { 25 | net.Conn 26 | rb *bufio.Reader 27 | } 28 | 29 | // Read reads from the buffered reader instead of directly from the embedded net.Conn. 30 | func (conn bufferedReadConn) Read(b []byte) (n int, err error) { 31 | return conn.rb.Read(b) 32 | } 33 | 34 | // newBufferedReadConn wraps conn with a read buffer of defaultReaderSize bytes. 35 | func newBufferedReadConn(conn net.Conn) *bufferedReadConn { 36 | return &bufferedReadConn{ 37 | Conn: conn, 38 | rb: bufio.NewReaderSize(conn, defaultReaderSize), 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /pkg/proxy/metrics/backend.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import "github.com/prometheus/client_golang/prometheus" 4 | 5 | const ( 6 | BackendEventIniting = "initing" 7 | BackendEventInited = "inited" 8 | BackendEventClosing = "closing" 9 | BackendEventClosed = "closed" 10 | ) 11 | 12 | var ( 13 | BackendEventCounter = prometheus.NewCounterVec( 14 | prometheus.CounterOpts{ 15 | Namespace: ModuleWeirProxy, 16 | Subsystem: LabelBackend, 17 | Name: "backend_event_total", 18 | Help: "Counter of backend event.", 19 | }, []string{LblCluster, LblNamespace, LblType}) 20 | 21 | BackendQueryCounter = prometheus.NewCounterVec( 22 | prometheus.CounterOpts{ 23 | Namespace: ModuleWeirProxy, 24 | Subsystem: LabelBackend, 25 | Name: "b_conn_cnt", 26 | Help: "Counter of backend query count.", 27 | }, []string{LblCluster, LblNamespace, LblBackendAddr}) 28 | 29 | BackendConnInUseGauge = prometheus.NewGaugeVec( 30 | prometheus.GaugeOpts{ 31 | Namespace: ModuleWeirProxy, 32 | Subsystem: LabelBackend, 33 | Name: "b_conn_in_use", 34 | Help: "Number of backend conn in use.", 35 | }, []string{LblCluster, LblNamespace, LblBackendAddr}) 36 | ) 37 | -------------------------------------------------------------------------------- /pkg/util/rate_limit_breaker/rate_limit/leaky_bucket_test.go: -------------------------------------------------------------------------------- 1 | package rate_limit 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "testing" 7 | "time" 8 | ) 
9 | 10 | func TestLeakyBucketRateLimiter_Wait(t *testing.T) { 11 | // not really a test 12 | t.Skip() 13 | start := time.Now() 14 | qpsThreshold := int64(20000) 15 | rateLimiter := NewLeakyBucketRateLimiter(qpsThreshold) 16 | defer rateLimiter.Close() 17 | 18 | wg := &sync.WaitGroup{} 19 | // Change the threshold concurrently with the processors below. Waiting on 20 | // wg keeps this goroutine from outliving the deferred Close. 21 | wg.Add(1) 22 | go func() { 23 | defer wg.Done() 24 | for i := 0; i < 10; i++ { 25 | rateLimiter.ChangeQpsThreshold(qpsThreshold) 26 | } 27 | }() 28 | 29 | for i := 0; i < 1; i++ { 30 | processorName := fmt.Sprintf("#%d", i) 31 | wg.Add(1) 32 | go func() { 33 | processorLeakyBucketQueue(wg, processorName, rateLimiter, 10000) 34 | }() 35 | } 36 | wg.Wait() 37 | 38 | dur := time.Since(start) 39 | fmt.Printf("duration: %s\n", dur) 40 | } 41 | 42 | func processorLeakyBucketQueue(wg *sync.WaitGroup, processorName string, rateLimiter *LeakyBucketRateLimiter, iterates int) { 43 | defer wg.Done() 44 | for i := 0; i < iterates; i++ { 45 | rateLimiter.Limit() 46 | // fmt.Printf("processor=%s, time: %s. task_id: %d\n", processorName, time.Now().Format("15:04:05"), i) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /pkg/util/errors/errors_test.go: -------------------------------------------------------------------------------- 1 | package errors 2 | 3 | import ( 4 | stderrors "errors" 5 | "testing" 6 | 7 | "github.com/pingcap/errors" 8 | "github.com/siddontang/go-mysql/mysql" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestIs(t *testing.T) { 13 | badConn := mysql.ErrBadConn 14 | err := errors.Wrapf(badConn, "same error type") 15 | assert.True(t, Is(err, badConn)) 16 | } 17 | 18 | func TestStdIs(t *testing.T) { 19 | badConn := mysql.ErrBadConn 20 | err := errors.Wrapf(badConn, "another error type") 21 | assert.False(t, stderrors.Is(err, badConn)) 22 | } 23 | 24 | func TestCheckAndGetMyError_True(t *testing.T) { 25 | myErr := mysql.NewError(1105, "unknown") 26 | err, is := CheckAndGetMyError(myErr) 27 | assert.True(t, is) 28 | assert.NotNil(t, err) 29 | } 30 | 
31 | func TestCheckAndGetMyError_False(t *testing.T) { 32 | myErr := errors.New("not a myError") 33 | err, is := CheckAndGetMyError(myErr) 34 | assert.False(t, is) 35 | assert.Nil(t, err) 36 | } 37 | 38 | func TestCheckAndGetMyError_Cause_True(t *testing.T) { 39 | myErr := errors.Wrapf(mysql.NewError(1105, "unknown"), "wrap error") 40 | err, is := CheckAndGetMyError(myErr) 41 | assert.True(t, is) 42 | assert.NotNil(t, err) 43 | } 44 | -------------------------------------------------------------------------------- /docs/cn/connection-management.md: -------------------------------------------------------------------------------- 1 | # 连接管理 2 | 3 | 作为 TiDB 的应用层代理中间件, Weir Proxy 的一项重要职责就是管理客户端的连接与后端 TiDB Server 集群的连接. Weir Proxy 使用连接池机制, 客户端连接可以复用数据库连接执行SQL请求. 4 | 5 | ## 后端连接池 6 | 7 | Weir Proxy 会为每个租户(namespace)的 TiDB Server 集群中的每个实例创建一个连接池, 连接池的配置参数是租户级别的, 其最大连接数是确定的. 客户端在向 Weir Proxy 发起 SQL 查询请求时, Weir Proxy 会首先根据集群负载均衡策略选出一个后端实例, 从该后端实例的连接池中取出一个连接, 根据客户端连接当前状态初始化连接后, 执行 SQL 查询语句, 查询完成后再放回连接池中. 8 | 9 | 10 | 11 | ### 连接绑定 12 | 13 | 对于大多数普通SQL查询, 都可以使用连接池完成. 但是对于一些依赖后端连接状态的查询场景, 连接池就不适用了. 例如: 事务, Prepare查询. 14 | 15 | 面对这些场景, Weir Proxy的做法是: 在状态改变时从后端连接池取出一个连接, 并"绑定"到当前客户端连接上, 期间客户端连接的所有查询请求全部使用这个绑定连接, 直到状态恢复时, 再将连接放回连接池. 16 | 17 | 以下命令可能会触发后端连接绑定 (是否真正绑定与当前状态有关): 18 | 19 | - BEGIN 20 | - SET AUTOCOMMIT = 0 21 | - Binary Prepare (COM_STMT_PREPARE命令) 22 | 23 | 以下命令可能会触发后端连接解绑 (是否真正解绑与当前状态有关): 24 | 25 | - COMMIT / ROLLBACK 26 | - SET AUTOCOMMIT = 1 27 | - Binary Close (COM_STMT_CLOSE命令) 28 | 29 | ## 连接状态传递 30 | 31 | 客户端连接在执行某些SQL语句时会改变自身状态, 这些状态会影响SQL语句的执行, 例如: 切换Database, 设置系统变量等. 32 | 33 | 当客户端连接执行这些语句时, Weir Proxy会记录这些连接状态, 而在执行真正的查询请求时, Weir Proxy会对后端连接 (连接池连接或绑定连接) 执行一次初始化操作, 将后端连接的状态与客户端连接状态同步, 然后再执行查询请求.
34 | 35 | 目前支持传递给后端连接的状态: 36 | 37 | - USE DB 38 | - 设置Session级别系统变量 (TODO) 39 | -------------------------------------------------------------------------------- /docs/cn/RESTful_api.md: -------------------------------------------------------------------------------- 1 | # API接口 2 | 3 | ## 移除 namespace 4 | 5 | #### Request 6 | - Method: **POST** 7 | - URL: ```/admin/namespace/remove/:namespace``` 8 | - Headers: 9 | 10 | #### Response 11 | - Body 12 | ``` 13 | { 14 | "code":200, 15 | "msg":"success" 16 | } 17 | ``` 18 | 19 | #### 错误码 20 | 21 | | 错误码 | 信息 | 22 | | --- | --- | 23 | | 400 | bad namespace parameter | 24 | | 200 | success | 25 | 26 | 27 | ## 准备重新加载 namespace 28 | 29 | #### Request 30 | - Method: **POST** 31 | - URL: ```/admin/namespace/reload/prepare/:namespace``` 32 | 33 | #### Response 34 | - Body 35 | ``` 36 | { 37 | "code":200, 38 | "msg":"success" 39 | } 40 | ``` 41 | 42 | #### 错误码 43 | 44 | | 错误码 | 信息 | 45 | | --- | --- | 46 | | 400 | bad namespace parameter | 47 | | 500 | get namespace value from configcenter error | 48 | | 500 | prepare reload namespace error | 49 | | 200 | success | 50 | 51 | 52 | ## 提交重新加载 namespace 53 | 54 | #### Request 55 | - Method: **POST** 56 | - URL: ```/admin/namespace/reload/commit/:namespace``` 57 | 58 | #### Response 59 | - Body 60 | ``` 61 | { 62 | "code":200, 63 | "msg":"success" 64 | } 65 | ``` 66 | 67 | #### 错误码 68 | 69 | | 错误码 | 信息 | 70 | | --- | --- | 71 | | 400 | bad namespace parameter | 72 | | 500 | commit reload namespace error | 73 | | 200 | success | -------------------------------------------------------------------------------- /pkg/proxy/namespace/frontend.go: -------------------------------------------------------------------------------- 1 | package namespace 2 | 3 | import ( 4 | "crypto/subtle" 5 | 6 | "github.com/tidb-incubator/weir/pkg/util/passwd" 7 | ) 8 | 9 | // SQLInfo describes one entry of a namespace's SQL blacklist or whitelist. 10 | type SQLInfo struct { 11 | SQL string 12 | } 13 | 14 | // FrontendNamespace implements the client-facing checks of a namespace: user 15 | // authentication, allowed databases, and SQL black/white lists keyed by SQL 16 | // feature. 17 | type FrontendNamespace struct { 18 | allowedDBs []string 19 | allowedDBSet
map[string]struct{} 20 | userPasswd map[string]string 21 | sqlBlacklist map[uint32]SQLInfo 22 | sqlWhitelist map[uint32]SQLInfo 23 | } 24 | 25 | // Auth reports whether passwdBytes equals passwd.CalculatePassword(salt, p), 26 | // where p is the configured password of username. Unknown users are rejected. 27 | func (n *FrontendNamespace) Auth(username string, passwdBytes []byte, salt []byte) bool { 28 | userPasswd, ok := n.userPasswd[username] 29 | if !ok { 30 | return false 31 | } 32 | userPasswdBytes := passwd.CalculatePassword(salt, []byte(userPasswd)) 33 | // Compare in constant time so response latency does not reveal how many 34 | // leading bytes of passwdBytes matched. 35 | return subtle.ConstantTimeCompare(userPasswdBytes, passwdBytes) == 1 36 | } 37 | 38 | // IsDatabaseAllowed reports whether db is one of the namespace's allowed databases. 39 | func (n *FrontendNamespace) IsDatabaseAllowed(db string) bool { 40 | _, ok := n.allowedDBSet[db] 41 | return ok 42 | } 43 | 44 | // ListDatabases returns a copy of the allowed databases, so callers cannot 45 | // modify the namespace's own slice. 46 | func (n *FrontendNamespace) ListDatabases() []string { 47 | ret := make([]string, len(n.allowedDBs)) 48 | copy(ret, n.allowedDBs) 49 | return ret 50 | } 51 | 52 | // IsDeniedSQL reports whether sqlFeature is in the SQL blacklist. 53 | func (n *FrontendNamespace) IsDeniedSQL(sqlFeature uint32) bool { 54 | _, ok := n.sqlBlacklist[sqlFeature] 55 | return ok 56 | } 57 | 58 | // IsAllowedSQL reports whether sqlFeature is in the SQL whitelist. 59 | func (n *FrontendNamespace) IsAllowedSQL(sqlFeature uint32) bool { 60 | _, ok := n.sqlWhitelist[sqlFeature] 61 | return ok 62 | } 63 | -------------------------------------------------------------------------------- /pkg/util/errors/errors.go: -------------------------------------------------------------------------------- 1 | package errors 2 | 3 | import ( 4 | "reflect" 5 | 6 | gomysql "github.com/siddontang/go-mysql/mysql" 7 | ) 8 | 9 | // Is reports whether any error in err's chain matches target. It is copied 10 | // from the standard library's errors.Is, but follows Cause() (e.g. errors 11 | // wrapped with github.com/pingcap/errors) instead of Unwrap(). 12 | func Is(err, target error) bool { 13 | if target == nil { 14 | return err == target 15 | } 16 | 17 | isComparable := reflect.TypeOf(target).Comparable() 18 | for { 19 | if isComparable && err == target { 20 | return true 21 | } 22 | if x, ok := err.(interface{ Is(error) bool }); ok && x.Is(target) { 23 | return true 24 | } 25 | // TODO: consider supporting target.Is(err). This would allow 26 | // user-definable predicates, but also may allow for coping with sloppy 27 | // APIs, thereby making it easier to get away with them.
28 | if err = Cause(err); err == nil { 29 | return false 30 | } 31 | } 32 | } 33 | 34 | // CheckAndGetMyError walks err's Cause() chain and returns the first 35 | // *gomysql.MyError it finds, with true. It returns nil, false if err is nil or 36 | // the chain contains no *gomysql.MyError. 37 | func CheckAndGetMyError(err error) (*gomysql.MyError, bool) { 38 | if err == nil { 39 | return nil, false 40 | } 41 | 42 | for { 43 | if err1, ok := err.(*gomysql.MyError); ok { 44 | return err1, true 45 | } 46 | if err = Cause(err); err == nil { 47 | return nil, false 48 | } 49 | } 50 | } 51 | 52 | // Cause returns the result of err's Cause() method, or nil if err has none; 53 | // that nil result is what ends the chain walks in Is and CheckAndGetMyError. 54 | func Cause(err error) error { 55 | u, ok := err.(interface { 56 | Cause() error 57 | }) 58 | if !ok { 59 | return nil 60 | } 61 | return u.Cause() 62 | } 63 | -------------------------------------------------------------------------------- /pkg/proxy/server/tokenlimiter.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package server 15 | 16 | // Token is used as a permission to keep on running. 17 | type Token struct { 18 | } 19 | 20 | // TokenLimiter is used to limit the number of concurrent tasks. 21 | type TokenLimiter struct { 22 | count uint 23 | ch chan *Token 24 | } 25 | 26 | // Put releases the token. 27 | func (tl *TokenLimiter) Put(tk *Token) { 28 | tl.ch <- tk 29 | } 30 | 31 | // Get obtains a token. 32 | func (tl *TokenLimiter) Get() *Token { 33 | return <-tl.ch 34 | } 35 | 36 | // NewTokenLimiter creates a TokenLimiter with count tokens.
37 | func NewTokenLimiter(count uint) *TokenLimiter { 38 | tl := &TokenLimiter{count: count, ch: make(chan *Token, count)} 39 | for i := uint(0); i < count; i++ { 40 | tl.ch <- &Token{} 41 | } 42 | 43 | return tl 44 | } 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Weir 2 | 3 | Weir is a database proxy middleware platform, mainly providing traffic management for TiDB. 4 | 5 | Weir is maintained by [伴鱼](https://www.ipalfish.com/) and [PingCAP](https://pingcap.com/). 6 | 7 | [中文文档](README-CN.md) 8 | 9 | ## Features 10 | 11 | - __L7 Proxy__ 12 | 13 | Weir provides an application-layer proxy for the MySQL protocol and is compatible with TiDB 4.0. 14 | 15 | - __Connection Management__ 16 | 17 | Weir uses connection pools to manage backend connections and supports load balancing. 18 | 19 | - __Multi-tenant Management__ 20 | 21 | Weir supports multi-tenant management. All namespaces can be dynamically reloaded at runtime. 22 | 23 | - __Fault Tolerance__ 24 | 25 | Weir supports rate limiting and circuit breaking to protect both clients and TiDB servers. 26 | 27 | ## Architecture 28 | 29 | There are three core components in the Weir platform: the proxy, the controller, and the UI dashboard. 30 | 31 | 32 | 33 | ## Roadmap 34 | 35 | - Web Application Firewall (WAF) for SQL 36 | - Database Mesh for TiDB 37 | - SQL audit 38 | 39 | ## Code of Conduct 40 | 41 | This project is for everyone. We ask that our users and contributors take a few minutes to review our [Code of Conduct](code-of-conduct.md). 42 | 43 | ## License 44 | 45 | Weir is licensed under the Apache 2.0 license. See the [LICENSE](./LICENSE) file for details.
46 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/tidb-incubator/weir 2 | 3 | go 1.14 4 | 5 | require ( 6 | github.com/gin-gonic/gin v1.7.2 7 | github.com/go-playground/validator/v10 v10.8.0 // indirect 8 | github.com/goccy/go-yaml v1.8.2 9 | github.com/golang/protobuf v1.5.2 // indirect 10 | github.com/json-iterator/go v1.1.11 // indirect 11 | github.com/mattn/go-isatty v0.0.13 // indirect 12 | github.com/opentracing/opentracing-go v1.1.0 13 | github.com/pingcap/check v0.0.0-20200212061837-5e12011dc712 14 | github.com/pingcap/errors v0.11.5-0.20190809092503-95897b64e011 15 | github.com/pingcap/failpoint v0.0.0-20200702092429-9f69995143ce 16 | github.com/pingcap/parser v0.0.0-20200803072748-fdf66528323d 17 | github.com/pingcap/tidb v1.1.0-beta.0.20200826081922-9c1c21270001 18 | github.com/prometheus/client_golang v1.5.1 19 | github.com/shirou/gopsutil v3.21.6+incompatible // indirect 20 | github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 21 | github.com/siddontang/go-mysql v1.1.0 22 | github.com/stretchr/testify v1.6.1 23 | github.com/tklauser/go-sysconf v0.3.7 // indirect 24 | github.com/ugorji/go v1.2.6 // indirect 25 | go.etcd.io/etcd v0.5.0-alpha.5.0.20191023171146-3cf2f69b5738 26 | go.uber.org/zap v1.15.0 27 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 28 | google.golang.org/protobuf v1.27.1 // indirect 29 | gopkg.in/yaml.v2 v2.4.0 // indirect 30 | ) 31 | 32 | replace github.com/siddontang/go-mysql => github.com/ibanyu/go-mysql v1.1.0 33 | -------------------------------------------------------------------------------- /pkg/util/sync2/toggle_test.go: -------------------------------------------------------------------------------- 1 | package sync2 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func Test_Toggle_Current(t *testing.T) { 11 |
currVal := 1 12 | toggle := NewToggle(currVal) 13 | assert.Equal(t, currVal, toggle.Current()) 14 | } 15 | 16 | func Test_Toggle_SwapOther(t *testing.T) { 17 | currVal := 1 18 | toggle := NewToggle(currVal) 19 | swap1 := 2 20 | ret := toggle.SwapOther(swap1) 21 | assert.Nil(t, ret) 22 | assert.Equal(t, currVal, toggle.Current()) 23 | 24 | swap2 := 3 25 | ret = toggle.SwapOther(swap2) 26 | assert.Equal(t, swap1, ret) 27 | assert.Equal(t, currVal, toggle.Current()) 28 | } 29 | 30 | func Test_Toggle_Toggle_Success(t *testing.T) { 31 | currVal := 1 32 | toggle := NewToggle(currVal) 33 | 34 | swap := 2 35 | _ = toggle.SwapOther(swap) 36 | err := toggle.Toggle() 37 | assert.NoError(t, err) 38 | assert.Equal(t, swap, toggle.Current()) 39 | } 40 | 41 | func Test_Toggle_Toggle_Error_NotPrepared(t *testing.T) { 42 | currVal := 1 43 | toggle := NewToggle(currVal) 44 | 45 | err := toggle.Toggle() 46 | assert.EqualError(t, err, ErrToggleNotPrepared.Error()) 47 | } 48 | 49 | func BenchmarkToggle(b *testing.B) { 50 | toggle := NewToggle(1) 51 | // Background writer that keeps staging and toggling values while the 52 | // benchmark reads. The testing package calls a benchmark function several 53 | // times with growing b.N, so the goroutine must stop when this call returns 54 | // or one more copy would leak on every run. 55 | done := make(chan struct{}) 56 | defer close(done) 57 | go func() { 58 | for { 59 | select { 60 | case <-done: 61 | return 62 | default: 63 | } 64 | toggle.SwapOther(2) 65 | time.Sleep(1 * time.Millisecond) 66 | _ = toggle.Toggle() 67 | time.Sleep(1 * time.Millisecond) 68 | } 69 | }() 70 | for i := 0; i < b.N; i++ { 71 | toggle.Current() 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /pkg/proxy/server/server_util.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package server 15 | 16 | import ( 17 | "sync" 18 | 19 | "github.com/pingcap/tidb/sessionctx/variable" 20 | "github.com/pingcap/tidb/util/logutil" 21 | "github.com/pingcap/tidb/util/timeutil" 22 | "go.uber.org/zap" 23 | ) 24 | 25 | // setSysTimeZoneOnce is used for parallel run tests. When several servers are running, 26 | // only the first will actually do setSystemTimeZoneVariable, thus we can avoid data race. 27 | var setSysTimeZoneOnce = &sync.Once{} 28 | // setSystemTimeZoneVariable sets variable.SysVars["system_time_zone"] to the host time zone at most once per process; if the lookup fails it logs the error and keeps the existing value. 29 | func setSystemTimeZoneVariable() { 30 | setSysTimeZoneOnce.Do(func() { 31 | tz, err := timeutil.GetSystemTZ() 32 | if err != nil { 33 | logutil.BgLogger().Error( 34 | "Error getting SystemTZ, use default value instead", 35 | zap.Error(err), 36 | zap.String("default system_time_zone", variable.SysVars["system_time_zone"].Value)) 37 | return 38 | } 39 | variable.SysVars["system_time_zone"].Value = tz 40 | }) 41 | } 42 | -------------------------------------------------------------------------------- /pkg/proxy/backend/client/req.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "github.com/siddontang/go-mysql/utils" 5 | ) 6 | // writeCommand resets the packet sequence and writes a packet carrying only the command byte. 7 | func (c *Conn) writeCommand(command byte) error { 8 | c.ResetSequence() 9 | 10 | return c.WritePacket([]byte{ 11 | 0x01, //1 bytes long 12 | 0x00, 13 | 0x00, 14 | 0x00, //sequence 15 | command, 16 | }) 17 | } 18 | // writeCommandBuf resets the sequence and writes command followed by arg, using a buffer from utils.ByteSliceGet that is handed back with utils.ByteSlicePut after the write. 19 | func (c *Conn) writeCommandBuf(command byte, arg []byte) error { 20 | c.ResetSequence() 21 | 22 | length := len(arg) + 1 23 | data := utils.ByteSliceGet(length + 4) // NOTE(review): bytes 0-3 are left for the packet header and are presumably filled in by WritePacket (writeCommand sets them explicitly) — confirm 24 | data[4] = command 25 | 26 | copy(data[5:],
arg) 27 | 28 | err := c.WritePacket(data) 29 | 30 | utils.ByteSlicePut(data) 31 | 32 | return err 33 | } 34 | // writeCommandStr is writeCommandBuf with a string argument, converted via utils.StringToByteSlice. 35 | func (c *Conn) writeCommandStr(command byte, arg string) error { 36 | return c.writeCommandBuf(command, utils.StringToByteSlice(arg)) 37 | } 38 | // writeCommandUint32 resets the sequence and writes command followed by arg as 4 little-endian bytes. 39 | func (c *Conn) writeCommandUint32(command byte, arg uint32) error { 40 | c.ResetSequence() 41 | 42 | return c.WritePacket([]byte{ 43 | 0x05, //5 bytes long 44 | 0x00, 45 | 0x00, 46 | 0x00, //sequence 47 | 48 | command, 49 | 50 | byte(arg), 51 | byte(arg >> 8), 52 | byte(arg >> 16), 53 | byte(arg >> 24), 54 | }) 55 | } 56 | // writeCommandStrStr resets the sequence and writes command, arg1, a 0 byte, then arg2, after 4 reserved leading header bytes. 57 | func (c *Conn) writeCommandStrStr(command byte, arg1 string, arg2 string) error { 58 | c.ResetSequence() 59 | 60 | data := make([]byte, 4, 6+len(arg1)+len(arg2)) 61 | 62 | data = append(data, command) 63 | data = append(data, arg1...) 64 | data = append(data, 0) 65 | data = append(data, arg2...) 66 | 67 | return c.WritePacket(data) 68 | } 69 | -------------------------------------------------------------------------------- /pkg/util/rand2/rand.go: -------------------------------------------------------------------------------- 1 | package rand2 2 | 3 | import ( 4 | "math/rand" 5 | "sync" 6 | ) 7 | 8 | type Rand struct { 9 | sync.Mutex 10 | stdRand *rand.Rand 11 | } 12 | 13 | func New(src rand.Source) *Rand { 14 | return &Rand{ 15 | stdRand: rand.New(src), 16 | } 17 | } 18 | 19 | func (r *Rand) Int63() int64 { 20 | r.Lock() 21 | ret := r.stdRand.Int63() 22 | r.Unlock() 23 | return ret 24 | } 25 | 26 | func (r *Rand) Uint32() uint32 { 27 | r.Lock() 28 | ret := r.stdRand.Uint32() 29 | r.Unlock() 30 | return ret 31 | 32 | } 33 | 34 | func (r *Rand) Uint64() uint64 { 35 | r.Lock() 36 | ret := r.stdRand.Uint64() 37 | r.Unlock() 38 | return ret 39 | } 40 | 41 | func (r *Rand) Int31() int32 { 42 | r.Lock() 43 | ret := r.stdRand.Int31() 44 | r.Unlock() 45 | return ret 46 | } 47 | 48 | func (r *Rand) Int() int { 49 | r.Lock() 50 | ret := r.stdRand.Int() 51 | r.Unlock() 52 | return ret 53 |
} 54 | 55 | func (r *Rand) Int63n(n int64) int64 { 56 | r.Lock() 57 | ret := r.stdRand.Int63n(n) 58 | r.Unlock() 59 | return ret 60 | } 61 | 62 | func (r *Rand) Int31n(n int32) int32 { 63 | r.Lock() 64 | ret := r.stdRand.Int31n(n) 65 | r.Unlock() 66 | return ret 67 | } 68 | 69 | func (r *Rand) Intn(n int) int { 70 | r.Lock() 71 | ret := r.stdRand.Intn(n) 72 | r.Unlock() 73 | return ret 74 | } 75 | 76 | func (r *Rand) Float64() float64 { 77 | r.Lock() 78 | ret := r.stdRand.Float64() 79 | r.Unlock() 80 | return ret 81 | } 82 | 83 | func (r *Rand) Float32() float32 { 84 | r.Lock() 85 | ret := r.stdRand.Float32() 86 | r.Unlock() 87 | return ret 88 | } 89 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 7 | 8 | ### What problem does this PR solve? 9 | 10 | 11 | 12 | ### What is changed and how it works? 13 | 14 | ### Check List 15 | 16 | 17 | 18 | Tests 19 | 20 | 21 | 22 | - Unit test 23 | - Integration test 24 | - Manual test (add detailed scripts or steps below) 25 | - No code 26 | 27 | Code changes 28 | 29 | - Has configuration change 30 | - Has HTTP API interfaces change (Don't forget to [add the declarative for API](https://github.com/tikv/pd/blob/master/docs/development.md#updating-api-documentation)) 31 | - Has persistent data change 32 | 33 | Side effects 34 | 35 | - Possible performance regression 36 | - Increased code complexity 37 | - Breaking backward compatibility 38 | 39 | Related changes 40 | 41 | - PR to update [`pingcap/docs`](https://github.com/pingcap/docs)/[`pingcap/docs-cn`](https://github.com/pingcap/docs-cn): 42 | - PR to update [`pingcap/tidb-ansible`](https://github.com/pingcap/tidb-ansible): 43 | - Need to cherry-pick to the release branch 44 | 45 | ### Release note 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- 
/pkg/util/rate_limit_breaker/rate_limit/sliding_window.go: -------------------------------------------------------------------------------- 1 | package rate_limit 2 | 3 | import ( 4 | "errors" 5 | . "github.com/tidb-incubator/weir/pkg/util/rate_limit_breaker" 6 | "sync" 7 | "sync/atomic" 8 | ) 9 | 10 | var ErrRateLimited error = errors.New("rate limited") 11 | 12 | // 基于滑动窗口的,并发安全的限流器。 13 | type SlidingWindowRateLimiter struct { 14 | sw *SlidingWindow // guarded by mu 15 | mu *sync.Mutex // guard sw 16 | qpsThreshold int64 // read/write through atomic operation 17 | } 18 | 19 | func NewSlidingWindowRateLimiter(qpsThreshold int64) *SlidingWindowRateLimiter { 20 | swrl := &SlidingWindowRateLimiter{ 21 | // 滑动窗口覆盖 1s 时间,划分为 10 个 cell,每个 cell 时长为 100ms。 22 | sw: NewSlidingWindow(10, 100), 23 | mu: &sync.Mutex{}, 24 | qpsThreshold: qpsThreshold, 25 | } 26 | return swrl 27 | } 28 | 29 | // 如果被限流,则返回 ErrRateLimited;未被限流,则返回 nil 30 | func (swrl *SlidingWindowRateLimiter) Limit() error { 31 | nowMs := GetNowMs() 32 | qpsThreshold := atomic.LoadInt64(&swrl.qpsThreshold) 33 | 34 | swrl.mu.Lock() 35 | defer swrl.mu.Unlock() 36 | 37 | const HitMetric = "hit" 38 | hits := swrl.sw.GetHit(nowMs, HitMetric) 39 | actualDurationMs := swrl.sw.GetActualDurationMs(nowMs) 40 | // actualQPS = hits / (actualDurationMs / 1000) 41 | // actualQPS >= qpsThreshold 改写即得下述表达式。 42 | if hits*1000 >= qpsThreshold*actualDurationMs { 43 | return ErrRateLimited 44 | } else { 45 | swrl.sw.Hit(nowMs, HitMetric) 46 | return nil 47 | } 48 | } 49 | 50 | func (swrl *SlidingWindowRateLimiter) ChangeQpsThreshold(newQpsThreshold int64) { 51 | atomic.StoreInt64(&swrl.qpsThreshold, newQpsThreshold) 52 | } 53 | -------------------------------------------------------------------------------- /docs/cn/fault-tolerant.md: -------------------------------------------------------------------------------- 1 | # 熔断限流机制 2 | 3 | 4 | 5 | ## 熔断 6 | ``` 7 | scope: "sql" 8 | strategies: 9 | - min_qps: 3 10 | 
failure_rate_threshold: 0 11 | failure_num: 5 12 | sql_timeout_ms: 2000 13 | open_status_duration_ms: 5000 14 | size: 10 15 | cell_interval_ms: 1000 16 | ``` 17 | ### 熔断过程 18 | 19 | 从配置文件中我们得知 strategies 是一个数组,那么 namespace 中可以在 scope 下配置多种熔断策略。当请求从客户端进入 weir ,weir 会根据链接账户选择要进入的租户(namesapce),同时启动对应租户下的熔断管理器, 20 | 熔断管理器根据 scope 可以选择当前是哪一类熔断器,再根据类别中的特征,比如库名表名,sql 特征等选择对应的熔断器对象,进行计数统计,如下图: 21 | 22 | 23 | 24 | 当熔断时返回错误 **circuit breaker triggered** 。 25 | 26 | ### 熔断级别 27 | 28 | 这里要说的是熔断级别的问题,目前 weir 支持4种熔断级别 29 | 30 | | 级别 | 说明 | 31 | | --- | --- | 32 | | namespace | 租户级别熔断,熔断触发时所有对此租户进行熔断 | 33 | | db | 库级别熔断, 熔断触发时所有对此库的 sql 进行熔断 | 34 | | table | 表级别熔断, 熔断触发时所有对此表的 sql 进行熔断 | 35 | | sql | sql 级别熔断, 对每一个特征 sql 做监控管理,颗粒度最细,熔断触发时对这一类 sql 进行熔断 | 36 | 37 | sql 熔断在内存中的输入 eg : 38 | 39 | ``` 40 | select * from test_table where id = 2; 41 | select * from test_table where in (1,2,3); 42 | ``` 43 | 在熔断器中将被转化成: 44 | 45 | ``` 46 | select * from test_table where id = ?; 47 | select * from test_table where in (?); 48 | ``` 49 | 50 | 这里我们在 ast 解析时通过判断 ast 的 node 类型进行了值的替换, 并重写了 sql 进行输出, 这样就可以确定一类 sql 并提取他们的摘要方便我们后续使用, 如下图: 51 | 52 | 53 | 熔断器计数采用的是滑动窗口计数器,滑动窗口有实现简单,能应对周期比较长的统计 54 | 55 | 56 | ## 限流 57 | 58 | 限流分为阻塞式限流和拒绝式限流,目前当前版本完成的是拒绝式限流 59 | 拒绝式限流统计数据同样是采用滑动窗口计数器,在周期内会统计 qps 数据,qps 一旦大于阈值将执行限流,期间返回错误 **rate limited** 60 | -------------------------------------------------------------------------------- /pkg/proxy/backend/selector_test.go: -------------------------------------------------------------------------------- 1 | package backend 2 | 3 | import ( 4 | "strconv" 5 | "testing" 6 | 7 | "github.com/tidb-incubator/weir/pkg/util/rand2" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type testSource struct { 12 | val int64 13 | } 14 | 15 | func (t *testSource) Int63() int64 { 16 | return t.val 17 | } 18 | 19 | func (*testSource) Seed(seed int64) { 20 | } 21 | 22 | func TestRandomSelector_Select_Success(t *testing.T) { 23 | source := &testSource{} 24 | rd := 
rand2.New(source) 25 | selector := NewRandomSelector(rd) 26 | 27 | host := "127.0.0.1" 28 | ports := []int{4000, 4001, 4002} 29 | instances := prepareInstances(host, ports) 30 | 31 | for i := 0; i < len(ports); i++ { 32 | source.val = int64(i) 33 | instance, err := selector.Select(instances) 34 | assert.NoError(t, err) 35 | assert.Equal(t, getAddr(host, ports[i]), instance.addr) 36 | } 37 | } 38 | 39 | func TestRandomSelector_Select_ErrNoInstanceToSelect(t *testing.T) { 40 | source := &testSource{} 41 | rd := rand2.New(source) 42 | selector := NewRandomSelector(rd) 43 | 44 | host := "127.0.0.1" 45 | var ports []int 46 | instances := prepareInstances(host, ports) 47 | 48 | instance, err := selector.Select(instances) 49 | assert.Nil(t, instance) 50 | assert.EqualError(t, err, ErrNoInstanceToSelect.Error()) 51 | } 52 | 53 | func prepareInstances(host string, ports []int) []*Instance { 54 | var instances []*Instance 55 | for _, p := range ports { 56 | instance := &Instance{ 57 | addr: getAddr(host, p), 58 | } 59 | instances = append(instances, instance) 60 | } 61 | return instances 62 | } 63 | 64 | func getAddr(host string, port int) string { 65 | return host + ":" + strconv.Itoa(port) 66 | } 67 | -------------------------------------------------------------------------------- /pkg/util/sync2/semaphore_flaky_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sync2 18 | 19 | import ( 20 | "testing" 21 | "time" 22 | ) 23 | 24 | func TestSemaNoTimeout(t *testing.T) { 25 | s := NewSemaphore(1, 0) 26 | s.Acquire() 27 | released := false 28 | go func() { 29 | time.Sleep(10 * time.Millisecond) 30 | released = true 31 | s.Release() 32 | }() 33 | s.Acquire() 34 | if !released { 35 | t.Errorf("release: false, want true") 36 | } 37 | } 38 | 39 | func TestSemaTimeout(t *testing.T) { 40 | s := NewSemaphore(1, 5*time.Millisecond) 41 | s.Acquire() 42 | go func() { 43 | time.Sleep(10 * time.Millisecond) 44 | s.Release() 45 | }() 46 | if s.Acquire() { 47 | t.Errorf("Acquire: true, want false") 48 | } 49 | time.Sleep(10 * time.Millisecond) 50 | if !s.Acquire() { 51 | t.Errorf("Acquire: false, want true") 52 | } 53 | } 54 | 55 | func TestSemaTryAcquire(t *testing.T) { 56 | s := NewSemaphore(1, 0) 57 | if !s.TryAcquire() { 58 | t.Errorf("TryAcquire: false, want true") 59 | } 60 | if s.TryAcquire() { 61 | t.Errorf("TryAcquire: true, want false") 62 | } 63 | s.Release() 64 | if !s.TryAcquire() { 65 | t.Errorf("TryAcquire: false, want true") 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /pkg/proxy/backend/selector.go: -------------------------------------------------------------------------------- 1 | package backend 2 | 3 | import ( 4 | "errors" 5 | "math/rand" 6 | "time" 7 | 8 | "github.com/tidb-incubator/weir/pkg/util/rand2" 9 | ) 10 | 11 | const ( 12 | SelectorTypeRandom = 1 + iota 13 | ) 14 | 15 | const ( 16 | SelectorNameUnknown = "unknown" 17 | SelectorNameRandom = "random" 18 | ) 19 | 20 | var ( 21 | selectorTypeMap = map[int]string{ 22 | SelectorTypeRandom: SelectorNameRandom, 23 | } 24 | selectorNameMap = map[string]int{ 25 | SelectorNameRandom: SelectorTypeRandom, 26 | } 27 | ) 28 | 29 | var ( 30 | ErrNoInstanceToSelect = errors.New("no instance 
to select") 31 | ErrInvalidSelectorType = errors.New("invalid selector type") 32 | ) 33 | 34 | type Selector interface { 35 | Select(instances []*Instance) (*Instance, error) 36 | } 37 | 38 | type RandomSelector struct { 39 | rd *rand2.Rand 40 | } 41 | 42 | func CreateSelector(selectorType int) (Selector, error) { 43 | switch selectorType { 44 | case SelectorTypeRandom: 45 | source := rand.NewSource(time.Now().Unix()) 46 | rd := rand2.New(source) 47 | return NewRandomSelector(rd), nil 48 | default: 49 | return nil, ErrInvalidSelectorType 50 | } 51 | } 52 | 53 | func NewRandomSelector(rd *rand2.Rand) *RandomSelector { 54 | return &RandomSelector{ 55 | rd: rd, 56 | } 57 | } 58 | 59 | func (s *RandomSelector) Select(instances []*Instance) (*Instance, error) { 60 | length := len(instances) 61 | if length == 0 { 62 | return nil, ErrNoInstanceToSelect 63 | } 64 | idx := s.rd.Int63n(int64(length)) 65 | return instances[idx], nil 66 | } 67 | 68 | func SelectorNameToType(name string) (int, bool) { 69 | t, ok := selectorNameMap[name] 70 | return t, ok 71 | } 72 | 73 | func SelectorTypeToName(t int) (string, bool) { 74 | n, ok := selectorTypeMap[t] 75 | return n, ok 76 | } 77 | -------------------------------------------------------------------------------- /docs/cn/proxy-config.md: -------------------------------------------------------------------------------- 1 | # Proxy配置详解 2 | 3 | 以下给出了Weir Proxy的示例配置🌰. 
4 | 5 | ``` 6 | version: "v1" 7 | proxy_server: 8 | addr: "0.0.0.0:6000" 9 | max_connections: 1000 10 | session_timeout: 600 11 | admin_server: 12 | addr: "0.0.0.0:6001" 13 | enable_basic_auth: false 14 | user: "" 15 | password: "" 16 | log: 17 | level: "debug" 18 | format: "console" 19 | log_file: 20 | filename: "" 21 | max_size: 300 22 | max_days: 1 23 | max_backups: 1 24 | registry: 25 | enable: false 26 | config_center: 27 | type: "file" 28 | config_file: 29 | path: "./conf/namespace" 30 | strict_parse: false 31 | performance: 32 | tcp_keep_alive: true 33 | ``` 34 | 35 | | 配置名 | 说明 | 36 | | --- | --- | 37 | | version | 配置 schema 的版本号, 目前为 v1 | 38 | | proxy_server | Proxy 代理服务相关配置 | 39 | | proxy_server.addr | Proxy服务端口监听地址 | 40 | | proxy_server.max_connections | 最大客户端连接数 | 41 | | proxy_server.session_timeout | 客户端空闲链接超时时间 | 42 | | admin_server | Proxy 管理相关配置 | 43 | | admin_server.addr | Proxy admin 口监听地址 | 44 | | admin_server.enable_basic_auth | 是否开启Basic Auth | 45 | | admin_server.user | Basic Auth User | 46 | | admin_server.password | Basic Auth Password | 47 | | log | 日志配置 | 48 | | log.level | 日志级别 (支持 debug, info, warn, error) | 49 | | log.format | 日志输出方式 (支持 console, file) | 50 | | log.log_file | 日志文件相关配置 | 51 | | log.log_file.filename | 日志文件名 | 52 | | log.log_file.max_size | 单个日志文件最大尺寸 | 53 | | log.log_file.max_days | 单个日志文件保存最大天数 | 54 | | config_center | 配置中心 | 55 | | config_center.type | 配置中心类型 (支持 file, etcd) | 56 | | config_center.config_file | 配置文件信息,在 type 为file时有效 | 57 | | config_center.config_file.path | Namespace配置文件所在目录 | 58 | | strict_parse | 对命名空间名称的严格校验,如果禁用strictParse,则在列出所有命名空间时将忽略解析命名空间错误 | 59 | | performance | 性能相关配置 | 60 | | tcp_keep_alive | 对客户端连接是否开启TCP Keep Alive | 61 | -------------------------------------------------------------------------------- /cmd/weirproxy/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "io/ioutil" 7 | "os" 8 | 
"os/signal" 9 | "sync" 10 | "syscall" 11 | 12 | "github.com/pingcap/tidb/util/logutil" 13 | "github.com/tidb-incubator/weir/pkg/config" 14 | "github.com/tidb-incubator/weir/pkg/proxy" 15 | "go.uber.org/zap" 16 | ) 17 | 18 | var ( 19 | configFilePath = flag.String("config", "conf/weirproxy.yaml", "weir proxy config file path") 20 | ) 21 | 22 | func main() { 23 | flag.Parse() 24 | proxyConfigData, err := ioutil.ReadFile(*configFilePath) 25 | if err != nil { 26 | fmt.Printf("read config file error: %v\n", err) 27 | os.Exit(1) 28 | } 29 | 30 | proxyCfg, err := config.UnmarshalProxyConfig(proxyConfigData) 31 | if err != nil { 32 | fmt.Printf("parse config file error: %v\n", err) 33 | os.Exit(1) 34 | } 35 | 36 | p := proxy.NewProxy(proxyCfg) 37 | 38 | if err = p.Init(); err != nil { 39 | fmt.Printf("proxy init error: %v\n", err) 40 | p.Close() 41 | os.Exit(1) 42 | } 43 | 44 | sc := make(chan os.Signal, 1) 45 | signal.Notify(sc, 46 | syscall.SIGINT, 47 | syscall.SIGTERM, 48 | syscall.SIGQUIT, 49 | syscall.SIGPIPE, 50 | syscall.SIGUSR1, 51 | ) 52 | 53 | var wg sync.WaitGroup 54 | wg.Add(1) 55 | 56 | go func() { 57 | defer wg.Done() 58 | for { 59 | sig := <-sc 60 | if sig == syscall.SIGINT || sig == syscall.SIGTERM || sig == syscall.SIGQUIT { 61 | logutil.BgLogger().Warn("get os signal, close proxy server", zap.String("signal", sig.String())) 62 | p.Close() 63 | break 64 | } else { 65 | logutil.BgLogger().Warn("ignore os signal", zap.String("signal", sig.String())) 66 | } 67 | } 68 | }() 69 | 70 | if err := p.Run(); err != nil { 71 | logutil.BgLogger().Error("proxy run error, exit", zap.Error(err)) 72 | } 73 | 74 | wg.Wait() 75 | } 76 | -------------------------------------------------------------------------------- /pkg/util/timer/randticker.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package timer 18 | 19 | import ( 20 | "math/rand" 21 | "time" 22 | ) 23 | 24 | // RandTicker is just like time.Ticker, except that 25 | // it adds randomness to the events. 26 | type RandTicker struct { 27 | C <-chan time.Time 28 | done chan struct{} 29 | } 30 | 31 | // NewRandTicker creates a new RandTicker. d is the duration, 32 | // and variance specifies the variance. The ticker will tick 33 | // every d +/- variance. 34 | func NewRandTicker(d, variance time.Duration) *RandTicker { 35 | c := make(chan time.Time, 1) 36 | done := make(chan struct{}) 37 | go func() { 38 | rnd := rand.New(rand.NewSource(time.Now().UnixNano())) 39 | for { 40 | vr := time.Duration(rnd.Int63n(int64(2*variance)) - int64(variance)) 41 | tmr := time.NewTimer(d + vr) 42 | select { 43 | case <-tmr.C: 44 | select { 45 | case c <- time.Now(): 46 | default: 47 | } 48 | case <-done: 49 | tmr.Stop() 50 | close(c) 51 | return 52 | } 53 | } 54 | }() 55 | return &RandTicker{ 56 | C: c, 57 | done: done, 58 | } 59 | } 60 | 61 | // Stop stops the ticker and closes the underlying channel. 
62 | func (tkr *RandTicker) Stop() { 63 | close(tkr.done) 64 | } 65 | -------------------------------------------------------------------------------- /pkg/proxy/metrics/server.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "strconv" 5 | 6 | "github.com/pingcap/errors" 7 | "github.com/pingcap/parser/terror" 8 | "github.com/prometheus/client_golang/prometheus" 9 | ) 10 | 11 | var ( 12 | // PanicCounter measures the count of panics. 13 | PanicCounter = prometheus.NewCounterVec( 14 | prometheus.CounterOpts{ 15 | Namespace: ModuleWeirProxy, 16 | Subsystem: LabelServer, 17 | Name: "panic_total", 18 | Help: "Counter of panic.", 19 | }, []string{LblCluster, LblType}) 20 | 21 | QueryTotalCounter = prometheus.NewCounterVec( 22 | prometheus.CounterOpts{ 23 | Namespace: ModuleWeirProxy, 24 | Subsystem: LabelServer, 25 | Name: "query_total", 26 | Help: "Counter of queries.", 27 | }, []string{LblCluster, LblType, LblResult}) 28 | 29 | ExecuteErrorCounter = prometheus.NewCounterVec( 30 | prometheus.CounterOpts{ 31 | Namespace: ModuleWeirProxy, 32 | Subsystem: LabelServer, 33 | Name: "execute_error_total", 34 | Help: "Counter of execute errors.", 35 | }, []string{LblCluster, LblType}) 36 | 37 | ConnGauge = prometheus.NewGaugeVec( 38 | prometheus.GaugeOpts{ 39 | Namespace: ModuleWeirProxy, 40 | Subsystem: LabelServer, 41 | Name: "connections", 42 | Help: "Number of connections.", 43 | }, []string{LblCluster}) 44 | 45 | EventStart = "start" 46 | EventGracefulDown = "graceful_shutdown" 47 | // Eventkill occurs when the server.Kill() function is called. 48 | EventKill = "kill" 49 | EventClose = "close" 50 | ) 51 | 52 | // ExecuteErrorToLabel converts an execute error to label. 
53 | func ExecuteErrorToLabel(err error) string { 54 | err = errors.Cause(err) 55 | switch x := err.(type) { 56 | case *terror.Error: 57 | return x.Class().String() + ":" + strconv.Itoa(int(x.Code())) 58 | default: 59 | return "unknown" 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /pkg/util/timer/randticker_flaky_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package timer 18 | 19 | import ( 20 | "testing" 21 | "time" 22 | ) 23 | 24 | const ( 25 | testDuration = 100 * time.Millisecond 26 | testVariance = 20 * time.Millisecond 27 | ) 28 | 29 | func TestTick(t *testing.T) { 30 | tkr := NewRandTicker(testDuration, testVariance) 31 | for i := 0; i < 5; i++ { 32 | start := time.Now() 33 | end := <-tkr.C 34 | diff := start.Add(testDuration).Sub(end) 35 | tolerance := testVariance + 20*time.Millisecond 36 | if diff < -tolerance || diff > tolerance { 37 | t.Errorf("start: %v, end: %v, diff %v. 
Want <%v tolerenace", start, end, diff, tolerance) 38 | } 39 | } 40 | tkr.Stop() 41 | _, ok := <-tkr.C 42 | if ok { 43 | t.Error("Channel was not closed") 44 | } 45 | } 46 | 47 | func TestTickSkip(t *testing.T) { 48 | tkr := NewRandTicker(10*time.Millisecond, 1*time.Millisecond) 49 | time.Sleep(35 * time.Millisecond) 50 | end := <-tkr.C 51 | diff := time.Since(end) 52 | if diff < 20*time.Millisecond { 53 | t.Errorf("diff: %v, want >20ms", diff) 54 | } 55 | 56 | // This tick should be up-to-date 57 | end = <-tkr.C 58 | diff = time.Since(end) 59 | if diff > 1*time.Millisecond { 60 | t.Errorf("diff: %v, want <1ms", diff) 61 | } 62 | tkr.Stop() 63 | } 64 | -------------------------------------------------------------------------------- /pkg/proxy/namespace/user.go: -------------------------------------------------------------------------------- 1 | package namespace 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/tidb-incubator/weir/pkg/config" 7 | "github.com/pingcap/errors" 8 | ) 9 | 10 | type UserNamespaceMapper struct { 11 | userToNamespace map[string]string 12 | } 13 | 14 | func CreateUserNamespaceMapper(namespaces []*config.Namespace) (*UserNamespaceMapper, error) { 15 | mapper := make(map[string]string) 16 | for _, ns := range namespaces { 17 | frontendNamespace := ns.Frontend 18 | for _, user := range frontendNamespace.Users { 19 | originNamespace, ok := mapper[user.Username] 20 | if ok { 21 | return nil, errors.WithMessage(ErrDuplicatedUser, 22 | fmt.Sprintf("user: %s, namespace: %s, %s", user.Username, originNamespace, ns.Namespace)) 23 | } 24 | mapper[user.Username] = ns.Namespace 25 | } 26 | } 27 | 28 | ret := &UserNamespaceMapper{userToNamespace: mapper} 29 | return ret, nil 30 | } 31 | 32 | func (u *UserNamespaceMapper) GetUserNamespace(username string) (string, bool) { 33 | ns, ok := u.userToNamespace[username] 34 | return ns, ok 35 | } 36 | 37 | func (u *UserNamespaceMapper) Clone() *UserNamespaceMapper { 38 | ret := make(map[string]string) 39 | for k, v 
:= range u.userToNamespace { 40 | ret[k] = v 41 | } 42 | return &UserNamespaceMapper{userToNamespace: ret} 43 | } 44 | 45 | func (u *UserNamespaceMapper) RemoveNamespaceUsers(ns string) { 46 | for k, namespace := range u.userToNamespace { 47 | if ns == namespace { 48 | delete(u.userToNamespace, k) 49 | } 50 | } 51 | } 52 | 53 | func (u *UserNamespaceMapper) AddNamespaceUsers(ns string, cfg *config.FrontendNamespace) error { 54 | for _, userInfo := range cfg.Users { 55 | if originNamespace, ok := u.userToNamespace[userInfo.Username]; ok { 56 | return errors.WithMessage(ErrDuplicatedUser, fmt.Sprintf("namespace: %s", originNamespace)) 57 | } 58 | u.userToNamespace[userInfo.Username] = ns 59 | } 60 | return nil 61 | } 62 | -------------------------------------------------------------------------------- /pkg/proxy/driver/mock_Namespace.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v2.3.0. DO NOT EDIT. 2 | 3 | package driver 4 | 5 | import ( 6 | context "context" 7 | 8 | mock "github.com/stretchr/testify/mock" 9 | ) 10 | 11 | // MockNamespace is an autogenerated mock type for the Namespace type 12 | type MockNamespace struct { 13 | mock.Mock 14 | } 15 | 16 | // GetPooledConn provides a mock function with given fields: _a0 17 | func (_m *MockNamespace) GetPooledConn(_a0 context.Context) (PooledBackendConn, error) { 18 | ret := _m.Called(_a0) 19 | 20 | var r0 PooledBackendConn 21 | if rf, ok := ret.Get(0).(func(context.Context) PooledBackendConn); ok { 22 | r0 = rf(_a0) 23 | } else { 24 | if ret.Get(0) != nil { 25 | r0 = ret.Get(0).(PooledBackendConn) 26 | } 27 | } 28 | 29 | var r1 error 30 | if rf, ok := ret.Get(1).(func(context.Context) error); ok { 31 | r1 = rf(_a0) 32 | } else { 33 | r1 = ret.Error(1) 34 | } 35 | 36 | return r0, r1 37 | } 38 | 39 | // IsDatabaseAllowed provides a mock function with given fields: db 40 | func (_m *MockNamespace) IsDatabaseAllowed(db string) bool { 41 | ret := 
_m.Called(db) 42 | 43 | var r0 bool 44 | if rf, ok := ret.Get(0).(func(string) bool); ok { 45 | r0 = rf(db) 46 | } else { 47 | r0 = ret.Get(0).(bool) 48 | } 49 | 50 | return r0 51 | } 52 | 53 | // ListDatabases provides a mock function with given fields: 54 | func (_m *MockNamespace) ListDatabases() []string { 55 | ret := _m.Called() 56 | 57 | var r0 []string 58 | if rf, ok := ret.Get(0).(func() []string); ok { 59 | r0 = rf() 60 | } else { 61 | if ret.Get(0) != nil { 62 | r0 = ret.Get(0).([]string) 63 | } 64 | } 65 | 66 | return r0 67 | } 68 | 69 | // Name provides a mock function with given fields: 70 | func (_m *MockNamespace) Name() string { 71 | ret := _m.Called() 72 | 73 | var r0 string 74 | if rf, ok := ret.Get(0).(func() string); ok { 75 | r0 = rf() 76 | } else { 77 | r0 = ret.Get(0).(string) 78 | } 79 | 80 | return r0 81 | } 82 | -------------------------------------------------------------------------------- /pkg/config/namespace.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | type Namespace struct { 4 | Version string `yaml:"version"` 5 | Namespace string `yaml:"namespace"` 6 | Frontend FrontendNamespace `yaml:"frontend"` 7 | Backend BackendNamespace `yaml:"backend"` 8 | Breaker BreakerInfo `yaml:"breaker"` 9 | RateLimiter RateLimiterInfo `yaml:"rate_limiter"` 10 | } 11 | 12 | type FrontendNamespace struct { 13 | AllowedDBs []string `yaml:"allowed_dbs"` 14 | SlowSQLTime int `yaml:"slow_sql_time"` 15 | DeniedIPs []string `yaml:"denied_ips"` 16 | IdleTimeout int `yaml:"idle_timeout"` 17 | Users []FrontendUserInfo `yaml:"users"` 18 | SQLBlackList []SQLInfo `yaml:"sql_blacklist"` 19 | SQLWhiteList []SQLInfo `yaml:"sql_whitelist"` 20 | } 21 | 22 | type FrontendUserInfo struct { 23 | Username string `yaml:"username"` 24 | Password string `yaml:"password"` 25 | } 26 | 27 | type SQLInfo struct { 28 | SQL string `yaml:"sql"` 29 | } 30 | 31 | type RateLimiterInfo struct { 32 | Scope string 
`yaml:"scope"` 33 | QPS int `yaml:"qps"` 34 | } 35 | 36 | type BackendNamespace struct { 37 | Username string `yaml:"username"` 38 | Password string `yaml:"password"` 39 | Instances []string `yaml:"instances"` 40 | SelectorType string `yaml:"selector_type"` 41 | PoolSize int `yaml:"pool_size"` 42 | IdleTimeout int `yaml:"idle_timeout"` 43 | } 44 | 45 | type StrategyInfo struct { 46 | MinQps int64 `yaml:"min_qps"` 47 | SqlTimeoutMs int64 `yaml:"sql_timeout_ms"` 48 | FailureRatethreshold int64 `yaml:"failure_rate_threshold"` 49 | FailureNum int64 `yaml:"failure_num"` 50 | OpenStatusDurationMs int64 `yaml:"open_status_duration_ms"` 51 | Size int64 `yaml:"size"` 52 | CellIntervalMs int64 `yaml:"cell_interval_ms"` 53 | } 54 | 55 | type BreakerInfo struct { 56 | Scope string `yaml:"scope"` 57 | Strategies []StrategyInfo `yaml:"strategies"` 58 | } 59 | -------------------------------------------------------------------------------- /pkg/util/timer/timer_flaky_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | package timer 18 | 19 | import ( 20 | "testing" 21 | "time" 22 | 23 | "github.com/tidb-incubator/weir/pkg/util/sync2" 24 | "github.com/stretchr/testify/assert" 25 | ) 26 | 27 | const ( 28 | half = 50 * time.Millisecond 29 | quarter = 25 * time.Millisecond 30 | tenth = 10 * time.Millisecond 31 | ) 32 | 33 | var numcalls sync2.AtomicInt64 34 | 35 | func f() { 36 | numcalls.Add(1) 37 | } 38 | 39 | func TestWait(t *testing.T) { 40 | numcalls.Set(0) 41 | timer := NewTimer(quarter) 42 | assert.False(t, timer.Running()) 43 | timer.Start(f) 44 | defer timer.Stop() 45 | assert.True(t, timer.Running()) 46 | time.Sleep(tenth) 47 | assert.Equal(t, int64(0), numcalls.Get()) 48 | time.Sleep(quarter) 49 | assert.Equal(t, int64(1), numcalls.Get()) 50 | time.Sleep(quarter) 51 | assert.Equal(t, int64(2), numcalls.Get()) 52 | } 53 | 54 | func TestReset(t *testing.T) { 55 | numcalls.Set(0) 56 | timer := NewTimer(half) 57 | timer.Start(f) 58 | defer timer.Stop() 59 | timer.SetInterval(quarter) 60 | time.Sleep(tenth) 61 | assert.Equal(t, int64(0), numcalls.Get()) 62 | time.Sleep(quarter) 63 | assert.Equal(t, int64(1), numcalls.Get()) 64 | } 65 | 66 | func TestIndefinite(t *testing.T) { 67 | numcalls.Set(0) 68 | timer := NewTimer(0) 69 | timer.Start(f) 70 | defer timer.Stop() 71 | timer.TriggerAfter(quarter) 72 | time.Sleep(tenth) 73 | assert.Equal(t, int64(0), numcalls.Get()) 74 | time.Sleep(quarter) 75 | assert.Equal(t, int64(1), numcalls.Get()) 76 | } 77 | -------------------------------------------------------------------------------- /pkg/config/proxy.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | const ( 4 | MIN_SESSION_TIMEOUT = 600 5 | ) 6 | 7 | const ( 8 | DefaultClusterName = "default" 9 | ) 10 | 11 | type Proxy struct { 12 | Version string `yaml:"version"` 13 | Cluster string `yaml:"cluster"` 14 | ProxyServer ProxyServer `yaml:"proxy_server"` 15 | AdminServer AdminServer `yaml:"admin_server"` 16 
| Log Log `yaml:"log"` 17 | Registry Registry `yaml:"registry"` 18 | ConfigCenter ConfigCenter `yaml:"config_center"` 19 | Performance Performance `yaml:"performance"` 20 | } 21 | 22 | type ProxyServer struct { 23 | Addr string `yaml:"addr"` 24 | MaxConnections uint32 `yaml:"max_connections"` 25 | SessionTimeout int `yaml:"session_timeout"` 26 | } 27 | 28 | type AdminServer struct { 29 | Addr string `yaml:"addr"` 30 | EnableBasicAuth bool `yaml:"enable_basic_auth"` 31 | User string `yaml:"user"` 32 | Password string `yaml:"password"` 33 | } 34 | 35 | type Log struct { 36 | Level string `yaml:"level"` 37 | Format string `yaml:"format"` 38 | LogFile LogFile `yaml:"log_file"` 39 | } 40 | 41 | type LogFile struct { 42 | Filename string `yaml:"filename"` 43 | MaxSize int `yaml:"max_size"` 44 | MaxDays int `yaml:"max_days"` 45 | MaxBackups int `yaml:"max_backups"` 46 | } 47 | 48 | type Registry struct { 49 | Enable bool `yaml:"enable"` 50 | Type string `yaml:"type"` 51 | Addrs []string `yaml:"addrs"` 52 | } 53 | 54 | type ConfigCenter struct { 55 | Type string `yaml:"type"` 56 | ConfigFile ConfigFile `yaml:"config_file"` 57 | ConfigEtcd ConfigEtcd `yaml:"config_etcd"` 58 | } 59 | 60 | type ConfigFile struct { 61 | Path string `yaml:"path"` 62 | } 63 | 64 | type ConfigEtcd struct { 65 | Addrs []string `yaml:"addrs"` 66 | BasePath string `yaml:"base_path"` 67 | Username string `yaml:"username"` 68 | Password string `yaml:"password"` 69 | // If strictParse is disabled, parsing namespace error will be ignored when list all namespaces. 
70 | StrictParse bool `yaml:"strict_parse"` 71 | } 72 | 73 | type Performance struct { 74 | TCPKeepAlive bool `yaml:"tcp_keep_alive"` 75 | } 76 | -------------------------------------------------------------------------------- /pkg/config/marshaller_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | var testNamespaceConfig = Namespace{ 10 | Version: "v1", 11 | Namespace: "test_ns", 12 | Frontend: FrontendNamespace{ 13 | AllowedDBs: []string{"db0", "db1"}, 14 | SlowSQLTime: 10, 15 | DeniedIPs: []string{"127.0.0.0", "128.0.0.0"}, 16 | IdleTimeout: 10, 17 | Users: []FrontendUserInfo{ 18 | {Username: "user0", Password: "pwd0"}, 19 | {Username: "user1", Password: "pwd1"}, 20 | }, 21 | }, 22 | Backend: BackendNamespace{ 23 | Username: "user0", 24 | Password: "pwd0", 25 | Instances: []string{"127.0.0.1:4000", "127.0.0.1:4001"}, 26 | SelectorType: "random", 27 | PoolSize: 1, 28 | IdleTimeout: 20, 29 | }, 30 | } 31 | 32 | var testProxyConfig = Proxy{ 33 | Version: "v1", 34 | ProxyServer: ProxyServer{ 35 | Addr: "0.0.0.0:4000", 36 | MaxConnections: 1, 37 | }, 38 | AdminServer: AdminServer{ 39 | Addr: "0.0.0.0:4001", 40 | EnableBasicAuth: false, 41 | User: "user", 42 | Password: "pwd", 43 | }, 44 | Log: Log{ 45 | Level: "info", 46 | Format: "console", 47 | LogFile: LogFile{ 48 | Filename: ".", 49 | MaxSize: 10, 50 | MaxDays: 1, 51 | MaxBackups: 1, 52 | }, 53 | }, 54 | Registry: Registry{ 55 | Enable: false, 56 | Type: "etcd", 57 | Addrs: []string{"127.0.0.1:4000", "127.0.0.1:4001"}, 58 | }, 59 | ConfigCenter: ConfigCenter{ 60 | Type: "file", 61 | ConfigFile: ConfigFile{ 62 | Path: ".", 63 | }, 64 | }, 65 | Performance: Performance{ 66 | TCPKeepAlive: true, 67 | }, 68 | } 69 | 70 | func TestNamespaceConfigEncodeAndDecode(t *testing.T) { 71 | data, err := MarshalNamespaceConfig(&testNamespaceConfig) 72 | assert.NoError(t, err) 73 | 
cfg, err := UnmarshalNamespaceConfig(data) 74 | assert.NoError(t, err) 75 | assert.Equal(t, testNamespaceConfig, *cfg) 76 | } 77 | 78 | func TestProxyConfigEncodeAndDecode(t *testing.T) { 79 | data, err := MarshalProxyConfig(&testProxyConfig) 80 | assert.NoError(t, err) 81 | cfg, err := UnmarshalProxyConfig(data) 82 | assert.NoError(t, err) 83 | assert.Equal(t, testProxyConfig, *cfg) 84 | } 85 | -------------------------------------------------------------------------------- /pkg/configcenter/file.go: -------------------------------------------------------------------------------- 1 | package configcenter 2 | 3 | import ( 4 | "github.com/pingcap/errors" 5 | "github.com/tidb-incubator/weir/pkg/config" 6 | "io/ioutil" 7 | "path" 8 | "path/filepath" 9 | ) 10 | 11 | var ( 12 | ErrNamespaceNotFound = errors.New("namespace not found") 13 | ) 14 | 15 | // FileConfigCenter is only for test use, 16 | // please do not use it in production environment. 17 | type FileConfigCenter struct { 18 | dir string 19 | cfgs map[string]*config.Namespace // key: namespace 20 | nspath map[string]string // key: namespace, value: config file path 21 | } 22 | 23 | func CreateFileConfigCenter(nsdir string) (*FileConfigCenter, error) { 24 | yamlFiles, err := listAllYamlFiles(nsdir) 25 | if err != nil { 26 | return nil, err 27 | } 28 | 29 | c := newFileConfigCenter(nsdir) 30 | 31 | for _, yamlFile := range yamlFiles { 32 | fileData, err := ioutil.ReadFile(yamlFile) 33 | if err != nil { 34 | return nil, err 35 | } 36 | cfg, err := config.UnmarshalNamespaceConfig(fileData) 37 | if err != nil { 38 | return nil, err 39 | } 40 | c.cfgs[cfg.Namespace] = cfg 41 | c.nspath[cfg.Namespace] = yamlFile 42 | } 43 | return c, nil 44 | } 45 | 46 | func newFileConfigCenter(dir string) *FileConfigCenter { 47 | return &FileConfigCenter{ 48 | dir: dir, 49 | cfgs: make(map[string]*config.Namespace), 50 | nspath: make(map[string]string), 51 | } 52 | } 53 | 54 | func listAllYamlFiles(dir string) ([]string, error) { 
55 | infos, err := ioutil.ReadDir(dir) 56 | if err != nil { 57 | return nil, err 58 | } 59 | 60 | var ret []string 61 | for _, info := range infos { 62 | fileName := info.Name() 63 | if path.Ext(fileName) == ".yaml" { 64 | ret = append(ret, filepath.Join(dir, fileName)) 65 | } 66 | } 67 | 68 | return ret, nil 69 | } 70 | 71 | func (f *FileConfigCenter) GetNamespace(ns string) (*config.Namespace, error) { 72 | cfg, ok := f.cfgs[ns] 73 | if !ok { 74 | return nil, ErrNamespaceNotFound 75 | } 76 | return cfg, nil 77 | } 78 | 79 | func (f *FileConfigCenter) ListAllNamespace() ([]*config.Namespace, error) { 80 | var ret []*config.Namespace 81 | for _, cfg := range f.cfgs { 82 | ret = append(ret, cfg) 83 | } 84 | return ret, nil 85 | } 86 | -------------------------------------------------------------------------------- /pkg/proxy/driver/domain.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/siddontang/go-mysql/mysql" 7 | ) 8 | 9 | type NamespaceManager interface { 10 | Auth(username string, pwd, salt []byte) (Namespace, bool) 11 | } 12 | 13 | type Namespace interface { 14 | Name() string 15 | IsDatabaseAllowed(db string) bool 16 | ListDatabases() []string 17 | IsDeniedSQL(sqlFeature uint32) bool 18 | IsAllowedSQL(sqlFeature uint32) bool 19 | GetPooledConn(context.Context) (PooledBackendConn, error) 20 | IncrConnCount() 21 | DescConnCount() 22 | GetBreaker() (Breaker, error) 23 | GetRateLimiter() RateLimiter 24 | } 25 | 26 | type Breaker interface { 27 | IsUseBreaker() bool 28 | GetBreakerScope() string 29 | Hit(name string, idx int, isFail bool) error 30 | Status(name string) (int32, int) 31 | AddTimeWheelTask(name string, connectionID uint64, flag *int32) error 32 | RemoveTimeWheelTask(connectionID uint64) error 33 | CASHalfOpenProbeSent(name string, idx int, halfOpenProbeSent bool) bool 34 | CloseBreaker() 35 | } 36 | 37 | type RateLimiter interface { 38 | Scope() 
string 39 | Limit(ctx context.Context, key string) error 40 | } 41 | 42 | type PooledBackendConn interface { 43 | // PutBack put conn back to pool 44 | PutBack() 45 | 46 | // ErrorClose close conn and connpool create a new conn 47 | // call this function when conn is broken. 48 | ErrorClose() error 49 | BackendConn 50 | } 51 | 52 | type SimpleBackendConn interface { 53 | Close() error 54 | BackendConn 55 | } 56 | 57 | type BackendConn interface { 58 | Ping() error 59 | UseDB(dbName string) error 60 | GetDB() string 61 | Execute(command string, args ...interface{}) (*mysql.Result, error) 62 | Begin() error 63 | Commit() error 64 | Rollback() error 65 | StmtPrepare(sql string) (Stmt, error) 66 | StmtExecuteForward(data []byte) (*mysql.Result, error) 67 | StmtClosePrepare(stmtId int) error 68 | SetCharset(charset string) error 69 | FieldList(table string, wildcard string) ([]*mysql.Field, error) 70 | SetAutoCommit(bool) error 71 | IsAutoCommit() bool 72 | IsInTransaction() bool 73 | GetCharset() string 74 | GetConnectionID() uint32 75 | GetStatus() uint16 76 | } 77 | 78 | type Stmt interface { 79 | ID() int 80 | ParamNum() int 81 | ColumnNum() int 82 | } 83 | -------------------------------------------------------------------------------- /pkg/proxy/proxy.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/tidb-incubator/weir/pkg/config" 7 | "github.com/tidb-incubator/weir/pkg/configcenter" 8 | "github.com/tidb-incubator/weir/pkg/proxy/driver" 9 | "github.com/tidb-incubator/weir/pkg/proxy/metrics" 10 | "github.com/tidb-incubator/weir/pkg/proxy/namespace" 11 | "github.com/tidb-incubator/weir/pkg/proxy/server" 12 | ) 13 | 14 | type Proxy struct { 15 | cfg *config.Proxy 16 | svr *server.Server 17 | apiServer *HttpApiServer 18 | nsmgr *namespace.NamespaceManager 19 | configCenter configcenter.ConfigCenter 20 | } 21 | 22 | func supplementProxyConfig(cfg *config.Proxy) 
*config.Proxy { 23 | if cfg.ProxyServer.SessionTimeout <= config.MIN_SESSION_TIMEOUT { 24 | cfg.ProxyServer.SessionTimeout = config.MIN_SESSION_TIMEOUT 25 | } 26 | if cfg.Cluster == "" { 27 | cfg.Cluster = config.DefaultClusterName 28 | } 29 | return cfg 30 | } 31 | 32 | func NewProxy(cfg *config.Proxy) *Proxy { 33 | return &Proxy{ 34 | cfg: supplementProxyConfig(cfg), 35 | } 36 | } 37 | 38 | func (p *Proxy) Init() error { 39 | metrics.RegisterProxyMetrics(p.cfg.Cluster) 40 | cc, err := configcenter.CreateConfigCenter(p.cfg.ConfigCenter) 41 | if err != nil { 42 | return err 43 | } 44 | p.configCenter = cc 45 | 46 | nss, err := cc.ListAllNamespace() 47 | if err != nil { 48 | return err 49 | } 50 | nsmgr, err := namespace.CreateNamespaceManager(nss, namespace.BuildNamespace, namespace.DefaultAsyncCloseNamespace) 51 | if err != nil { 52 | return err 53 | } 54 | p.nsmgr = nsmgr 55 | driverImpl := driver.NewDriverImpl(nsmgr) 56 | svr, err := server.NewServer(p.cfg, driverImpl) 57 | if err != nil { 58 | return err 59 | } 60 | p.svr = svr 61 | apiServer, err := CreateHttpApiServer(svr, nsmgr, cc, p.cfg) 62 | if err != nil { 63 | return err 64 | } 65 | p.apiServer = apiServer 66 | 67 | return nil 68 | } 69 | 70 | // TODO(eastfisher): refactor this function 71 | func (p *Proxy) Run() error { 72 | go func() { 73 | time.Sleep(200 * time.Millisecond) 74 | p.apiServer.Run() 75 | }() 76 | return p.svr.Run() 77 | } 78 | 79 | func (p *Proxy) Close() { 80 | if p.apiServer != nil { 81 | p.apiServer.Close() 82 | } 83 | if p.svr != nil { 84 | p.svr.Close() 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /pkg/util/sync2/semaphore.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package sync2 18 | 19 | // What's in a name? Channels have all you need to emulate a counting 20 | // semaphore with a boatload of extra functionality. However, in some 21 | // cases, you just want a familiar API. 22 | 23 | import ( 24 | "time" 25 | ) 26 | 27 | // Semaphore is a counting semaphore with the option to 28 | // specify a timeout. 29 | type Semaphore struct { 30 | slots chan struct{} 31 | timeout time.Duration 32 | } 33 | 34 | // NewSemaphore creates a Semaphore. The count parameter must be a positive 35 | // number. A timeout of zero means that there is no timeout. 36 | func NewSemaphore(count int, timeout time.Duration) *Semaphore { 37 | sem := &Semaphore{ 38 | slots: make(chan struct{}, count), 39 | timeout: timeout, 40 | } 41 | for i := 0; i < count; i++ { 42 | sem.slots <- struct{}{} 43 | } 44 | return sem 45 | } 46 | 47 | // Acquire returns true on successful acquisition, and 48 | // false on a timeout. 49 | func (sem *Semaphore) Acquire() bool { 50 | if sem.timeout == 0 { 51 | <-sem.slots 52 | return true 53 | } 54 | tm := time.NewTimer(sem.timeout) 55 | defer tm.Stop() 56 | select { 57 | case <-sem.slots: 58 | return true 59 | case <-tm.C: 60 | return false 61 | } 62 | } 63 | 64 | // TryAcquire acquires a semaphore if it's immediately available. 65 | // It returns false otherwise. 66 | func (sem *Semaphore) TryAcquire() bool { 67 | select { 68 | case <-sem.slots: 69 | return true 70 | default: 71 | return false 72 | } 73 | } 74 | 75 | // Release releases the acquired semaphore. 
You must 76 | // not release more than the number of semaphores you've 77 | // acquired. 78 | func (sem *Semaphore) Release() { 79 | sem.slots <- struct{}{} 80 | } 81 | 82 | // Size returns the current number of available slots. 83 | func (sem *Semaphore) Size() int { 84 | return len(sem.slots) 85 | } 86 | -------------------------------------------------------------------------------- /pkg/util/rate_limit_breaker/rate_limit/sliding_window_test.go: -------------------------------------------------------------------------------- 1 | package rate_limit 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestSlidingWindowRateLimiter_Limit(t *testing.T) { 10 | rl := NewSlidingWindowRateLimiter(10) 11 | ch := make(chan int) 12 | go func() { 13 | for i := 0; i < 100; i++ { 14 | ch <- i 15 | } 16 | }() 17 | 18 | // 使用 timer 确保只处理 1s 之内的数据 19 | OUTER: 20 | for { 21 | select { 22 | case <-time.NewTimer(time.Second).C: // time is up 23 | break OUTER 24 | case i := <-ch: 25 | err := rl.Limit() 26 | if i < 10 { 27 | assert.Nil(t, err) 28 | } else { 29 | assert.Equal(t, err, ErrRateLimited) 30 | if i == 99 { 31 | break OUTER 32 | } 33 | } 34 | } 35 | } 36 | } 37 | 38 | func TestSlidingWindowRateLimiter_ChangeQpsThreshold(t *testing.T) { 39 | // 测试 ChangeQpsThreshold。qpsThreshold 一开始为 10 40 | // 1)发送 100 个请求,前 10 个通过,其他被限流 41 | // 2)等待 1s 并将 qpsThreshold 置为 20 42 | // 3)发送 100 个请求,前 20 个通过,其他被限流 43 | rl := NewSlidingWindowRateLimiter(10) 44 | ch := make(chan int) 45 | go func() { 46 | for i := 0; i < 100; i++ { 47 | ch <- i 48 | } 49 | 50 | // 等待 1s,将 qps 阈值置为 20 51 | <-time.NewTimer(time.Second).C 52 | rl.ChangeQpsThreshold(20) 53 | for i := 100; i < 200; i++ { 54 | ch <- i 55 | } 56 | }() 57 | 58 | OUTER: 59 | for { 60 | select { 61 | case <-time.NewTimer(time.Second * 2).C: // time is up 62 | break OUTER 63 | case i := <-ch: 64 | err := rl.Limit() 65 | if i < 100 { // 前 100 个,10 个通过 66 | if i < 10 { 67 | assert.Nil(t, err) 68 
| } else { 69 | assert.Equal(t, err, ErrRateLimited) 70 | } 71 | } else { // 后 100 个,20 个通过 72 | if i < 120 { 73 | assert.Nil(t, err) 74 | } else { 75 | assert.Equal(t, err, ErrRateLimited) 76 | if i == 199 { 77 | break OUTER 78 | } 79 | } 80 | } 81 | } 82 | } 83 | } 84 | 85 | func TestSlidingWindowRateLimiter_ConcurrentLimit(t *testing.T) { 86 | rl := NewSlidingWindowRateLimiter(20000) 87 | resultCh := make(chan int) 88 | const GoRoutines = 100 89 | for i := 0; i < GoRoutines; i++ { 90 | go func() { 91 | passedCount := 0 92 | for j := 0; j < 1000; j++ { 93 | err := rl.Limit() 94 | if err == nil { 95 | passedCount++ 96 | } 97 | } 98 | resultCh <- passedCount 99 | }() 100 | } 101 | 102 | sum := 0 103 | i := 0 104 | for { 105 | count := <-resultCh 106 | sum += count 107 | i++ 108 | if i >= GoRoutines { 109 | break 110 | } 111 | } 112 | assert.Equal(t, sum, 20000) 113 | } 114 | -------------------------------------------------------------------------------- /docs/cn/quickstart.md: -------------------------------------------------------------------------------- 1 | # 快速上手 2 | 3 | 本文介绍如何快速上手体验Weir平台. 4 | 5 | ## 前提 6 | 7 | 使用 weir-proxy 的前提首先要部署一套TiDB集群。对 weir-proxy 来说, 后端数据库也可使用 MySQL 代替 TiDB 进行测试. 8 | 9 | ### 安装 TiDB/MySQL 10 | 11 | 安装 TiDB 可参考 [TiDB数据库快速上手指南](https://docs.pingcap.com/zh/tidb/stable/quick-start-with-tidb) 进行安装。 12 | 13 | ### 构造数据 14 | 15 | 安装完成后, 需要连接TiDB / MySQL 执行以下SQL语句进行建库和建表操作, 方便测试weir-proxy. 
16 | 17 | #### 建库 18 | ``` 19 | DROP DATABASE IF EXISTS `test_weir_db`; 20 | CREATE DATABASE `test_weir_db`; 21 | USE `test_weir_db`; 22 | ``` 23 | 24 | #### 建表 25 | ``` 26 | CREATE TABLE `test_weir_user` ( 27 | `id` bigint(22) unsigned NOT NULL AUTO_INCREMENT, 28 | `name` varchar(128) NOT NULL, 29 | PRIMARY KEY (`id`), 30 | UNIQUE `uniq_name` (`name`) 31 | ) DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; 32 | 33 | CREATE TABLE `test_weir_admin` ( 34 | `id` bigint(22) unsigned NOT NULL AUTO_INCREMENT, 35 | `name` varchar(128) NOT NULL, 36 | `status` varchar(128) NOT NULL DEFAULT 'normal', 37 | PRIMARY KEY (`id`), 38 | UNIQUE `uniq_name` (`name`) 39 | ) DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; 40 | ``` 41 | 42 | #### 写入测试数据 43 | ``` 44 | INSERT INTO `test_weir_user` (name) VALUES ('Bob'); 45 | INSERT INTO `test_weir_user` (name) VALUES ('Alice'); 46 | 47 | INSERT INTO `test_weir_admin` (name) VALUES ('Ed'); 48 | INSERT INTO `test_weir_admin` (name) VALUES ('Huang'); 49 | ``` 50 | 51 | ## 安装启动 weir-proxy 52 | 53 | ### 从源码编译安装 54 | 55 | 首先, 从github克隆代码仓库到本地. 56 | 57 | ``` 58 | $ git clone https://github.com/tidb-incubator/weir 59 | ``` 60 | 61 | 构建weir-proxy. 62 | 63 | ``` 64 | $ make weirproxy 65 | ``` 66 | 67 | 生成的weirproxy二进制文件位于bin/目录下, 文件名为bin/weirproxy. 68 | 69 | 启动weir-proxy. 70 | 71 | ``` 72 | $ ./bin/weirproxy & 73 | ``` 74 | 75 | weir-proxy会默认读取示例配置文件conf/weirproxy.yaml进行启动, 示例配置使用本地文件作为namespace配置中心, 配置文件位于conf/namespace/目录下. 76 | 77 | 使用MySQL客户端通过weir-proxy访问TiDB集群. 78 | 79 | ``` 80 | $ mysql -h127.0.0.1 -P6000 -uhello -pworld test_weir_db 81 | 82 | mysql: [Warning] Using a password on the command line interface can be insecure. 83 | Welcome to the MySQL monitor. Commands end with ; or \g. 84 | Your MySQL connection id is 1 85 | Server version: 5.7.25-TiDB-None MySQL Community Server (GPL) 86 | 87 | Copyright (c) 2000, 2016, Oracle and/or its affiliates. All rights reserved.
88 | 89 | Oracle is a registered trademark of Oracle Corporation and/or its 90 | affiliates. Other names may be trademarks of their respective 91 | owners. 92 | 93 | Type 'help;' or '\h' for help. Type '\c' to clear the current input statement. 94 | 95 | mysql> 96 | ``` 97 | 98 | 如果看到连接成功, 说明weir-proxy已经启动并可以使用了. 恭喜你! 99 | 100 | 目前 Weir 平台的代理层中间件 weir-proxy 已开源, 中控组件 weir-controller 和控制台 weir-dashboard 还未开源. -------------------------------------------------------------------------------- /pkg/util/rate_limit_breaker/rate_limit/leaky_bucket.go: -------------------------------------------------------------------------------- 1 | package rate_limit 2 | 3 | import ( 4 | "sync/atomic" 5 | "time" 6 | ) 7 | 8 | /* 并发的 rate limiter。 9 | * 基于队列的 leaky bucket 算法实现。 10 | * 参见 https://en.wikipedia.org/wiki/Leaky_bucket leaky bucket 有两种实现方式 11 | * As a meter: 此与 token bucket 等价 12 | * As a queue: 此具有更严格的限速,能够避免 burst flow 13 | * 14 | * 15 | * 测试结果(MacBook Pro 16 英寸): 16 | * 1K QPS: 1 req per 1 milli sec ( 1ms) 17 | * 1w QPS: 1 req per 100 micro sec ( 0.1ms) timer 可以支持此精度。 18 | * 2w QPS: 1 req per 50 micro sec (0.05ms) timer 可以支持此精度。 19 | * 10w QPS: 1 req per 10 micro sec (0.01ms) timer 开始出现误差。 20 | * 为确保精度,建议不要超过 10w QPS。 21 | * (此问题并非无解,是有优化方案的。) 22 | */ 23 | type LeakyBucketRateLimiter struct { 24 | qpsThreshold int64 // 共享,需通过原子操作进行读写 25 | ch chan chan struct{} // the leaky bucket 26 | stopCh chan struct{} // 用于关闭这个 rate limiter 27 | changeCh chan struct{} // 修改本 rate limiter 之后,通过此 channel 进行通知 28 | } 29 | 30 | func NewLeakyBucketRateLimiter(qpsThreshold int64) *LeakyBucketRateLimiter { 31 | lbrl := &LeakyBucketRateLimiter{ 32 | qpsThreshold: qpsThreshold, 33 | ch: make(chan chan struct{}, 1), 34 | stopCh: make(chan struct{}), 35 | changeCh: make(chan struct{}), 36 | } 37 | go func() { 38 | lbrl.leak() 39 | }() 40 | return lbrl 41 | } 42 | 43 | func (lbrl *LeakyBucketRateLimiter) ChangeQpsThreshold(newQpsThreshold int64) { 44 | atomic.StoreInt64(&lbrl.qpsThreshold, 
newQpsThreshold) 45 | lbrl.changeCh <- struct{}{} 46 | } 47 | 48 | func (lbrl *LeakyBucketRateLimiter) getTick() time.Duration { 49 | qpsThreshold := atomic.LoadInt64(&lbrl.qpsThreshold) 50 | tick := time.Duration(1000.0 * int64(time.Millisecond) / qpsThreshold) 51 | return tick 52 | } 53 | 54 | func (lbrl *LeakyBucketRateLimiter) leak() { 55 | // 这里的基本逻辑是,根据 qpsThreshold 计算 tick,每个 tick 的时间间隔里,只允许一次动作。 56 | // 然而,qpsThreshold 很大时,tick 很小。这带来两个问题: 57 | // 1) timer 的精度可能无法满足要求,导致误差变大 58 | // 2) 本方法循环次数很多,占用较多 CPU 59 | // 一个优化方案,是去掉「每个 tick 的时间间隔里,只允许一次动作」限制。 60 | // 而是首先确保 tick 保持在合理的范围(如 >=0.1ms 且越小越好),并根据该 tick 和 qpsThreshold 确定每个 tick 里允许的动作次数(须为整数)。 61 | tickCh := time.Tick(lbrl.getTick()) 62 | OUTER: 63 | for { 64 | select { 65 | case <-lbrl.stopCh: // stopped 66 | break OUTER 67 | case <-lbrl.changeCh: // rate limiter modified 68 | newTick := lbrl.getTick() 69 | tickCh = time.Tick(newTick) 70 | case <-tickCh: 71 | select { 72 | case waiterCh := <-lbrl.ch: 73 | waiterCh <- struct{}{} 74 | default: 75 | // pass 76 | } 77 | } 78 | } 79 | } 80 | 81 | func (lbrl *LeakyBucketRateLimiter) Limit() error { 82 | ch := make(chan struct{}, 1) 83 | lbrl.ch <- ch 84 | <-ch 85 | return nil 86 | } 87 | 88 | func (lbrl *LeakyBucketRateLimiter) Close() { 89 | lbrl.stopCh <- struct{}{} 90 | } 91 | -------------------------------------------------------------------------------- /pkg/util/ast/ast_util.go: -------------------------------------------------------------------------------- 1 | package ast 2 | 3 | import ( 4 | "context" 5 | "encoding/binary" 6 | "strings" 7 | 8 | "github.com/pingcap/parser/ast" 9 | "github.com/pingcap/parser/format" 10 | driver "github.com/pingcap/tidb/types/parser_driver" 11 | ) 12 | 13 | const ctxAstTableNameKey = "ctx_ast_table_name" 14 | 15 | func CtxWithAstTableName(ctx context.Context, tableName string) context.Context { 16 | return context.WithValue(ctx, ctxAstTableNameKey, tableName) 17 | } 18 | 19 | func GetAstTableNameFromCtx(ctx 
context.Context) (string, bool) { 20 | tableName := ctx.Value(ctxAstTableNameKey) 21 | if tableName == nil { 22 | return "", false 23 | } 24 | tableNameStr, ok := tableName.(string) 25 | if !ok { 26 | return "", false 27 | } 28 | return tableNameStr, true 29 | } 30 | 31 | type FirstTableNameVisitor struct { 32 | table string 33 | found bool 34 | } 35 | 36 | func (f *FirstTableNameVisitor) Enter(n ast.Node) (node ast.Node, skipChildren bool) { 37 | switch nn := n.(type) { 38 | case *ast.TableName: 39 | f.table = nn.Name.String() 40 | f.found = true 41 | return n, true 42 | } 43 | return n, false 44 | } 45 | 46 | func (f *FirstTableNameVisitor) Leave(n ast.Node) (node ast.Node, ok bool) { 47 | return n, !f.found 48 | } 49 | 50 | func (f *FirstTableNameVisitor) TableName() string { 51 | return f.table 52 | } 53 | 54 | func ExtractFirstTableNameFromStmt(stmt ast.StmtNode) string { 55 | visitor := &FirstTableNameVisitor{} 56 | stmt.Accept(visitor) 57 | return visitor.table 58 | } 59 | 60 | type AstVisitor struct { 61 | sqlFeature string 62 | } 63 | 64 | func ExtractAstVisit(stmt ast.StmtNode) (*AstVisitor, error) { 65 | visitor := &AstVisitor{} 66 | 67 | stmt.Accept(visitor) 68 | 69 | sb := strings.Builder{} 70 | if err := stmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)); err != nil { 71 | return nil, err 72 | } 73 | visitor.sqlFeature = sb.String() 74 | 75 | return visitor, nil 76 | } 77 | 78 | func (f *AstVisitor) Enter(n ast.Node) (node ast.Node, skipChildren bool) { 79 | switch nn := n.(type) { 80 | case *ast.PatternInExpr: 81 | if len(nn.List) == 0 { 82 | return nn, false 83 | } 84 | if _, ok := nn.List[0].(*driver.ValueExpr); ok { 85 | nn.List = nn.List[:1] 86 | } 87 | case *driver.ValueExpr: 88 | nn.SetValue("?") 89 | } 90 | return n, false 91 | } 92 | 93 | func (f *AstVisitor) Leave(n ast.Node) (node ast.Node, ok bool) { 94 | return n, true 95 | } 96 | 97 | func (f *AstVisitor) SqlFeature() string { 98 | return f.sqlFeature 99 | } 100 | 101 | 
func UInt322Bytes(n uint32) []byte { 102 | b := make([]byte, 4) 103 | binary.LittleEndian.PutUint32(b, n) 104 | return b 105 | } 106 | 107 | func Bytes2Uint32(b []byte) uint32 { 108 | return binary.LittleEndian.Uint32(b) 109 | } 110 | -------------------------------------------------------------------------------- /pkg/util/timer/time_wheel_test.go: -------------------------------------------------------------------------------- 1 | package timer 2 | 3 | import ( 4 | "strconv" 5 | "sync/atomic" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | type A struct { 13 | a int 14 | b string 15 | isCallbacked int32 16 | } 17 | 18 | func (a *A) callback() { 19 | atomic.StoreInt32(&a.isCallbacked, 1) 20 | } 21 | 22 | func (a *A) getCallbackValue() int32 { 23 | return atomic.LoadInt32(&a.isCallbacked) 24 | } 25 | 26 | func newTimeWheel() *TimeWheel { 27 | tw, err := NewTimeWheel(time.Second, 3600) 28 | if err != nil { 29 | panic(err) 30 | } 31 | tw.Start() 32 | return tw 33 | } 34 | 35 | func TestNewTimeWheel(t *testing.T) { 36 | tests := []struct { 37 | name string 38 | tick time.Duration 39 | bucketNum int 40 | hasErr bool 41 | }{ 42 | {tick: time.Second, bucketNum: 0, hasErr: true}, 43 | {tick: time.Millisecond, bucketNum: 1, hasErr: true}, 44 | {tick: time.Second, bucketNum: 1, hasErr: false}, 45 | } 46 | for _, test := range tests { 47 | t.Run(test.name, func(t *testing.T) { 48 | _, err := NewTimeWheel(test.tick, test.bucketNum) 49 | assert.Equal(t, test.hasErr, err != nil) 50 | }) 51 | } 52 | } 53 | 54 | func TestAdd(t *testing.T) { 55 | tw := newTimeWheel() 56 | a := &A{} 57 | err := tw.Add(time.Second*1, "test", a.callback) 58 | assert.NoError(t, err) 59 | 60 | time.Sleep(time.Millisecond * 500) 61 | assert.Equal(t, int32(0), a.getCallbackValue()) 62 | time.Sleep(time.Second * 2) 63 | assert.Equal(t, int32(1), a.getCallbackValue()) 64 | tw.Stop() 65 | } 66 | 67 | func TestAddMultipleTimes(t *testing.T) { 68 | a := &A{} 69 | tw := 
newTimeWheel() 70 | for i := 0; i < 4; i++ { 71 | err := tw.Add(time.Second, "test", a.callback) 72 | assert.NoError(t, err) 73 | time.Sleep(time.Millisecond * 500) 74 | t.Logf("current: %d", i) 75 | assert.Equal(t, int32(0), a.getCallbackValue()) 76 | } 77 | 78 | time.Sleep(time.Second * 2) 79 | assert.Equal(t, int32(1), a.getCallbackValue()) 80 | tw.Stop() 81 | } 82 | 83 | func TestRemove(t *testing.T) { 84 | a := &A{a: 10, b: "test"} 85 | tw := newTimeWheel() 86 | err := tw.Add(time.Second*1, a, a.callback) 87 | assert.NoError(t, err) 88 | 89 | time.Sleep(time.Millisecond * 500) 90 | assert.Equal(t, int32(0), a.getCallbackValue()) 91 | err = tw.Remove(a) 92 | assert.NoError(t, err) 93 | time.Sleep(time.Second * 2) 94 | assert.Equal(t, int32(0), a.getCallbackValue()) 95 | tw.Stop() 96 | } 97 | 98 | func BenchmarkAdd(b *testing.B) { 99 | a := &A{} 100 | tw := newTimeWheel() 101 | for i := 0; i < b.N; i++ { 102 | key := "test" + strconv.Itoa(i) 103 | err := tw.Add(time.Second, key, a.callback) 104 | if err != nil { 105 | b.Fatalf("benchmark Add failed, %v", err) 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /pkg/proxy/server/column.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 

package server

import (
	"github.com/pingcap/parser/mysql"
)

// maxColumnNameSize caps the byte length of Name and OrgName written by Dump.
// NOTE(review): truncation is by bytes and may split a multi-byte UTF-8 character.
const maxColumnNameSize = 256

// ColumnInfo contains information of a column, as sent to clients in
// column-definition packets.
type ColumnInfo struct {
	Schema             string
	Table              string
	OrgTable           string
	Name               string
	OrgName            string
	ColumnLength       uint32
	Charset            uint16
	Flag               uint16
	Decimal            uint8
	Type               uint8
	DefaultValueLength uint64 // not read by Dump
	DefaultValue       []byte
}

// Dump appends the column definition to buffer and returns the extended
// slice. The field order appears to follow MySQL's ColumnDefinition41
// packet: catalog "def", schema, table, org_table, name, org_name, the 0x0c
// fixed-length marker, charset, column length, type, flags, decimals, two
// filler bytes, and finally the default value when DefaultValue is non-nil.
func (column *ColumnInfo) Dump(buffer []byte) []byte {
	nameDump, orgnameDump := []byte(column.Name), []byte(column.OrgName)
	if len(nameDump) > maxColumnNameSize {
		nameDump = nameDump[0:maxColumnNameSize]
	}
	if len(orgnameDump) > maxColumnNameSize {
		orgnameDump = orgnameDump[0:maxColumnNameSize]
	}
	buffer = dumpLengthEncodedString(buffer, []byte("def"))
	buffer = dumpLengthEncodedString(buffer, []byte(column.Schema))
	buffer = dumpLengthEncodedString(buffer, []byte(column.Table))
	buffer = dumpLengthEncodedString(buffer, []byte(column.OrgTable))
	buffer = dumpLengthEncodedString(buffer, nameDump)
	buffer = dumpLengthEncodedString(buffer, orgnameDump)

	// Length of the fixed-length fields that follow.
	buffer = append(buffer, 0x0c)

	buffer = dumpUint16(buffer, column.Charset)
	buffer = dumpUint32(buffer, column.ColumnLength)
	buffer = append(buffer, dumpType(column.Type))
	buffer = dumpUint16(buffer, dumpFlag(column.Type, column.Flag))
	buffer = append(buffer, column.Decimal)
	// Two filler bytes.
	buffer = append(buffer, 0, 0)

	if column.DefaultValue != nil {
		// NOTE(review): the protocol describes this length as a length-encoded
		// integer; confirm what dumpUint64 writes and that clients accept it.
		buffer = dumpUint64(buffer, uint64(len(column.DefaultValue)))
		buffer = append(buffer, column.DefaultValue...)
	}

	return buffer
}

// dumpFlag returns the flags to send for a column of type tp: SET and ENUM
// columns get SetFlag/EnumFlag added, and columns carrying the binary flag
// get NotNullFlag added.
func dumpFlag(tp byte, flag uint16) uint16 {
	switch tp {
	case mysql.TypeSet:
		return flag | uint16(mysql.SetFlag)
	case mysql.TypeEnum:
		return flag | uint16(mysql.EnumFlag)
	default:
		if mysql.HasBinaryFlag(uint(flag)) {
			return flag | uint16(mysql.NotNullFlag)
		}
		return flag
	}
}

// dumpType maps the column type to the type sent on the wire: SET and ENUM
// are reported as TypeString, every other type is sent unchanged.
func dumpType(tp byte) byte {
	switch tp {
	case mysql.TypeSet, mysql.TypeEnum:
		return mysql.TypeString
	default:
		return tp
	}
}

--------------------------------------------------------------------------------
/pkg/util/sync2/atomic_test.go:
--------------------------------------------------------------------------------
/*
Copyright 2019 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
15 | */ 16 | 17 | package sync2 18 | 19 | import ( 20 | "testing" 21 | "time" 22 | 23 | "github.com/stretchr/testify/assert" 24 | ) 25 | 26 | func TestAtomicInt32(t *testing.T) { 27 | i := NewAtomicInt32(1) 28 | assert.Equal(t, int32(1), i.Get()) 29 | 30 | i.Set(2) 31 | assert.Equal(t, int32(2), i.Get()) 32 | 33 | i.Add(1) 34 | assert.Equal(t, int32(3), i.Get()) 35 | 36 | i.CompareAndSwap(3, 4) 37 | assert.Equal(t, int32(4), i.Get()) 38 | 39 | i.CompareAndSwap(3, 5) 40 | assert.Equal(t, int32(4), i.Get()) 41 | } 42 | 43 | func TestAtomicInt64(t *testing.T) { 44 | i := NewAtomicInt64(1) 45 | assert.Equal(t, int64(1), i.Get()) 46 | 47 | i.Set(2) 48 | assert.Equal(t, int64(2), i.Get()) 49 | 50 | i.Add(1) 51 | assert.Equal(t, int64(3), i.Get()) 52 | 53 | i.CompareAndSwap(3, 4) 54 | assert.Equal(t, int64(4), i.Get()) 55 | 56 | i.CompareAndSwap(3, 5) 57 | assert.Equal(t, int64(4), i.Get()) 58 | } 59 | 60 | func TestAtomicDuration(t *testing.T) { 61 | d := NewAtomicDuration(time.Second) 62 | assert.Equal(t, time.Second, d.Get()) 63 | 64 | d.Set(time.Second * 2) 65 | assert.Equal(t, time.Second*2, d.Get()) 66 | 67 | d.Add(time.Second) 68 | assert.Equal(t, time.Second*3, d.Get()) 69 | 70 | d.CompareAndSwap(time.Second*3, time.Second*4) 71 | assert.Equal(t, time.Second*4, d.Get()) 72 | 73 | d.CompareAndSwap(time.Second*3, time.Second*5) 74 | assert.Equal(t, time.Second*4, d.Get()) 75 | } 76 | 77 | func TestAtomicString(t *testing.T) { 78 | var s AtomicString 79 | assert.Equal(t, "", s.Get()) 80 | 81 | s.Set("a") 82 | assert.Equal(t, "a", s.Get()) 83 | 84 | assert.Equal(t, false, s.CompareAndSwap("b", "c")) 85 | assert.Equal(t, "a", s.Get()) 86 | 87 | assert.Equal(t, true, s.CompareAndSwap("a", "c")) 88 | assert.Equal(t, "c", s.Get()) 89 | } 90 | 91 | func TestAtomicBool(t *testing.T) { 92 | b := NewAtomicBool(true) 93 | assert.Equal(t, true, b.Get()) 94 | 95 | b.Set(false) 96 | assert.Equal(t, false, b.Get()) 97 | 98 | b.Set(true) 99 | assert.Equal(t, true, b.Get()) 100 | 
101 | assert.Equal(t, false, b.CompareAndSwap(false, true)) 102 | 103 | assert.Equal(t, true, b.CompareAndSwap(true, false)) 104 | 105 | assert.Equal(t, true, b.CompareAndSwap(false, false)) 106 | 107 | assert.Equal(t, true, b.CompareAndSwap(false, true)) 108 | 109 | assert.Equal(t, true, b.CompareAndSwap(true, true)) 110 | } 111 | -------------------------------------------------------------------------------- /pkg/proxy/server/conn_stmt.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "encoding/binary" 6 | 7 | "github.com/pingcap/errors" 8 | "github.com/pingcap/parser/mysql" 9 | ) 10 | 11 | // TODO(eastfisher): fix me when prepare is implemented 12 | func (cc *clientConn) preparedStmt2String(stmtID uint32) string { 13 | return "" 14 | } 15 | 16 | // TODO(eastfisher): fix me when prepare is implemented 17 | func (cc *clientConn) preparedStmt2StringNoArgs(stmtID uint32) string { 18 | return "" 19 | } 20 | 21 | func (cc *clientConn) handleStmtPrepare(ctx context.Context, sql string) error { 22 | stmtId, columns, params, err := cc.ctx.Prepare(ctx, sql) 23 | if err != nil { 24 | return err 25 | } 26 | data := make([]byte, 4, 128) 27 | 28 | //status ok 29 | data = append(data, 0) 30 | //stmt id 31 | data = dumpUint32(data, uint32(stmtId)) 32 | //number columns 33 | data = dumpUint16(data, uint16(len(columns))) 34 | //number params 35 | data = dumpUint16(data, uint16(len(params))) 36 | //filter [00] 37 | data = append(data, 0) 38 | //warning count 39 | data = append(data, 0, 0) //TODO support warning count 40 | 41 | if err := cc.writePacket(data); err != nil { 42 | return err 43 | } 44 | 45 | if len(params) > 0 { 46 | for i := 0; i < len(params); i++ { 47 | data = data[0:4] 48 | data = params[i].Dump(data) 49 | 50 | if err := cc.writePacket(data); err != nil { 51 | return err 52 | } 53 | } 54 | 55 | if err := cc.writeEOF(0); err != nil { 56 | return err 57 | } 58 | } 59 | 60 | if 
len(columns) > 0 { 61 | for i := 0; i < len(columns); i++ { 62 | data = data[0:4] 63 | data = columns[i].Dump(data) 64 | 65 | if err := cc.writePacket(data); err != nil { 66 | return err 67 | } 68 | } 69 | 70 | if err := cc.writeEOF(0); err != nil { 71 | return err 72 | } 73 | 74 | } 75 | return cc.flush() 76 | } 77 | 78 | func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) error { 79 | if len(data) < 9 { 80 | return mysql.ErrMalformPacket 81 | } 82 | 83 | stmtID := binary.LittleEndian.Uint32(data[0:4]) 84 | ret, err := cc.ctx.StmtExecuteForward(ctx, int(stmtID), data) 85 | if err != nil { 86 | return err 87 | } 88 | 89 | if ret != nil { 90 | err = cc.writeGoMySQLResultset(ctx, ret.Resultset, true, ret.Status, 0) 91 | } else { 92 | err = cc.writeOK() 93 | } 94 | return err 95 | } 96 | 97 | // TODO(eastfisher): implement this function 98 | func (cc *clientConn) handleStmtSendLongData(data []byte) error { 99 | return errors.New("stmt not implemented") 100 | } 101 | 102 | // TODO(eastfisher): implement this function 103 | func (cc *clientConn) handleStmtReset(data []byte) error { 104 | return errors.New("stmt not implemented") 105 | } 106 | 107 | func (cc *clientConn) handleStmtClose(ctx context.Context, data []byte) error { 108 | if len(data) < 4 { 109 | return nil 110 | } 111 | 112 | stmtID := int(binary.LittleEndian.Uint32(data[0:4])) 113 | if err := cc.ctx.StmtClose(ctx, stmtID); err != nil { 114 | return err 115 | } 116 | 117 | return cc.writeOK() 118 | } 119 | -------------------------------------------------------------------------------- /pkg/proxy/metrics/metrics.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | ) 6 | 7 | const ( 8 | ModuleWeirProxy = "weirproxy" 9 | ) 10 | 11 | // metrics labels. 
12 | const ( 13 | LabelServer = "server" 14 | LabelQueryCtx = "queryctx" 15 | LabelBackend = "backend" 16 | LabelSession = "session" 17 | LabelDomain = "domain" 18 | LabelDDLOwner = "ddl-owner" 19 | LabelDDL = "ddl" 20 | LabelDDLWorker = "ddl-worker" 21 | LabelDDLSyncer = "ddl-syncer" 22 | LabelGCWorker = "gcworker" 23 | LabelAnalyze = "analyze" 24 | 25 | LabelBatchRecvLoop = "batch-recv-loop" 26 | LabelBatchSendLoop = "batch-send-loop" 27 | 28 | opSucc = "ok" 29 | opFailed = "err" 30 | 31 | LableScope = "scope" 32 | ScopeGlobal = "global" 33 | ScopeSession = "session" 34 | ) 35 | 36 | // RetLabel returns "ok" when err == nil and "err" when err != nil. 37 | // This could be useful when you need to observe the operation result. 38 | func RetLabel(err error) string { 39 | if err == nil { 40 | return opSucc 41 | } 42 | return opFailed 43 | } 44 | 45 | func RegisterProxyMetrics(cluster string) { 46 | curryingLabelsWithLblCluster := map[string]string{LblCluster: cluster} 47 | 48 | PanicCounter = PanicCounter.MustCurryWith(curryingLabelsWithLblCluster) 49 | prometheus.MustRegister(PanicCounter) 50 | QueryTotalCounter = QueryTotalCounter.MustCurryWith(curryingLabelsWithLblCluster) 51 | prometheus.MustRegister(QueryTotalCounter) 52 | ExecuteErrorCounter = ExecuteErrorCounter.MustCurryWith(curryingLabelsWithLblCluster) 53 | prometheus.MustRegister(ExecuteErrorCounter) 54 | ConnGauge = ConnGauge.MustCurryWith(curryingLabelsWithLblCluster) 55 | prometheus.MustRegister(ConnGauge) 56 | 57 | // query ctx metrics 58 | QueryCtxQueryCounter = QueryCtxQueryCounter.MustCurryWith(curryingLabelsWithLblCluster) 59 | prometheus.MustRegister(QueryCtxQueryCounter) 60 | QueryCtxQueryDeniedCounter = QueryCtxQueryDeniedCounter.MustCurryWith(curryingLabelsWithLblCluster) 61 | prometheus.MustRegister(QueryCtxQueryDeniedCounter) 62 | QueryCtxQueryDurationHistogram = QueryCtxQueryDurationHistogram.MustCurryWith(curryingLabelsWithLblCluster).(*prometheus.HistogramVec) 63 | 
prometheus.MustRegister(QueryCtxQueryDurationHistogram) 64 | QueryCtxGauge = QueryCtxGauge.MustCurryWith(curryingLabelsWithLblCluster) 65 | prometheus.MustRegister(QueryCtxGauge) 66 | QueryCtxAttachedConnGauge = QueryCtxAttachedConnGauge.MustCurryWith(curryingLabelsWithLblCluster) 67 | prometheus.MustRegister(QueryCtxAttachedConnGauge) 68 | QueryCtxTransactionDuration = QueryCtxTransactionDuration.MustCurryWith(curryingLabelsWithLblCluster).(*prometheus.HistogramVec) 69 | prometheus.MustRegister(QueryCtxTransactionDuration) 70 | 71 | // backend metrics 72 | BackendEventCounter = BackendEventCounter.MustCurryWith(curryingLabelsWithLblCluster) 73 | prometheus.MustRegister(BackendEventCounter) 74 | BackendQueryCounter = BackendQueryCounter.MustCurryWith(curryingLabelsWithLblCluster) 75 | prometheus.MustRegister(BackendQueryCounter) 76 | BackendConnInUseGauge = BackendConnInUseGauge.MustCurryWith(curryingLabelsWithLblCluster) 77 | prometheus.MustRegister(BackendConnInUseGauge) 78 | } 79 | -------------------------------------------------------------------------------- /pkg/util/rate_limit_breaker/sliding_window.go: -------------------------------------------------------------------------------- 1 | package rate_limit_breaker 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | /* 8 | * 一个 SlidingWindow 由若干个子单元(Cell)组成,数量(Size)在创建 SlidingWindow 时指定。 9 | * 各个 Cell 的时长(CellIntervalMs)相同,在创建 SlidingWindow 时指定。 10 | * Cell 持有一个 map 用于统计计数 11 | * 12 | * 自 epoch 以来,按照一个 Cell 的时长,将时间轴切成一个个的段,对应到 Cell。 13 | * 每个 Cell 记录其开始时间。因为我们并不是按照时间定时更新 Cells,而是 hit 时更新,而 hit 是由调用者控制的。 14 | */ 15 | 16 | type Cell struct { 17 | startMs int64 18 | // metricName => count 19 | stats map[string]int64 20 | } 21 | 22 | func (cell *Cell) Reset() { 23 | cell.startMs = 0 24 | cell.stats = map[string]int64{} 25 | } 26 | 27 | type SlidingWindow struct { 28 | Size int64 29 | CellIntervalMs int64 30 | Cells []*Cell // invariant: len(Cells) == Size. 
31 | } 32 | 33 | func NewSlidingWindow(size int64, cellIntervalMs int64) *SlidingWindow { 34 | cells := make([]*Cell, size) 35 | for i := 0; int64(i) < size; i++ { 36 | cells[i] = &Cell{ 37 | startMs: 0, 38 | stats: map[string]int64{}, 39 | } 40 | } 41 | 42 | return &SlidingWindow{ 43 | Size: size, 44 | CellIntervalMs: cellIntervalMs, 45 | Cells: cells, 46 | } 47 | } 48 | 49 | // 一次动作(如请求),记为一次 Hit。 50 | func (sw *SlidingWindow) Hit(nowMs int64, metricNames ...string) { 51 | cell := sw.getCell(nowMs) 52 | if nowMs-cell.startMs >= sw.CellIntervalMs { // lazily check if cell expired 53 | cell.startMs = sw.cellStartMs(nowMs) 54 | cell.stats = map[string]int64{} 55 | } 56 | for _, metric := range metricNames { 57 | cell.stats[metric]++ 58 | } 59 | } 60 | 61 | func (sw *SlidingWindow) getCell(nowMs int64) *Cell { 62 | idx := nowMs / sw.CellIntervalMs % sw.Size 63 | return sw.Cells[idx] 64 | } 65 | 66 | func (sw *SlidingWindow) cellStartMs(nowMs int64) int64 { 67 | return nowMs - nowMs%sw.CellIntervalMs 68 | } 69 | 70 | func (sw *SlidingWindow) GetHits(nowMs int64, metricNames ...string) map[string]int64 { 71 | windowStart := nowMs - sw.Size*sw.CellIntervalMs 72 | stats := map[string]int64{} 73 | for _, cell := range sw.Cells { 74 | if cell.startMs < windowStart { // lazily check if cell expired 75 | continue 76 | } 77 | for _, metricName := range metricNames { 78 | stats[metricName] += cell.stats[metricName] 79 | } 80 | } 81 | return stats 82 | } 83 | 84 | func (sw *SlidingWindow) GetNowHits(nowMs int64, metricNames ...string) map[string]int64 { 85 | cell := sw.getCell(nowMs) 86 | stats := map[string]int64{} 87 | for _, metricName := range metricNames { 88 | stats[metricName] += cell.stats[metricName] 89 | } 90 | return stats 91 | } 92 | 93 | func (sw *SlidingWindow) GetHit(nowMs int64, metricName string) int64 { 94 | return sw.GetHits(nowMs, metricName)[metricName] 95 | } 96 | 97 | func (sw *SlidingWindow) GetActualDurationMs(nowMs int64) int64 { 98 | // 当前时间点,可能处在某个 
cell 正中间,这里精确计算所涉及 cell 的总时间段 99 | actualDurationMs := (sw.Size-1)*sw.CellIntervalMs + nowMs%sw.CellIntervalMs 100 | return actualDurationMs 101 | } 102 | 103 | // timestamp in ms 104 | func GetNowMs() int64 { 105 | return time.Now().UnixNano() / int64(time.Millisecond) 106 | } 107 | -------------------------------------------------------------------------------- /pkg/proxy/namespace/namespace.go: -------------------------------------------------------------------------------- 1 | package namespace 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/pingcap/errors" 7 | "github.com/tidb-incubator/weir/pkg/config" 8 | "github.com/tidb-incubator/weir/pkg/proxy/driver" 9 | "github.com/tidb-incubator/weir/pkg/proxy/metrics" 10 | ) 11 | 12 | type NamespaceHolder struct { 13 | nss map[string]Namespace 14 | } 15 | 16 | type NamespaceWrapper struct { 17 | nsmgr *NamespaceManager 18 | name string 19 | } 20 | 21 | func CreateNamespaceHolder(cfgs []*config.Namespace, build NamespaceBuilder) (*NamespaceHolder, error) { 22 | nss := make(map[string]Namespace, len(cfgs)) 23 | 24 | for _, cfg := range cfgs { 25 | ns, err := build(cfg) 26 | if err != nil { 27 | return nil, errors.WithMessage(err, fmt.Sprintf("create namespace error, namespace: %s", cfg.Namespace)) 28 | } 29 | nss[cfg.Namespace] = ns 30 | } 31 | 32 | holder := &NamespaceHolder{ 33 | nss: nss, 34 | } 35 | return holder, nil 36 | } 37 | 38 | func (n *NamespaceHolder) Get(name string) (Namespace, bool) { 39 | ns, ok := n.nss[name] 40 | return ns, ok 41 | } 42 | 43 | func (n *NamespaceHolder) Set(name string, ns Namespace) { 44 | n.nss[name] = ns 45 | } 46 | 47 | func (n *NamespaceHolder) Delete(name string) { 48 | delete(n.nss, name) 49 | } 50 | 51 | func (n *NamespaceHolder) Clone() *NamespaceHolder { 52 | nss := make(map[string]Namespace) 53 | for name, ns := range n.nss { 54 | nss[name] = ns 55 | } 56 | return &NamespaceHolder{ 57 | nss: nss, 58 | } 59 | } 60 | 61 | func (n *NamespaceWrapper) Name() string { 62 
| return n.name 63 | } 64 | 65 | func (n *NamespaceWrapper) IsDatabaseAllowed(db string) bool { 66 | return n.mustGetCurrentNamespace().IsDatabaseAllowed(db) 67 | } 68 | 69 | func (n *NamespaceWrapper) ListDatabases() []string { 70 | return n.mustGetCurrentNamespace().ListDatabases() 71 | } 72 | 73 | func (n *NamespaceWrapper) IsDeniedSQL(sqlFeature uint32) bool { 74 | return n.mustGetCurrentNamespace().IsDeniedSQL(sqlFeature) 75 | } 76 | 77 | func (n *NamespaceWrapper) IsAllowedSQL(sqlFeature uint32) bool { 78 | return n.mustGetCurrentNamespace().IsAllowedSQL(sqlFeature) 79 | } 80 | 81 | func (n *NamespaceWrapper) GetPooledConn(ctx context.Context) (driver.PooledBackendConn, error) { 82 | return n.mustGetCurrentNamespace().GetPooledConn(ctx) 83 | } 84 | 85 | func (n *NamespaceWrapper) IncrConnCount() { 86 | metrics.QueryCtxGauge.WithLabelValues(n.name).Inc() 87 | } 88 | 89 | func (n *NamespaceWrapper) DescConnCount() { 90 | metrics.QueryCtxGauge.WithLabelValues(n.name).Dec() 91 | } 92 | 93 | func (n *NamespaceWrapper) Closed() bool { 94 | _, ok := n.nsmgr.getCurrentNamespaces().Get(n.name) 95 | return !ok 96 | } 97 | 98 | func (n *NamespaceWrapper) GetBreaker() (driver.Breaker, error) { 99 | return n.mustGetCurrentNamespace().GetBreaker() 100 | } 101 | 102 | func (n *NamespaceWrapper) GetRateLimiter() driver.RateLimiter { 103 | return n.mustGetCurrentNamespace().GetRateLimiter() 104 | } 105 | 106 | func (n *NamespaceWrapper) mustGetCurrentNamespace() Namespace { 107 | ns, ok := n.nsmgr.getCurrentNamespaces().Get(n.name) 108 | if !ok { 109 | panic(errors.New("namespace not found")) 110 | } 111 | return ns 112 | } 113 | -------------------------------------------------------------------------------- /pkg/proxy/driver/sessionvars.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "fmt" 5 | "sync/atomic" 6 | 7 | "github.com/pingcap/parser/ast" 8 | "github.com/pingcap/tidb/sessionctx/variable" 9 
| ) 10 | 11 | type SessionVarsWrapper struct { 12 | sessionVarMap map[string]*ast.VariableAssignment 13 | sessionVars *variable.SessionVars 14 | affectedRows uint64 15 | } 16 | 17 | func NewSessionVarsWrapper(sessionVars *variable.SessionVars) *SessionVarsWrapper { 18 | return &SessionVarsWrapper{ 19 | sessionVars: sessionVars, 20 | sessionVarMap: make(map[string]*ast.VariableAssignment), 21 | } 22 | } 23 | 24 | func (s *SessionVarsWrapper) SessionVars() *variable.SessionVars { 25 | return s.sessionVars 26 | } 27 | 28 | func (s *SessionVarsWrapper) GetAllSystemVars() map[string]*ast.VariableAssignment { 29 | ret := make(map[string]*ast.VariableAssignment, len(s.sessionVarMap)) 30 | for k, v := range s.sessionVarMap { 31 | ret[k] = v 32 | } 33 | return ret 34 | } 35 | 36 | func (s *SessionVarsWrapper) SetSystemVarAST(name string, v *ast.VariableAssignment) { 37 | s.sessionVarMap[name] = v 38 | } 39 | 40 | func (s *SessionVarsWrapper) CheckSessionSysVarValid(name string) error { 41 | if name == ast.SetNames || name == ast.SetCharset { 42 | return nil 43 | } 44 | sysVar := variable.GetSysVar(name) 45 | if sysVar == nil { 46 | return fmt.Errorf("%s is not a valid sysvar", name) 47 | } 48 | if (sysVar.Scope & variable.ScopeSession) == 0 { 49 | return fmt.Errorf("%s is not a session scope sysvar", name) 50 | } 51 | return nil 52 | } 53 | 54 | func (s *SessionVarsWrapper) SetSystemVarDefault(name string) { 55 | delete(s.sessionVarMap, name) 56 | } 57 | 58 | func (s *SessionVarsWrapper) Status() uint16 { 59 | return s.sessionVars.Status 60 | } 61 | 62 | func (s *SessionVarsWrapper) GetStatusFlag(flag uint16) bool { 63 | return s.sessionVars.GetStatusFlag(flag) 64 | } 65 | 66 | func (s *SessionVarsWrapper) SetStatusFlag(flag uint16, on bool) { 67 | s.sessionVars.SetStatusFlag(flag, on) 68 | } 69 | 70 | // TODO(eastfisher): remove this function 71 | func (s *SessionVarsWrapper) GetCharsetInfo() (charset, collation string) { 72 | return s.sessionVars.GetCharsetInfo() 73 | } 
74 | 75 | func (s *SessionVarsWrapper) AffectedRows() uint64 { 76 | return s.affectedRows 77 | } 78 | 79 | func (s *SessionVarsWrapper) SetAffectRows(count uint64) { 80 | s.affectedRows = count 81 | } 82 | 83 | func (s *SessionVarsWrapper) LastInsertID() uint64 { 84 | return s.sessionVars.StmtCtx.LastInsertID 85 | } 86 | 87 | func (s *SessionVarsWrapper) SetLastInsertID(id uint64) { 88 | s.sessionVars.StmtCtx.LastInsertID = id 89 | } 90 | 91 | func (s *SessionVarsWrapper) GetMessage() string { 92 | return s.sessionVars.StmtCtx.GetMessage() 93 | } 94 | 95 | func (s *SessionVarsWrapper) SetMessage(msg string) { 96 | s.sessionVars.StmtCtx.SetMessage(msg) 97 | } 98 | 99 | func (s *SessionVarsWrapper) GetClientCapability() uint32 { 100 | return s.sessionVars.ClientCapability 101 | } 102 | 103 | func (s *SessionVarsWrapper) SetClientCapability(capability uint32) { 104 | s.sessionVars.ClientCapability = capability 105 | } 106 | 107 | func (s *SessionVarsWrapper) SetCommandValue(command byte) { 108 | atomic.StoreUint32(&s.sessionVars.CommandValue, uint32(command)) 109 | } 110 | -------------------------------------------------------------------------------- /pkg/proxy/server/conn_util.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 
13 | 14 | package server 15 | 16 | import ( 17 | "encoding/binary" 18 | "fmt" 19 | "strconv" 20 | 21 | "github.com/pingcap/errors" 22 | "github.com/pingcap/parser" 23 | "github.com/pingcap/parser/mysql" 24 | "github.com/pingcap/tidb/kv" 25 | "github.com/pingcap/tidb/util/hack" 26 | ) 27 | 28 | var _ fmt.Stringer = getLastStmtInConn{} 29 | 30 | type getLastStmtInConn struct { 31 | *clientConn 32 | } 33 | 34 | func (cc getLastStmtInConn) String() string { 35 | if len(cc.lastPacket) == 0 { 36 | return "" 37 | } 38 | cmd, data := cc.lastPacket[0], cc.lastPacket[1:] 39 | switch cmd { 40 | case mysql.ComInitDB: 41 | return "Use " + string(data) 42 | case mysql.ComFieldList: 43 | return "ListFields " + string(data) 44 | case mysql.ComQuery, mysql.ComStmtPrepare: 45 | sql := string(hack.String(data)) 46 | if cc.ctx.GetSessionVars().EnableLogDesensitization { 47 | sql, _ = parser.NormalizeDigest(sql) 48 | } 49 | return queryStrForLog(sql) 50 | case mysql.ComStmtExecute, mysql.ComStmtFetch: 51 | stmtID := binary.LittleEndian.Uint32(data[0:4]) 52 | return queryStrForLog(cc.preparedStmt2String(stmtID)) 53 | case mysql.ComStmtClose, mysql.ComStmtReset: 54 | stmtID := binary.LittleEndian.Uint32(data[0:4]) 55 | return mysql.Command2Str[cmd] + " " + strconv.Itoa(int(stmtID)) 56 | default: 57 | if cmdStr, ok := mysql.Command2Str[cmd]; ok { 58 | return cmdStr 59 | } 60 | return string(hack.String(data)) 61 | } 62 | } 63 | 64 | // PProfLabel return sql label used to tag pprof. 
65 | func (cc getLastStmtInConn) PProfLabel() string { 66 | if len(cc.lastPacket) == 0 { 67 | return "" 68 | } 69 | cmd, data := cc.lastPacket[0], cc.lastPacket[1:] 70 | switch cmd { 71 | case mysql.ComInitDB: 72 | return "UseDB" 73 | case mysql.ComFieldList: 74 | return "ListFields" 75 | case mysql.ComStmtClose: 76 | return "CloseStmt" 77 | case mysql.ComStmtReset: 78 | return "ResetStmt" 79 | case mysql.ComQuery, mysql.ComStmtPrepare: 80 | return parser.Normalize(queryStrForLog(string(hack.String(data)))) 81 | case mysql.ComStmtExecute, mysql.ComStmtFetch: 82 | stmtID := binary.LittleEndian.Uint32(data[0:4]) 83 | return queryStrForLog(cc.preparedStmt2StringNoArgs(stmtID)) 84 | default: 85 | return "" 86 | } 87 | } 88 | 89 | func queryStrForLog(query string) string { 90 | const size = 4096 91 | if len(query) > size { 92 | return query[:size] + fmt.Sprintf("(len: %d)", len(query)) 93 | } 94 | return query 95 | } 96 | 97 | func errStrForLog(err error) string { 98 | if kv.ErrKeyExists.Equal(err) || parser.ErrParse.Equal(err) { 99 | // Do not log stack for duplicated entry error. 100 | return err.Error() 101 | } 102 | return errors.ErrorStack(err) 103 | } 104 | -------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | # options for analysis running 2 | run: 3 | # default concurrency is a available CPU number 4 | concurrency: 4 5 | 6 | # timeout for analysis, e.g. 30s, 5m, default is 1m 7 | timeout: 20m 8 | 9 | # exit code when at least one issue was found, default is 1 10 | issues-exit-code: 1 11 | 12 | # include test files or not, default is true 13 | tests: false 14 | 15 | # list of build tags, all linters use it. Default is empty list. 16 | build-tags: 17 | 18 | # default is true. 
Enables skipping of directories: 19 | # vendor$, third_party$, testdata$, examples$, Godeps$, builtin$ 20 | skip-dirs-use-default: true 21 | 22 | # which dirs to skip: they won't be analyzed; 23 | # can use regexp here: generated.*, regexp is applied on full path; 24 | # default value is empty list, but next dirs are always skipped independently 25 | # from this option's value: 26 | # vendor$, third_party$, testdata$, examples$, Godeps$, builtin$ 27 | # skip-dirs: 28 | # - ^test.* 29 | 30 | # by default isn't set. If set we pass it to "go list -mod={option}". From "go help modules": 31 | # If invoked with -mod=readonly, the go command is disallowed from the implicit 32 | # automatic updating of go.mod described above. Instead, it fails when any changes 33 | # to go.mod are needed. This setting is most useful to check that go.mod does 34 | # not need updates, such as in a continuous integration and testing system. 35 | # If invoked with -mod=vendor, the go command assumes that the vendor 36 | # directory holds the correct copies of dependencies and ignores 37 | # the dependency descriptions in go.mod. 38 | modules-download-mode: readonly 39 | 40 | # which files to skip: they will be analyzed, but issues from them 41 | # won't be reported. Default value is empty list, but there is 42 | # no need to include all autogenerated files, we confidently recognize 43 | # autogenerated files. If it's not please let us know. 44 | skip-files: 45 | # - ".*\\.my\\.go$" 46 | # - lib/bad.go 47 | 48 | # all available settings of specific linters 49 | linters-settings: 50 | misspell: 51 | # Correct spellings using locale preferences for US or UK. 52 | # Default is to use a neutral variety of English. 53 | # Setting locale to US will correct the British spelling of 'colour' to 'color'. 
54 | # locale: US 55 | ignore-words: 56 | - rela # This is for elf.SHT_RELA 57 | 58 | issues: 59 | # Excluding configuration per-path, per-linter, per-text and per-source 60 | exclude-rules: 61 | - linters: [staticcheck] 62 | text: "SA1019" # rule for use of deprecated methods 63 | - linters: [staticcheck] 64 | text: "SA9003: empty branch" 65 | - linters: [staticcheck] 66 | text: "SA2001: empty critical section" 67 | - linters: [goerr113] 68 | text: "do not define dynamic errors, use wrapped static errors instead" # Avoid the opinionated check that flags fmt.Errorf("text") 69 | - path: _test\.go # Skip the Go 1.13 error-wrapping check for test files 70 | linters: 71 | - goerr113 72 | linters: 73 | disable-all: true 74 | enable: 75 | - deadcode 76 | - goerr113 77 | - gofmt 78 | - goimports 79 | - gosimple 80 | - govet 81 | - ineffassign 82 | - misspell 83 | - staticcheck 84 | - stylecheck 85 | - varcheck 86 | 87 | # To enable later if it makes sense 88 | # - errcheck 89 | # - gocyclo 90 | # - golint 91 | # - gosec 92 | # - gosimple 93 | # - lll 94 | # - maligned 95 | # - misspell 96 | # - prealloc 97 | # - structcheck 98 | # - typecheck 99 | # - unused 100 | -------------------------------------------------------------------------------- /pkg/proxy/server/packetio_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License.
13 | 14 | package server 15 | 16 | import ( 17 | "bufio" 18 | "bytes" 19 | "net" 20 | "time" 21 | 22 | . "github.com/pingcap/check" 23 | "github.com/pingcap/parser/mysql" 24 | ) 25 | 26 | type PacketIOTestSuite struct { 27 | } 28 | 29 | var _ = Suite(new(PacketIOTestSuite)) 30 | 31 | func (s *PacketIOTestSuite) TestWrite(c *C) { 32 | // Test write one packet 33 | var outBuffer bytes.Buffer 34 | pkt := &packetIO{bufWriter: bufio.NewWriter(&outBuffer)} 35 | err := pkt.writePacket([]byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03}) 36 | c.Assert(err, IsNil) 37 | err = pkt.flush() 38 | c.Assert(err, IsNil) 39 | c.Assert(outBuffer.Bytes(), DeepEquals, []byte{0x03, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03}) 40 | 41 | // Test write more than one packet 42 | outBuffer.Reset() 43 | largeInput := make([]byte, mysql.MaxPayloadLen+4) 44 | pkt = &packetIO{bufWriter: bufio.NewWriter(&outBuffer)} 45 | err = pkt.writePacket(largeInput) 46 | c.Assert(err, IsNil) 47 | err = pkt.flush() 48 | c.Assert(err, IsNil) 49 | res := outBuffer.Bytes() 50 | c.Assert(res[0], Equals, byte(0xff)) 51 | c.Assert(res[1], Equals, byte(0xff)) 52 | c.Assert(res[2], Equals, byte(0xff)) 53 | c.Assert(res[3], Equals, byte(0)) 54 | } 55 | 56 | func (s *PacketIOTestSuite) TestRead(c *C) { 57 | var inBuffer bytes.Buffer 58 | _, err := inBuffer.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x01}) 59 | c.Assert(err, IsNil) 60 | // Test read one packet 61 | brc := newBufferedReadConn(&bytesConn{inBuffer}) 62 | pkt := newPacketIO(brc) 63 | bytes, err := pkt.readPacket() 64 | c.Assert(err, IsNil) 65 | c.Assert(pkt.sequence, Equals, uint8(1)) 66 | c.Assert(bytes, DeepEquals, []byte{0x01}) 67 | 68 | inBuffer.Reset() 69 | buf := make([]byte, mysql.MaxPayloadLen+9) 70 | buf[0] = 0xff 71 | buf[1] = 0xff 72 | buf[2] = 0xff 73 | buf[3] = 0 74 | buf[2+mysql.MaxPayloadLen] = 0x00 75 | buf[3+mysql.MaxPayloadLen] = 0x00 76 | buf[4+mysql.MaxPayloadLen] = 0x01 77 | buf[7+mysql.MaxPayloadLen] = 0x01 78 | buf[8+mysql.MaxPayloadLen] = 0x0a 79 | 
80 | _, err = inBuffer.Write(buf) 81 | c.Assert(err, IsNil) 82 | // Test read multiple packets 83 | brc = newBufferedReadConn(&bytesConn{inBuffer}) 84 | pkt = newPacketIO(brc) 85 | bytes, err = pkt.readPacket() 86 | c.Assert(err, IsNil) 87 | c.Assert(pkt.sequence, Equals, uint8(2)) 88 | c.Assert(len(bytes), Equals, mysql.MaxPayloadLen+1) 89 | c.Assert(bytes[mysql.MaxPayloadLen], DeepEquals, byte(0x0a)) 90 | } 91 | 92 | type bytesConn struct { 93 | b bytes.Buffer 94 | } 95 | 96 | func (c *bytesConn) Read(b []byte) (n int, err error) { 97 | return c.b.Read(b) 98 | } 99 | 100 | func (c *bytesConn) Write(b []byte) (n int, err error) { 101 | return 0, nil 102 | } 103 | 104 | func (c *bytesConn) Close() error { 105 | return nil 106 | } 107 | 108 | func (c *bytesConn) LocalAddr() net.Addr { 109 | return nil 110 | } 111 | 112 | func (c *bytesConn) RemoteAddr() net.Addr { 113 | return nil 114 | } 115 | 116 | func (c *bytesConn) SetDeadline(t time.Time) error { 117 | return nil 118 | } 119 | 120 | func (c *bytesConn) SetReadDeadline(t time.Time) error { 121 | return nil 122 | } 123 | 124 | func (c *bytesConn) SetWriteDeadline(t time.Time) error { 125 | return nil 126 | } 127 | -------------------------------------------------------------------------------- /docs/cn/namespace-config.md: -------------------------------------------------------------------------------- 1 | # Namespace配置详解 2 | 3 | ## 配置说明 4 | 5 | ### 客户端连接配置 6 | 7 | ``` 8 | version: "v1" 9 | namespace: "test_namespace" 10 | frontend: 11 | allowed_dbs: 12 | - "test_weir_db" 13 | slow_sql_time: 50 14 | sql_blacklist: 15 | - sql: "select * from tbl0" 16 | - sql: "select * from tbl1" 17 | sql_whitelist: 18 | - sql: "select * from tbl2" 19 | - sql: "select * from tbl3" 20 | denied_ips: 21 | users: 22 | - username: "hello" 23 | password: "world" 24 | - username: "hello1" 25 | password: "world1" 26 | ``` 27 | 28 | 字段说明 29 | 30 | | 配置 | 说明 | 31 | | --- | --- | 32 | | namespace | Namespace名称, 要求Proxy集群内唯一 | 33 | | 
frontend | 客户端连接相关配置 | 34 | | frontend.allowed_dbs | 客户端允许访问的Database列表 | 35 | | frontend.sql_blacklist | SQL黑名单列表 | 36 | | frontend.sql_whitelist | SQL白名单列表 | 37 | | frontend.denied_ips | 链接 ip 黑名单列表 | 38 | | frontend.users | 用户连接信息列表 | 39 | | frontend.users.username | 用户名 (要求Proxy集群内唯一) | 40 | | frontend.users.password | 密码 | 41 | 42 | ### 后端连接池配置 43 | 44 | ``` 45 | backend: 46 | instances: 47 | - "127.0.0.1:3306" 48 | username: "root" 49 | password: "12344321" 50 | selector_type: "random" 51 | pool_size: 10 52 | idle_timeout: 60 53 | ``` 54 | 55 | 字段说明 56 | 57 | | 配置 | 说明 | 58 | | --- | --- | 59 | | instances | TiDB Server实例地址列表 | 60 | | username | 连接TiDB Server用户名| 61 | | password | 连接TiDB Server密码 | 62 | | selector_type | 负载均衡策略, 目前只支持random | 63 | | pool_size | 连接池最大连接数 (针对每个TiDB Server) | 64 | | idle_timeout | 对 TIDB 连接池连接空闲超时关闭时间 (单位: 秒) | 65 | 66 | ### 熔断器配置 67 | 68 | 关于熔断的概念可以关注伴鱼技术团队的过往博客[点击了解熔断](https://tech.ipalfish.com/blog/2020/08/23/dolphin/) 69 | 70 | ``` 71 | breaker: 72 | scope: "sql" 73 | strategies: 74 | - min_qps: 3 75 | failure_rate_threshold: 0 76 | failure_num: 5 77 | sql_timeout_ms: 2000 78 | open_status_duration_ms: 5000 79 | size: 10 80 | cell_interval_ms: 1000 81 | ``` 82 | 83 | 字段说明 84 | 85 | | 配置 | 说明 | 86 | | --- | --- | 87 | | scope | 熔断器粒度, 支持参数: namespace, db, table, sql | 88 | | strategies | 熔断策略列表 | 89 | | strategies.min_qps | 熔断被能被触发的最小 QPS | 90 | | strategies.failure_rate_threshold | 需要达到可以熔断的错误率阈值百分数 (0~100) | 91 | | strategies.failure_num | 需要达到可以熔断的错误数阈值 (与failure_rate_threshold只能使用其一) | 92 | | strategies.sql_timeout_ms | SQL超时阈值, (单位: 毫秒) | 93 | | strategies.open_status_duration_ms | 熔断器开启状态持续时间 (单位: 毫秒) | 94 | | strategies.size| 滑动窗口计数器的统计单元数 | 95 | | strategies.cell_interval_ms | 每个单元的时长 (单位: 毫秒), 与size字段共同组成了熔断器时间统计的滑窗 | 96 | 97 | ### 限流器配置 98 | 99 | ``` 100 | rate_limiter: 101 | scope: "db" 102 | qps: 1000 103 | ``` 104 | 105 | 字段说明 106 | 107 | | 配置 | 说明 | 108 | | --- | --- | 109 | | scope | 限流器粒度, 支持参数: namespace, db, 
table | 110 | | qps | 限流QPS (超过阈值的请求会直接返回错误) | 111 | 112 | 113 | ## 完整配置示例 114 | 115 | ``` 116 | version: "v1" 117 | namespace: "test_namespace" 118 | frontend: 119 | allowed_dbs: 120 | - "test_weir_db" 121 | slow_sql_time: 50 122 | sql_blacklist: 123 | - sql: "select * from tbl0" 124 | - sql: "select * from tbl1" 125 | sql_whitelist: 126 | - sql: "select * from tbl2" 127 | - sql: "select * from tbl3" 128 | denied_ips: 129 | users: 130 | - username: "hello" 131 | password: "world" 132 | - username: "hello1" 133 | password: "world1" 134 | backend: 135 | instances: 136 | - "127.0.0.1:4000" 137 | username: "root" 138 | password: "" 139 | selector_type: "random" 140 | pool_size: 10 141 | idle_timeout: 60 142 | breaker: 143 | scope: "sql" 144 | strategies: 145 | - min_qps: 3 146 | failure_rate_threshold: 0 147 | failure_num: 5 148 | sql_timeout_ms: 2000 149 | open_status_duration_ms: 5000 150 | size: 10 151 | cell_interval_ms: 1000 152 | rate_limiter: 153 | scope: "db" 154 | qps: 1000 155 | ``` 156 | -------------------------------------------------------------------------------- /pkg/configcenter/etcd.go: -------------------------------------------------------------------------------- 1 | package configcenter 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "path" 7 | "time" 8 | 9 | "github.com/tidb-incubator/weir/pkg/config" 10 | "github.com/pingcap/errors" 11 | "github.com/pingcap/tidb/util/logutil" 12 | "go.etcd.io/etcd/clientv3" 13 | "go.etcd.io/etcd/mvcc/mvccpb" 14 | "go.uber.org/zap" 15 | ) 16 | 17 | const ( 18 | DefaultEtcdDialTimeout = 3 * time.Second 19 | ) 20 | 21 | type EtcdConfigCenter struct { 22 | etcdClient *clientv3.Client 23 | kv clientv3.KV 24 | basePath string 25 | strictParse bool 26 | } 27 | 28 | func CreateEtcdConfigCenter(cfg config.ConfigEtcd) (*EtcdConfigCenter, error) { 29 | etcdConfig := clientv3.Config{ 30 | Endpoints: cfg.Addrs, 31 | Username: cfg.Username, 32 | Password: cfg.Password, 33 | DialTimeout: DefaultEtcdDialTimeout, 34 | } 35 | 
etcdClient, err := clientv3.New(etcdConfig) 36 | if err != nil { 37 | return nil, errors.WithMessage(err, "create etcd config center error") 38 | } 39 | 40 | center := NewEtcdConfigCenter(etcdClient, cfg.BasePath, cfg.StrictParse) 41 | return center, nil 42 | } 43 | 44 | func NewEtcdConfigCenter(etcdClient *clientv3.Client, basePath string, strictParse bool) *EtcdConfigCenter { 45 | return &EtcdConfigCenter{ 46 | etcdClient: etcdClient, 47 | kv: clientv3.NewKV(etcdClient), 48 | basePath: basePath, 49 | strictParse: strictParse, 50 | } 51 | } 52 | 53 | func (e *EtcdConfigCenter) get(ctx context.Context, key string) (*mvccpb.KeyValue, error) { 54 | resp, err := e.kv.Get(ctx, getNamespacePath(e.basePath, key)) 55 | if err != nil { 56 | return nil, err 57 | } 58 | if len(resp.Kvs) == 0 { 59 | return nil, fmt.Errorf("key not found") 60 | } 61 | return resp.Kvs[0], nil 62 | } 63 | 64 | func (e *EtcdConfigCenter) list(ctx context.Context) ([]*mvccpb.KeyValue, error) { 65 | baseDir := appendSlashToDirPath(e.basePath) 66 | resp, err := e.kv.Get(ctx, baseDir, clientv3.WithPrefix()) 67 | if err != nil { 68 | return nil, err 69 | } 70 | return resp.Kvs, nil 71 | } 72 | 73 | func (e *EtcdConfigCenter) GetNamespace(ns string) (*config.Namespace, error) { 74 | ctx := context.Background() 75 | etcdKeyValue, err := e.get(ctx, ns) 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | return config.UnmarshalNamespaceConfig(etcdKeyValue.Value) 81 | } 82 | 83 | func (e *EtcdConfigCenter) ListAllNamespace() ([]*config.Namespace, error) { 84 | ctx := context.Background() 85 | etcdKeyValues, err := e.list(ctx) 86 | if err != nil { 87 | return nil, err 88 | } 89 | 90 | var ret []*config.Namespace 91 | for _, kv := range etcdKeyValues { 92 | nsCfg, err := config.UnmarshalNamespaceConfig(kv.Value) 93 | if err != nil { 94 | if e.strictParse { 95 | return nil, err 96 | } else { 97 | logutil.BgLogger().Warn("parse namespace config error", zap.Error(err), zap.ByteString("namespace", 
kv.Key)) 98 | continue 99 | } 100 | } 101 | ret = append(ret, nsCfg) 102 | } 103 | 104 | return ret, nil 105 | } 106 | 107 | func (e *EtcdConfigCenter) SetNamespace(ns string, value string) error { 108 | ctx := context.Background() 109 | _, err := e.kv.Put(ctx, ns, value) 110 | return err 111 | } 112 | 113 | func (e *EtcdConfigCenter) DelNamespace(ns string) error { 114 | ctx := context.Background() 115 | _, err := e.kv.Delete(ctx, ns) 116 | return err 117 | } 118 | 119 | func (e *EtcdConfigCenter) Close() { 120 | if err := e.etcdClient.Close(); err != nil { 121 | logutil.BgLogger().Error("close etcd client error", zap.Error(err)) 122 | } 123 | } 124 | 125 | func getNamespacePath(basePath, ns string) string { 126 | return path.Join(basePath, ns) 127 | } 128 | 129 | // avoid base dir path prefix equal 130 | func appendSlashToDirPath(dir string) string { 131 | if len(dir) == 0 { 132 | return "" 133 | } 134 | if dir[len(dir)-1] == '/' { 135 | return dir 136 | } 137 | return dir + "/" 138 | } 139 | -------------------------------------------------------------------------------- /pkg/proxy/driver/resultset.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync/atomic" 7 | 8 | "github.com/tidb-incubator/weir/pkg/proxy/server" 9 | "github.com/pingcap/tidb/types" 10 | "github.com/pingcap/tidb/util/chunk" 11 | "github.com/pingcap/tidb/util/hack" 12 | "github.com/siddontang/go-mysql/mysql" 13 | ) 14 | 15 | type weirResultSet struct { 16 | result *mysql.Result 17 | columnInfos []*server.ColumnInfo 18 | closed int32 19 | readed bool 20 | } 21 | 22 | func wrapMySQLResult(result *mysql.Result) *weirResultSet { 23 | resultSet := &weirResultSet{ 24 | result: result, 25 | } 26 | columnInfos := convertFieldsToColumnInfos(result.Fields) 27 | resultSet.columnInfos = columnInfos 28 | return resultSet 29 | } 30 | 31 | func createBinaryPrepareColumns(cnt int) []*server.ColumnInfo { 32 | 
info := &server.ColumnInfo{} 33 | return copyColumnInfo(info, cnt) 34 | } 35 | 36 | func createBinaryPrepareParams(cnt int) []*server.ColumnInfo { 37 | info := &server.ColumnInfo{Name: "?"} 38 | return copyColumnInfo(info, cnt) 39 | } 40 | 41 | func copyColumnInfo(info *server.ColumnInfo, cnt int) []*server.ColumnInfo { 42 | ret := make([]*server.ColumnInfo, 0, cnt) 43 | for i := 0; i < cnt; i++ { 44 | ret = append(ret, info) 45 | } 46 | return ret 47 | } 48 | 49 | func convertFieldsToColumnInfos(fields []*mysql.Field) []*server.ColumnInfo { 50 | var columnInfos []*server.ColumnInfo 51 | for _, field := range fields { 52 | columnInfo := &server.ColumnInfo{ 53 | Schema: string(hack.String(field.Schema)), 54 | Table: string(hack.String(field.Table)), 55 | OrgTable: string(hack.String(field.OrgTable)), 56 | Name: string(hack.String(field.Name)), 57 | OrgName: string(hack.String(field.OrgName)), 58 | ColumnLength: field.ColumnLength, 59 | Charset: field.Charset, 60 | Flag: field.Flag, 61 | Decimal: field.Decimal, 62 | Type: field.Type, 63 | DefaultValueLength: field.DefaultValueLength, 64 | DefaultValue: field.DefaultValue, 65 | } 66 | columnInfos = append(columnInfos, columnInfo) 67 | } 68 | return columnInfos 69 | } 70 | 71 | func convertFieldTypes(fields []*mysql.Field) []*types.FieldType { 72 | var ret []*types.FieldType 73 | for _, f := range fields { 74 | ft := types.NewFieldType(f.Type) 75 | ft.Flag = uint(f.Flag) 76 | ret = append(ret, ft) 77 | } 78 | return ret 79 | } 80 | 81 | func writeResultSetDataToTrunk(r *mysql.Resultset, c *chunk.Chunk) { 82 | for _, rowValue := range r.Values { 83 | for colIdx, colValue := range rowValue { 84 | switch colValue.Type { 85 | case mysql.FieldValueTypeNull: 86 | c.Column(colIdx).AppendNull() 87 | case mysql.FieldValueTypeUnsigned: 88 | c.Column(colIdx).AppendUint64(colValue.AsUint64()) 89 | case mysql.FieldValueTypeSigned: 90 | c.Column(colIdx).AppendInt64(colValue.AsInt64()) 91 | case mysql.FieldValueTypeFloat: 92 | 
c.Column(colIdx).AppendFloat64(colValue.AsFloat64()) 93 | case mysql.FieldValueTypeString: 94 | c.Column(colIdx).AppendBytes(colValue.AsString()) 95 | default: 96 | panic(fmt.Errorf("invalid col value type: %v", colValue.Type)) 97 | } 98 | } 99 | } 100 | } 101 | 102 | func (w *weirResultSet) Columns() []*server.ColumnInfo { 103 | return w.columnInfos 104 | } 105 | 106 | func (w *weirResultSet) NewChunk() *chunk.Chunk { 107 | columns := convertFieldTypes(w.result.Fields) 108 | rowCount := len(w.result.RowDatas) 109 | c := chunk.NewChunkWithCapacity(columns, rowCount) 110 | writeResultSetDataToTrunk(w.result.Resultset, c) 111 | return c 112 | } 113 | 114 | func (w *weirResultSet) Next(ctx context.Context, c *chunk.Chunk) error { 115 | // all the data has been converted and set into chunk when calling NewChunk(), 116 | // so we need to do nothing here 117 | if w.readed { 118 | c.Reset() 119 | } 120 | w.readed = true 121 | return nil 122 | } 123 | 124 | func (*weirResultSet) StoreFetchedRows(rows []chunk.Row) { 125 | panic("implement me") 126 | } 127 | 128 | func (*weirResultSet) GetFetchedRows() []chunk.Row { 129 | panic("implement me") 130 | } 131 | 132 | func (w *weirResultSet) Close() error { 133 | atomic.StoreInt32(&w.closed, 1) 134 | return nil 135 | } 136 | -------------------------------------------------------------------------------- /pkg/util/timer/time_wheel.go: -------------------------------------------------------------------------------- 1 | package timer 2 | 3 | import ( 4 | "errors" 5 | "time" 6 | ) 7 | 8 | // Task means handle unit in time wheel 9 | type Task struct { 10 | delay time.Duration 11 | key interface{} 12 | round int // optimize time wheel to handle delay bigger than bucketsNum * tick 13 | callback func() 14 | } 15 | 16 | // TimeWheel means time wheel 17 | type TimeWheel struct { 18 | tick time.Duration 19 | ticker *time.Ticker 20 | 21 | bucketsNum int 22 | buckets []map[interface{}]*Task // key: added item, value: *Task 23 | bucketIndexes 
map[interface{}]int // key: added item, value: bucket position 24 | 25 | currentIndex int 26 | 27 | addC chan *Task 28 | removeC chan interface{} 29 | stopC chan struct{} 30 | } 31 | 32 | // NewTimeWheel create new time wheel 33 | func NewTimeWheel(tick time.Duration, bucketsNum int) (*TimeWheel, error) { 34 | if bucketsNum <= 0 { 35 | return nil, errors.New("bucket number must be greater than 0") 36 | } 37 | // if int(tick.Seconds()) < 1 { 38 | // return nil, errors.New("tick cannot be less than 1s") 39 | // } 40 | 41 | tw := &TimeWheel{ 42 | tick: tick, 43 | bucketsNum: bucketsNum, 44 | bucketIndexes: make(map[interface{}]int, 1024), 45 | buckets: make([]map[interface{}]*Task, bucketsNum), 46 | currentIndex: 0, 47 | addC: make(chan *Task, 1024), 48 | removeC: make(chan interface{}, 1024), 49 | stopC: make(chan struct{}), 50 | } 51 | 52 | for i := 0; i < bucketsNum; i++ { 53 | tw.buckets[i] = make(map[interface{}]*Task, 16) 54 | } 55 | 56 | return tw, nil 57 | } 58 | 59 | // Start start the time wheel 60 | func (tw *TimeWheel) Start() { 61 | tw.ticker = time.NewTicker(tw.tick) 62 | go tw.start() 63 | } 64 | 65 | func (tw *TimeWheel) start() { 66 | for { 67 | select { 68 | case <-tw.ticker.C: 69 | tw.handleTick() 70 | case task := <-tw.addC: 71 | tw.add(task) 72 | case key := <-tw.removeC: 73 | tw.remove(key) 74 | case <-tw.stopC: 75 | tw.ticker.Stop() 76 | return 77 | } 78 | } 79 | } 80 | 81 | // Stop stop the time wheel 82 | func (tw *TimeWheel) Stop() { 83 | tw.stopC <- struct{}{} 84 | } 85 | 86 | func (tw *TimeWheel) handleTick() { 87 | bucket := tw.buckets[tw.currentIndex] 88 | for k := range bucket { 89 | if bucket[k].round > 0 { 90 | bucket[k].round-- 91 | continue 92 | } 93 | go bucket[k].callback() 94 | delete(bucket, k) 95 | delete(tw.bucketIndexes, k) 96 | } 97 | if tw.currentIndex == tw.bucketsNum-1 { 98 | tw.currentIndex = 0 99 | return 100 | } 101 | tw.currentIndex++ 102 | } 103 | 104 | // Add add an item into time wheel 105 | func (tw *TimeWheel) 
Add(delay time.Duration, key interface{}, callback func()) error { 106 | if delay <= 0 || key == nil { 107 | return errors.New("invalid params") 108 | } 109 | tw.addC <- &Task{delay: delay, key: key, callback: callback} 110 | return nil 111 | } 112 | 113 | func (tw *TimeWheel) add(task *Task) { 114 | round := tw.calculateRound(task.delay) 115 | index := tw.calculateIndex(task.delay) 116 | task.round = round 117 | if originIndex, ok := tw.bucketIndexes[task.key]; ok { 118 | delete(tw.buckets[originIndex], task.key) 119 | } 120 | tw.bucketIndexes[task.key] = index 121 | tw.buckets[index][task.key] = task 122 | } 123 | 124 | func (tw *TimeWheel) calculateRound(delay time.Duration) (round int) { 125 | delaySeconds := int(delay.Milliseconds()) 126 | tickSeconds := int(tw.tick.Milliseconds()) 127 | round = delaySeconds / tickSeconds / tw.bucketsNum 128 | return 129 | } 130 | 131 | func (tw *TimeWheel) calculateIndex(delay time.Duration) (index int) { 132 | delaySeconds := int(delay.Milliseconds()) 133 | tickSeconds := int(tw.tick.Milliseconds()) 134 | index = (tw.currentIndex + delaySeconds/tickSeconds) % tw.bucketsNum 135 | return 136 | } 137 | 138 | // Remove remove an item from time wheel 139 | func (tw *TimeWheel) Remove(key interface{}) error { 140 | if key == nil { 141 | return errors.New("invalid params") 142 | } 143 | tw.removeC <- key 144 | return nil 145 | } 146 | 147 | // don't need to call callback 148 | func (tw *TimeWheel) remove(key interface{}) { 149 | if index, ok := tw.bucketIndexes[key]; ok { 150 | delete(tw.bucketIndexes, key) 151 | delete(tw.buckets[index], key) 152 | } 153 | return 154 | } 155 | -------------------------------------------------------------------------------- /pkg/util/timer/timer.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | // Package timer provides various enhanced timer functions. 18 | package timer 19 | 20 | import ( 21 | "sync" 22 | "time" 23 | 24 | "github.com/tidb-incubator/weir/pkg/util/sync2" 25 | ) 26 | 27 | // Out-of-band messages 28 | type typeAction int 29 | 30 | const ( 31 | timerStop typeAction = iota 32 | timerReset 33 | timerTrigger 34 | ) 35 | 36 | /* 37 | Timer provides timer functionality that can be controlled 38 | by the user. You start the timer by providing it a callback function, 39 | which it will call at the specified interval. 40 | 41 | var t = timer.NewTimer(1e9) 42 | t.Start(KeepHouse) 43 | 44 | func KeepHouse() { 45 | // do house keeping work 46 | } 47 | 48 | You can stop the timer by calling t.Stop, which is guaranteed to 49 | wait if KeepHouse is being executed. 50 | 51 | You can create an untimely trigger by calling t.Trigger. You can also 52 | schedule an untimely trigger by calling t.TriggerAfter. 53 | 54 | The timer interval can be changed on the fly by calling t.SetInterval. 55 | A zero value interval will cause the timer to wait indefinitely, and it 56 | will react only to an explicit Trigger or Stop. 
57 | */ 58 | type Timer struct { 59 | interval sync2.AtomicDuration 60 | 61 | // state management 62 | mu sync.Mutex 63 | running bool 64 | 65 | // msg is used for out-of-band messages 66 | msg chan typeAction 67 | } 68 | 69 | // NewTimer creates a new Timer object 70 | func NewTimer(interval time.Duration) *Timer { 71 | tm := &Timer{ 72 | msg: make(chan typeAction), 73 | } 74 | tm.interval.Set(interval) 75 | return tm 76 | } 77 | 78 | // Start starts the timer. 79 | func (tm *Timer) Start(keephouse func()) { 80 | tm.mu.Lock() 81 | defer tm.mu.Unlock() 82 | if tm.running { 83 | return 84 | } 85 | tm.running = true 86 | go tm.run(keephouse) 87 | } 88 | 89 | func (tm *Timer) run(keephouse func()) { 90 | var timer *time.Timer 91 | for { 92 | var ch <-chan time.Time 93 | interval := tm.interval.Get() 94 | if interval > 0 { 95 | timer = time.NewTimer(interval) 96 | ch = timer.C 97 | } 98 | select { 99 | case action := <-tm.msg: 100 | if timer != nil { 101 | timer.Stop() 102 | timer = nil 103 | } 104 | switch action { 105 | case timerStop: 106 | return 107 | case timerReset: 108 | continue 109 | } 110 | case <-ch: 111 | } 112 | keephouse() 113 | } 114 | } 115 | 116 | // SetInterval changes the wait interval. 117 | // It will cause the timer to restart the wait. 118 | func (tm *Timer) SetInterval(ns time.Duration) { 119 | tm.interval.Set(ns) 120 | tm.mu.Lock() 121 | defer tm.mu.Unlock() 122 | if tm.running { 123 | tm.msg <- timerReset 124 | } 125 | } 126 | 127 | // Trigger will cause the timer to immediately execute the keephouse function. 128 | // It will then cause the timer to restart the wait. 129 | func (tm *Timer) Trigger() { 130 | tm.mu.Lock() 131 | defer tm.mu.Unlock() 132 | if tm.running { 133 | tm.msg <- timerTrigger 134 | } 135 | } 136 | 137 | // TriggerAfter waits for the specified duration and triggers the next event. 
138 | func (tm *Timer) TriggerAfter(duration time.Duration) { 139 | go func() { 140 | time.Sleep(duration) 141 | tm.Trigger() 142 | }() 143 | } 144 | 145 | // Stop will stop the timer. It guarantees that the timer will not execute 146 | // any more calls to keephouse once it has returned. 147 | func (tm *Timer) Stop() { 148 | tm.mu.Lock() 149 | defer tm.mu.Unlock() 150 | if tm.running { 151 | tm.msg <- timerStop 152 | tm.running = false 153 | } 154 | } 155 | 156 | // Interval returns the current interval. 157 | func (tm *Timer) Interval() time.Duration { 158 | return tm.interval.Get() 159 | } 160 | 161 | func (tm *Timer) Running() bool { 162 | tm.mu.Lock() 163 | defer tm.mu.Unlock() 164 | return tm.running 165 | } 166 | -------------------------------------------------------------------------------- /pkg/util/rate_limit_breaker/circuit_breaker/circuit_breaker_test.go: -------------------------------------------------------------------------------- 1 | package circuit_breaker 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | rateLimitBreaker "github.com/tidb-incubator/weir/pkg/util/rate_limit_breaker" 7 | "github.com/stretchr/testify/assert" 8 | "testing" 9 | ) 10 | 11 | func TestCircuitBreaker_Do_NoError(t *testing.T) { 12 | ctx := context.Background() 13 | cb := NewCircuitBreaker(&CircuitBreakerConfig{ 14 | minQPS: 10, 15 | failureRateThreshold: 10, 16 | failureNum: 5, 17 | OpenStatusDurationMs: 10000, // 10s 18 | forceOpen: false, 19 | size: 10, 20 | cellIntervalMs: 1000, 21 | }) 22 | 23 | successCount := 0 24 | errCount := 0 25 | for i := 0; i < 1000; i++ { 26 | cb.Do(ctx, func(ctx context.Context) error { 27 | successCount++ 28 | return nil 29 | }, func(ctx context.Context, err error) error { 30 | errCount++ 31 | return err 32 | }) 33 | } 34 | assert.Equal(t, successCount, 1000) 35 | assert.Equal(t, errCount, 0) 36 | assert.Equal(t, cb.Status(), CircuitBreakerStatusOpen) 37 | } 38 | 39 | func TestCircuitBreaker_Do_AlwaysError(t *testing.T) { 40 | ctx := 
context.Background() 41 | cb := NewCircuitBreaker(&CircuitBreakerConfig{ 42 | minQPS: 10, 43 | failureRateThreshold: 10, 44 | OpenStatusDurationMs: 10000, // 10s 45 | forceOpen: false, 46 | }) 47 | 48 | successCount := 0 49 | errCount := 0 50 | for i := 0; i < 1000; i++ { 51 | cb.Do(ctx, func(ctx context.Context) error { 52 | successCount++ 53 | return errors.New("just_error") 54 | }, func(ctx context.Context, err error) error { 55 | errCount++ 56 | return err 57 | }) 58 | } 59 | 60 | assert.True(t, successCount < 1000) 61 | assert.Equal(t, errCount, 1000) 62 | assert.Equal(t, cb.Status(), CircuitBreakerStatusOpen) 63 | } 64 | 65 | func TestCircuitBreaker_ForceOpen(t *testing.T) { 66 | ctx := context.Background() 67 | cb := NewCircuitBreaker(&CircuitBreakerConfig{ 68 | minQPS: 10, 69 | failureRateThreshold: 10, 70 | OpenStatusDurationMs: 10000, // 10s 71 | forceOpen: true, 72 | }) 73 | 74 | errCount := 0 75 | for i := 0; i < 1000; i++ { 76 | cb.Do(ctx, func(ctx context.Context) error { 77 | return nil 78 | }, func(ctx context.Context, err error) error { 79 | errCount++ 80 | return nil 81 | }) 82 | } 83 | assert.Equal(t, errCount, 1000) 84 | assert.Equal(t, cb.Status(), CircuitBreakerStatusForceOpen) 85 | } 86 | 87 | func TestCircuitBreaker_ChangeConfig_WithForceOpen(t *testing.T) { 88 | cb := NewCircuitBreaker(&CircuitBreakerConfig{ 89 | minQPS: 10, 90 | failureRateThreshold: 10, 91 | OpenStatusDurationMs: 10000, // 10s 92 | forceOpen: true, 93 | }) 94 | assert.Equal(t, cb.status, CircuitBreakerStatusForceOpen) 95 | assert.Equal(t, cb.Status(), CircuitBreakerStatusForceOpen) 96 | 97 | // cancel forceOpen, status goes back to closed 98 | cb.ChangeConfig(&CircuitBreakerConfig{ 99 | minQPS: 10, 100 | failureRateThreshold: 10, 101 | OpenStatusDurationMs: 10000, 102 | forceOpen: false, 103 | }) 104 | assert.Equal(t, cb.status, CircuitBreakerStatusClosed) 105 | assert.Equal(t, cb.Status(), CircuitBreakerStatusClosed) 106 | } 107 | 108 | func 
TestCircuitBreaker_ChangeConfig_WithoutForceOpen(t *testing.T) { 109 | cb := NewCircuitBreaker(&CircuitBreakerConfig{ 110 | minQPS: 10, 111 | failureRateThreshold: 10, 112 | OpenStatusDurationMs: 10000, // 10s 113 | forceOpen: false, 114 | }) 115 | 116 | // 当前为 open 状态,然后修改配置(没有强制开启),检查仍为 open 状态 117 | cb.status = CircuitBreakerStatusOpen 118 | cb.openStartMs = rateLimitBreaker.GetNowMs() 119 | cb.ChangeConfig(&CircuitBreakerConfig{ 120 | minQPS: 10, 121 | failureRateThreshold: 10, 122 | OpenStatusDurationMs: 100000, // 100s 123 | forceOpen: false, 124 | }) 125 | assert.Equal(t, cb.status, CircuitBreakerStatusOpen) 126 | assert.Equal(t, cb.Status(), CircuitBreakerStatusOpen) 127 | 128 | // 当前为 closed 状态,然后修改配置(没有强制开启),检查仍为 closed 状态 129 | cb.status = CircuitBreakerStatusClosed 130 | cb.ChangeConfig(&CircuitBreakerConfig{ 131 | minQPS: 10, 132 | failureRateThreshold: 10, 133 | OpenStatusDurationMs: 10000, // 10s 134 | forceOpen: false, 135 | }) 136 | assert.Equal(t, cb.status, CircuitBreakerStatusClosed) 137 | assert.Equal(t, cb.Status(), CircuitBreakerStatusClosed) 138 | } 139 | -------------------------------------------------------------------------------- /pkg/proxy/metrics/queryctx.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "github.com/pingcap/parser/ast" 5 | "github.com/prometheus/client_golang/prometheus" 6 | ) 7 | 8 | type AstStmtType int 9 | 10 | const ( 11 | StmtTypeUnknown AstStmtType = iota 12 | StmtTypeSelect 13 | StmtTypeInsert 14 | StmtTypeUpdate 15 | StmtTypeDelete 16 | StmtTypeDDL 17 | StmtTypeBegin 18 | StmtTypeCommit 19 | StmtTypeRollback 20 | StmtTypeSet 21 | StmtTypeShow 22 | StmtTypeUse 23 | StmtTypeComment 24 | ) 25 | 26 | const ( 27 | StmtNameUnknown = "unknown" 28 | StmtNameSelect = "select" 29 | StmtNameInsert = "insert" 30 | StmtNameUpdate = "update" 31 | StmtNameDelete = "delete" 32 | StmtNameDDL = "ddl" 33 | StmtNameBegin = "begin" 34 | StmtNameCommit = 
"commit" 35 | StmtNameRollback = "rollback" 36 | StmtNameSet = "set" 37 | StmtNameShow = "show" 38 | StmtNameUse = "use" 39 | StmtNameComment = "comment" 40 | ) 41 | 42 | var ( 43 | QueryCtxQueryCounter = prometheus.NewCounterVec( 44 | prometheus.CounterOpts{ 45 | Namespace: ModuleWeirProxy, 46 | Subsystem: LabelQueryCtx, 47 | Name: "query_total", 48 | Help: "Counter of queries.", 49 | }, []string{LblCluster, LblNamespace, LblDb, LblTable, LblSQLType, LblResult}) 50 | 51 | QueryCtxQueryDeniedCounter = prometheus.NewCounterVec( 52 | prometheus.CounterOpts{ 53 | Namespace: ModuleWeirProxy, 54 | Subsystem: LabelQueryCtx, 55 | Name: "query_denied", 56 | Help: "Counter of denied queries.", 57 | }, []string{LblCluster, LblNamespace, LblDb, LblTable, LblSQLType}) 58 | 59 | QueryCtxQueryDurationHistogram = prometheus.NewHistogramVec( 60 | prometheus.HistogramOpts{ 61 | Namespace: ModuleWeirProxy, 62 | Subsystem: LabelQueryCtx, 63 | Name: "handle_query_duration_seconds", 64 | Help: "Bucketed histogram of processing time (s) of handled queries.", 65 | Buckets: prometheus.ExponentialBuckets(0.0005, 2, 29), // 0.5ms ~ 1.5days 66 | }, []string{LblCluster, LblNamespace, LblDb, LblTable, LblSQLType}) 67 | 68 | QueryCtxGauge = prometheus.NewGaugeVec( 69 | prometheus.GaugeOpts{ 70 | Namespace: ModuleWeirProxy, 71 | Subsystem: LabelQueryCtx, 72 | Name: "queryctx", 73 | Help: "Number of queryctx (equals to client connection).", 74 | }, []string{LblCluster, LblNamespace}) 75 | 76 | QueryCtxAttachedConnGauge = prometheus.NewGaugeVec( 77 | prometheus.GaugeOpts{ 78 | Namespace: ModuleWeirProxy, 79 | Subsystem: LabelQueryCtx, 80 | Name: "attached_connections", 81 | Help: "Number of attached backend connections.", 82 | }, []string{LblCluster, LblNamespace}) 83 | 84 | QueryCtxTransactionDuration = prometheus.NewHistogramVec( 85 | prometheus.HistogramOpts{ 86 | Namespace: "tidb", 87 | Subsystem: "session", 88 | Name: "transaction_duration_seconds", 89 | Help: "Bucketed histogram of a 
transaction execution duration, including retry.", 90 | Buckets: prometheus.ExponentialBuckets(0.001, 2, 28), // 1ms ~ 1.5days 91 | }, []string{LblCluster, LblNamespace, LblDb, LblSQLType}) 92 | ) 93 | 94 | func GetStmtType(stmt ast.StmtNode) AstStmtType { 95 | switch stmt.(type) { 96 | case *ast.SelectStmt: 97 | return StmtTypeSelect 98 | case *ast.InsertStmt: 99 | return StmtTypeInsert 100 | case *ast.UpdateStmt: 101 | return StmtTypeUpdate 102 | case *ast.DeleteStmt: 103 | return StmtTypeDelete 104 | case *ast.BeginStmt: 105 | return StmtTypeBegin 106 | case *ast.CommitStmt: 107 | return StmtTypeCommit 108 | case *ast.RollbackStmt: 109 | return StmtTypeRollback 110 | case *ast.SetStmt: 111 | return StmtTypeSet 112 | case *ast.ShowStmt: 113 | return StmtTypeShow 114 | case *ast.UseStmt: 115 | return StmtTypeUse 116 | default: 117 | return StmtTypeUnknown 118 | } 119 | } 120 | 121 | func GetStmtTypeName(stmt ast.StmtNode) string { 122 | switch stmt.(type) { 123 | case *ast.SelectStmt: 124 | return StmtNameSelect 125 | case *ast.InsertStmt: 126 | return StmtNameInsert 127 | case *ast.UpdateStmt: 128 | return StmtNameUpdate 129 | case *ast.DeleteStmt: 130 | return StmtNameDelete 131 | case *ast.BeginStmt: 132 | return StmtNameBegin 133 | case *ast.CommitStmt: 134 | return StmtNameCommit 135 | case *ast.RollbackStmt: 136 | return StmtNameRollback 137 | case *ast.SetStmt: 138 | return StmtNameSet 139 | case *ast.ShowStmt: 140 | return StmtNameShow 141 | case *ast.UseStmt: 142 | return StmtNameUse 143 | default: 144 | return StmtNameUnknown 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /pkg/proxy/namespace/builder.go: -------------------------------------------------------------------------------- 1 | package namespace 2 | 3 | import ( 4 | "hash/crc32" 5 | "time" 6 | 7 | "github.com/tidb-incubator/weir/pkg/config" 8 | "github.com/tidb-incubator/weir/pkg/proxy/backend" 9 | 
"github.com/tidb-incubator/weir/pkg/proxy/driver" 10 | wast "github.com/tidb-incubator/weir/pkg/util/ast" 11 | "github.com/tidb-incubator/weir/pkg/util/datastructure" 12 | "github.com/pingcap/errors" 13 | "github.com/pingcap/parser" 14 | ) 15 | 16 | type NamespaceImpl struct { 17 | name string 18 | Br driver.Breaker 19 | Backend 20 | Frontend 21 | rateLimiter *NamespaceRateLimiter 22 | } 23 | 24 | func BuildNamespace(cfg *config.Namespace) (Namespace, error) { 25 | be, err := BuildBackend(cfg.Namespace, &cfg.Backend) 26 | if err != nil { 27 | return nil, errors.WithMessage(err, "build backend error") 28 | } 29 | fe, err := BuildFrontend(&cfg.Frontend) 30 | if err != nil { 31 | return nil, errors.WithMessage(err, "build frontend error") 32 | } 33 | wrapper := &NamespaceImpl{ 34 | name: cfg.Namespace, 35 | Backend: be, 36 | Frontend: fe, 37 | } 38 | brm, err := NewBreaker(&cfg.Breaker) 39 | if err != nil { 40 | return nil, err 41 | } 42 | br, err := brm.GetBreaker() 43 | if err != nil { 44 | return nil, err 45 | } 46 | wrapper.Br = br 47 | 48 | rateLimiter := NewNamespaceRateLimiter(cfg.RateLimiter.Scope, cfg.RateLimiter.QPS) 49 | wrapper.rateLimiter = rateLimiter 50 | 51 | return wrapper, nil 52 | } 53 | 54 | func (n *NamespaceImpl) Name() string { 55 | return n.name 56 | } 57 | 58 | func (n *NamespaceImpl) GetBreaker() (driver.Breaker, error) { 59 | return n.Br, nil 60 | } 61 | 62 | func (n *NamespaceImpl) GetRateLimiter() driver.RateLimiter { 63 | return n.rateLimiter 64 | } 65 | 66 | func BuildBackend(ns string, cfg *config.BackendNamespace) (Backend, error) { 67 | bcfg, err := parseBackendConfig(cfg) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | b := backend.NewBackendImpl(ns, bcfg) 73 | if err := b.Init(); err != nil { 74 | return nil, err 75 | } 76 | 77 | return b, nil 78 | } 79 | 80 | func BuildFrontend(cfg *config.FrontendNamespace) (Frontend, error) { 81 | fns := &FrontendNamespace{ 82 | allowedDBs: cfg.AllowedDBs, 83 | } 84 | fns.allowedDBSet 
= datastructure.StringSliceToSet(cfg.AllowedDBs) 85 | 86 | userPasswds := make(map[string]string) 87 | for _, u := range cfg.Users { 88 | userPasswds[u.Username] = u.Password 89 | } 90 | fns.userPasswd = userPasswds 91 | 92 | sqlBlacklist := make(map[uint32]SQLInfo) 93 | fns.sqlBlacklist = sqlBlacklist 94 | 95 | p := parser.New() 96 | for _, deniedSQL := range cfg.SQLBlackList { 97 | stmtNodes, _, err := p.Parse(deniedSQL.SQL, "", "") 98 | if err != nil { 99 | return nil, err 100 | } 101 | if len(stmtNodes) != 1 { 102 | return nil, nil 103 | } 104 | v, err := wast.ExtractAstVisit(stmtNodes[0]) 105 | if err != nil { 106 | return nil, err 107 | } 108 | fns.sqlBlacklist[crc32.ChecksumIEEE([]byte(v.SqlFeature()))] = SQLInfo{SQL: deniedSQL.SQL} 109 | } 110 | 111 | sqlWhitelist := make(map[uint32]SQLInfo) 112 | fns.sqlWhitelist = sqlWhitelist 113 | for _, allowedSQL := range cfg.SQLWhiteList { 114 | stmtNodes, _, err := p.Parse(allowedSQL.SQL, "", "") 115 | if err != nil { 116 | return nil, err 117 | } 118 | if len(stmtNodes) != 1 { 119 | return nil, nil 120 | } 121 | v, err := wast.ExtractAstVisit(stmtNodes[0]) 122 | if err != nil { 123 | return nil, err 124 | } 125 | fns.sqlWhitelist[crc32.ChecksumIEEE([]byte(v.SqlFeature()))] = SQLInfo{SQL: allowedSQL.SQL} 126 | } 127 | 128 | return fns, nil 129 | } 130 | 131 | func parseBackendConfig(cfg *config.BackendNamespace) (*backend.BackendConfig, error) { 132 | selectorType, valid := backend.SelectorNameToType(cfg.SelectorType) 133 | if !valid { 134 | return nil, ErrInvalidSelectorType 135 | } 136 | 137 | addrs := make(map[string]struct{}) 138 | for _, ins := range cfg.Instances { 139 | addrs[ins] = struct{}{} 140 | } 141 | 142 | bcfg := &backend.BackendConfig{ 143 | Addrs: addrs, 144 | UserName: cfg.Username, 145 | Password: cfg.Password, 146 | Capacity: cfg.PoolSize, 147 | IdleTimeout: time.Duration(cfg.IdleTimeout) * time.Second, 148 | SelectorType: selectorType, 149 | } 150 | return bcfg, nil 151 | } 152 | 153 | func 
DefaultAsyncCloseNamespace(ns Namespace) error { 154 | nsWrapper, ok := ns.(*NamespaceImpl) 155 | if !ok { 156 | return errors.Errorf("invalid namespace type: %T", ns) 157 | } 158 | go func() { 159 | time.Sleep(30 * time.Second) 160 | //nsWrapper.BreakerHolder.CloseBreaker() 161 | nsWrapper.Backend.Close() 162 | }() 163 | return nil 164 | } 165 | -------------------------------------------------------------------------------- /pkg/proxy/namespace/manager.go: -------------------------------------------------------------------------------- 1 | package namespace 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/pingcap/errors" 7 | "github.com/pingcap/tidb/util/logutil" 8 | "github.com/tidb-incubator/weir/pkg/config" 9 | "github.com/tidb-incubator/weir/pkg/proxy/driver" 10 | "github.com/tidb-incubator/weir/pkg/util/sync2" 11 | "go.uber.org/zap" 12 | ) 13 | 14 | type NamespaceManager struct { 15 | switchIndex sync2.BoolIndex 16 | users [2]*UserNamespaceMapper 17 | nss [2]*NamespaceHolder 18 | build NamespaceBuilder 19 | close NamespaceCloser 20 | 21 | reloadLock sync.Mutex 22 | reloadPrepared map[string]bool 23 | } 24 | 25 | type NamespaceBuilder func(cfg *config.Namespace) (Namespace, error) 26 | type NamespaceCloser func(ns Namespace) error 27 | 28 | func CreateNamespaceManager(cfgs []*config.Namespace, builder NamespaceBuilder, closer NamespaceCloser) (*NamespaceManager, error) { 29 | users, err := CreateUserNamespaceMapper(cfgs) 30 | if err != nil { 31 | return nil, errors.WithMessage(err, "create UserNamespaceMapper error") 32 | } 33 | 34 | nss, err := CreateNamespaceHolder(cfgs, builder) 35 | if err != nil { 36 | return nil, errors.WithMessage(err, "create NamespaceHolder error") 37 | } 38 | 39 | mgr := NewNamespaceManager(users, nss, builder, closer) 40 | return mgr, nil 41 | } 42 | 43 | func NewNamespaceManager(users *UserNamespaceMapper, nss *NamespaceHolder, builder NamespaceBuilder, closer NamespaceCloser) *NamespaceManager { 44 | mgr := &NamespaceManager{ 45 | 
build: builder, 46 | close: closer, 47 | reloadPrepared: make(map[string]bool), 48 | } 49 | mgr.users[0] = users 50 | mgr.nss[0] = nss 51 | return mgr 52 | } 53 | 54 | func (n *NamespaceManager) Auth(username string, pwd, salt []byte) (driver.Namespace, bool) { 55 | nsName, ok := n.getNamespaceByUsername(username) 56 | if !ok { 57 | return nil, false 58 | } 59 | 60 | wrapper := &NamespaceWrapper{ 61 | nsmgr: n, 62 | name: nsName, 63 | } 64 | 65 | return wrapper, true 66 | } 67 | 68 | func (n *NamespaceManager) PrepareReloadNamespace(namespace string, cfg *config.Namespace) error { 69 | n.reloadLock.Lock() 70 | defer n.reloadLock.Unlock() 71 | 72 | newUsers := n.getCurrentUsers().Clone() 73 | newUsers.RemoveNamespaceUsers(namespace) 74 | if err := newUsers.AddNamespaceUsers(namespace, &cfg.Frontend); err != nil { 75 | return errors.WithMessage(err, "add namespace users error") 76 | } 77 | 78 | newNs, err := n.build(cfg) 79 | if err != nil { 80 | return errors.WithMessage(err, "build namespace error") 81 | } 82 | 83 | newNss := n.getCurrentNamespaces().Clone() 84 | newNss.Set(namespace, newNs) 85 | 86 | n.setOther(newUsers, newNss) 87 | n.reloadPrepared[namespace] = true 88 | 89 | return nil 90 | } 91 | 92 | func (n *NamespaceManager) CommitReloadNamespaces(namespaces []string) error { 93 | n.reloadLock.Lock() 94 | defer n.reloadLock.Unlock() 95 | 96 | for _, namespace := range namespaces { 97 | if !n.reloadPrepared[namespace] { 98 | return errors.Errorf("namespace is not prepared: %s", namespace) 99 | } 100 | } 101 | 102 | n.toggle() 103 | return nil 104 | } 105 | 106 | func (n *NamespaceManager) RemoveNamespace(name string) { 107 | n.reloadLock.Lock() 108 | defer n.reloadLock.Unlock() 109 | 110 | n.getCurrentUsers().RemoveNamespaceUsers(name) 111 | nss := n.getCurrentNamespaces() 112 | ns, ok := nss.Get(name) 113 | if !ok { 114 | return 115 | } 116 | 117 | if err := n.close(ns); err != nil { 118 | logutil.BgLogger().Error("remove namespace error", zap.Error(err), 
zap.String("namespace", name)) 119 | return 120 | } 121 | 122 | nss.Delete(name) 123 | } 124 | 125 | func (n *NamespaceManager) getNamespaceByUsername(username string) (string, bool) { 126 | return n.getCurrentUsers().GetUserNamespace(username) 127 | } 128 | 129 | func (n *NamespaceManager) getCurrent() (*UserNamespaceMapper, *NamespaceHolder) { 130 | current, _, _ := n.switchIndex.Get() 131 | return n.users[current], n.nss[current] 132 | } 133 | 134 | func (n *NamespaceManager) getCurrentUsers() *UserNamespaceMapper { 135 | current, _, _ := n.switchIndex.Get() 136 | return n.users[current] 137 | } 138 | 139 | func (n *NamespaceManager) getCurrentNamespaces() *NamespaceHolder { 140 | current, _, _ := n.switchIndex.Get() 141 | return n.nss[current] 142 | } 143 | 144 | func (n *NamespaceManager) setOther(users *UserNamespaceMapper, nss *NamespaceHolder) { 145 | _, other, _ := n.switchIndex.Get() 146 | n.users[other] = users 147 | n.nss[other] = nss 148 | } 149 | 150 | func (n *NamespaceManager) toggle() { 151 | _, _, currentFlag := n.switchIndex.Get() 152 | n.switchIndex.Set(!currentFlag) 153 | } 154 | -------------------------------------------------------------------------------- /pkg/proxy/server/column_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package server 15 | 16 | import ( 17 | . 
"github.com/pingcap/check" 18 | "github.com/pingcap/parser/mysql" 19 | ) 20 | 21 | type ColumnTestSuite struct { 22 | } 23 | 24 | var _ = Suite(new(ColumnTestSuite)) 25 | 26 | func (s ColumnTestSuite) TestDumpColumn(c *C) { 27 | info := ColumnInfo{ 28 | Schema: "testSchema", 29 | Table: "testTable", 30 | OrgTable: "testOrgTable", 31 | Name: "testName", 32 | OrgName: "testOrgName", 33 | ColumnLength: 1, 34 | Charset: 106, 35 | Flag: 0, 36 | Decimal: 1, 37 | Type: 14, 38 | DefaultValueLength: 2, 39 | DefaultValue: []byte{5, 2}, 40 | } 41 | r := info.Dump(nil) 42 | exp := []byte{0x3, 0x64, 0x65, 0x66, 0xa, 0x74, 0x65, 0x73, 0x74, 0x53, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x9, 0x74, 0x65, 0x73, 0x74, 0x54, 0x61, 0x62, 0x6c, 0x65, 0xc, 0x74, 0x65, 0x73, 0x74, 0x4f, 0x72, 0x67, 0x54, 0x61, 0x62, 0x6c, 0x65, 0x8, 0x74, 0x65, 0x73, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0xb, 0x74, 0x65, 0x73, 0x74, 0x4f, 0x72, 0x67, 0x4e, 0x61, 0x6d, 0x65, 0xc, 0x6a, 0x0, 0x1, 0x0, 0x0, 0x0, 0xe, 0x0, 0x0, 0x1, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x5, 0x2} 43 | c.Assert(r, DeepEquals, exp) 44 | 45 | c.Assert(dumpFlag(mysql.TypeSet, 0), Equals, uint16(mysql.SetFlag)) 46 | c.Assert(dumpFlag(mysql.TypeEnum, 0), Equals, uint16(mysql.EnumFlag)) 47 | c.Assert(dumpFlag(mysql.TypeString, 0), Equals, uint16(0)) 48 | 49 | c.Assert(dumpType(mysql.TypeSet), Equals, mysql.TypeString) 50 | c.Assert(dumpType(mysql.TypeEnum), Equals, mysql.TypeString) 51 | c.Assert(dumpType(mysql.TypeBit), Equals, mysql.TypeBit) 52 | } 53 | 54 | func (s ColumnTestSuite) TestColumnNameLimit(c *C) { 55 | aLongName := make([]byte, 0, 300) 56 | for i := 0; i < 300; i++ { 57 | aLongName = append(aLongName, 'a') 58 | } 59 | info := ColumnInfo{ 60 | Schema: "testSchema", 61 | Table: "testTable", 62 | OrgTable: "testOrgTable", 63 | Name: string(aLongName), 64 | OrgName: "testOrgName", 65 | ColumnLength: 1, 66 | Charset: 106, 67 | Flag: 0, 68 | Decimal: 1, 69 | Type: 14, 70 | DefaultValueLength: 2, 71 | DefaultValue: []byte{5, 2}, 
72 | } 73 | r := info.Dump(nil) 74 | exp := []byte{0x3, 0x64, 0x65, 0x66, 0xa, 0x74, 0x65, 0x73, 0x74, 0x53, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x9, 0x74, 0x65, 0x73, 0x74, 0x54, 0x61, 0x62, 0x6c, 0x65, 0xc, 0x74, 0x65, 0x73, 0x74, 0x4f, 0x72, 0x67, 0x54, 0x61, 0x62, 0x6c, 0x65, 0xfc, 0x0, 0x1, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0xb, 0x74, 0x65, 0x73, 0x74, 0x4f, 0x72, 0x67, 0x4e, 0x61, 0x6d, 0x65, 0xc, 0x6a, 0x0, 0x1, 0x0, 0x0, 0x0, 0xe, 0x0, 0x0, 0x1, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x5, 0x2} 75 | c.Assert(r, DeepEquals, exp) 76 | } 77 | -------------------------------------------------------------------------------- /pkg/proxy/server/packetio.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. 2 | // 3 | // This Source Code Form is subject to the terms of the Mozilla Public 4 | // License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 | // You can obtain one at http://mozilla.org/MPL/2.0/. 6 | 7 | // The MIT License (MIT) 8 | // 9 | // Copyright (c) 2014 wandoulabs 10 | // Copyright (c) 2014 siddontang 11 | // 12 | // Permission is hereby granted, free of charge, to any person obtaining a copy of 13 | // this software and associated documentation files (the "Software"), to deal in 14 | // the Software without restriction, including without limitation the rights to 15 | // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 16 | // the Software, and to permit persons to whom the Software is furnished to do so, 17 | // subject to the following conditions: 18 | // 19 | // The above copyright notice and this permission notice shall be included in all 20 | // copies or substantial portions of the Software. 21 | 22 | // Copyright 2015 PingCAP, Inc. 23 | // 24 | // Licensed under the Apache License, Version 2.0 (the "License"); 25 | // you may not use this file except in compliance with the License. 26 | // You may obtain a copy of the License at 27 | // 28 | // http://www.apache.org/licenses/LICENSE-2.0 29 | // 30 | // Unless required by applicable law or agreed to in writing, software 31 | // distributed under the License is distributed on an "AS IS" BASIS, 32 | // See the License for the specific language governing permissions and 33 | // limitations under the License. 
34 | 35 | package server 36 | 37 | import ( 38 | "bufio" 39 | "io" 40 | "time" 41 | 42 | "github.com/pingcap/errors" 43 | "github.com/pingcap/parser/mysql" 44 | "github.com/pingcap/parser/terror" 45 | ) 46 | 47 | const defaultWriterSize = 16 * 1024 48 | 49 | // packetIO is a helper to read and write data in packet format. 50 | type packetIO struct { 51 | bufReadConn *bufferedReadConn 52 | bufWriter *bufio.Writer 53 | sequence uint8 54 | // readTimeout, when positive, bounds every read made by readOnePacket. 55 | readTimeout time.Duration 56 | } 57 | 
58 | // newPacketIO returns a packetIO over bufReadConn with the sequence id reset to 0. 59 | func newPacketIO(bufReadConn *bufferedReadConn) *packetIO { 60 | p := &packetIO{sequence: 0} 61 | p.setBufferedReadConn(bufReadConn) 62 | return p 63 | } 64 | 65 | // setBufferedReadConn reads from bufReadConn and writes through a new buffered writer on it. 66 | func (p *packetIO) setBufferedReadConn(bufReadConn *bufferedReadConn) { 67 | p.bufReadConn = bufReadConn 68 | p.bufWriter = bufio.NewWriterSize(bufReadConn, defaultWriterSize) 69 | } 70 | 71 | // setReadTimeout sets the timeout applied to each read; zero disables it. 72 | func (p *packetIO) setReadTimeout(timeout time.Duration) { 73 | p.readTimeout = timeout 74 | } 75 | 76 | // applyReadTimeout arms the connection read deadline from readTimeout, if one is set. 77 | func (p *packetIO) applyReadTimeout() error { 78 | if p.readTimeout <= 0 { 79 | return nil 80 | } 81 | if err := p.bufReadConn.SetReadDeadline(time.Now().Add(p.readTimeout)); err != nil { 82 | return errors.Trace(err) 83 | } 84 | return nil 85 | } 86 | 
87 | // readOnePacket reads one packet: a 4-byte header (3-byte little-endian payload 88 | // length plus sequence id) followed by the payload. The sequence id must match 89 | // the expected one, which is then advanced. 90 | func (p *packetIO) readOnePacket() ([]byte, error) { 91 | var header [4]byte 92 | 93 | if err := p.applyReadTimeout(); err != nil { 94 | return nil, err 95 | } 96 | if _, err := io.ReadFull(p.bufReadConn, header[:]); err != nil { 97 | return nil, errors.Trace(err) 98 | } 99 | 100 | sequence := header[3] 101 | if sequence != p.sequence { 102 | return nil, errInvalidSequence.GenWithStack("invalid sequence %d != %d", sequence, p.sequence) 103 | } 104 | 105 | p.sequence++ 106 | 107 | length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) 108 | 109 | if err := p.applyReadTimeout(); err != nil { 110 | return nil, err 111 | } 112 | data := make([]byte, length) 113 | if _, err := io.ReadFull(p.bufReadConn, data); err != nil { 114 | return nil, errors.Trace(err) 115 | } 116 | return data, nil 117 | } 118 | 
119 | // readPacket reads a logical packet, appending continuation packets for as long 120 | // as each payload is mysql.MaxPayloadLen bytes long. 121 | func (p *packetIO) readPacket() ([]byte, error) { 122 | data, err := p.readOnePacket() 123 | if err != nil { 124 | return nil, errors.Trace(err) 125 | } 126 | 127 | if len(data) < mysql.MaxPayloadLen { 128 | return data, nil 129 | } 130 | 131 | // handle multi-packet 132 | for { 133 | buf, err := p.readOnePacket() 134 | if err != nil { 135 | return nil, errors.Trace(err) 136 | } 137 | 138 | data = append(data, buf...) 139 | 140 | if len(buf) < mysql.MaxPayloadLen { 141 | break 142 | } 143 | } 144 | 145 | return data, nil 146 | } 147 | 
148 | // writePacket writes data whose first 4 bytes are reserved for the packet 149 | // header. Payloads of mysql.MaxPayloadLen bytes or more are split into several 150 | // packets; data is modified in place. Write failures are reported as 151 | // mysql.ErrBadConn. 152 | func (p *packetIO) writePacket(data []byte) error { 153 | length := len(data) - 4 154 | 155 | for length >= mysql.MaxPayloadLen { 156 | data[0] = 0xff 157 | data[1] = 0xff 158 | data[2] = 0xff 159 | 160 | data[3] = p.sequence 161 | 162 | if n, err := p.bufWriter.Write(data[:4+mysql.MaxPayloadLen]); err != nil { 163 | terror.Log(errors.Trace(err)) 164 | return errors.Trace(mysql.ErrBadConn) 165 | } else if n != (4 + mysql.MaxPayloadLen) { 166 | return errors.Trace(mysql.ErrBadConn) 167 | } 168 | 169 | p.sequence++ 170 | length -= mysql.MaxPayloadLen 171 | // The last 4 bytes of the chunk just written become the next header. 172 | data = data[mysql.MaxPayloadLen:] 173 | } 174 | 175 | data[0] = byte(length) 176 | data[1] = byte(length >> 8) 177 | data[2] = byte(length >> 16) 178 | data[3] = p.sequence 179 | 180 | if n, err := p.bufWriter.Write(data); err != nil { 181 | terror.Log(errors.Trace(err)) 182 | return errors.Trace(mysql.ErrBadConn) 183 | } else if n != len(data) { 184 | return errors.Trace(mysql.ErrBadConn) 185 | } 186 | 187 | p.sequence++ 188 | return nil 189 | } 190 | 
191 | // flush writes any buffered data to the underlying connection. 192 | func (p *packetIO) flush() error { 193 | if err := p.bufWriter.Flush(); err != nil { 194 | return errors.Trace(err) 195 | } 196 | return nil 197 | } 198 | -------------------------------------------------------------------------------- /pkg/proxy/server/driver.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 PingCAP, Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | package server 15 | 16 | import ( 17 | "context" 18 | "crypto/tls" 19 | "fmt" 20 | "time" 21 | 22 | "github.com/pingcap/parser/auth" 23 | "github.com/pingcap/tidb/sessionctx/variable" 24 | "github.com/pingcap/tidb/types" 25 | "github.com/pingcap/tidb/util" 26 | "github.com/pingcap/tidb/util/chunk" 27 | "github.com/siddontang/go-mysql/mysql" 28 | ) 29 | 30 | // IDriver opens a QueryCtx for a client connection. 31 | type IDriver interface { 32 | // OpenCtx opens a QueryCtx for connection connID with the client capability flags, collation, initial database dbname and, optionally, the TLS state. 33 | OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState) (QueryCtx, error) 34 | } 35 | 
36 | // QueryCtx is the per-connection context through which client commands are executed. 37 | type QueryCtx interface { 38 | // Status returns server status code. 39 | Status() uint16 40 | 41 | // LastInsertID returns last inserted ID. 42 | LastInsertID() uint64 43 | 44 | // LastMessage returns last info message generated by some commands. 45 | LastMessage() string 46 | 47 | // AffectedRows returns affected rows of last executed command. 48 | AffectedRows() uint64 49 | 50 | // Value returns the value associated with this context for key. 51 | Value(key fmt.Stringer) interface{} 52 | 53 | // SetValue saves a value associated with this context for key. 54 | SetValue(key fmt.Stringer, value interface{}) 55 | 56 | // SetProcessInfo records the statement being processed: its text, start 57 | // time, command byte and max execution time. 58 | SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64) 59 | 60 | // CommitTxn commits the transaction operations. 61 | CommitTxn(ctx context.Context) error 62 | 63 | // RollbackTxn undoes the transaction operations. 64 | RollbackTxn() 65 | 66 | // WarningCount returns warning count of last executed command. 67 | WarningCount() uint16 68 | 69 | // CurrentDB returns current DB. 70 | CurrentDB() string 71 | 72 | // Execute executes a SQL statement. 73 | Execute(ctx context.Context, sql string) (*mysql.Result, error) 74 | 75 | // ExecuteInternal executes an internal SQL statement. 76 | ExecuteInternal(ctx context.Context, sql string) ([]ResultSet, error) 77 | 78 | // SetClientCapability sets client capability flags. 79 | SetClientCapability(uint32) 80 | 81 | // Prepare prepares a statement. 82 | Prepare(ctx context.Context, sql string) (stmtId int, columns, params []*ColumnInfo, err error) 83 | 84 | // StmtExecuteForward executes prepared statement stmtId with data. 85 | // NOTE(review): data presumably is the raw COM_STMT_EXECUTE payload forwarded to the backend — confirm. 86 | StmtExecuteForward(ctx context.Context, stmtId int, data []byte) (*mysql.Result, error) 87 | 88 | // StmtClose closes prepared statement stmtId. 89 | StmtClose(ctx context.Context, stmtId int) error 90 | 91 | // FieldList returns columns of a table. 92 | FieldList(tableName string) (columns []*ColumnInfo, err error) 93 | 94 | // Close closes the QueryCtx. 95 | Close() error 96 | 97 | // Auth verifies user's authentication. 98 | Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool 99 | 100 | // ShowProcess shows the information about the session. 101 | ShowProcess() *util.ProcessInfo 102 | 103 | // GetSessionVars returns the session variables. 104 | GetSessionVars() *variable.SessionVars 105 | 106 | // SetCommandValue records command as the command currently being executed. 107 | SetCommandValue(command byte) 108 | 109 | // SetSessionManager sets the util.SessionManager used by this context. 110 | SetSessionManager(util.SessionManager) 111 | } 112 | 
113 | // PreparedStatement is the interface to use a prepared statement. 114 | type PreparedStatement interface { 115 | // ID returns statement ID. 116 | ID() int 117 | 118 | // Execute executes the statement. 119 | Execute(context.Context, []types.Datum) (ResultSet, error) 120 | 121 | // AppendParam appends parameter to the statement. 122 | AppendParam(paramID int, data []byte) error 123 | 124 | // NumParams returns number of parameters. 125 | NumParams() int 126 | 127 | // BoundParams returns bound parameters. 128 | BoundParams() [][]byte 129 | 130 | // SetParamsType sets type for parameters. 131 | SetParamsType([]byte) 132 | 133 | // GetParamsType returns the type for parameters. 134 | GetParamsType() []byte 135 | 136 | // StoreResultSet stores ResultSet for subsequent stmt fetching. 137 | StoreResultSet(rs ResultSet) 138 | 139 | // GetResultSet gets the ResultSet associated with this statement. 140 | GetResultSet() ResultSet 141 | 142 | // Reset removes all bound parameters. 143 | Reset() 144 | 145 | // Close closes the statement. 146 | Close() error 147 | } 148 | 149 | // ResultSet is the result set of a query. 150 | type ResultSet interface { 151 | // Columns returns the column metadata. 152 | Columns() []*ColumnInfo 153 | // NewChunk returns a chunk to be passed to Next. 154 | NewChunk() *chunk.Chunk 155 | // Next fills the chunk with the next batch of rows. 156 | Next(context.Context, *chunk.Chunk) error 157 | // StoreFetchedRows stores rows for later retrieval with GetFetchedRows. 158 | StoreFetchedRows(rows []chunk.Row) 159 | // GetFetchedRows returns the rows stored by StoreFetchedRows. 160 | GetFetchedRows() []chunk.Row 161 | // Close closes the result set. 162 | Close() error 163 | } 164 | 165 | // fetchNotifier is notified when a COM_FETCH request returns. 166 | type fetchNotifier interface { 167 | // OnFetchReturned is called when COM_FETCH returns. 168 | // It is used for server-side cursors. 169 | OnFetchReturned() 170 | } 171 | 
-------------------------------------------------------------------------------- /pkg/proxy/backend/backend.go: -------------------------------------------------------------------------------- 1 | package backend 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "sync" 7 | "time" 8 | 9 | "github.com/tidb-incubator/weir/pkg/proxy/backend/client" 10 | "github.com/tidb-incubator/weir/pkg/proxy/driver" 11 | "github.com/tidb-incubator/weir/pkg/proxy/metrics" 12 | "github.com/tidb-incubator/weir/pkg/util/sync2" 13 | "github.com/pingcap/tidb/util/logutil" 14 | "go.uber.org/zap" 15 | ) 16 | 17 | var ( 18 | ErrNoBackendAddr = errors.New("no backend addr") 19 | ErrBackendClosed = errors.New("backend is closed") 20 | ErrBackendNotFound = errors.New("backend not found") 21 | ) 22 | 23 | // BackendConfig describes the backend instances of a namespace and the 24 | // settings of their connection pools. 25 | type BackendConfig struct { 26 | Addrs map[string]struct{} 27 | UserName string 28 | Password string 29 | Capacity int 30 | IdleTimeout time.Duration 31 | SelectorType int 32 | } 33 | 34 | // BackendImpl routes a namespace's requests to its backend instances and keeps 35 | // one connection pool per instance address. 36 | type BackendImpl struct { 37 | ns string 38 | cfg *BackendConfig 39 | connPools map[string]*ConnPool // key: addr 40 | instances []*Instance 41 | selector Selector 42 | 43 | lock sync.RWMutex 44 | closed sync2.AtomicBool 45 | } 46 | 
47 | // NewBackendImpl returns a backend for namespace ns; Init must be called before use. 48 | func NewBackendImpl(ns string, cfg *BackendConfig) *BackendImpl { 49 | return &BackendImpl{ 50 | cfg: cfg, 51 | closed: sync2.NewAtomicBool(false), 52 | ns: ns, 53 | } 54 | } 55 | 56 | // Init creates the instance selector, the instance list and one connection pool 57 | // per configured address. 58 | func (b *BackendImpl) Init() error { 59 | b.lock.Lock() 60 | defer b.lock.Unlock() 61 | metrics.BackendEventCounter.WithLabelValues(b.ns, metrics.BackendEventIniting).Inc() 62 | 63 | if err := b.initSelector(); err != nil { 64 | return err 65 | } 66 | if err := b.initInstances(); err != nil { 67 | return err 68 | } 69 | if err := b.initConnPools(); err != nil { 70 | return err 71 | } 72 | 73 | metrics.BackendEventCounter.WithLabelValues(b.ns, metrics.BackendEventInited).Inc() 74 | return nil 75 | } 76 | 77 | func (b *BackendImpl) initSelector() error { 78 | selector, err := CreateSelector(b.cfg.SelectorType) 79 | if err != nil { 80 | return err 81 | } 82 | b.selector = selector 83 | return nil 84 | } 85 | 86 | func (b *BackendImpl) initInstances() error { 87 | instances, err := createInstances(b.cfg) 88 | if err != nil { 89 | return err 90 | } 91 | b.instances = instances 92 | return nil 93 | } 94 | 
95 | // initConnPools creates and initializes one ConnPool per address. If any pool 96 | // fails to initialize, the pools initialized before it are closed again and 97 | // the error is returned. 98 | func (b *BackendImpl) initConnPools() error { 99 | connPools := make(map[string]*ConnPool, len(b.cfg.Addrs)) 100 | for addr := range b.cfg.Addrs { 101 | poolCfg := &ConnPoolConfig{ 102 | Config: Config{Addr: addr, UserName: b.cfg.UserName, Password: b.cfg.Password}, 103 | Capacity: b.cfg.Capacity, 104 | IdleTimeout: b.cfg.IdleTimeout, 105 | } 106 | connPool := NewConnPool(b.ns, poolCfg) 107 | connPools[addr] = connPool 108 | } 109 | 110 | successfulInitConnPoolAddrs := make(map[string]struct{}) 111 | var initConnPoolErr error 112 | for addr, connPool := range connPools { 113 | if err := connPool.Init(); err != nil { 114 | initConnPoolErr = err 115 | break 116 | } 117 | successfulInitConnPoolAddrs[addr] = struct{}{} 118 | } 119 | 120 | if initConnPoolErr != nil { 121 | for addr := range successfulInitConnPoolAddrs { 122 | if err := connPools[addr].Close(); err != nil { 123 | logutil.BgLogger().Error("close inited conn pool error", zap.String("addr", addr), zap.Error(err)) 124 | } 125 | } 126 | return initConnPoolErr 127 | } 128 | 129 | b.connPools = connPools 130 | return nil 131 | } 132 | 
133 | // GetConn dials a new, unpooled connection to an instance chosen by the 134 | // selector. ctx is currently not used. 135 | func (b *BackendImpl) GetConn(ctx context.Context) (driver.SimpleBackendConn, error) { 136 | if b.closed.Get() { 137 | return nil, ErrBackendClosed 138 | } 139 | 140 | instance, err := b.route(b.instances) 141 | if err != nil { 142 | return nil, err 143 | } 144 | 145 | conn, err := client.Connect(instance.Addr(), b.cfg.UserName, b.cfg.Password, "") 146 | if err != nil { 147 | // Return a literal nil: a nil *client.Conn wrapped in the interface would not compare equal to nil. 148 | return nil, err 149 | } 150 | return conn, nil 151 | } 152 | 153 | // GetPooledConn borrows a connection from the pool of an instance chosen by the selector. 154 | func (b *BackendImpl) GetPooledConn(ctx context.Context) (driver.PooledBackendConn, error) { 155 | if b.closed.Get() { 156 | return nil, ErrBackendClosed 157 | } 158 | 159 | instance, err := b.route(b.instances) 160 | if err != nil { 161 | return nil, err 162 | } 163 | 164 | b.lock.RLock() 165 | connPool, ok := b.connPools[instance.Addr()] 166 | b.lock.RUnlock() 167 | if !ok { 168 | return nil, ErrBackendNotFound 169 | } 170 | 171 | return connPool.GetConn(ctx) 172 | } 173 | 
174 | // Close marks the backend closed and closes all connection pools. Only the 175 | // first call does any work (including the closing/closed metrics); later calls 176 | // return immediately. 177 | func (b *BackendImpl) Close() { 178 | if !b.closed.CompareAndSwap(false, true) { 179 | return 180 | } 181 | metrics.BackendEventCounter.WithLabelValues(b.ns, metrics.BackendEventClosing).Inc() 182 | 183 | b.lock.Lock() 184 | defer b.lock.Unlock() 185 | 186 | for addr, connPool := range b.connPools { 187 | if err := connPool.Close(); err != nil { 188 | logutil.BgLogger().Error("close conn pool error", zap.String("addr", addr), zap.Error(err)) 189 | } 190 | } 191 | 192 | metrics.BackendEventCounter.WithLabelValues(b.ns, metrics.BackendEventClosed).Inc() 193 | } 194 | 195 | // route picks one of instances with the configured selector. 196 | func (b *BackendImpl) route(instances []*Instance) (*Instance, error) { 197 | instance, err := b.selector.Select(instances) 198 | if err != nil { 199 | return nil, err 200 | } 201 | 202 | return instance, nil 203 | } 204 | 205 | // createInstances returns one Instance per configured address, or 206 | // ErrNoBackendAddr when there is none. 207 | func createInstances(cfg *BackendConfig) ([]*Instance, error) { 208 | if len(cfg.Addrs) == 0 { 209 | return nil, ErrNoBackendAddr 210 | } 211 | 212 | ret := make([]*Instance, 0, len(cfg.Addrs)) 213 | for addr := range cfg.Addrs { 214 | ins := &Instance{addr: addr} 215 | ret = append(ret, ins) 216 | } 217 | return ret, nil 218 | } 219 | 
-------------------------------------------------------------------------------- /pkg/proxy/driver/connmgr.go: -------------------------------------------------------------------------------- 1 | package driver 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | "sync" 7 | 8 | "github.com/tidb-incubator/weir/pkg/proxy/metrics" 9 | utilerrors "github.com/tidb-incubator/weir/pkg/util/errors" 10 | "github.com/pingcap/parser/mysql" 11 | "github.com/pingcap/tidb/util/logutil" 12 | gomysql "github.com/siddontang/go-mysql/mysql" 13 | "go.uber.org/zap" 14 | ) 15 | 16 | // BackendConnManager drives a session's backend operations through an FSM. 17 | // Every exported method holds mu for its whole duration, so operations on one 18 | // manager are serialized. 19 | type BackendConnManager struct { 20 | fsm *FSM 21 | state FSMState 22 | 23 | ns Namespace 24 | 25 | mu sync.Mutex 26 | // txnConn is the pooled connection currently attached to the session, or nil. 27 | txnConn PooledBackendConn 28 | 29 | // TODO: use stmt id set 30 | isPrepared bool 31 | } 32 | 33 | // NewBackendConnManager returns a manager for ns in the initial FSM state. 34 | func NewBackendConnManager(fsm *FSM, ns Namespace) *BackendConnManager { 35 | return &BackendConnManager{ 36 | fsm: fsm, 37 | state: stateInitial, 38 | ns: ns, 39 | isPrepared: false, 40 | } 41 | } 42 | 
43 | // MergeStatus sets the in-transaction and autocommit status flags of svw from 44 | // the current FSM state. 45 | func (f *BackendConnManager) MergeStatus(svw *SessionVarsWrapper) { 46 | f.mu.Lock() 47 | defer f.mu.Unlock() 48 | 49 | svw.SetStatusFlag(mysql.ServerStatusInTrans, f.state.IsInTransaction()) 50 | svw.SetStatusFlag(mysql.ServerStatusAutocommit, f.state.IsAutoCommit()) 51 | } 52 | 53 | // Query executes sql in database db by firing EventQuery on the FSM. 54 | func (f *BackendConnManager) Query(ctx context.Context, db, sql string) (*gomysql.Result, error) { 55 | f.mu.Lock() 56 | defer f.mu.Unlock() 57 | 58 | ret, err := f.fsm.Call(ctx, EventQuery, f, db, sql) 59 | if err != nil { 60 | return nil, err 61 | } 62 | return ret.(*gomysql.Result), nil 63 | } 64 | 65 | // SetAutoCommit fires EventEnableAutoCommit or EventDisableAutoCommit on the FSM. 66 | func (f *BackendConnManager) SetAutoCommit(ctx context.Context, autocommit bool) error { 67 | f.mu.Lock() 68 | defer f.mu.Unlock() 69 | 70 | var err error 71 | if autocommit { 72 | _, err = f.fsm.Call(ctx, EventEnableAutoCommit, f) 73 | } else { 74 | _, err = f.fsm.Call(ctx, EventDisableAutoCommit, f) 75 | } 76 | return err 77 | } 78 | 79 | // Begin fires EventBegin on the FSM. 80 | func (f *BackendConnManager) Begin(ctx context.Context) error { 81 | f.mu.Lock() 82 | defer f.mu.Unlock() 83 | 84 | _, err := f.fsm.Call(ctx, EventBegin, f) 85 | return err 86 | } 87 | 88 | // CommitOrRollback fires EventCommitOrRollback on the FSM; commit selects which. 89 | func (f *BackendConnManager) CommitOrRollback(ctx context.Context, commit bool) error { 90 | f.mu.Lock() 91 | defer f.mu.Unlock() 92 | 93 | _, err := f.fsm.Call(ctx, EventCommitOrRollback, f, commit) 94 | return err 95 | } 96 | 
97 | // StmtPrepare prepares sql in database db by firing EventStmtPrepare on the FSM. 98 | func (f *BackendConnManager) StmtPrepare(ctx context.Context, db, sql string) (Stmt, error) { 99 | f.mu.Lock() 100 | defer f.mu.Unlock() 101 | 102 | ret, err := f.fsm.Call(ctx, EventStmtPrepare, f, db, sql) 103 | if err != nil { 104 | return nil, err 105 | } 106 | return ret.(Stmt), nil 107 | } 108 | 109 | // StmtExecuteForward fires EventStmtForwardData on the FSM for statement 110 | // stmtId. A nil FSM result is returned as a nil *gomysql.Result. 111 | func (f *BackendConnManager) StmtExecuteForward(ctx context.Context, stmtId int, data []byte) (*gomysql.Result, error) { 112 | f.mu.Lock() 113 | defer f.mu.Unlock() 114 | 115 | ret, err := f.fsm.Call(ctx, EventStmtForwardData, f, stmtId, data) 116 | if err != nil { 117 | return nil, err 118 | } 119 | if ret == nil { 120 | return nil, nil 121 | } 122 | return ret.(*gomysql.Result), nil 123 | } 124 | 125 | // StmtClose fires EventStmtClose on the FSM for statement stmtId. 126 | func (f *BackendConnManager) StmtClose(ctx context.Context, stmtId int) error { 127 | f.mu.Lock() 128 | defer f.mu.Unlock() 129 | 130 | _, err := f.fsm.Call(ctx, EventStmtClose, f, stmtId) 131 | return err 132 | } 133 | 
134 | // Close discards any attached connection (closing it instead of returning it 135 | // to the pool) and resets the FSM state. It always returns nil. 136 | // TODO(eastfisher): is it possible to use FSM to manage close? 137 | func (f *BackendConnManager) Close() error { 138 | f.mu.Lock() 139 | defer f.mu.Unlock() 140 | 141 | if f.txnConn != nil { 142 | errClosePooledBackendConn(f.txnConn, f.ns.Name()) 143 | } 144 | f.state = stateInitial 145 | f.unsetAttachedConn() 146 | return nil 147 | } 148 | 149 | // queryWithoutTxn runs sql on a connection borrowed from the pool for this call 150 | // only. The connection is closed if the error indicates a broken connection and 151 | // is returned to the pool otherwise. 152 | func (f *BackendConnManager) queryWithoutTxn(ctx context.Context, db, sql string) (*gomysql.Result, error) { 153 | conn, err := f.ns.GetPooledConn(ctx) 154 | if err != nil { 155 | return nil, err 156 | } 157 | 158 | // err is captured by this closure: below it must be assigned, never redeclared. 159 | defer func() { 160 | if err != nil && isConnError(err) { 161 | if errClose := conn.ErrorClose(); errClose != nil { 162 | logutil.BgLogger().Error("close backend conn error", zap.Error(errClose)) 163 | } 164 | } else { 165 | conn.PutBack() 166 | } 167 | }() 168 | 169 | if err = conn.UseDB(db); err != nil { 170 | return nil, err 171 | } 172 | 173 | var ret *gomysql.Result 174 | ret, err = conn.Execute(sql) 175 | return ret, err 176 | } 177 | 
178 | // queryInTxn runs sql in database db on the attached connection txnConn. 179 | func (f *BackendConnManager) queryInTxn(ctx context.Context, db, sql string) (*gomysql.Result, error) { 180 | if err := f.txnConn.UseDB(db); err != nil { 181 | return nil, err 182 | } 183 | return f.txnConn.Execute(sql) 184 | } 185 | 186 | // releaseAttachedConn detaches txnConn, closing it when err is non-nil and 187 | // returning it to the pool otherwise. 188 | func (f *BackendConnManager) releaseAttachedConn(err error) { 189 | if err != nil { 190 | errClosePooledBackendConn(f.txnConn, f.ns.Name()) 191 | } else { 192 | f.txnConn.PutBack() 193 | } 194 | f.unsetAttachedConn() 195 | } 196 | 197 | // setAttachedConn attaches conn and increments the attached-connection gauge. 198 | func (f *BackendConnManager) setAttachedConn(conn PooledBackendConn) { 199 | f.txnConn = conn 200 | metrics.QueryCtxAttachedConnGauge.WithLabelValues(f.ns.Name()).Inc() 201 | } 202 | 203 | // unsetAttachedConn clears txnConn, decrementing the gauge if a connection was attached. 204 | func (f *BackendConnManager) unsetAttachedConn() { 205 | if f.txnConn != nil { 206 | metrics.QueryCtxAttachedConnGauge.WithLabelValues(f.ns.Name()).Dec() 207 | } 208 | f.txnConn = nil 209 | } 210 | 211 | // errClosePooledBackendConn closes conn via ErrorClose and logs any close error. 212 | func errClosePooledBackendConn(conn PooledBackendConn, ns string) { 213 | if err := conn.ErrorClose(); err != nil { 214 | logutil.BgLogger().Error("close backend conn error", zap.Error(err), zap.String("namespace", ns)) 215 | } 216 | } 217 | 218 | // isConnError reports whether err is a bad-connection error from go-mysql or database/sql/driver. 219 | func isConnError(err error) bool { 220 | return utilerrors.Is(err, gomysql.ErrBadConn) || utilerrors.Is(err, driver.ErrBadConn) 221 | } 222 | 
-------------------------------------------------------------------------------- /pkg/util/sync2/atomic.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2019 The Vitess Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | */ 16 | 17 | package sync2 18 | 19 | import ( 20 | "sync" 21 | "sync/atomic" 22 | "time" 23 | ) 24 | 25 | // AtomicInt32 is a wrapper with a simpler interface around atomic.(Add|Store|Load|CompareAndSwap)Int32 functions. 26 | type AtomicInt32 struct { 27 | int32 28 | } 29 | 30 | // NewAtomicInt32 initializes a new AtomicInt32 with a given value. 31 | func NewAtomicInt32(n int32) AtomicInt32 { 32 | return AtomicInt32{n} 33 | } 34 | 35 | // Add atomically adds n to the value and returns the new value. 36 | func (i *AtomicInt32) Add(n int32) int32 { 37 | return atomic.AddInt32(&i.int32, n) 38 | } 39 | 40 | // Set atomically sets n as new value. 41 | func (i *AtomicInt32) Set(n int32) { 42 | atomic.StoreInt32(&i.int32, n) 43 | } 44 | 45 | // Get atomically returns the current value. 46 | func (i *AtomicInt32) Get() int32 { 47 | return atomic.LoadInt32(&i.int32) 48 | } 49 | 50 | // CompareAndSwap atomically sets the value to newval if it equals oldval and reports whether it did. 51 | func (i *AtomicInt32) CompareAndSwap(oldval, newval int32) (swapped bool) { 52 | return atomic.CompareAndSwapInt32(&i.int32, oldval, newval) 53 | } 54 | 
55 | // AtomicInt64 is a wrapper with a simpler interface around atomic.(Add|Store|Load|CompareAndSwap)Int64 functions. 56 | type AtomicInt64 struct { 57 | int64 58 | } 59 | 60 | // NewAtomicInt64 initializes a new AtomicInt64 with a given value. 61 | func NewAtomicInt64(n int64) AtomicInt64 { 62 | return AtomicInt64{n} 63 | } 64 | 65 | // Add atomically adds n to the value and returns the new value. 66 | func (i *AtomicInt64) Add(n int64) int64 { 67 | return atomic.AddInt64(&i.int64, n) 68 | } 69 | 70 | // Set atomically sets n as new value. 71 | func (i *AtomicInt64) Set(n int64) { 72 | atomic.StoreInt64(&i.int64, n) 73 | } 74 | 75 | // Get atomically returns the current value. 76 | func (i *AtomicInt64) Get() int64 { 77 | return atomic.LoadInt64(&i.int64) 78 | } 79 | 80 | // CompareAndSwap atomically sets the value to newval if it equals oldval and reports whether it did. 81 | func (i *AtomicInt64) CompareAndSwap(oldval, newval int64) (swapped bool) { 82 | return atomic.CompareAndSwapInt64(&i.int64, oldval, newval) 83 | } 84 | 
85 | // AtomicDuration is a wrapper with a simpler interface around atomic.(Add|Store|Load|CompareAndSwap)Int64 functions. 86 | type AtomicDuration struct { 87 | int64 88 | } 89 | 90 | // NewAtomicDuration initializes a new AtomicDuration with a given value. 91 | func NewAtomicDuration(duration time.Duration) AtomicDuration { 92 | return AtomicDuration{int64(duration)} 93 | } 94 | 95 | // Add atomically adds duration to the value and returns the new value. 96 | func (d *AtomicDuration) Add(duration time.Duration) time.Duration { 97 | return time.Duration(atomic.AddInt64(&d.int64, int64(duration))) 98 | } 99 | 100 | // Set atomically sets duration as new value. 101 | func (d *AtomicDuration) Set(duration time.Duration) { 102 | atomic.StoreInt64(&d.int64, int64(duration)) 103 | } 104 | 105 | // Get atomically returns the current value. 106 | func (d *AtomicDuration) Get() time.Duration { 107 | return time.Duration(atomic.LoadInt64(&d.int64)) 108 | } 109 | 110 | // CompareAndSwap atomically sets the value to newval if it equals oldval and reports whether it did. 111 | func (d *AtomicDuration) CompareAndSwap(oldval, newval time.Duration) (swapped bool) { 112 | return atomic.CompareAndSwapInt64(&d.int64, int64(oldval), int64(newval)) 113 | } 114 | 
115 | // AtomicBool gives an atomic boolean variable. 116 | type AtomicBool struct { 117 | int32 118 | } 119 | 120 | // NewAtomicBool initializes a new AtomicBool with a given value. 121 | func NewAtomicBool(n bool) AtomicBool { 122 | if n { 123 | return AtomicBool{1} 124 | } 125 | return AtomicBool{0} 126 | } 127 | 128 | // Set atomically sets n as new value. 129 | func (i *AtomicBool) Set(n bool) { 130 | if n { 131 | atomic.StoreInt32(&i.int32, 1) 132 | } else { 133 | atomic.StoreInt32(&i.int32, 0) 134 | } 135 | } 136 | 137 | // Get atomically returns the current value. 138 | func (i *AtomicBool) Get() bool { 139 | return atomic.LoadInt32(&i.int32) != 0 140 | } 141 | 142 | // CompareAndSwap atomically sets the value to n if it equals o and reports whether it did. 143 | func (i *AtomicBool) CompareAndSwap(o, n bool) bool { 144 | var oldVal, newVal int32 145 | if o { 146 | oldVal = 1 147 | } 148 | if n { 149 | newVal = 1 150 | } 151 | return atomic.CompareAndSwapInt32(&i.int32, oldVal, newVal) 152 | } 153 | 
154 | // AtomicString gives you atomic-style APIs for string, but 155 | // it's only a convenience wrapper that uses a mutex. So, it's 156 | // not as efficient as the rest of the atomic types. 157 | type AtomicString struct { 158 | mu sync.Mutex 159 | str string 160 | } 161 | 162 | // Set atomically sets str as new value. 163 | func (s *AtomicString) Set(str string) { 164 | s.mu.Lock() 165 | s.str = str 166 | s.mu.Unlock() 167 | } 168 | 169 | // Get atomically returns the current value. 170 | func (s *AtomicString) Get() string { 171 | s.mu.Lock() 172 | str := s.str 173 | s.mu.Unlock() 174 | return str 175 | } 176 | 177 | // CompareAndSwap sets the value to newval if it equals oldval, under the mutex, and reports whether it did. 178 | func (s *AtomicString) CompareAndSwap(oldval, newval string) (swapped bool) { 179 | s.mu.Lock() 180 | defer s.mu.Unlock() 181 | if s.str == oldval { 182 | s.str = newval 183 | return true 184 | } 185 | return false 186 | } 187 | 
-------------------------------------------------------------------------------- /pkg/proxy/backend/client/conn.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "net" 7 | "strings" 8 | "time" 9 | 10 | "github.com/pingcap/errors" 11 | . "github.com/siddontang/go-mysql/mysql" 12 | "github.com/siddontang/go-mysql/packet" 13 | "github.com/tidb-incubator/weir/pkg/proxy/constant" 14 | ) 15 | // Conn is a client connection to a MySQL server, built on an embedded packet.Conn. 16 | type Conn struct { 17 | *packet.Conn 18 | 19 | user string 20 | password string 21 | db string 22 | tlsConfig *tls.Config 23 | proto string 24 | 25 | capability uint32 26 | 27 | status uint16 28 | 29 | charset string 30 | 31 | salt []byte 32 | authPluginName string 33 | 34 | connectionID uint32 35 | } 36 | // getNetProto returns "unix" if addr contains a '/' (a socket path) and "tcp" otherwise. 37 | func getNetProto(addr string) string { 38 | proto := "tcp" 39 | if strings.Contains(addr, "/") { 40 | proto = "unix" 41 | } 42 | return proto 43 | } 44 | 
45 | // Connect to a MySQL server, addr can be ip:port, or a unix socket domain like /var/sock. 46 | // Accepts a series of configuration functions as a variadic argument. 47 | func Connect(addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) { 48 | proto := getNetProto(addr) 49 | 50 | c := new(Conn) 51 | 52 | var err error 53 | conn, err := net.DialTimeout(proto, addr, 10*time.Second) 54 | if err != nil { 55 | return nil, errors.Trace(err) 56 | } 57 | // NOTE(review): c comes from new(Conn) and the options are only applied further down ("Apply configuration functions"), so c.tlsConfig is always nil here and the TLS branch looks unreachable — confirm and apply options before this check. 58 | if c.tlsConfig != nil { 59 | c.Conn = packet.NewTLSConn(conn) 60 | } else { 61 | c.Conn = packet.NewConn(conn) 62 | } 63 | 64 | c.user = user 65 | c.password = password 66 | c.db = dbName 67 | c.proto = proto 68 | 69 | //use default charset here 70 | c.charset = constant.DefaultCharset 71 | 72 | // Apply configuration functions.
73 | for i := range options { 74 | options[i](c) 75 | } 76 | 77 | if err = c.handshake(); err != nil { 78 | return nil, errors.Trace(err) 79 | } 80 | 81 | return c, nil 82 | } 83 | 84 | func (c *Conn) handshake() error { 85 | var err error 86 | if err = c.readInitialHandshake(); err != nil { 87 | c.Close() 88 | return errors.Trace(err) 89 | } 90 | 91 | if err := c.writeAuthHandshake(); err != nil { 92 | c.Close() 93 | 94 | return errors.Trace(err) 95 | } 96 | 97 | if err := c.handleAuthResult(); err != nil { 98 | c.Close() 99 | return errors.Trace(err) 100 | } 101 | 102 | return nil 103 | } 104 | 105 | func (c *Conn) Close() error { 106 | return c.Conn.Close() 107 | } 108 | 109 | func (c *Conn) Ping() error { 110 | if err := c.writeCommand(COM_PING); err != nil { 111 | return errors.Trace(err) 112 | } 113 | 114 | if _, err := c.readOK(); err != nil { 115 | return errors.Trace(err) 116 | } 117 | 118 | return nil 119 | } 120 | 121 | // UseSSL: use default SSL 122 | // pass to options when connect 123 | func (c *Conn) UseSSL(insecureSkipVerify bool) { 124 | c.tlsConfig = &tls.Config{InsecureSkipVerify: insecureSkipVerify} 125 | } 126 | 127 | // SetTLSConfig: use user-specified TLS config 128 | // pass to options when connect 129 | func (c *Conn) SetTLSConfig(config *tls.Config) { 130 | c.tlsConfig = config 131 | } 132 | 133 | func (c *Conn) UseDB(dbName string) error { 134 | if c.db == dbName { 135 | return nil 136 | } 137 | 138 | if err := c.writeCommandStr(COM_INIT_DB, dbName); err != nil { 139 | return errors.Trace(err) 140 | } 141 | 142 | if _, err := c.readOK(); err != nil { 143 | return errors.Trace(err) 144 | } 145 | 146 | c.db = dbName 147 | return nil 148 | } 149 | 150 | func (c *Conn) GetDB() string { 151 | return c.db 152 | } 153 | 154 | func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { 155 | if len(args) == 0 { 156 | return c.exec(command) 157 | } else { 158 | if s, err := c.Prepare(command); err != nil { 159 | return nil, 
errors.Trace(err) 160 | } else { 161 | var r *Result 162 | r, err = s.Execute(args...) 163 | s.Close() 164 | return r, err 165 | } 166 | } 167 | } 168 | 169 | func (c *Conn) Begin() error { 170 | _, err := c.exec("BEGIN") 171 | return errors.Trace(err) 172 | } 173 | 174 | func (c *Conn) Commit() error { 175 | _, err := c.exec("COMMIT") 176 | return errors.Trace(err) 177 | } 178 | 179 | func (c *Conn) Rollback() error { 180 | _, err := c.exec("ROLLBACK") 181 | return errors.Trace(err) 182 | } 183 | 184 | func (c *Conn) SetCharset(charset string) error { 185 | if c.charset == charset { 186 | return nil 187 | } 188 | 189 | if _, err := c.exec(fmt.Sprintf("SET NAMES %s", charset)); err != nil { 190 | return errors.Trace(err) 191 | } else { 192 | c.charset = charset 193 | return nil 194 | } 195 | } 196 | 197 | func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) { 198 | if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil { 199 | return nil, errors.Trace(err) 200 | } 201 | fs := make([]*Field, 0, 4) 202 | for { 203 | data, err := c.ReadPacket() 204 | if err != nil { 205 | return nil, errors.Trace(err) 206 | } 207 | if data[0] == ERR_HEADER { 208 | return nil, c.handleErrorPacket(data) 209 | } 210 | 211 | // EOF Packet 212 | if c.isEOFPacket(data) { 213 | break 214 | } 215 | f, err := FieldData(data).Parse() 216 | if err != nil { 217 | return nil, errors.Trace(err) 218 | } 219 | fs = append(fs, f) 220 | } 221 | return fs, nil 222 | } 223 | 224 | func (c *Conn) SetAutoCommit(autocommit bool) error { 225 | if c.IsAutoCommit() == autocommit { 226 | return nil 227 | } 228 | 229 | if autocommit { 230 | if _, err := c.exec("SET AUTOCOMMIT = 1"); err != nil { 231 | return errors.Trace(err) 232 | } 233 | return nil 234 | } else { 235 | if _, err := c.exec("SET AUTOCOMMIT = 0"); err != nil { 236 | return errors.Trace(err) 237 | } 238 | return nil 239 | } 240 | } 241 | 242 | func (c *Conn) IsAutoCommit() bool { 243 | return 
c.status&SERVER_STATUS_AUTOCOMMIT > 0 244 | } 245 | 246 | func (c *Conn) IsInTransaction() bool { 247 | return c.status&SERVER_STATUS_IN_TRANS > 0 248 | } 249 | 250 | func (c *Conn) GetCharset() string { 251 | return c.charset 252 | } 253 | 254 | func (c *Conn) GetConnectionID() uint32 { 255 | return c.connectionID 256 | } 257 | 258 | func (c *Conn) GetStatus() uint16 { 259 | return c.status 260 | } 261 | 262 | func (c *Conn) HandleOKPacket(data []byte) *Result { 263 | r, _ := c.handleOKPacket(data) 264 | return r 265 | } 266 | 267 | func (c *Conn) HandleErrorPacket(data []byte) error { 268 | return c.handleErrorPacket(data) 269 | } 270 | 271 | func (c *Conn) ReadOKPacket() (*Result, error) { 272 | return c.readOK() 273 | } 274 | 275 | func (c *Conn) exec(query string) (*Result, error) { 276 | if err := c.writeCommandStr(COM_QUERY, query); err != nil { 277 | return nil, errors.Trace(err) 278 | } 279 | 280 | return c.readResult(false) 281 | } 282 | --------------------------------------------------------------------------------