├── .gitignore ├── Dockerfile ├── Dockerrun.aws.json.template ├── Godeps ├── Godeps.json └── Readme ├── LICENSE ├── Makefile ├── README.md ├── bin └── .keep ├── conf ├── conf.go ├── config_test.go └── config_test.toml ├── etc ├── example.config.toml └── supervisord.conf ├── health └── http.go ├── main.go ├── metrics ├── console.go ├── librato.go ├── proxy.go └── runtime.go ├── proxy └── proxy.go ├── rewrite ├── auth_rewriter.go ├── auth_rewriter_test.go ├── message_rewriter.go ├── message_rewriter_test.go ├── topic_rewriter.go └── topic_rewriter_test.go ├── ssl ├── cert.pem └── key.pem ├── store ├── mysql.go └── store.go ├── tcp ├── conn.go └── server.go ├── util └── util.go └── version.go /.gitignore: -------------------------------------------------------------------------------- 1 | config.toml 2 | mqtt-proxy 3 | pkg 4 | .gopath 5 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.4 2 | MAINTAINER Ninja Blocks 3 | 4 | RUN apt-get -qy update && apt-get -qy install vim-common gcc mercurial bzr supervisor 5 | RUN mkdir -p /var/log/supervisor 6 | RUN mkdir -p /etc/mqtt-proxy 7 | 8 | COPY etc/supervisord.conf /etc/supervisor/conf.d/supervisord.conf 9 | COPY etc/example.config.toml /etc/mqtt-proxy/config.toml 10 | 11 | COPY build/mqtt-proxy /app/ 12 | WORKDIR /app 13 | 14 | EXPOSE 6300 15 | CMD ["/usr/bin/supervisord"] -------------------------------------------------------------------------------- /Dockerrun.aws.json.template: -------------------------------------------------------------------------------- 1 | { 2 | "AWSEBDockerrunVersion": "1", 3 | "Authentication": { 4 | "Bucket": "ninjablocks-sphere-docker", 5 | "Key": "dockercfg" 6 | }, 7 | "Image": { 8 | "Name": "ninjablocks/mqtt-proxy:<TAG>", 9 | "Update": "true" 10 | }, 11 | "Ports": [ 12 | { 13 | "ContainerPort": "6300" 14 | } 15 | ] 16 | } 17 | 
-------------------------------------------------------------------------------- /Godeps/Godeps.json: -------------------------------------------------------------------------------- 1 | { 2 | "ImportPath": "github.com/ninjablocks/mqtt-proxy", 3 | "GoVersion": "go1.6", 4 | "GodepVersion": "v60", 5 | "Deps": [ 6 | { 7 | "ImportPath": "github.com/BurntSushi/toml", 8 | "Comment": "v0.1.0-21-g056c9bc", 9 | "Rev": "056c9bc7be7190eaa7715723883caffa5f8fa3e4" 10 | }, 11 | { 12 | "ImportPath": "github.com/cloudfoundry/gosigar", 13 | "Comment": "scotty_09012012-21-gd906efd", 14 | "Rev": "d906efd1da51405714ee9b67c79cf14cdb58fd29" 15 | }, 16 | { 17 | "ImportPath": "github.com/davecgh/go-spew/spew", 18 | "Rev": "2df174808ee097f90d259e432cc04442cf60be21" 19 | }, 20 | { 21 | "ImportPath": "github.com/go-sql-driver/mysql", 22 | "Comment": "v1.2-171-g267b128", 23 | "Rev": "267b128680c46286b9ca13475c3cca5de8f79bd7" 24 | }, 25 | { 26 | "ImportPath": "github.com/rcrowley/go-metrics", 27 | "Rev": "3e5e593311103d49927c8d2b0fd93ccdfe4a525c" 28 | }, 29 | { 30 | "ImportPath": "github.com/rcrowley/go-metrics/librato", 31 | "Rev": "3e5e593311103d49927c8d2b0fd93ccdfe4a525c" 32 | }, 33 | { 34 | "ImportPath": "github.com/wolfeidau/mqtt", 35 | "Rev": "b3e33afe097268cf25bc43dc257c5042f44560ab" 36 | } 37 | ] 38 | } 39 | -------------------------------------------------------------------------------- /Godeps/Readme: -------------------------------------------------------------------------------- 1 | This directory tree is generated automatically by godep. 2 | 3 | Please do not edit. 4 | 5 | See https://github.com/tools/godep for more information. 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014-2015 Ninja Blocks Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PROJECT ?= mqtt-proxy 2 | EB_BUCKET ?= ninjablocks-sphere-docker 3 | 4 | APP_NAME ?= mqtt-proxy 5 | APP_ENV ?= mqtt-proxy-prod 6 | 7 | SHA1 := $(shell git rev-parse --short HEAD | tr -d "\n") 8 | 9 | DOCKERRUN_FILE := Dockerrun.aws.json 10 | APP_FILE := ${SHA1}.zip 11 | 12 | build: binary 13 | docker build -t "ninjablocks/${PROJECT}:${SHA1}" . 
14 | 15 | push: 16 | docker push "ninjablocks/${PROJECT}:${SHA1}" 17 | 18 | services: 19 | docker run --name ninja-rabbit -p 5672:5672 -p 15672:15672 -d mikaelhg/docker-rabbitmq 20 | 21 | local: 22 | docker run -t -i --rm --link ninja-rabbit:rabbit -e "DEBUG=true" \ 23 | -p 6300:6300 -t "ninjablocks/${PROJECT}:${SHA1}" 24 | 25 | binary: 26 | godep restore 27 | GOOS=linux GOARCH=amd64 go build -o build/mqtt-proxy -ldflags "\ 28 | -X main.buildVersion=$$(grep "const Version " version.go | sed -E 's/.*"(.+)"$$/\1/' ) \ 29 | -X main.buildRevision=$$(git rev-parse --short HEAD) \ 30 | -X main.buildBranch=$$(git rev-parse --abbrev-ref HEAD) \ 31 | -X main.buildDate=$$(date +%Y%m%d-%H:%M:%S)" 32 | 33 | deploy: 34 | sed "s/<TAG>/${SHA1}/" < Dockerrun.aws.json.template > ${DOCKERRUN_FILE} 35 | zip -r ${APP_FILE} ${DOCKERRUN_FILE} .ebextensions 36 | 37 | aws s3 cp ${APP_FILE} s3://${EB_BUCKET}/${APP_ENV}/${APP_FILE} 38 | 39 | aws elasticbeanstalk create-application-version --application-name ${APP_NAME} \ 40 | --version-label ${SHA1} --source-bundle S3Bucket=${EB_BUCKET},S3Key=${APP_ENV}/${APP_FILE} 41 | 42 | # # Update Elastic Beanstalk environment to new version 43 | aws elasticbeanstalk update-environment --environment-name ${APP_ENV} \ 44 | --version-label ${SHA1} 45 | 46 | clean: 47 | rm *.zip || true 48 | rm ${DOCKERRUN_FILE} || true 49 | rm -rf build || true 50 | 51 | .PHONY: all build push local services deploy clean 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mqtt-proxy 2 | 3 | This service acts as a front end for MQTT servers, performing preauthentication, load balancing and rate limiting. 4 | 5 | # setup 6 | 7 | Create the tables that hold MQTT ids and their tokens. 
8 | 9 | ```sql 10 | CREATE TABLE legends ( 11 | uid int(11) NOT NULL AUTO_INCREMENT, 12 | mqtt_id varchar(128) COLLATE utf8_bin NOT NULL, 13 | PRIMARY KEY (uid), 14 | UNIQUE KEY mqtt_id_UNIQUE (mqtt_id) 15 | ) DEFAULT CHARSET=utf8; 16 | 17 | CREATE TABLE tokens ( 18 | token_id int(11) NOT NULL AUTO_INCREMENT, 19 | uid int(11) NOT NULL, 20 | token varchar(64) COLLATE utf8_bin NOT NULL, 21 | PRIMARY KEY (token_id), 22 | UNIQUE KEY token_UNIQUE (token) 23 | ) DEFAULT CHARSET=utf8; 24 | ``` 25 | 26 | Configure rabbitmq as a backend for mqtt-proxy by editing or creating `/usr/local/etc/rabbitmq/rabbitmq.config`. 27 | 28 | ``` 29 | [{rabbit, [{tcp_listeners, [{"0.0.0.0", 5672}]}]}, 30 | {rabbitmq_mqtt, [{default_user, <<"guest">>}, 31 | {default_pass, <<"guest">>}, 32 | {allow_anonymous, true}, 33 | {vhost, <<"/">>}, 34 | {exchange, <<"amq.topic">>}, 35 | {subscription_ttl, 1800000}, 36 | {prefetch, 10}, 37 | {ssl_listeners, []}, 38 | {tcp_listeners, [2883]}, 39 | {tcp_listen_options, [binary, 40 | {packet, raw}, 41 | {reuseaddr, true}, 42 | {backlog, 128}, 43 | {nodelay, true}]}]} 44 | ]. 45 | ``` 46 | 47 | Enable the required plugins by running the following commands. 48 | 49 | ``` 50 | rabbitmq-plugins enable rabbitmq_management 51 | rabbitmq-plugins enable rabbitmq_mqtt 52 | rabbitmq-plugins enable rabbitmq_tracing 53 | ``` 54 | 55 | Restart rabbitmq. 56 | 57 | ``` 58 | rabbitmqctl stop 59 | rabbitmq-server -detached 60 | ``` 61 | 62 | # Port redirection for 443 -> WS port 63 | 64 | iptables -A PREROUTING -t nat -i eth0 -p tcp --dport 443 -j REDIRECT --to-port 9000 65 | 66 | # status 67 | 68 | * Preauthentication supports MySQL at the moment 69 | 70 | # Licensing 71 | 72 | mqtt-proxy is licensed under the MIT License. See LICENSE for the full license text. 
73 | -------------------------------------------------------------------------------- /bin/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ninjasphere/mqtt-proxy/f7ff9f29f8b15f8b204329d69f83c90c2ba92579/bin/.keep -------------------------------------------------------------------------------- /conf/conf.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "errors" 5 | "io/ioutil" 6 | "log" 7 | "os" 8 | "time" 9 | 10 | "github.com/BurntSushi/toml" 11 | "github.com/davecgh/go-spew/spew" 12 | ) 13 | 14 | type MysqlConfiguration struct { 15 | ConnectionString string `toml:"connection-string"` 16 | Select string `toml:"select"` 17 | } 18 | 19 | type MqttConfiguration struct { 20 | ListenAddress string `toml:"listen-address"` 21 | Cert string `toml:"cert"` 22 | Key string `toml:"key"` 23 | } 24 | 25 | type InfluxConfiguration struct { 26 | Host string `toml:"host"` 27 | User string `toml:"user"` 28 | Pass string `toml:"pass"` 29 | Database string `toml:"database"` 30 | } 31 | 32 | type LibratoConfiguration struct { 33 | Email string `toml:"email"` 34 | Token string `toml:"token"` 35 | } 36 | 37 | type Configuration struct { 38 | BackendServers []string `toml:"backend-servers"` 39 | User string `toml:"user"` 40 | Pass string `toml:"pass"` 41 | 42 | // typically us-west | us-east 43 | // prepended to metrics 44 | Region string `toml:"region"` 45 | 46 | // typically develop | beta | prod 47 | // prepended to metrics 48 | Environment string `toml:"env"` 49 | 50 | ReadTimeout int `toml:"read-timeout"` 51 | 52 | MqttStoreMysql MysqlConfiguration `toml:"mqtt-store"` 53 | Mqtt MqttConfiguration `toml:"mqtt"` 54 | Influx InfluxConfiguration `toml:"influx"` 55 | Librato LibratoConfiguration `toml:"librato"` 56 | } 57 | 58 | func (c *Configuration) GetReadTimeout() time.Duration { 59 | return time.Second * time.Duration(c.ReadTimeout) 60 | 
} 61 | 62 | func (c *Configuration) envOverrides() { 63 | 64 | if backendUser := os.Getenv("BACKEND_USER"); backendUser != "" { 65 | c.User = backendUser 66 | } 67 | 68 | if backendPass := os.Getenv("BACKEND_PASS"); backendPass != "" { 69 | c.Pass = backendPass 70 | } 71 | 72 | } 73 | 74 | func (c *Configuration) validate() error { 75 | 76 | if len(c.BackendServers) == 0 { 77 | return errors.New("At least one backend servers required.") 78 | } 79 | 80 | return nil 81 | } 82 | 83 | func (c *Configuration) assignDefaults() { 84 | 85 | if c.Region == "" { 86 | c.Region = "us-east" 87 | } 88 | 89 | if c.Environment == "" { 90 | c.Environment = "develop" 91 | } 92 | 93 | if c.Mqtt.ListenAddress == "" { 94 | c.Mqtt.ListenAddress = ":1883" 95 | } 96 | 97 | if c.User == "" { 98 | c.User = "guest" 99 | } 100 | 101 | if c.Pass == "" { 102 | c.Pass = "guest" 103 | } 104 | 105 | // need a way to merge defaults.. 106 | if c.MqttStoreMysql.ConnectionString == "" { 107 | c.MqttStoreMysql.ConnectionString = "root:@tcp(127.0.0.1:3306)/mqtt" 108 | } 109 | 110 | if c.MqttStoreMysql.Select == "" { 111 | c.MqttStoreMysql.Select = "select uid, mqtt_id from users where mqtt_id = ?" 
112 | } 113 | 114 | } 115 | 116 | func LoadConfiguration(fileName string) *Configuration { 117 | config, err := parseTomlConfiguration(fileName) 118 | if err != nil { 119 | log.Println("Couldn't parse configuration file: " + fileName) 120 | panic(err) 121 | } 122 | return config 123 | } 124 | 125 | func parseTomlConfiguration(filename string) (*Configuration, error) { 126 | body, err := ioutil.ReadFile(filename) 127 | if err != nil { 128 | return nil, err 129 | } 130 | tomlConfiguration := &Configuration{} 131 | _, err = toml.Decode(string(body), tomlConfiguration) 132 | if err != nil { 133 | return nil, err 134 | } 135 | log.Println(spew.Sprintf("sql = %v", tomlConfiguration)) 136 | 137 | tomlConfiguration.assignDefaults() 138 | tomlConfiguration.envOverrides() 139 | 140 | err = tomlConfiguration.validate() 141 | 142 | if err != nil { 143 | return nil, err 144 | } 145 | 146 | return tomlConfiguration, nil 147 | } 148 | -------------------------------------------------------------------------------- /conf/config_test.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | . 
"launchpad.net/gocheck" 5 | 6 | "testing" 7 | ) 8 | 9 | func Test(t *testing.T) { 10 | TestingT(t) 11 | } 12 | 13 | type LoadConfigurationSuite struct{} 14 | 15 | var _ = Suite(&LoadConfigurationSuite{}) 16 | 17 | // TestConfig checks the values in config_test.toml plus the defaults applied to unset keys. 18 | func (s *LoadConfigurationSuite) TestConfig(c *C) { 19 | config := LoadConfiguration("config_test.toml") 20 | c.Assert(config.BackendServers, DeepEquals, []string{"hosta:8090", "hostb:8090"}) 21 | // check mqtt defaults 22 | c.Assert(config.Mqtt.ListenAddress, Equals, ":1883") 23 | c.Assert(config.User, Equals, "guest") 24 | c.Assert(config.Pass, Equals, "guest") 25 | // check metric prefix defaults 26 | c.Assert(config.Region, Equals, "us-east") 27 | c.Assert(config.Environment, Equals, "develop") 28 | } 29 | -------------------------------------------------------------------------------- /conf/config_test.toml: -------------------------------------------------------------------------------- 1 | backend-servers = ["hosta:8090", "hostb:8090"] 2 | 3 | [http] 4 | listen-address = ":9000" 5 | -------------------------------------------------------------------------------- /etc/example.config.toml: -------------------------------------------------------------------------------- 1 | # 2 | # Used for docker and development 3 | # 4 | # globals 5 | backend-servers = ["rabbit:1883"] 6 | user = "guest" 7 | pass = "guest" 8 | 9 | # read timeout in seconds 10 | read-timeout = 30 11 | 12 | [mqtt-store] 13 | connection-string = "douitsu:douitsu@tcp(mysql:3306)/douitsu" 14 | select = "select at.userID, at.mqtt_client_id from accesstoken at, accesstoken_scope ats where at.id = ats.accesstoken and ats.scope_domain in ('mqtt', '*') and ats.scope_item = '*' and at.id = ?" 
15 | 16 | [mqtt] 17 | listen-address = ":6300" 18 | -------------------------------------------------------------------------------- /etc/supervisord.conf: -------------------------------------------------------------------------------- 1 | [supervisord] 2 | nodaemon=true 3 | loglevel=debug 4 | 5 | [program:mqtt-proxy] 6 | command=/app/mqtt-proxy -config=/etc/mqtt-proxy/config.toml 7 | redirect_stderr=true 8 | -------------------------------------------------------------------------------- /health/http.go: -------------------------------------------------------------------------------- 1 | package health 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net/http" 7 | 8 | "github.com/ninjablocks/mqtt-proxy/conf" 9 | ) 10 | 11 | // listenAddress is where the health check endpoint is served. 12 | const listenAddress = ":1880" 13 | 14 | // HeathServer is an empty placeholder type; its name is kept for compatibility. 15 | type HeathServer struct { 16 | } 17 | 18 | // StartHealthServer registers /health on the default ServeMux and serves it on 19 | // listenAddress, blocking until the listener fails, which exits the process. 20 | // cfg is currently unused. 21 | func StartHealthServer(cfg *conf.Configuration) { 22 | http.HandleFunc("/health", HomeHandler) 23 | log.Printf("[health] listening %s", listenAddress) 24 | log.Fatal(http.ListenAndServe(listenAddress, nil)) 25 | } 26 | 27 | // HomeHandler writes the plain body "OK". 28 | func HomeHandler(w http.ResponseWriter, r *http.Request) { 29 | fmt.Fprintf(w, "OK") 30 | } 31 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 
1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "log" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | 11 | "github.com/ninjablocks/mqtt-proxy/conf" 12 | "github.com/ninjablocks/mqtt-proxy/health" 13 | "github.com/ninjablocks/mqtt-proxy/metrics" 14 | "github.com/ninjablocks/mqtt-proxy/proxy" 15 | "github.com/ninjablocks/mqtt-proxy/tcp" 16 | ) 17 | 18 | var configFile = flag.String("config", "config.toml", "configuration file") 19 | var debug = flag.Bool("debug", false, "enable debugging") 20 | var version = flag.Bool("version", false, "show version") 21 | 22 | // main loads the configuration, starts the MQTT TCP server, the health endpoint 23 | // and the metrics jobs, then blocks until SIGINT or SIGTERM arrives. 24 | func main() { 25 | log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) 26 | 27 | flag.Parse() 28 | 29 | if *version { 30 | fmt.Printf("Version: %s\n", Version) 31 | os.Exit(0) 32 | } 33 | 34 | cfg := conf.LoadConfiguration(*configFile) 35 | 36 | if *debug { 37 | log.Printf("[main] conf %+v", cfg) 38 | } 39 | 40 | p := proxy.CreateMQTTProxy(cfg) 41 | 42 | // assign the servers 43 | tcpServer := tcp.CreateTcpServer(p) 44 | 45 | go tcpServer.StartServer(&cfg.Mqtt) 46 | go health.StartHealthServer(cfg) 47 | 48 | metrics.StartMetricsJobs(cfg) 49 | 50 | // SIGKILL (os.Kill) can never be caught, so wait for SIGTERM instead: it is 51 | // the default stop signal of both supervisord and `docker stop`. 52 | c := make(chan os.Signal, 1) 53 | signal.Notify(c, os.Interrupt, syscall.SIGTERM) 54 | 55 | // Block until a signal is received. 56 | log.Printf("Got signal %s, exiting now", <-c) 57 | } 58 | -------------------------------------------------------------------------------- /metrics/console.go: -------------------------------------------------------------------------------- 
1 | package metrics 2 | 3 | import ( 4 | "log" 5 | "time" 6 | 7 | gmetrics "github.com/rcrowley/go-metrics" 8 | ) 9 | 10 | // logForever logs every metric in r, then sleeps for d, in an endless loop. 11 | func logForever(r gmetrics.Registry, d time.Duration) { 12 | for { 13 | r.Each(func(name string, i interface{}) { 14 | switch m := i.(type) { 15 | case gmetrics.Counter: 16 | log.Printf("counter %s\n", name) 17 | log.Printf(" count: %9d\n", m.Count()) 18 | case gmetrics.Gauge: 19 | log.Printf("gauge %s\n", name) 20 | log.Printf(" value: %9d\n", m.Value()) 21 | case gmetrics.GaugeFloat64: 22 | log.Printf("gauge %s\n", name) 23 | log.Printf(" value: %12.2f\n", m.Value()) 24 | case gmetrics.Healthcheck: 25 | m.Check() 26 | log.Printf("healthcheck %s\n", name) 27 | log.Printf(" error: %v\n", m.Error()) 28 | case gmetrics.Histogram: 29 | ps := m.Percentiles([]float64{0.5, 0.75, 0.95, 0.99, 0.999}) 30 | log.Printf("histogram %s\n", name) 31 | log.Printf(" count: %9d\n", m.Count()) 32 | log.Printf(" min: %9d\n", m.Min()) 33 | log.Printf(" max: %9d\n", m.Max()) 34 | log.Printf(" mean: %12.2f\n", m.Mean()) 35 | log.Printf(" stddev: %12.2f\n", m.StdDev()) 36 | log.Printf(" median: %12.2f\n", ps[0]) 37 | log.Printf(" 75%%: %12.2f\n", ps[1]) 38 | log.Printf(" 95%%: %12.2f\n", ps[2]) 39 | log.Printf(" 99%%: %12.2f\n", ps[3]) 40 | log.Printf(" 99.9%%: %12.2f\n", ps[4]) 41 | case gmetrics.Meter: 42 | log.Printf("meter %s\n", name) 43 | log.Printf(" count: %9d\n", m.Count()) 44 | log.Printf(" 1-min rate: %12.2f\n", m.Rate1()) 45 | log.Printf(" 5-min rate: %12.2f\n", m.Rate5()) 46 | log.Printf(" 15-min rate: %12.2f\n", m.Rate15()) 47 | log.Printf(" mean rate: %12.2f\n", m.RateMean()) 48 | case gmetrics.Timer: 49 | ps := m.Percentiles([]float64{0.5, 0.75, 0.95, 0.99, 0.999}) 50 | log.Printf("timer %s\n", name) 51 | log.Printf(" count: %9d\n", m.Count()) 52 | log.Printf(" min: %9d\n", m.Min()) 53 | log.Printf(" max: %9d\n", m.Max()) 54 | log.Printf(" mean: %12.2f\n", m.Mean()) 55 | log.Printf(" stddev: 
%12.2f\n", m.StdDev()) 56 | log.Printf(" median: %12.2f\n", ps[0]) 57 | log.Printf(" 75%%: %12.2f\n", ps[1]) 58 | log.Printf(" 95%%: %12.2f\n", ps[2]) 59 | log.Printf(" 99%%: %12.2f\n", ps[3]) 60 | log.Printf(" 99.9%%: %12.2f\n", ps[4]) 61 | log.Printf(" 1-min rate: %12.2f\n", m.Rate1()) 62 | log.Printf(" 5-min rate: %12.2f\n", m.Rate5()) 63 | log.Printf(" 15-min rate: %12.2f\n", m.Rate15()) 64 | log.Printf(" mean rate: %12.2f\n", m.RateMean()) 65 | } 66 | }) 67 | time.Sleep(d) 68 | } 69 | } 70 | 71 | // ConsoleOutput starts logging the default registry every 10 seconds in the background. 72 | func ConsoleOutput() { 73 | log.Println("starting metrics job") 74 | go logForever(gmetrics.DefaultRegistry, 10*time.Second) 75 | } 76 | -------------------------------------------------------------------------------- /metrics/librato.go: -------------------------------------------------------------------------------- 
1 | package metrics 2 | 3 | import ( 4 | "log" 5 | "os" 6 | "time" 7 | 8 | "github.com/ninjablocks/mqtt-proxy/conf" 9 | gmetrics "github.com/rcrowley/go-metrics" 10 | "github.com/rcrowley/go-metrics/librato" 11 | ) 12 | 13 | // UploadToLibrato starts a background reporter that pushes the default metrics 14 | // registry to Librato every 30 seconds. It does nothing when no account email is 15 | // configured and exits the process if the hostname cannot be determined. 16 | func UploadToLibrato(config *conf.LibratoConfiguration) { 17 | if config.Email != "" { 18 | hostname, err := os.Hostname() 19 | if err != nil { 20 | log.Fatalf("Unable to retrieve a hostname %s", err) 21 | } 22 | 23 | go librato.Librato(gmetrics.DefaultRegistry, 24 | 30*time.Second, // interval 25 | config.Email, // account email address 26 | config.Token, // auth token 27 | hostname, // source 28 | []float64{0.95}, // percentiles to send, as fractions (0.95 = p95) 29 | time.Millisecond) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /metrics/proxy.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/ninjablocks/mqtt-proxy/conf" 7 | gmetrics "github.com/rcrowley/go-metrics" 8 | ) 9 | 10 | // ProxyMetrics records proxy related meters to enable monitoring 11 | // of throughput and volume. 12 | type ProxyMetrics struct { 13 | Msgs gmetrics.Meter 14 | MsgReply gmetrics.Meter 15 | MsgForward gmetrics.Meter 16 | MsgBodySize gmetrics.Histogram 17 | Connects gmetrics.Meter 18 | Connections gmetrics.Gauge 19 | } 20 | 21 | // NewProxyMetrics creates the proxy metrics and registers each one in the default 22 | // registry as "{region}.{env}.mqtt-proxy.proxy.{name}" (see conf.Environment, conf.Region). 
23 | func NewProxyMetrics(env string, region string) ProxyMetrics { 24 | 25 | prefix := buildPrefix(env, region) 26 | 27 | pm := ProxyMetrics{ 28 | Msgs: gmetrics.NewMeter(), 29 | MsgReply: gmetrics.NewMeter(), 30 | MsgForward: gmetrics.NewMeter(), 31 | MsgBodySize: gmetrics.NewHistogram(gmetrics.NewExpDecaySample(1028, 0.015)), 32 | Connects: gmetrics.NewMeter(), 33 | Connections: gmetrics.NewGauge(), 34 | } 35 | 36 | gmetrics.Register(prefix+".proxy.msgs", pm.Msgs) 37 | gmetrics.Register(prefix+".proxy.msg_reply", pm.MsgReply) 38 | gmetrics.Register(prefix+".proxy.msg_forward", pm.MsgForward) 39 | gmetrics.Register(prefix+".proxy.msg_body_size", pm.MsgBodySize) 40 | gmetrics.Register(prefix+".proxy.connects", pm.Connects) 41 | gmetrics.Register(prefix+".proxy.connections", pm.Connections) 42 | 43 | return pm 44 | } 45 | 46 | // buildPrefix returns the metric name prefix "{region}.{env}.mqtt-proxy". 47 | func buildPrefix(env string, region string) string { 48 | return strings.Join([]string{region, env, "mqtt-proxy"}, ".") 49 | } 50 | 51 | // StartMetricsJobs starts runtime metric sampling and, when configured, the Librato upload. 52 | func StartMetricsJobs(config *conf.Configuration) { 53 | StartRuntimeMetricsJob(config.Environment, config.Region) 54 | UploadToLibrato(&config.Librato) 55 | } 56 | -------------------------------------------------------------------------------- 
/metrics/runtime.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "os" 5 | "runtime" 6 | "time" 7 | 8 | sigar "github.com/cloudfoundry/gosigar" 9 | gmetrics "github.com/rcrowley/go-metrics" 10 | ) 11 | 12 | // MetricsGroup is a set of related metrics that are refreshed together. 13 | type MetricsGroup interface { 14 | Update() 15 | } 16 | 17 | // RuntimeMetrics bundles the Go internals, process memory, process CPU and system load groups. 18 | type RuntimeMetrics struct { 19 | Internals, Memory, Cpu, Load MetricsGroup 20 | } 21 | 22 | // NewRuntimeMetrics creates every runtime metric group and registers it under prefix. 23 | func NewRuntimeMetrics(prefix string) *RuntimeMetrics { 24 | return &RuntimeMetrics{ 25 | Internals: NewGoInternalMetrics(prefix), 26 | Memory: NewProcessMemoryMetrics(prefix), 27 | Cpu: NewProcessCpuMetrics(prefix), 28 | Load: NewLoadMetrics(prefix), 29 | } 30 | } 31 | 32 | // Update refreshes every group. 33 | func (rm *RuntimeMetrics) Update() { 34 | rm.Internals.Update() 35 | rm.Memory.Update() 36 | rm.Cpu.Update() 37 | rm.Load.Update() 38 | } 39 | 40 | // LoadMetrics holds the 1, 5 and 15 minute system load averages. 41 | type LoadMetrics struct { 42 | One, Five, Fifteen gmetrics.GaugeFloat64 43 | } 44 | 45 | // NewLoadMetrics creates the load gauges and registers them as {prefix}.load.{1,5,15}min. 46 | func NewLoadMetrics(prefix string) *LoadMetrics { 47 | 48 | load := &LoadMetrics{ 49 | One: gmetrics.NewGaugeFloat64(), 50 | Five: gmetrics.NewGaugeFloat64(), 51 | Fifteen: gmetrics.NewGaugeFloat64(), 52 | } 53 | 54 | gmetrics.Register(prefix+".load.1min", load.One) 55 | gmetrics.Register(prefix+".load.5min", load.Five) 56 | gmetrics.Register(prefix+".load.15min", load.Fifteen) 57 | 58 | return load 59 | } 60 | 61 | // Update samples the system load average; if the read fails the gauges keep their last values. 
62 | func (lm *LoadMetrics) Update() { 63 | 64 | load := sigar.LoadAverage{} 65 | 66 | err := load.Get() 67 | 68 | if err == nil { 69 | lm.One.Update(load.One) 70 | lm.Five.Update(load.Five) 71 | lm.Fifteen.Update(load.Fifteen) 72 | } 73 | 74 | } 75 | 76 | // ProcessMemoryMetrics tracks this process's resident and shared memory and its page faults. 77 | type ProcessMemoryMetrics struct { 78 | Resident, Shared gmetrics.Gauge 79 | PageFaults gmetrics.Meter 80 | } 81 | 82 | // NewProcessMemoryMetrics creates the memory metrics and registers them as {prefix}.mem.*. 83 | func NewProcessMemoryMetrics(prefix string) *ProcessMemoryMetrics { 84 | 85 | mem := &ProcessMemoryMetrics{ 86 | Resident: gmetrics.NewGauge(), 87 | Shared: gmetrics.NewGauge(), 88 | PageFaults: gmetrics.NewMeter(), 89 | } 90 | 91 | gmetrics.Register(prefix+".mem.resident", mem.Resident) 92 | gmetrics.Register(prefix+".mem.shared", mem.Shared) 93 | gmetrics.Register(prefix+".mem.pagefaults", mem.PageFaults) 94 | 95 | return mem 96 | } 
97 | 98 | // Update samples the current process's memory; a failed read is skipped. 99 | func (pmm *ProcessMemoryMetrics) Update() { 100 | pid := os.Getpid() 101 | 102 | mem := sigar.ProcMem{} 103 | 104 | err := mem.Get(pid) 105 | 106 | if err == nil { 107 | pmm.Resident.Update(int64(mem.Resident)) 108 | pmm.Shared.Update(int64(mem.Share)) 109 | 110 | updateMeter(pmm.PageFaults, mem.PageFaults) 111 | } 112 | } 113 | 114 | // ProcessCpuMetrics tracks this process's user, system and total CPU time. 115 | type ProcessCpuMetrics struct { 116 | User, Sys, Total gmetrics.Meter 117 | } 118 | 119 | // NewProcessCpuMetrics creates the CPU meters and registers them as {prefix}.cpu.*. 120 | func NewProcessCpuMetrics(prefix string) *ProcessCpuMetrics { 121 | cpu := &ProcessCpuMetrics{ 122 | User: gmetrics.NewMeter(), 123 | Sys: gmetrics.NewMeter(), 124 | Total: gmetrics.NewMeter(), 125 | } 126 | 127 | gmetrics.Register(prefix+".cpu.user", cpu.User) 128 | gmetrics.Register(prefix+".cpu.sys", cpu.Sys) 129 | gmetrics.Register(prefix+".cpu.total", cpu.Total) 130 | 131 | return cpu 132 | } 133 | 134 | // Update samples the current process's CPU time; a failed read is skipped. 135 | func (pcm *ProcessCpuMetrics) Update() { 136 | pid := os.Getpid() 137 | cpu := sigar.ProcTime{} 138 | 139 | err := cpu.Get(pid) 140 | 141 | if err == nil { 142 | updateMeter(pcm.User, cpu.User) 143 | updateMeter(pcm.Sys, cpu.Sys) 144 | updateMeter(pcm.Total, cpu.Total) 145 | } 146 | 147 | } 148 | 149 | // GoInternalMetrics tracks Go heap allocation and the goroutine count. 150 | type GoInternalMetrics struct { 151 | Alloc, TotalAlloc, NumGoroutine gmetrics.Gauge 152 | } 153 | 154 | // NewGoInternalMetrics creates the gauges and registers them as {prefix}.goint.*. 
155 | func NewGoInternalMetrics(prefix string) *GoInternalMetrics { 156 | goint := &GoInternalMetrics{ 157 | Alloc: gmetrics.NewGauge(), 158 | TotalAlloc: gmetrics.NewGauge(), 159 | NumGoroutine: gmetrics.NewGauge(), 160 | } 161 | 162 | gmetrics.Register(prefix+".goint.alloc", goint.Alloc) 163 | gmetrics.Register(prefix+".goint.total_alloc", goint.TotalAlloc) 164 | gmetrics.Register(prefix+".goint.go_routines", goint.NumGoroutine) 165 | 166 | return goint 167 | } 168 | 169 | // Update reads runtime.MemStats and the current goroutine count. 170 | func (gim *GoInternalMetrics) Update() { 171 | 172 | ms := &runtime.MemStats{} 173 | runtime.ReadMemStats(ms) 174 | 175 | gim.Alloc.Update(int64(ms.Alloc)) 176 | gim.TotalAlloc.Update(int64(ms.TotalAlloc)) 177 | gim.NumGoroutine.Update(int64(runtime.NumGoroutine())) 178 | 179 | } 180 | 181 | // updateMeter marks meter with the difference between the cumulative counter 182 | // newValue and the meter's current count, so that Count() follows newValue. 183 | func updateMeter(meter gmetrics.Meter, newValue uint64) { 184 | meter.Mark(int64(newValue) - meter.Count()) 185 | } 186 | 187 | // StartRuntimeMetricsJob registers the runtime metrics under the 188 | // "{region}.{env}.mqtt-proxy" prefix and refreshes them every 2 seconds from 189 | // a background goroutine that runs for the life of the process. 190 | func StartRuntimeMetricsJob(env string, region string) { 191 | 192 | prefix := buildPrefix(env, region) 193 | 194 | rm := NewRuntimeMetrics(prefix) 195 | 196 | ticker := time.NewTicker(time.Second * 2) 197 | go func() { 198 | for range ticker.C { 199 | rm.Update() 200 | } 201 | }() 202 | } 203 | -------------------------------------------------------------------------------- 
/proxy/proxy.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "sync" 7 | 8 | "github.com/ninjablocks/mqtt-proxy/conf" 9 | "github.com/ninjablocks/mqtt-proxy/metrics" 10 | "github.com/ninjablocks/mqtt-proxy/rewrite" 11 | "github.com/ninjablocks/mqtt-proxy/store" 12 | ) 13 | 14 | // ProxyConn is a proxied connection that can be identified and closed. 15 | type ProxyConn interface { 16 | Id() string 17 | Close() 18 | } 19 | 20 | // MQTTProxy holds the configuration, metrics and open sessions shared by all 21 | // proxied connections. 22 | type MQTTProxy struct { 23 | Conf *conf.Configuration 24 | Metrics metrics.ProxyMetrics 25 | 26 | // mu guards connections: every session registers and unregisters itself, 27 | // potentially from its own goroutine. 28 | mu sync.Mutex 29 | connections map[string]net.Conn 30 | } 31 | 32 | // CreateMQTTProxy builds a proxy for cfg and registers its metrics. 33 | func CreateMQTTProxy(cfg *conf.Configuration) *MQTTProxy { 34 | p := &MQTTProxy{ 35 | Conf: cfg, 36 | Metrics: metrics.NewProxyMetrics(cfg.Environment, cfg.Region), 37 | connections: make(map[string]net.Conn), 38 | } 39 | 40 | return p 41 | } 42 | 43 | // sessionID keys a connection by its remote and local addresses. 44 | func sessionID(conn net.Conn) string { 45 | return fmt.Sprintf("%s %s", conn.RemoteAddr(), conn.LocalAddr()) 46 | } 47 | 48 | // RegisterSession records conn and updates the open connections gauge. 49 | func (p *MQTTProxy) RegisterSession(conn net.Conn) { 50 | id := sessionID(conn) 51 | 52 | p.mu.Lock() 53 | defer p.mu.Unlock() 54 | p.connections[id] = conn 55 | p.Metrics.Connections.Update(int64(len(p.connections))) 56 | } 57 | 58 | // UnRegisterSession forgets conn and updates the open connections gauge. 
59 | func (p *MQTTProxy) UnRegisterSession(conn net.Conn) { 60 | id := sessionID(conn) 61 | 62 | p.mu.Lock() 63 | defer p.mu.Unlock() 64 | delete(p.connections, id) 65 | p.Metrics.Connections.Update(int64(len(p.connections))) 66 | } 67 | 68 | // mqttCredentialsRewriter replaces client credentials with the configured backend 69 | // user/pass and sets the client id to user.MqttId. 70 | func (p *MQTTProxy) mqttCredentialsRewriter(user *store.User) rewrite.CredentialsRewriter { 71 | return rewrite.NewCredentialsReplaceRewriter(p.Conf.User, p.Conf.Pass, user.UserId, user.MqttId) 72 | } 73 | 74 | // mqttTopicRewriter builds a topic rewriter keyed by mqttId for one direction. 75 | func (p *MQTTProxy) mqttTopicRewriter(mqttId string, direction int) rewrite.TopicRewriter { 76 | return rewrite.NewTopicPartRewriter(mqttId, direction) 77 | } 78 | 79 | // MqttMsgRewriter builds the per-user message rewriter; topic rewriting in both 80 | // directions is keyed by user.UserId. 81 | func (p *MQTTProxy) MqttMsgRewriter(user *store.User) *rewrite.MsgRewriter { 82 | return rewrite.CreatMsgRewriter(p.mqttCredentialsRewriter(user), p.mqttTopicRewriter(user.UserId, rewrite.INGRESS), p.mqttTopicRewriter(user.UserId, rewrite.EGRESS)) 83 | } 84 | -------------------------------------------------------------------------------- 
/rewrite/auth_rewriter.go: -------------------------------------------------------------------------------- 1 | package rewrite 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/wolfeidau/mqtt" 7 | ) 8 | 9 | // CredentialsRewriter rewrites the credentials of an MQTT CONNECT message. 10 | type CredentialsRewriter interface { 11 | RewriteCredentials(msg *mqtt.Connect) *mqtt.Connect 12 | } 13 | 14 | // CredentialsReplaceRewriter overwrites client credentials with fixed backend 15 | // credentials and a client id taken from MqttId. 16 | type CredentialsReplaceRewriter struct { 17 | User string 18 | Pass string 19 | UserId string 20 | MqttId string 21 | } 22 | 23 | // NewCredentialsReplaceRewriter returns a rewriter for the given backend 24 | // user/pass, user id and MQTT client id. 25 | func NewCredentialsReplaceRewriter(user string, pass string, uid string, mqttId string) *CredentialsReplaceRewriter { 26 | return &CredentialsReplaceRewriter{ 27 | User: user, 28 | Pass: pass, 29 | UserId: uid, 30 | MqttId: mqttId, 31 | } 32 | } 33 | 34 | // RewriteCredentials sets Username/Password (and their flags) when the rewriter 35 | // holds non-empty values, always replaces ClientId with MqttId, and returns msg, 36 | // which is modified in place. 37 | func (crr *CredentialsReplaceRewriter) RewriteCredentials(msg *mqtt.Connect) *mqtt.Connect { 38 | 39 | if crr.User != "" { 40 | msg.UsernameFlag = true 41 | msg.Username = crr.User 42 | } 43 | 44 | if crr.Pass != "" { 45 | msg.PasswordFlag = true 46 | msg.Password = crr.Pass 47 | } 48 | 49 | msg.ClientId = crr.MqttId 50 | 51 | log.Printf("[creds] connecting ClientId %s", msg.ClientId) 52 | 53 | return msg 54 | } 55 | -------------------------------------------------------------------------------- 
/rewrite/auth_rewriter_test.go: -------------------------------------------------------------------------------- 1 | package rewrite 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/wolfeidau/mqtt" 7 | 8 | . "launchpad.net/gocheck" 9 | ) 10 | 11 | func Test4(t *testing.T) { TestingT(t) } 12 | 13 | type CredentialsRewriterSuite struct { 14 | credentialsRewriter CredentialsRewriter 15 | } 16 | 17 | var _ = Suite(&CredentialsRewriterSuite{}) 18 | 19 | func (s *CredentialsRewriterSuite) SetUpTest(c *C) { 20 | s.credentialsRewriter = NewCredentialsReplaceRewriter("user", "pass", "1", "1") 21 | } 22 | 23 | func (s *CredentialsRewriterSuite) TestCredsRewrite(c *C) { 24 | 25 | // connection request message 26 | connect := createConnectMessage("bob", "11223344", true, true, "abc") 27 | expectedConnect := createConnectMessage("user", "pass", true, true, "1") 28 | 29 | modConnect := s.credentialsRewriter.RewriteCredentials(connect) 30 | c.Assert(modConnect, DeepEquals, expectedConnect) 31 | 32 | } 33 | 34 | func (s *CredentialsRewriterSuite) TestCredsRewriteJustUser(c *C) { 35 | 36 | // connection request message 37 | connect := createConnectMessage("bob", "", true, false, "abc") 38 | expectedConnect := createConnectMessage("user", "pass", true, true, "1") 39 | 40 | modConnect := s.credentialsRewriter.RewriteCredentials(connect) 41 | c.Assert(modConnect, DeepEquals, expectedConnect) 42 | 43 | } 44 | 45 | func createConnectMessage(user string, pass string, userFlag bool, passFlag bool, clientId string) *mqtt.Connect { 46 | return &mqtt.Connect{ 47 | ProtocolName: "MQIsdp", 48 | ProtocolVersion: 3, 49 | UsernameFlag: userFlag, 50 | PasswordFlag: passFlag, 51 | WillRetain: false, 52 | WillQos: 1, 53 | WillFlag: true, 54 | CleanSession: true, 55 | KeepAliveTimer: 10, 56 | ClientId: clientId, 57 | WillTopic: "topic", 58 | WillMessage: "message", 59 | Username: user, 60 | Password: pass, 61 | } 62 | } 63 | -------------------------------------------------------------------------------- 
/rewrite/message_rewriter.go: -------------------------------------------------------------------------------- 1 | package rewrite 2 | 3 | import "github.com/wolfeidau/mqtt" 4 | 5 | // MsgRewriter rewrites MQTT messages passing through the proxy: credentials and 6 | // topics on ingress, topics on egress. 7 | type MsgRewriter struct { 8 | CredentialsRewriter CredentialsRewriter 9 | IngressRewriter TopicRewriter 10 | EgressRewriter TopicRewriter 11 | } 12 | 13 | // CreatMsgRewriter returns a MsgRewriter built from the given rewriters. 14 | func CreatMsgRewriter(credentialsRewriter CredentialsRewriter, ingressRewriter TopicRewriter, egressRewriter TopicRewriter) *MsgRewriter { 15 | return &MsgRewriter{ 16 | CredentialsRewriter: credentialsRewriter, 17 | IngressRewriter: ingressRewriter, 18 | EgressRewriter: egressRewriter, 19 | } 20 | } 21 | 22 | // RewriteIngress rewrites an inbound (client) message: CONNECT credentials and 23 | // PUBLISH/SUBSCRIBE/UNSUBSCRIBE topics. Other message types pass through unchanged. 24 | func (mr *MsgRewriter) RewriteIngress(msg mqtt.Message) mqtt.Message { 25 | 26 | // log.Printf("[ingress] msg: %s %v", reflect.TypeOf(msg), msg) 27 | 28 | switch msg := msg.(type) { 29 | case *mqtt.Connect: 30 | // Return the rewriter's result directly: assigning it to the case-local 31 | // msg would be lost, because the return below uses the outer msg. 32 | return mr.CredentialsRewriter.RewriteCredentials(msg) 33 | case *mqtt.Publish: 34 | msg.TopicName = mr.IngressRewriter.RewriteTopicName(msg.TopicName) 35 | case *mqtt.Subscribe: 36 | msg.Topics = mr.IngressRewriter.RewriteTopics(msg.Topics) 37 | case *mqtt.Unsubscribe: 38 | msg.Topics = mr.IngressRewriter.RenameTopicNames(msg.Topics) 39 | } 40 | return msg 41 | } 42 | 43 | // RewriteEgress rewrites an outbound message: only PUBLISH topics are changed. 
44 | func (mr *MsgRewriter) RewriteEgress(msg mqtt.Message) mqtt.Message { 45 | 46 | // log.Printf("[egress] msg: %s %v", reflect.TypeOf(msg), msg) 47 | 48 | switch msg := msg.(type) { 49 | case *mqtt.Publish: 50 | msg.TopicName = mr.EgressRewriter.RewriteTopicName(msg.TopicName) 51 | } 52 | return msg 53 | } 54 | -------------------------------------------------------------------------------- /rewrite/message_rewriter_test.go: -------------------------------------------------------------------------------- 1 | package rewrite 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/wolfeidau/mqtt" 7 | . 
"launchpad.net/gocheck" 8 | ) 9 | 10 | func Test2(t *testing.T) { TestingT(t) } 11 | 12 | type MessagRewriterSuite struct { 13 | msgRewriter *MsgRewriter 14 | } 15 | 16 | var _ = Suite(&MessagRewriterSuite{}) 17 | 18 | func (s *MessagRewriterSuite) SetUpTest(c *C) { 19 | s.msgRewriter = &MsgRewriter{ 20 | CredentialsRewriter: &CredentialsReplaceRewriter{ 21 | User: "guest", 22 | Pass: "123", 23 | UserId: "1", 24 | MqttId: "1", 25 | }, 26 | IngressRewriter: &TopicPartRewriter{ 27 | Token: "123", 28 | Direction: INGRESS, 29 | }, 30 | EgressRewriter: &TopicPartRewriter{ 31 | Token: "123", 32 | Direction: EGRESS, 33 | }, 34 | } 35 | } 36 | 37 | func (s *MessagRewriterSuite) TestIngressMsgPublish(c *C) { 38 | 39 | // client publish a message to a topic 40 | pub := createPublish("$cloud/456/789") 41 | expectedPub := createPublish("123/$cloud/456/789") 42 | 43 | modPub := s.msgRewriter.RewriteIngress(pub) 44 | c.Assert(modPub, DeepEquals, expectedPub) 45 | 46 | } 47 | 48 | func (s *MessagRewriterSuite) TestIngressMsgSubscribe(c *C) { 49 | 50 | // client subscribe a message to a topic 51 | sub := createSubscribe("$cloud/456/789") 52 | expectedSub := createSubscribe("123/$cloud/456/789") 53 | 54 | modSub := s.msgRewriter.RewriteIngress(sub) 55 | c.Assert(modSub, DeepEquals, expectedSub) 56 | 57 | } 58 | 59 | func (s *MessagRewriterSuite) TestConnect(c *C) { 60 | 61 | // connection request message 62 | connect := createConnect("bob", "11223344", "abc") 63 | expectedConnect := createConnect("guest", "123", "1") 64 | 65 | modConnect := s.msgRewriter.RewriteIngress(connect) 66 | c.Assert(modConnect, DeepEquals, expectedConnect) 67 | } 68 | 69 | func (s *MessagRewriterSuite) TestIngressMsgUnsubscribe(c *C) { 70 | 71 | // client unsubscribe to a topic 72 | unsub := createUnsubscribe("$cloud/456/789") 73 | expectedUnsub := createUnsubscribe("123/$cloud/456/789") 74 | 75 | modUnsub := s.msgRewriter.RewriteIngress(unsub) 76 | c.Assert(modUnsub, DeepEquals, expectedUnsub) 77 | } 78 | 
79 | func (s *MessagRewriterSuite) TestEgressMsgPublish(c *C) { 80 | 81 | // client publish a message to a topic 82 | pub := createPublish("123/$block/456/789") 83 | expectedPub := createPublish("$block/456/789") 84 | 85 | modPub := s.msgRewriter.RewriteEgress(pub) 86 | c.Assert(modPub, DeepEquals, expectedPub) 87 | 88 | } 89 | 90 | func createConnect(user string, pass string, clientId string) mqtt.Message { 91 | return &mqtt.Connect{ 92 | ProtocolName: "MQIsdp", 93 | ProtocolVersion: 3, 94 | UsernameFlag: true, 95 | PasswordFlag: true, 96 | WillRetain: false, 97 | WillQos: 1, 98 | WillFlag: true, 99 | CleanSession: true, 100 | KeepAliveTimer: 10, 101 | ClientId: clientId, 102 | WillTopic: "topic", 103 | WillMessage: "message", 104 | Username: user, 105 | Password: pass, 106 | } 107 | } 108 | 109 | func createPublish(topic string) mqtt.Message { 110 | return &mqtt.Publish{ 111 | Header: mqtt.Header{ 112 | DupFlag: false, 113 | QosLevel: mqtt.QosAtMostOnce, 114 | Retain: false, 115 | }, 116 | TopicName: topic, 117 | Payload: mqtt.BytesPayload{1, 2, 3}, 118 | } 119 | } 120 | 121 | func createSubscribe(topic string) mqtt.Message { 122 | return &mqtt.Subscribe{ 123 | Header: mqtt.Header{ 124 | DupFlag: false, 125 | QosLevel: mqtt.QosAtLeastOnce, 126 | }, 127 | MessageId: 0x4321, 128 | Topics: []mqtt.TopicQos{ 129 | {topic, mqtt.QosExactlyOnce}, 130 | }, 131 | } 132 | } 133 | 134 | func createUnsubscribe(topic string) mqtt.Message { 135 | return &mqtt.Unsubscribe{ 136 | Header: mqtt.Header{ 137 | DupFlag: false, 138 | QosLevel: mqtt.QosAtLeastOnce, 139 | }, 140 | MessageId: 0x4321, 141 | Topics: []string{topic}, 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /rewrite/topic_rewriter.go: -------------------------------------------------------------------------------- 1 | package rewrite 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "strings" 7 | 8 | "github.com/wolfeidau/mqtt" 9 | ) 10 | 11 | var INGRESS = 1 12 | var 
EGRESS = 2 13 | 14 | type TopicRewriter interface { 15 | RewriteTopicName(topic string) string 16 | RenameTopicNames(topicNames []string) []string 17 | RewriteTopics(topics []mqtt.TopicQos) []mqtt.TopicQos 18 | } 19 | 20 | // rewriter which inserts a partition after the first token in the topic. 21 | type TopicPartRewriter struct { 22 | Token string 23 | Direction int 24 | } 25 | 26 | func NewTopicPartRewriter(token string, direction int) *TopicPartRewriter { 27 | return &TopicPartRewriter{ 28 | Token: token, 29 | Direction: direction, 30 | } 31 | } 32 | 33 | func (tppw *TopicPartRewriter) RewriteTopicName(topic string) string { 34 | switch tppw.Direction { 35 | case INGRESS: 36 | return insertToken(topic, tppw.Token) 37 | case EGRESS: 38 | return removeToken(topic, tppw.Token) 39 | } 40 | return topic 41 | } 42 | 43 | func (tppw *TopicPartRewriter) RenameTopicNames(topicNames []string) []string { 44 | for i := range topicNames { 45 | topicNames[i] = tppw.RewriteTopicName(topicNames[i]) 46 | } 47 | return topicNames 48 | } 49 | 50 | func (tppw *TopicPartRewriter) RewriteTopics(topics []mqtt.TopicQos) []mqtt.TopicQos { 51 | for i := range topics { 52 | topics[i].Topic = tppw.RewriteTopicName(topics[i].Topic) 53 | } 54 | return topics 55 | } 56 | 57 | func insertToken(topic string, token string) string { 58 | return fmt.Sprintf("%s/%s", token, topic) 59 | } 60 | 61 | func removeToken(topic string, token string) string { 62 | tokens := strings.Split(topic, "/") 63 | 64 | if tokens[0] == token { 65 | return strings.Join(tokens[1:], "/") 66 | } 67 | 68 | log.Printf("[topic] token not found %d %s", topic, token) 69 | return topic 70 | } 71 | -------------------------------------------------------------------------------- /rewrite/topic_rewriter_test.go: -------------------------------------------------------------------------------- 1 | package rewrite 2 | 3 | import ( 4 | "testing" 5 | 6 | . 
"launchpad.net/gocheck" 7 | ) 8 | 9 | func Test3(t *testing.T) { TestingT(t) } 10 | 11 | type TopicRewriterSuite struct { 12 | egressPartRewriter TopicRewriter 13 | ingressPartRewriter TopicRewriter 14 | } 15 | 16 | var _ = Suite(&TopicRewriterSuite{}) 17 | 18 | func (s *TopicRewriterSuite) SetUpTest(c *C) { 19 | 20 | s.egressPartRewriter = &TopicPartRewriter{ 21 | Token: "123", 22 | Direction: EGRESS, 23 | } 24 | 25 | s.ingressPartRewriter = &TopicPartRewriter{ 26 | Token: "123", 27 | Direction: INGRESS, 28 | } 29 | } 30 | 31 | func (s *TopicRewriterSuite) TestPartTopicName(c *C) { 32 | 33 | topicName := s.ingressPartRewriter.RewriteTopicName("$cloud/test/123") 34 | c.Assert(topicName, Equals, "123/$cloud/test/123") 35 | 36 | topicName = s.ingressPartRewriter.RewriteTopicName("cloud") 37 | c.Assert(topicName, Equals, "123/cloud") 38 | } 39 | -------------------------------------------------------------------------------- /ssl/cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIC6TCCAdOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD 3 | bzAeFw0xNDAyMDYyMzU0MTNaFw0xNTAyMDYyMzU0MTNaMBIxEDAOBgNVBAoTB0Fj 4 | bWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC191oEsvb1pCsu 5 | yBWj1X62rs3gjKDj74HucpyqS3m/KkhltyVujVapPUrEJNXlwqZOBMsTFMN5XKjK 6 | cmaTDD7VlBatmAXNRgxeaWUUxY8MHXWyxvbKz0NXWtrytWye+iFZVxUNfly0t9hi 7 | ir6BcxDvv2lffDLKbaHsikkrQ2KclWRTOPwWxnURjE3deGQMQjSiHCg5/nk6asW+ 8 | jGf5tXipOGxIiyLbILDZFKkryREPCFvSp5pGq1cEvyowhRPuPC/Gk4NLzskFeKu8 9 | YbH/aeUPCjrXjsVqnHWBp65H26XvssvU69P2LDa8NrhpI5ZYGeAeO/VoRCfQ3nwa 10 | X3npX+MZAgMBAAGjTjBMMA4GA1UdDwEB/wQEAwIAoDATBgNVHSUEDDAKBggrBgEF 11 | BQcDATAMBgNVHRMBAf8EAjAAMBcGA1UdEQQQMA6CDGZsaXBweS5sb2NhbDALBgkq 12 | hkiG9w0BAQUDggEBAI+k0XHyIxlk8oWfJeKY0scmBLjBlsEnxmmRjt0nrBckDk0T 13 | 3vm2zqpKNBu+POpk7rUv/cCbGjLwmYU+xc15ESzY4Im1+xb6LufhYMos8eUSlRUa 14 | 0qNpt2Hs3I3gZYalw1qgnwqvr8qlUH+RsIbr7IFmcaYNCtym1UHG5bZBjuAot4em 15 | 
gRGLYaPOJgxa9Pov3tfEi5CIu0kdp4aTRfUt/rT9dITSbceBh4HFYFulKuNWhPCQ 16 | eecuguaFfpBr425U4U3eEee38egN9P0I9ZPKiZKY5l/a1MeaZcrgEfNZTSH52Y93 17 | Ys7igm+6NIQT4+2pOsFXQ8R043FYrw3OAhqqmVk= 18 | -----END CERTIFICATE----- 19 | -------------------------------------------------------------------------------- /ssl/key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpAIBAAKCAQEAtfdaBLL29aQrLsgVo9V+tq7N4Iyg4++B7nKcqkt5vypIZbcl 3 | bo1WqT1KxCTV5cKmTgTLExTDeVyoynJmkww+1ZQWrZgFzUYMXmllFMWPDB11ssb2 4 | ys9DV1ra8rVsnvohWVcVDX5ctLfYYoq+gXMQ779pX3wyym2h7IpJK0NinJVkUzj8 5 | FsZ1EYxN3XhkDEI0ohwoOf55OmrFvoxn+bV4qThsSIsi2yCw2RSpK8kRDwhb0qea 6 | RqtXBL8qMIUT7jwvxpODS87JBXirvGGx/2nlDwo6147Fapx1gaeuR9ul77LL1OvT 7 | 9iw2vDa4aSOWWBngHjv1aEQn0N58Gl956V/jGQIDAQABAoIBAQCWcEcl92ehMfbZ 8 | mGX7qzg1hFOFP/6MM6kyH+NSD1A6MZTlMwNpYMSy6o9zlhiY+dJQUjoqLlJldau+ 9 | o9IV5FvWa7ZMEpFJYo47R9tfzu0y0PBLu56xkaVVBTJa0o9Y5+bGW+5113CBoDTv 10 | U1Go6B6qd4+Ad8ft/7GNQ862S+GtHZdGHPnw6g2pXUGd8f3XCD4hRXGUe7PT3I6L 11 | kSUpa4+KbKAFOClpZ652XP9sRDh/X3HgrPbat77vX2P23F9hRbQ83lKPzSqJd+4Y 12 | jsL+mZ0qUkq6qhWNCPd+JIuWqqp6aX4m4tytMFxdR57W/7zGM03QUkMTlEGR77LN 13 | 5YwLotEBAoGBAOsA7lrfIkDDSuzAk46n+Yd1FkFUTSU3losxM07KZIaPEQ8EsN8w 14 | yXzoBmnAk7uYScCiePlusJuEweD/vuGGnbbT0cVrQYdYzKm7F8+/wGBm12cxDPw5 15 | iWHitLwGJsHheKvlUfVRUK36z1ppcDZ37iPfoXgdhBAL5P+iHvQKOE2pAoGBAMY5 16 | VXfdxBgXgmNgBqLLJvbRtC1waECwSN+BHLsN6vOHdPZs6Q8405Rf+l3g4Hj30T6S 17 | YPWSEXXeqeo4dIXH0hRHVBAA2hcED1XBPCVpBTWGruoUTtBGDN2hQSmFl3C8kqpG 18 | 2qr7iyhhYG46+bpk3LoVFNiaw3hrW5Hk2uIJru/xAoGAVAiaMx51Nilfgnd+jFWe 19 | kgSZd7T4fSV6jL2ENlmDRuaj1/X6dWURt7uUh35YlY1oWhz/G1qshoAbgCSTkju7 20 | 6+OksG6hGQ/054DCjARqe05rGjhdB1hfuwQBUvb0JwJET1uKSinQqtX0DcWEXcXW 21 | /zb5m2Uak05djdfgL63z4jkCgYAl5TYauVHQzUXHG8eI/c+QJh0NBs1XeJwl5ngI 22 | pquBLSdGKSIRH+sLFaI2qlQfrDjbfn581BT0dMIFHg/gt4fJCq3edVs8RTFtUoje 23 | Ggq95eawp3s9w/aXtElR19FQ4ywi03Lgd0BuUtdtm2a8pKWyCW+3zTaLYfLanGbg 24 | 
Cbvg8QKBgQCueiu7ZI4UokmnlJ8HFVcJ1NWXd1JpH5JxtS6zgR6DH8o7VY2vB3z9 25 | 1fZ2tryuwJwQL/Opx38+UP+YELB/B45Oq5UhHl1SdAv4cHJlV47r38jqZ9+VFEny 26 | DclrGNUsWzUIYd1lfYXXMbLDckKARlcz0OilEGIfs3rvLUDdThU3Dw== 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /store/mysql.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import ( 4 | "database/sql" 5 | "log" 6 | 7 | _ "github.com/go-sql-driver/mysql" 8 | "github.com/ninjablocks/mqtt-proxy/conf" 9 | ) 10 | // MysqlStore looks users up in MySQL through a *sql.DB handle opened once at construction. 11 | type MysqlStore struct { 12 | db *sql.DB 13 | conf *conf.MysqlConfiguration 14 | } 15 | // NewMysqlStore opens conf.ConnectionString with the mysql driver and pings it; any failure terminates the process via log.Fatal. 16 | func NewMysqlStore(conf *conf.MysqlConfiguration) *MysqlStore { 17 | db, err := sql.Open("mysql", conf.ConnectionString) 18 | if err != nil { 19 | log.Fatal(err) 20 | } 21 | err = db.Ping() 22 | if err != nil { 23 | log.Fatal(err) 24 | } 25 | // the handle stays open for the store's lifetime; Close releases it 26 | return &MysqlStore{ 27 | db: db, 28 | conf: conf, 29 | } 30 | } 31 | 32 | // Health reports whether the MySQL server answers a ping. 33 | func (s *MysqlStore) Health() bool { 34 | // do not close s.db here: FindUser keeps using the same handle 35 | err := s.db.Ping() 36 | if err != nil { 37 | return false 38 | } 39 | return true 40 | } 41 | 42 | // FindUser runs the configured Select query with token and scans (uid, mqttId); it returns ErrUserNotFound when no row matches and any other query error as-is. 43 | func (s *MysqlStore) FindUser(token string) (*User, error) { 44 | 45 | var uid string 46 | var mqttId string 47 | err := s.db.QueryRow(s.conf.Select, token).Scan(&uid, &mqttId) 48 | if err != nil { 49 | if err == sql.ErrNoRows { 50 | return nil, ErrUserNotFound 51 | } 52 | return nil, err 53 | } 54 | return &User{ 55 | UserId: uid, 56 | MqttId: mqttId, 57 | }, nil 58 | } 59 | 60 | // Close closes the underlying database handle.
61 | func (s *MysqlStore) Close() { 62 | s.db.Close() 63 | } 64 | -------------------------------------------------------------------------------- /store/store.go: -------------------------------------------------------------------------------- 1 | package store 2 | 3 | import "errors" 4 | // ErrUserNotFound is returned by FindUser when no user matches the token. 5 | var ErrUserNotFound = errors.New("User not found in store") 6 | // Store looks up proxy users by token and reports its own health. 7 | type Store interface { 8 | Health() bool 9 | FindUser(token string) (*User, error) 10 | Close() 11 | } 12 | // User is an authenticated proxy user: UserId partitions its topics and MqttId becomes its MQTT client id. 13 | type User struct { 14 | UserId string `json:"uid"` 15 | MqttId string `json:"mqttId"` 16 | } 17 | -------------------------------------------------------------------------------- /tcp/conn.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net" 7 | 8 | "github.com/ninjablocks/mqtt-proxy/rewrite" 9 | ) 10 | 11 | // TcpProxyConn pairs an accepted client connection with its upstream backend connection. 12 | type TcpProxyConn struct { 13 | 14 | // proxy (upstream backend) connection 15 | pConn net.Conn 16 | 17 | // client connection 18 | cConn net.Conn 19 | 20 | id string 21 | // rewriter is nil until the client has authenticated 22 | rewriter *rewrite.MsgRewriter 23 | } 24 | // CreateTcpProxyConn resolves backend, dials it over TCP and pairs the result with the client conn. Errors carry a "[serv]" prefix. 25 | func CreateTcpProxyConn(conn net.Conn, backend string) (*TcpProxyConn, error) { 26 | 27 | addr, err := net.ResolveTCPAddr("tcp", backend) 28 | 29 | if err != nil { 30 | return nil, fmt.Errorf("[serv] Error resolving upstream: %s", err) 31 | } 32 | log.Printf("Opening connection to %s", addr) 33 | tcpconn, err := net.DialTCP("tcp", nil, addr) 34 | 35 | if err != nil { 36 | return nil, fmt.Errorf("[serv] Error connecting to upstream: %s", err) 37 | } 38 | 39 | return &TcpProxyConn{cConn: conn, pConn: tcpconn, id: fmt.Sprintf("%s %s", conn.RemoteAddr(), conn.LocalAddr())}, nil 40 | 41 | } 42 | // Id returns the client's "remote local" address pair captured at creation. 43 | func (c *TcpProxyConn) Id() string { 44 | return c.id 45 | } 46 | // Close closes only the upstream connection; the client conn is left to its owner. 47 | func (c *TcpProxyConn) Close() { 48 | c.pConn.Close() 49 | } 50 | -------------------------------------------------------------------------------- /tcp/server.go:
-------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "io" 7 | "log" 8 | "net" 9 | "reflect" 10 | 11 | "github.com/davecgh/go-spew/spew" 12 | "github.com/ninjablocks/mqtt-proxy/conf" 13 | "github.com/ninjablocks/mqtt-proxy/proxy" 14 | "github.com/ninjablocks/mqtt-proxy/store" 15 | "github.com/ninjablocks/mqtt-proxy/util" 16 | "github.com/wolfeidau/mqtt" 17 | ) 18 | 19 | // TcpServer accepts MQTT client connections, authenticates them against a Store and proxies them to a backend. 20 | type TcpServer struct { 21 | proxy *proxy.MQTTProxy 22 | store store.Store 23 | } 24 | // CreateTcpServer returns a TcpServer whose store is a MysqlStore built from proxy.Conf.MqttStoreMysql. 25 | func CreateTcpServer(proxy *proxy.MQTTProxy) *TcpServer { 26 | 27 | store := store.NewMysqlStore(&proxy.Conf.MqttStoreMysql) 28 | 29 | return &TcpServer{ 30 | proxy: proxy, 31 | store: store, 32 | } 33 | } 34 | // StartServer listens on conf.ListenAddress and serves each accepted connection in its own goroutine. It never returns; if the listener cannot be created the process exits via log.Fatalln. 35 | func (t *TcpServer) StartServer(conf *conf.MqttConfiguration) { 36 | 37 | log.Printf("[tcp] listening on %s", conf.ListenAddress) 38 | 39 | listener, err := t.startListener(conf) 40 | 41 | if err != nil { 42 | log.Fatalln("error listening:", err.Error()) 43 | } 44 | // NOTE(review): a permanent Accept error (e.g. a closed listener) makes this loop spin, logging on every iteration. 45 | for { 46 | conn, err := listener.Accept() 47 | if err != nil { 48 | log.Printf("Client error: %s", err) 49 | } else { 50 | go t.clientHandler(conn) 51 | } 52 | } 53 | 54 | } 55 | // startListener returns a TLS listener when conf.Cert is set and a plain TCP listener otherwise; a key pair that fails to load is returned as an error rather than exiting here. 56 | func (t *TcpServer) startListener(conf *conf.MqttConfiguration) (net.Listener, error) { 57 | if conf.Cert != "" { 58 | cert, err := tls.LoadX509KeyPair(conf.Cert, conf.Key) 59 | 60 | if err != nil { 61 | return nil, fmt.Errorf("server: loadkeys: %s", err) 62 | } 63 | log.Println("[serv] Starting tls listener") 64 | 65 | config := tls.Config{Certificates: []tls.Certificate{cert}} 66 | 67 | return tls.Listen("tcp", conf.ListenAddress, &config) 68 | } else { 69 | log.Println("[serv] Starting tcp listener") 70 | return net.Listen("tcp", conf.ListenAddress) 71 | } 72 | 73 | } 74 | // clientHandler proxies one client connection: it dials the backend, authenticates the client's CONNECT, then relays messages in both directions until either side fails. 75 | func (t *TcpServer) clientHandler(conn net.Conn) { 76 | 77 | log.Printf("[serv] client connection opened - %s", conn.RemoteAddr()) 78 | 79 | defer conn.Close() 80 | 81 |
t.proxy.RegisterSession(conn) 82 | defer t.proxy.UnRegisterSession(conn) 83 | 84 | // create channels for the return messages from the client 85 | cmr := util.CreateMqttTcpMessageReader(conn, t.proxy.Conf.GetReadTimeout()) 86 | 87 | go cmr.ReadMqttMessages() 88 | 89 | // TODO: only the first backend is used (indexing panics if the list is empty); this needs to be distributed across all servers 90 | backend := t.proxy.Conf.BackendServers[0] 91 | 92 | p, err := CreateTcpProxyConn(conn, backend) 93 | 94 | if err != nil { 95 | log.Printf("[serv] Error creating proxy connection - %s", err) 96 | sendServerUnavailable(conn) 97 | return 98 | } 99 | 100 | defer p.Close() 101 | 102 | t.proxy.Metrics.Connects.Mark(1) 103 | 104 | // do the authentication up front before going into normal operation 105 | if err = t.handleAuth(cmr, p); err != nil { 106 | log.Printf("[serv] Error authenticating connection - %s", err) 107 | // be very careful and clear on the error type as we are saying 108 | // for sure these credentials are not valid. 109 | if err == store.ErrUserNotFound { 110 | sendBadUsernameOrPassword(conn) 111 | } else { 112 | sendServerUnavailable(conn) 113 | } 114 | return 115 | } 116 | 117 | // create channels for the return messages from the backend 118 | pmr := util.CreateMqttTcpMessageReader(p.pConn, t.proxy.Conf.GetReadTimeout()) 119 | 120 | go pmr.ReadMqttMessages() 121 | 122 | Loop: 123 | for { 124 | 125 | select { 126 | 127 | case msg := <-cmr.InMsgs: 128 | 129 | //util.DebugMQTT("client", conn, msg) 130 | msg = p.rewriter.RewriteIngress(msg) 131 | 132 | t.updateMsgCount(msg) 133 | 134 | // write to the proxy connection 135 | n, err := msg.Encode(p.pConn) 136 | 137 | if err != nil { 138 | log.Printf("[serv] proxy connection error - %s", err) 139 | break Loop 140 | } 141 | t.updateMsgBodySize(n) 142 | 143 | case err := <-cmr.InErrors: 144 | if err == io.EOF { 145 | log.Printf("[serv] client closed connection") 146 | } else { 147 | log.Printf("[serv] client connection read error - %s", err) 148 | } 149 | break Loop 150 | 151 | case
msg := <-pmr.InMsgs: 152 | 153 | //util.DebugMQTT("proxy", conn, msg) 154 | msg = p.rewriter.RewriteEgress(msg) 155 | 156 | switch msg := msg.(type) { 157 | case *mqtt.ConnAck: 158 | log.Printf("[serv] got connack for %s", conn.RemoteAddr()) 159 | log.Printf("[serv] connack %+v", msg) 160 | case *mqtt.Disconnect: 161 | log.Printf("[serv] got disconnect for %s", conn.RemoteAddr()) 162 | log.Printf("[serv] disconnect %+v", msg) 163 | } 164 | t.proxy.Metrics.MsgReply.Mark(1) 165 | 166 | // write to the client connection 167 | n, err := msg.Encode(p.cConn) 168 | if err != nil { 169 | log.Printf("[serv] client connection write error - %s", err) 170 | break Loop 171 | } 172 | t.updateMsgBodySize(n) 173 | 174 | case err := <-pmr.InErrors: 175 | if err == io.EOF { 176 | log.Printf("[serv] proxy connection closed by backend server") 177 | } else { 178 | log.Printf("[serv] proxy connection read error - %s", err) 179 | 180 | } 181 | break Loop 182 | } 183 | 184 | } 185 | 186 | } 187 | // handleAuth waits for the client's first message, which must be a CONNECT: it looks up cmsg.Username as the store token, installs that user's rewriter on proxyConn and forwards the rewritten CONNECT to the backend. Any other first message, a store error or a read/write error is returned. 188 | func (t *TcpServer) handleAuth(cmr *util.MqttTcpMessageReader, proxyConn *TcpProxyConn) error { 189 | 190 | select { 191 | case msg := <-cmr.InMsgs: 192 | 193 | //util.DebugMQTT("auth", proxyConn.cConn, msg) 194 | t.updateMsgCount(msg) 195 | 196 | switch cmsg := msg.(type) { 197 | case *mqtt.Connect: 198 | 199 | authUser, err := t.store.FindUser(cmsg.Username) 200 | 201 | if err != nil { 202 | log.Printf("[serv] authentication failed for %s - %s", proxyConn.cConn.RemoteAddr(), err) // authUser is nil on error and the username is the lookup token, so log the address instead 203 | return err 204 | } 205 | 206 | proxyConn.rewriter = t.proxy.MqttMsgRewriter(authUser) 207 | 208 | msg = proxyConn.rewriter.RewriteIngress(msg) 209 | 210 | n, err := msg.Encode(proxyConn.pConn) 211 | 212 | if err != nil { 213 | log.Printf("[serv] proxy connection error - %s", err) 214 | log.Println(spew.Sprintf("msg %v", msg)) // NOTE(review): this dumps the rewritten CONNECT, including the proxy password 215 | return err 216 | } 217 | 218 | t.updateMsgBodySize(n) 219 | 220 | return nil 221 | 222 | } 223 | // anything else is bad 224 | return fmt.Errorf("expected connect got - %v",
reflect.TypeOf(msg)) 225 | 226 | case err := <-cmr.InErrors: 227 | log.Printf("connection error ocurred during authentication - %s", err) 228 | return err 229 | } 230 | 231 | } 232 | // updateMsgCount marks one message on the Msgs meter; msg itself is not inspected. 233 | func (t *TcpServer) updateMsgCount(msg mqtt.Message) { 234 | t.proxy.Metrics.Msgs.Mark(1) 235 | } 236 | // updateMsgBodySize records n, the count returned by Encode (presumably bytes written), in the MsgBodySize metric. 237 | func (t *TcpServer) updateMsgBodySize(n int) { 238 | t.proxy.Metrics.MsgBodySize.Update(int64(n)) 239 | } 240 | // sendBadUsernameOrPassword writes a CONNACK with RetCodeBadUsernameOrPassword to conn; the write result is ignored (best effort). 241 | func sendBadUsernameOrPassword(conn net.Conn) { 242 | log.Printf("[serv] bad username / password %s %s", conn.LocalAddr(), conn.RemoteAddr()) 243 | connAck := &mqtt.ConnAck{ 244 | ReturnCode: mqtt.RetCodeBadUsernameOrPassword, 245 | } 246 | connAck.Encode(conn) 247 | } 248 | // sendServerUnavailable writes a CONNACK with RetCodeServerUnavailable to conn; the write result is ignored (best effort). 249 | func sendServerUnavailable(conn net.Conn) { 250 | log.Printf("[serv] server unavailable %s %s", conn.LocalAddr(), conn.RemoteAddr()) 251 | connAck := &mqtt.ConnAck{ 252 | ReturnCode: mqtt.RetCodeServerUnavailable, 253 | } 254 | connAck.Encode(conn) 255 | } 256 | -------------------------------------------------------------------------------- /util/util.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "log" 5 | "net" 6 | "reflect" 7 | "time" 8 | 9 | "github.com/ninjablocks/mqtt-proxy/proxy" 10 | "github.com/wolfeidau/mqtt" 11 | ) 12 | // MqttTcpMessageReader decodes MQTT messages from Tcpconn and delivers them on InMsgs; the first decode error is delivered on InErrors. 13 | type MqttTcpMessageReader struct { 14 | Tcpconn net.Conn 15 | InMsgs chan mqtt.Message 16 | InErrors chan error 17 | ReadTimeout time.Duration 18 | } 19 | // CreateMqttTcpMessageReader returns a reader for tcpconn with single-slot message and error channels. 20 | func CreateMqttTcpMessageReader(tcpconn net.Conn, readTimeout time.Duration) *MqttTcpMessageReader { 21 | return &MqttTcpMessageReader{ 22 | Tcpconn: tcpconn, 23 | InMsgs: make(chan mqtt.Message, 1), 24 | InErrors: make(chan error, 1), 25 | ReadTimeout: readTimeout, 26 | } 27 | } 28 | // ReadMqttMessages decodes messages until the first error, which it sends on InErrors before returning; callers run it in its own goroutine. 29 | func (m *MqttTcpMessageReader) ReadMqttMessages() { 30 | 31 | defer log.Println("[serv] Reader done -", m.Tcpconn.RemoteAddr()) 32 | 33 | for { 34 | 35 | // only set deadlines when a positive ReadTimeout is configured 36 | if
m.ReadTimeout > 0 { 37 | deadline := time.Now().Add(m.ReadTimeout) 38 | m.Tcpconn.SetDeadline(deadline) // sets both the read and the write deadline 39 | } 40 | 41 | msg, err := mqtt.DecodeOneMessage(m.Tcpconn, nil) 42 | 43 | if err != nil { 44 | m.InErrors <- err 45 | break 46 | } 47 | // NOTE(review): this send blocks forever if the consumer has stopped reading InMsgs, leaking the goroutine. 48 | m.InMsgs <- msg 49 | } 50 | 51 | } 52 | // IsMqttDisconnect reports whether msg is a DISCONNECT message, i.e. a *mqtt.Disconnect. 53 | func IsMqttDisconnect(msg mqtt.Message) bool { 54 | return reflect.TypeOf(msg) == reflect.TypeOf((*mqtt.Disconnect)(nil)) 55 | } 56 | // DebugMQTTMsg logs tag, the proxy connection id and the message type. 57 | func DebugMQTTMsg(tag string, c proxy.ProxyConn, msg mqtt.Message) { 58 | log.Printf("[%s] (%s) %s", tag, c.Id(), reflect.TypeOf(msg)) 59 | } 60 | // DebugMQTT logs tag, the connection's remote address and the message type. 61 | func DebugMQTT(tag string, c net.Conn, msg mqtt.Message) { 62 | log.Printf("[%s] (%s) %s", tag, c.RemoteAddr(), reflect.TypeOf(msg)) 63 | } 64 | -------------------------------------------------------------------------------- /version.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // GitCommit is the git commit that was compiled; it is filled in at build time (presumably by the linker via -ldflags -X). 4 | var GitCommit string 5 | 6 | // The main version number that is being run at the moment. 7 | const Version = "1.1.1" 8 | 9 | // A pre-release marker for the version. If this is "" (empty string) 10 | // then it means that it is a final release. Otherwise, this is a pre-release 11 | // such as "dev" (in development), "beta", "rc1", etc. 12 | const VersionPrerelease = "" 13 | --------------------------------------------------------------------------------