├── LICENSE ├── README.md ├── broker ├── client.go ├── clients.go ├── config.go ├── memory_persistence.go ├── messageids.go ├── persistence.go ├── router.go ├── server.go ├── state.go ├── stats.go ├── subscriptions.go └── unit_router_test.go ├── config.go ├── main.go ├── packets ├── connack.go ├── connect.go ├── disconnect.go ├── packets.go ├── packets_test.go ├── pingreq.go ├── pingresp.go ├── puback.go ├── pubcomp.go ├── publish.go ├── pubrec.go ├── pubrel.go ├── suback.go ├── subscribe.go ├── unsuback.go └── unsubscribe.go ├── plugins ├── plugin.go ├── redirect_plugin.go └── twitter_plugin.go └── samples └── simple_broker.go /LICENSE: -------------------------------------------------------------------------------- 1 | Eclipse Public License - v 1.0 2 | 3 | THE ACCOMPANYING PROGRAM IS PROVIDED UNDER THE TERMS OF THIS ECLIPSE PUBLIC 4 | LICENSE ("AGREEMENT"). ANY USE, REPRODUCTION OR DISTRIBUTION OF THE PROGRAM 5 | CONSTITUTES RECIPIENT'S ACCEPTANCE OF THIS AGREEMENT. 6 | 7 | 1. DEFINITIONS 8 | 9 | "Contribution" means: 10 | 11 | a) in the case of the initial Contributor, the initial code and documentation 12 | distributed under this Agreement, and 13 | b) in the case of each subsequent Contributor: 14 | i) changes to the Program, and 15 | ii) additions to the Program; 16 | 17 | where such changes and/or additions to the Program originate from and are 18 | distributed by that particular Contributor. A Contribution 'originates' 19 | from a Contributor if it was added to the Program by such Contributor 20 | itself or anyone acting on such Contributor's behalf. Contributions do not 21 | include additions to the Program which: (i) are separate modules of 22 | software distributed in conjunction with the Program under their own 23 | license agreement, and (ii) are not derivative works of the Program. 24 | 25 | "Contributor" means any person or entity that distributes the Program. 
26 | 27 | "Licensed Patents" mean patent claims licensable by a Contributor which are 28 | necessarily infringed by the use or sale of its Contribution alone or when 29 | combined with the Program. 30 | 31 | "Program" means the Contributions distributed in accordance with this 32 | Agreement. 33 | 34 | "Recipient" means anyone who receives the Program under this Agreement, 35 | including all Contributors. 36 | 37 | 2. GRANT OF RIGHTS 38 | a) Subject to the terms of this Agreement, each Contributor hereby grants 39 | Recipient a non-exclusive, worldwide, royalty-free copyright license to 40 | reproduce, prepare derivative works of, publicly display, publicly 41 | perform, distribute and sublicense the Contribution of such Contributor, 42 | if any, and such derivative works, in source code and object code form. 43 | b) Subject to the terms of this Agreement, each Contributor hereby grants 44 | Recipient a non-exclusive, worldwide, royalty-free patent license under 45 | Licensed Patents to make, use, sell, offer to sell, import and otherwise 46 | transfer the Contribution of such Contributor, if any, in source code and 47 | object code form. This patent license shall apply to the combination of 48 | the Contribution and the Program if, at the time the Contribution is 49 | added by the Contributor, such addition of the Contribution causes such 50 | combination to be covered by the Licensed Patents. The patent license 51 | shall not apply to any other combinations which include the Contribution. 52 | No hardware per se is licensed hereunder. 53 | c) Recipient understands that although each Contributor grants the licenses 54 | to its Contributions set forth herein, no assurances are provided by any 55 | Contributor that the Program does not infringe the patent or other 56 | intellectual property rights of any other entity. 
Each Contributor 57 | disclaims any liability to Recipient for claims brought by any other 58 | entity based on infringement of intellectual property rights or 59 | otherwise. As a condition to exercising the rights and licenses granted 60 | hereunder, each Recipient hereby assumes sole responsibility to secure 61 | any other intellectual property rights needed, if any. For example, if a 62 | third party patent license is required to allow Recipient to distribute 63 | the Program, it is Recipient's responsibility to acquire that license 64 | before distributing the Program. 65 | d) Each Contributor represents that to its knowledge it has sufficient 66 | copyright rights in its Contribution, if any, to grant the copyright 67 | license set forth in this Agreement. 68 | 69 | 3. REQUIREMENTS 70 | 71 | A Contributor may choose to distribute the Program in object code form under 72 | its own license agreement, provided that: 73 | 74 | a) it complies with the terms and conditions of this Agreement; and 75 | b) its license agreement: 76 | i) effectively disclaims on behalf of all Contributors all warranties 77 | and conditions, express and implied, including warranties or 78 | conditions of title and non-infringement, and implied warranties or 79 | conditions of merchantability and fitness for a particular purpose; 80 | ii) effectively excludes on behalf of all Contributors all liability for 81 | damages, including direct, indirect, special, incidental and 82 | consequential damages, such as lost profits; 83 | iii) states that any provisions which differ from this Agreement are 84 | offered by that Contributor alone and not by any other party; and 85 | iv) states that source code for the Program is available from such 86 | Contributor, and informs licensees how to obtain it in a reasonable 87 | manner on or through a medium customarily used for software exchange. 
88 | 89 | When the Program is made available in source code form: 90 | 91 | a) it must be made available under this Agreement; and 92 | b) a copy of this Agreement must be included with each copy of the Program. 93 | Contributors may not remove or alter any copyright notices contained 94 | within the Program. 95 | 96 | Each Contributor must identify itself as the originator of its Contribution, 97 | if 98 | any, in a manner that reasonably allows subsequent Recipients to identify the 99 | originator of the Contribution. 100 | 101 | 4. COMMERCIAL DISTRIBUTION 102 | 103 | Commercial distributors of software may accept certain responsibilities with 104 | respect to end users, business partners and the like. While this license is 105 | intended to facilitate the commercial use of the Program, the Contributor who 106 | includes the Program in a commercial product offering should do so in a manner 107 | which does not create potential liability for other Contributors. Therefore, 108 | if a Contributor includes the Program in a commercial product offering, such 109 | Contributor ("Commercial Contributor") hereby agrees to defend and indemnify 110 | every other Contributor ("Indemnified Contributor") against any losses, 111 | damages and costs (collectively "Losses") arising from claims, lawsuits and 112 | other legal actions brought by a third party against the Indemnified 113 | Contributor to the extent caused by the acts or omissions of such Commercial 114 | Contributor in connection with its distribution of the Program in a commercial 115 | product offering. The obligations in this section do not apply to any claims 116 | or Losses relating to any actual or alleged intellectual property 117 | infringement. 
In order to qualify, an Indemnified Contributor must: 118 | a) promptly notify the Commercial Contributor in writing of such claim, and 119 | b) allow the Commercial Contributor to control, and cooperate with the 120 | Commercial Contributor in, the defense and any related settlement 121 | negotiations. The Indemnified Contributor may participate in any such claim at 122 | its own expense. 123 | 124 | For example, a Contributor might include the Program in a commercial product 125 | offering, Product X. That Contributor is then a Commercial Contributor. If 126 | that Commercial Contributor then makes performance claims, or offers 127 | warranties related to Product X, those performance claims and warranties are 128 | such Commercial Contributor's responsibility alone. Under this section, the 129 | Commercial Contributor would have to defend claims against the other 130 | Contributors related to those performance claims and warranties, and if a 131 | court requires any other Contributor to pay any damages as a result, the 132 | Commercial Contributor must pay those damages. 133 | 134 | 5. NO WARRANTY 135 | 136 | EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, THE PROGRAM IS PROVIDED ON AN 137 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR 138 | IMPLIED INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OR CONDITIONS OF TITLE, 139 | NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. Each 140 | Recipient is solely responsible for determining the appropriateness of using 141 | and distributing the Program and assumes all risks associated with its 142 | exercise of rights under this Agreement , including but not limited to the 143 | risks and costs of program errors, compliance with applicable laws, damage to 144 | or loss of data, programs or equipment, and unavailability or interruption of 145 | operations. 146 | 147 | 6. 
DISCLAIMER OF LIABILITY 148 | 149 | EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, NEITHER RECIPIENT NOR ANY 150 | CONTRIBUTORS SHALL HAVE ANY LIABILITY FOR ANY DIRECT, INDIRECT, INCIDENTAL, 151 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING WITHOUT LIMITATION 152 | LOST PROFITS), HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 153 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 154 | ARISING IN ANY WAY OUT OF THE USE OR DISTRIBUTION OF THE PROGRAM OR THE 155 | EXERCISE OF ANY RIGHTS GRANTED HEREUNDER, EVEN IF ADVISED OF THE POSSIBILITY 156 | OF SUCH DAMAGES. 157 | 158 | 7. GENERAL 159 | 160 | If any provision of this Agreement is invalid or unenforceable under 161 | applicable law, it shall not affect the validity or enforceability of the 162 | remainder of the terms of this Agreement, and without further action by the 163 | parties hereto, such provision shall be reformed to the minimum extent 164 | necessary to make such provision valid and enforceable. 165 | 166 | If Recipient institutes patent litigation against any entity (including a 167 | cross-claim or counterclaim in a lawsuit) alleging that the Program itself 168 | (excluding combinations of the Program with other software or hardware) 169 | infringes such Recipient's patent(s), then such Recipient's rights granted 170 | under Section 2(b) shall terminate as of the date such litigation is filed. 171 | 172 | All Recipient's rights under this Agreement shall terminate if it fails to 173 | comply with any of the material terms or conditions of this Agreement and does 174 | not cure such failure in a reasonable period of time after becoming aware of 175 | such noncompliance. If all Recipient's rights under this Agreement terminate, 176 | Recipient agrees to cease use and distribution of the Program as soon as 177 | reasonably practicable. 
However, Recipient's obligations under this Agreement 178 | and any licenses granted by Recipient relating to the Program shall continue 179 | and survive. 180 | 181 | Everyone is permitted to copy and distribute copies of this Agreement, but in 182 | order to avoid inconsistency the Agreement is copyrighted and may only be 183 | modified in the following manner. The Agreement Steward reserves the right to 184 | publish new versions (including revisions) of this Agreement from time to 185 | time. No one other than the Agreement Steward has the right to modify this 186 | Agreement. The Eclipse Foundation is the initial Agreement Steward. The 187 | Eclipse Foundation may assign the responsibility to serve as the Agreement 188 | Steward to a suitable separate entity. Each new version of the Agreement will 189 | be given a distinguishing version number. The Program (including 190 | Contributions) may always be distributed subject to the version of the 191 | Agreement under which it was received. In addition, after a new version of the 192 | Agreement is published, Contributor may elect to distribute the Program 193 | (including its Contributions) under the new version. Except as expressly 194 | stated in Sections 2(a) and 2(b) above, Recipient receives no rights or 195 | licenses to the intellectual property of any Contributor under this Agreement, 196 | whether expressly, by implication, estoppel or otherwise. All rights in the 197 | Program not expressly granted under this Agreement are reserved. 198 | 199 | This Agreement is governed by the laws of the State of New York and the 200 | intellectual property laws of the United States of America. No party to this 201 | Agreement will bring a legal action under this Agreement more than one year 202 | after the cause of action arose. Each party waives its rights to a jury trial in 203 | any resulting litigation. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Hrotti 2 | ====== 3 | 4 | Hrotti is both a library that provides an MQTT server and a wrapper program around that library that provides a standalone MQTT server. 5 | 6 | When used as a library, you create a broker with the NewHrotti(maxQueueDepth int) function. This returns a broker with no listeners; the maxQueueDepth option is the number of messages that will be allowed to queue up for a client before any new messages that would be sent to that client are thrown away. 7 | 8 | To add a new listener to the broker, use AddListener(name string, config *ListenerConfig), which takes a pointer to a broker as the receiver. name is just a string to identify this listener, and config is a pointer to a ListenerConfig. Currently the only important field in a ListenerConfig is URL, which is a [url.URL](http://golang.org/pkg/net/url/#URL). 9 | 10 | To stop a listener, use StopListener(name string), which again takes a broker as the receiver; name is the name of the listener as given in AddListener(). It returns a nil error on success, otherwise an error indicating that it could not find the named listener. Calling StopListener() will disconnect all clients currently connected to that listener.
11 | 12 | Here's a simple example that implements a TCP MQTT server on port 1883 (note that the package imported from github.com/alsm/hrotti/broker is named hrotti): 13 | ``` 14 | package main 15 | 16 | import ( 17 | "github.com/alsm/hrotti/broker" 18 | "log" 19 | "os" 20 | "os/signal" 21 | "syscall" 22 | ) 23 | 24 | func main() { 25 | h := hrotti.NewHrotti(100) 26 | hrotti.INFO = log.New(os.Stdout, "INFO: ", log.Ldate|log.Ltime) 27 | h.AddListener("test", hrotti.NewListenerConfig("tcp://0.0.0.0:1883")) 28 | 29 | c := make(chan os.Signal, 1) 30 | signal.Notify(c, os.Interrupt, syscall.SIGTERM) 31 | <-c 32 | h.Stop() 33 | } 34 | ``` 35 | 36 | A slightly more extensive implementation is provided with this library: running go build in the project directory will produce a binary called hrotti, which allows multiple listeners to be configured with a JSON config file. If only a single listener is required, you can instead just set the HROTTI_URL environment variable. 37 | Only tcp and ws URL schemes are supported, e.g. tcp://0.0.0.0:1883 or ws://0.0.0.0:1883/mqtt. 38 | With a websocket URL, if no path is specified, it will automatically serve on /. 39 | 40 | Alternatively, a configuration file in JSON can be provided, allowing the creation of multiple listeners; currently all listeners share the same root node in the topic tree. To pass a configuration file, use the command line option "-conf", for example: 41 | ``` 42 | hrotti -conf config.json 43 | ``` 44 | The configuration expects an object called "listeners", which is a map of listener name to a JSON representation of a ListenerConfig; currently only the url can be specified. 45 | 46 | A listener listens via either tcp or websockets, not both on the same port. 47 | 48 | An example configuration file is shown below: 49 | ``` 50 | { 51 | "maxQueueDepth": 100, 52 | "listeners":{ 53 | "tcp":{ 54 | "url":"tcp://0.0.0.0:1883" 55 | }, 56 | "websockets":{ 57 | "url":"ws://0.0.0.0:2000/mqtt" 58 | } 59 | } 60 | } 61 | ``` 62 | 63 | The current persistence mechanism is in memory only.
-------------------------------------------------------------------------------- /broker/client.go: -------------------------------------------------------------------------------- 1 | package hrotti 2 | 3 | import ( 4 | //"errors" 5 | . "github.com/alsm/hrotti/packets" 6 | "github.com/google/uuid" 7 | //"io" 8 | "net" 9 | "sync" 10 | "time" 11 | // Plugins currently don't work (they create a cycle). We could break the cycle 12 | // by fudging things through main.go, but I think the real solution is to use RPC 13 | // and run plugins in a separate process 14 | // . "github.com/alsm/hrotti/plugins" 15 | ) 16 | 17 | type Client struct { 18 | sync.WaitGroup 19 | messageIDs 20 | clientID string 21 | conn net.Conn 22 | keepAlive uint16 23 | state State 24 | topicSpace string 25 | outboundMessages chan *PublishPacket 26 | outboundPriority chan ControlPacket 27 | stop chan struct{} 28 | stopOnce *sync.Once 29 | resetTimer chan bool 30 | cleanSession bool 31 | willMessage *PublishPacket 32 | takeOver bool 33 | } 34 | 35 | func newClient(conn net.Conn, clientID string, maxQDepth int) *Client { 36 | return &Client{ 37 | conn: conn, 38 | clientID: clientID, 39 | stop: make(chan struct{}), 40 | resetTimer: make(chan bool, 1), 41 | outboundMessages: make(chan *PublishPacket, maxQDepth), 42 | outboundPriority: make(chan ControlPacket, maxQDepth), 43 | stopOnce: new(sync.Once), 44 | messageIDs: messageIDs{ 45 | //idChan: make(chan uint16, 10), 46 | index: make(map[uint16]*uuid.UUID), 47 | }, 48 | } 49 | } 50 | 51 | func (c *Client) Connected() bool { 52 | return c.state.Value() == CONNECTED 53 | } 54 | 55 | func (c *Client) KeepAliveTimer(hrotti *Hrotti) { 56 | //this function is part of the client's waitgroup so call Done() when the function exits 57 | defer c.Done() 58 | //In a continuous loop create a Timer for 1.5 * the keepAlive setting 59 | for { 60 | t := time.NewTimer(time.Duration(float64(c.keepAlive)*1.5) * time.Second) 61 | //this select will block on all 3 cases until 
one of them is ready 62 | select { 63 | //if we get a value in on the resetTimer channel we drop out, stop the Timer then loop round again 64 | case <-c.resetTimer: 65 | DEBUG.Println(c.clientID, "resetting keepalive timer") 66 | //if the timer triggers then the client has failed to send us a packet in the keepAlive period so 67 | //must be disconnected, we call Stop() and the function returns. 68 | case <-t.C: 69 | ERROR.Println(c.clientID, "has timed out", c.keepAlive) 70 | go c.Stop(true, hrotti) 71 | return 72 | //the client sent a DISCONNECT or some error occurred that triggered the client to stop, so return. 73 | case <-c.stop: 74 | return 75 | } 76 | t.Stop() 77 | } 78 | } 79 | 80 | func (c *Client) StopForTakeover() { 81 | //close the stop channel, close the network connection, wait for all the goroutines in the waitgroup to 82 | //finish, set the conn and bufferedconn to nil 83 | c.takeOver = true 84 | c.stopOnce.Do(func() { 85 | INFO.Println("Closing Stop chan") 86 | close(c.stop) 87 | INFO.Println("Closing connection") 88 | c.conn.Close() 89 | c.Wait() 90 | c.conn = nil 91 | }) 92 | } 93 | 94 | func (c *Client) Stop(sendWill bool, hrotti *Hrotti) { 95 | //Its possible that error conditions with the network connection might cause both Send and Receive to 96 | //try and call Stop(), but we only want it to be called once, so using the sync.Once in the client we 97 | //run the embedded function, later calls with the same sync.Once will simply return. 98 | INFO.Println("Stopping client", c.clientID, c.conn.RemoteAddr()) 99 | if !c.takeOver { 100 | c.stopOnce.Do(func() { 101 | //close the stop channel, close the network connection, wait for all the goroutines in the waitgroup 102 | //set the state as disconnected, close the message channels. 
103 | close(c.stop) 104 | c.conn.Close() 105 | c.Wait() 106 | c.state.SetValue(DISCONNECTED) 107 | close(c.outboundMessages) 108 | close(c.outboundPriority) 109 | //If we've stopped in a situation where the will message should be sent, and there is a will 110 | //message, then send it. 111 | if sendWill && c.willMessage != nil { 112 | INFO.Println("Sending will message for", c.clientID) 113 | go hrotti.DeliverMessage(c.willMessage.TopicName, c.willMessage) 114 | } 115 | //if this client connected with cleansession true it means it does not need its state (such as 116 | //subscriptions, unreceived messages etc) kept around 117 | if c.cleanSession { 118 | //so we lock the clients map, delete the clientid and *Client from the map, remove all subscriptions 119 | //associated with this client, from the normal tree and any plugins. Then close the persistence 120 | //store that it was using. 121 | hrotti.clients.Lock() 122 | delete(hrotti.clients.list, c.clientID) 123 | hrotti.clients.Unlock() 124 | hrotti.DeleteSubAll(c.clientID) 125 | hrotti.PersistStore.Close(c.clientID) 126 | } 127 | }) 128 | } 129 | } 130 | 131 | func (c *Client) Start(cp *ConnectPacket, hrotti *Hrotti) { 132 | //If cleansession was set to 1 in the CONNECT packet set as true in the client. 133 | c.cleanSession = cp.CleanSession 134 | //There is a will message in the connect packet, so construct the publish packet that will be sent if 135 | //the will is triggered. 136 | if cp.WillFlag { 137 | pp := NewControlPacket(PUBLISH).(*PublishPacket) 138 | pp.FixedHeader.Qos = cp.WillQos 139 | pp.FixedHeader.Retain = cp.WillRetain 140 | pp.TopicName = cp.WillTopic 141 | pp.Payload = cp.WillMessage 142 | 143 | c.willMessage = pp 144 | } else { 145 | c.willMessage = nil 146 | } 147 | c.keepAlive = cp.KeepaliveTimer 148 | 149 | //If cleansession true, or there doesn't already exist a persistence store for this client (ie a new 150 | //durable client), create the inbound and outbound persistence stores. 
151 | if c.cleanSession || !hrotti.PersistStore.Exists(c.clientID) { 152 | hrotti.PersistStore.Open(c.clientID) 153 | } else { 154 | //we have an existing inbound and outbound persistence store for this client already, so lets 155 | //get any messages still in outbound and attempt to send them. 156 | INFO.Println("Getting unacknowledged messages from persistence") 157 | for _, msg := range hrotti.PersistStore.GetAll(c.clientID) { 158 | switch msg.(type) { 159 | //If the message in the store is a publish packet 160 | case *PublishPacket: 161 | //It's possible we already sent this message and didn't remove it from the store because we 162 | //didn't get an acknowledgement, so set the dup flag to 1. (only for QoS > 0) 163 | if msg.(*PublishPacket).Qos > 0 { 164 | msg.(*PublishPacket).Dup = true 165 | } 166 | c.outboundMessages <- msg.(*PublishPacket) 167 | //If it's something else like a PUBACK etc send it to the priority outbound channel 168 | default: 169 | c.outboundPriority <- msg 170 | } 171 | } 172 | } 173 | 174 | //Prepare and write the CONNACK packet. 175 | ca := NewControlPacket(CONNACK).(*ConnackPacket) 176 | ca.ReturnCode = CONN_ACCEPTED 177 | ca.Write(c.conn) 178 | //Receive and Send are part of this WaitGroup, so add 2 to the waitgroup and run the goroutines. 179 | c.Add(2) 180 | go c.Receive(hrotti) 181 | go c.Send(hrotti) 182 | c.state.SetValue(CONNECTED) 183 | //If keepalive value was set run the keepalive time and add 1 to the waitgroup. 
184 | if c.keepAlive > 0 { 185 | c.Add(1) 186 | go c.KeepAliveTimer(hrotti) 187 | } 188 | } 189 | 190 | func validateclientID(clientID string) bool { 191 | return true 192 | } 193 | 194 | func (c *Client) ResetTimer() { 195 | //If we're using keepalive on this client attempt to reset the timer, if the channel blocks it's because 196 | //the timer is already being reset so we can safely drop the attempt here (the default case of the select) 197 | if c.keepAlive > 0 { 198 | select { 199 | case c.resetTimer <- true: 200 | default: 201 | } 202 | } 203 | } 204 | 205 | func (c *Client) Receive(hrotti *Hrotti) { 206 | //part of the client waitgroup so call Done() when the function returns. 207 | defer c.Done() 208 | //loop forever... 209 | for { 210 | select { 211 | //if called to stop then return 212 | case <-c.stop: 213 | return 214 | //otherwise... 215 | default: 216 | /*var cph FixedHeader 217 | var err error 218 | var body []byte 219 | //var typeByte byte 220 | //the msgType will always be the first byte read from the network. 221 | //typeByte, err = c.bufferedConn.Peek(1) 222 | //if there was an error reading from the network, print it and call stop 223 | //true here means send the will message, if there is one, and return. 224 | if err != nil { 225 | ERROR.Println(err.Error(), c.clientID, c.conn.RemoteAddr()) 226 | go c.Stop(true, hrotti) 227 | return 228 | } 229 | //we've received a message so reset the keepalive timer. 230 | c.ResetTimer() 231 | //unpack the first byte into the fixedHeader and read the remaining length 232 | cph.unpack(typeByte) 233 | cph.RemainingLength = decodeLength(c.bufferedConn) 234 | //if the remaining length is > 0 then there is more to read for this packet so 235 | //make the body slice the size of the remaining data. 
readfull will not return 236 | //until the target slice is full or there was an error 237 | if cph.remainingLength > 0 { 238 | body = make([]byte, cph.remainingLength) 239 | _, err = io.ReadFull(c.bufferedConn, body) 240 | //if there was an error (such as broken network), call Stop (send will message) 241 | //and return. 242 | if err != nil { 243 | go c.Stop(true, hrotti) 244 | return 245 | } 246 | } 247 | //MQTT allows large messages that could take a long time to receive, ideally here 248 | //we should pause the keepalive timer, for now we just reset the timer again once 249 | //we've recevied the message. 250 | c.ResetTimer() 251 | //switch on the type of message we've received*/ 252 | cp, err := ReadPacket(c.conn) 253 | if err != nil { 254 | ERROR.Println(err.Error(), c.clientID) 255 | go c.Stop(true, hrotti) 256 | return 257 | } 258 | 259 | // reset the keep alive timer. 260 | c.ResetTimer() 261 | 262 | switch cp.(type) { 263 | //a second CONNECT packet is a protocol violation, so Stop (send will) and return. 264 | case *ConnectPacket: 265 | ERROR.Println("Received second CONNECT from", c.clientID) 266 | go c.Stop(true, hrotti) 267 | return 268 | //client wishes to disconnect so Stop (don't send will) and return. 269 | case *DisconnectPacket: 270 | INFO.Println("Received DISCONNECT from", c.clientID) 271 | go c.Stop(false, hrotti) 272 | return 273 | //client has sent us a PUBLISH message, unpack it persist (if QoS > 0) in the inbound store 274 | case *PublishPacket: 275 | pp := cp.(*PublishPacket) 276 | PROTOCOL.Println("Received PUBLISH from", c.clientID, pp.TopicName) 277 | if pp.Qos > 0 { 278 | hrotti.PersistStore.Add(c.clientID, INBOUND, pp) 279 | } 280 | //if this message has the retained flag set then set as the retained message for the 281 | //appropriate node in the topic tree 282 | if pp.Retain { 283 | hrotti.subs.SetRetained(pp.TopicName, pp) 284 | } 285 | //go and deliver the message to any subscribers. 
286 | go hrotti.DeliverMessage(pp.TopicName, pp) 287 | //if the message was QoS1 or QoS2 start the acknowledgement flows. 288 | switch pp.Qos { 289 | case 1: 290 | pa := NewControlPacket(PUBACK).(*PubackPacket) 291 | pa.MessageID = pp.MessageID 292 | c.HandleFlow(pa, hrotti) 293 | case 2: 294 | pr := NewControlPacket(PUBREC).(*PubrecPacket) 295 | pr.MessageID = pp.MessageID 296 | c.HandleFlow(pr, hrotti) 297 | } 298 | //We received a PUBACK acknowledging a QoS1 PUBLISH we sent to the client 299 | case *PubackPacket: 300 | pa := cp.(*PubackPacket) 301 | //Check that we also think this message id is in use, if it is remove the original 302 | //PUBLISH from the outbound persistence store and set the message id as free for reuse 303 | if c.inUse(pa.MessageID) { 304 | hrotti.PersistStore.Delete(c.clientID, OUTBOUND, pa.UUID()) 305 | c.freeID(pa.MessageID) 306 | } else { 307 | ERROR.Println("Received a PUBACK for unknown msgid", pa.MessageID, "from", c.clientID) 308 | } 309 | //We received a PUBREC for a QoS2 PUBLISH we sent to the client. 310 | case *PubrecPacket: 311 | pr := cp.(*PubrecPacket) 312 | //Check that we also think this message id is in use, if it is run the next stage of the 313 | //message flows for QoS2 messages. 314 | if c.inUse(pr.MessageID) { 315 | prel := NewControlPacket(PUBREL).(*PubrelPacket) 316 | prel.MessageID = pr.MessageID 317 | c.HandleFlow(prel, hrotti) 318 | } else { 319 | ERROR.Println("Received a PUBREC for unknown msgid", pr.MessageID, "from", c.clientID) 320 | } 321 | //We received a PUBREL for a QoS2 PUBLISH from the client, hrotti delivers on PUBLISH though 322 | //so we've already sent the original message to any subscribers, so just create a new 323 | //PUBCOMP message with the correct message id and pass it to the HandleFlow function. 
// Tail of the inbound packet-type switch; the enclosing method starts before this chunk.
			case *PubrelPacket:
				//Received a PUBREL: build a PUBCOMP with the same MessageID and pass it to
				//HandleFlow, which updates the persistence store before queueing it.
				pr := cp.(*PubrelPacket)
				pc := NewControlPacket(PUBCOMP).(*PubcompPacket)
				pc.MessageID = pr.MessageID
				c.HandleFlow(pc, hrotti)
			//Received a PUBCOMP for a QoS2 PUBLISH we originally sent the client. Check the messageid is
			//one we think is in use and, if so, free the message id for reuse.
			//NOTE(review): the delete of the original PUBLISH from the outbound persistence store is
			//commented out, so the persisted copy is not removed here.
			case *PubcompPacket:
				pc := cp.(*PubcompPacket)
				if c.inUse(pc.MessageID) {
					//hrotti.PersistStore.Delete(c, OUTBOUND, pc.UUID)
					c.freeID(pc.MessageID)
				} else {
					ERROR.Println("Received a PUBCOMP for unknown msgid", pc.MessageID, "from", c.clientID)
				}
			//The client wishes to make a subscription: pass the requested topics and QoS values to
			//AddSubscription, then reply with a SUBACK carrying the same MessageID and the granted
			//QoS values that AddSubscription returned.
			case *SubscribePacket:
				PROTOCOL.Println("Received SUBSCRIBE from", c.clientID)
				sp := cp.(*SubscribePacket)
				rQos := hrotti.AddSubscription(c, sp.Topics, sp.Qoss)
				sa := NewControlPacket(SUBACK).(*SubackPacket)
				sa.MessageID = sp.MessageID
				sa.GrantedQoss = append(sa.GrantedQoss, rQos...)
				//plain channel send: blocks while the outboundPriority buffer is full
				c.outboundPriority <- sa
			//The client wants to unsubscribe; remove the subscription and reply with an UNSUBACK
			//carrying the same MessageID.
			//NOTE(review): only up.Topics[0] is removed, so further topics in the same UNSUBSCRIBE
			//are ignored, and an empty Topics slice would panic on the index.
			case *UnsubscribePacket:
				PROTOCOL.Println("Received UNSUBSCRIBE from", c.clientID)
				up := cp.(*UnsubscribePacket)
				hrotti.RemoveSubscription(c, up.Topics[0])
				ua := NewControlPacket(UNSUBACK).(*UnsubackPacket)
				ua.MessageID = up.MessageID
				c.outboundPriority <- ua
			//As part of the keepalive if the client doesn't have any messages to send us for as long as the
			//keepalive period it will send a ping request, so we send a ping response back
			case *PingreqPacket:
				presp := NewControlPacket(PINGRESP).(*PingrespPacket)
				c.outboundPriority <- presp
			}
		}
	}
}

// HandleFlow updates the persistence store for an acknowledgement packet, then offers
// the packet to the client's outboundPriority channel.
//   - *PubrelPacket: PersistStore.Replace on the client's OUTBOUND entry.
//   - *PubackPacket, *PubcompPacket: PersistStore.Delete on the client's INBOUND entry,
//     keyed by msg.UUID().
func (c *Client) HandleFlow(msg ControlPacket, hrotti *Hrotti) {
	switch msg.(type) {
	case *PubrelPacket:
		hrotti.PersistStore.Replace(c.clientID, OUTBOUND, msg)
	case *PubackPacket, *PubcompPacket:
		hrotti.PersistStore.Delete(c.clientID, INBOUND, msg.UUID())
	}
	//Non-blocking send: if the outboundPriority buffer is full the packet is dropped.
	//NOTE(review): select/default does not guard against a closed channel; sending on a
	//closed channel still panics.
	select {
	case c.outboundPriority <- msg:
	default:
	}
}

// Send is the client's writer loop. Until c.stop is closed it takes packets from
// outboundPriority and outboundMessages and writes them to c.conn. Message IDs are
// assigned here, just before writing:
//   - SUBSCRIBE/UNSUBSCRIBE packets get an ID from getMsgID.
//   - PUBLISH packets with QoS 1 or 2 get an ID from getMsgID.
//
// The hrotti parameter is not used in this body.
func (c *Client) Send(hrotti *Hrotti) {
	//Send is part of the client waitgroup so call Done when the function returns.
	defer c.Done()
	for {
		//3 way blocking select
		select {
		//the stop channel has been closed so we should return
		case <-c.stop:
			return
		//the two value receive from a channel tells us whether the channel is closed
		//as reading from a closed channel always returns the empty value for the channel
		//type. ok == false means the channel is closed and the msg will be nil.
		//NOTE(review): a closed channel is always ready, so if either outbound channel is
		//closed while stop is still open this loop keeps re-selecting that case.
		case msg, ok := <-c.outboundPriority:
			if ok {
				//Message IDs are not assigned until we're ready to send the message
				switch msg.(type) {
				case *SubscribePacket:
					msg.(*SubscribePacket).MessageID = c.getMsgID(msg.UUID())
				case *UnsubscribePacket:
					msg.(*UnsubscribePacket).MessageID = c.getMsgID(msg.UUID())
				}
				//NOTE(review): any result returned by Write is discarded
				msg.Write(c.conn)
			}
		case msg, ok := <-c.outboundMessages:
			//ok == false means we were triggered because the channel
			//is closed, and the msg will be nil
			if ok {
				switch msg.Details().Qos {
				case 1, 2:
					msg.MessageID = c.getMsgID(msg.UUID())
				}
				msg.Write(c.conn)
			}
		}
	}
}

--------------------------------------------------------------------------------
/broker/clients.go:
--------------------------------------------------------------------------------
package hrotti

import (
	"sync"
)

// clients maps client identifiers to their *Client; the embedded RWMutex guards the map.
8 | type clients struct { 9 | sync.RWMutex 10 | list map[string]*Client 11 | } 12 | 13 | // Return empty Clients (value type) 14 | func newClients() *clients { 15 | c := &clients{ 16 | sync.RWMutex{}, 17 | make(map[string]*Client), 18 | } 19 | return c 20 | } 21 | -------------------------------------------------------------------------------- /broker/config.go: -------------------------------------------------------------------------------- 1 | package hrotti 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | "net/url" 7 | ) 8 | 9 | //loggers 10 | var ( 11 | INFO *log.Logger 12 | PROTOCOL *log.Logger 13 | ERROR *log.Logger 14 | DEBUG *log.Logger 15 | ) 16 | 17 | //The default output for all the loggers is set to ioutil.Discard 18 | func init() { 19 | INFO = log.New(ioutil.Discard, "", 0) 20 | PROTOCOL = log.New(ioutil.Discard, "", 0) 21 | ERROR = log.New(ioutil.Discard, "", 0) 22 | DEBUG = log.New(ioutil.Discard, "", 0) 23 | } 24 | 25 | //ListenerConfig is a struct containing a URL 26 | type ListenerConfig struct { 27 | URL *url.URL 28 | } 29 | 30 | //NewListenerConfig returns a pointer to a ListenerConfig prepared to listen 31 | //on the URL specified as rawURL 32 | func NewListenerConfig(rawURL string) *ListenerConfig { 33 | listenerURL, err := url.Parse(rawURL) 34 | if err != nil { 35 | return nil 36 | } 37 | l := &ListenerConfig{URL: listenerURL} 38 | return l 39 | } 40 | -------------------------------------------------------------------------------- /broker/memory_persistence.go: -------------------------------------------------------------------------------- 1 | package hrotti 2 | 3 | import ( 4 | . "github.com/alsm/hrotti/packets" 5 | "github.com/google/uuid" 6 | "sync" 7 | ) 8 | 9 | //a persistence entry is a map of msgIds and ControlPackets 10 | type MemoryPersistenceEntry struct { 11 | sync.Mutex 12 | messages map[string]ControlPacket 13 | } 14 | 15 | //the MemoryPersistence struct is a map of client pointers to pointers 16 | //to a Persistence Entry. 
So each client has its own map of msgIds/packets. 17 | type MemoryPersistence struct { 18 | sync.RWMutex 19 | inbound map[string]*MemoryPersistenceEntry 20 | outbound map[string]*MemoryPersistenceEntry 21 | } 22 | 23 | func (p *MemoryPersistence) Init() error { 24 | //init the Memory persistence, we haven't created our persistenceentrys yet 25 | p.inbound = make(map[string]*MemoryPersistenceEntry) 26 | p.outbound = make(map[string]*MemoryPersistenceEntry) 27 | return nil 28 | } 29 | 30 | func (p *MemoryPersistence) Open(client string) { 31 | //lock the whole persistence store while we add a new client entry 32 | p.Lock() 33 | defer p.Unlock() 34 | DEBUG.Println("Opening memory persistence for", client) 35 | //init the MemoryPersistenceEntry for this client 36 | p.inbound[client] = &MemoryPersistenceEntry{messages: make(map[string]ControlPacket)} 37 | p.outbound[client] = &MemoryPersistenceEntry{messages: make(map[string]ControlPacket)} 38 | } 39 | 40 | func (p *MemoryPersistence) Close(client string) { 41 | //lock the whole persistence store while we delete a client from the map 42 | p.Lock() 43 | defer p.Unlock() 44 | delete(p.inbound, client) 45 | delete(p.outbound, client) 46 | } 47 | 48 | func (p *MemoryPersistence) Add(client string, direction dirFlag, message ControlPacket) bool { 49 | //only need to get a read lock on the persistence store, but lock the underlying 50 | //persistenceentry for the client we're working with. 
51 | p.RLock() 52 | defer p.RUnlock() 53 | //the uuid is the key in the persistence entry 54 | id := message.UUID().String() 55 | switch direction { 56 | case INBOUND: 57 | p.inbound[client].Lock() 58 | defer p.inbound[client].Unlock() 59 | DEBUG.Println("Persisting inbound packet for", client, id) 60 | //if there is already an entry for this message id return false 61 | if _, ok := p.inbound[client].messages[id]; ok { 62 | return false 63 | } 64 | //otherwise insert this message into the map 65 | p.inbound[client].messages[id] = message 66 | case OUTBOUND: 67 | p.outbound[client].Lock() 68 | defer p.outbound[client].Unlock() 69 | DEBUG.Println("Persisting outbound packet for", client, id) 70 | //if there is already an entry for this message id return false 71 | if _, ok := p.outbound[client].messages[id]; ok { 72 | return false 73 | } 74 | //otherwise insert this message into the map 75 | p.outbound[client].messages[id] = message 76 | } 77 | return true 78 | } 79 | 80 | func (p *MemoryPersistence) Replace(client string, direction dirFlag, message ControlPacket) bool { 81 | //only need to get a read lock on the persistence store, but lock the underlying 82 | //persistenceentry for the client we're working with. 
83 | p.RLock() 84 | defer p.RUnlock() 85 | 86 | id := message.UUID().String() 87 | switch direction { 88 | //For QoS2 flows we want to replace the original PUBLISH with the related PUBREL 89 | //as it maintains the same message id 90 | case INBOUND: 91 | p.inbound[client].Lock() 92 | defer p.inbound[client].Unlock() 93 | DEBUG.Println("Replacing persisted message for", client, id) 94 | //if there is already an entry for this message id return false 95 | if _, ok := p.inbound[client].messages[id]; ok { 96 | return false 97 | } 98 | //otherwise insert this message into the map 99 | p.inbound[client].messages[id] = message 100 | case OUTBOUND: 101 | p.outbound[client].Lock() 102 | defer p.outbound[client].Unlock() 103 | DEBUG.Println("Replacing persisted message for", client, id) 104 | //if there is already an entry for this message id return false 105 | if _, ok := p.outbound[client].messages[id]; ok { 106 | return false 107 | } 108 | //otherwise insert this message into the map 109 | p.outbound[client].messages[id] = message 110 | } 111 | return true 112 | } 113 | 114 | func (p *MemoryPersistence) AddBatch(batch map[string]*PublishPacket) { 115 | //adding messages to many different client entries at the same time, as we're doing 116 | //this grabbing a full lock on the whole persistence mechanism 117 | p.Lock() 118 | defer p.Unlock() 119 | //the batch is a map keyed by client and value is a pointer to a PUBLISH 120 | //for each create an appropriate entry 121 | for client, message := range batch { 122 | p.inbound[client].messages[message.UUID().String()] = message 123 | } 124 | } 125 | 126 | func (p *MemoryPersistence) Delete(client string, direction dirFlag, uid uuid.UUID) bool { 127 | //only need to get a read lock on the persistence store, but lock the underlying 128 | //persistenceentry for the client we're working with. 
129 | p.RLock() 130 | defer p.RUnlock() 131 | //checks that there is actually an entry for the message id we're being asked to 132 | //delete, if there isn't return false, otherwise delete the entry. 133 | id := uid.String() 134 | DEBUG.Println("Removing persisted message for", client) 135 | switch direction { 136 | case INBOUND: 137 | p.inbound[client].Lock() 138 | defer p.inbound[client].Unlock() 139 | //if there is already an entry for this message id return false 140 | if _, ok := p.inbound[client].messages[id]; !ok { 141 | return false 142 | } 143 | delete(p.inbound[client].messages, id) 144 | case OUTBOUND: 145 | p.outbound[client].Lock() 146 | defer p.outbound[client].Unlock() 147 | //if there is already an entry for this message id return false 148 | if _, ok := p.outbound[client].messages[id]; !ok { 149 | return false 150 | } 151 | delete(p.outbound[client].messages, id) 152 | } 153 | return true 154 | } 155 | 156 | func (p *MemoryPersistence) GetAll(client string) (messages []ControlPacket) { 157 | //only need to get a read lock on the persistence store, but lock the underlying 158 | //persistenceentry for the client we're working with. 159 | p.RLock() 160 | p.outbound[client].Lock() 161 | defer p.outbound[client].Unlock() 162 | defer p.RUnlock() 163 | //Get every message in the persistence store for a given client, create a slice 164 | //of the ControlPackets (not just PUBLISHES in there) 165 | for _, message := range p.outbound[client].messages { 166 | messages = append(messages, message) 167 | } 168 | return messages 169 | } 170 | 171 | func (p *MemoryPersistence) Exists(client string) bool { 172 | //grab a read lock on the persistence and check if there is already an entry 173 | //for this client. 
174 | p.RLock() 175 | defer p.RUnlock() 176 | _, okInbound := p.inbound[client] 177 | _, okOutbound := p.outbound[client] 178 | return okInbound && okOutbound 179 | } 180 | -------------------------------------------------------------------------------- /broker/messageids.go: -------------------------------------------------------------------------------- 1 | package hrotti 2 | 3 | import ( 4 | "github.com/google/uuid" 5 | "sync" 6 | ) 7 | 8 | type messageIDs struct { 9 | sync.RWMutex 10 | //idChan chan uint16 11 | index map[uint16]*uuid.UUID 12 | } 13 | 14 | const ( 15 | msgIDMax uint16 = 65535 16 | msgIDMin uint16 = 1 17 | ) 18 | 19 | /*func (c *Client) genMsgIDs() { 20 | defer c.Done() 21 | m := &c.messageIDs 22 | for { 23 | m.Lock() 24 | for i := msgIDMin; i < msgIDMax; i++ { 25 | if m.index[i] == nil { 26 | m.index[i] = 27 | m.Unlock() 28 | select { 29 | case m.idChan <- i: 30 | case <-c.stop: 31 | return 32 | } 33 | break 34 | } 35 | } 36 | } 37 | }*/ 38 | 39 | func (m *messageIDs) getMsgID(id uuid.UUID) uint16 { 40 | m.Lock() 41 | defer m.Unlock() 42 | for i := msgIDMin; i < msgIDMax; i++ { 43 | if m.index[i] == nil { 44 | m.index[i] = &id 45 | return i 46 | } 47 | } 48 | return 0 49 | } 50 | 51 | func (m *messageIDs) inUse(id uint16) bool { 52 | m.RLock() 53 | defer m.RUnlock() 54 | return m.index[id] != nil 55 | } 56 | 57 | func (m *messageIDs) freeID(id uint16) { 58 | m.Lock() 59 | defer m.Unlock() 60 | m.index[id] = nil 61 | } 62 | -------------------------------------------------------------------------------- /broker/persistence.go: -------------------------------------------------------------------------------- 1 | package hrotti 2 | 3 | import ( 4 | . 
"github.com/alsm/hrotti/packets" 5 | "github.com/google/uuid" 6 | ) 7 | 8 | type dirFlag byte 9 | 10 | const ( 11 | INBOUND = 1 12 | OUTBOUND = 2 13 | ) 14 | 15 | type Persistence interface { 16 | Init() error 17 | Open(string) 18 | Close(string) 19 | Add(string, dirFlag, ControlPacket) bool 20 | Replace(string, dirFlag, ControlPacket) bool 21 | AddBatch(map[string]*PublishPacket) 22 | Delete(string, dirFlag, uuid.UUID) bool 23 | GetAll(string) []ControlPacket 24 | Exists(string) bool 25 | } 26 | -------------------------------------------------------------------------------- /broker/router.go: -------------------------------------------------------------------------------- 1 | package hrotti 2 | 3 | import ( 4 | . "github.com/alsm/hrotti/packets" 5 | "strings" 6 | "sync" 7 | ) 8 | 9 | type subscriptionMap struct { 10 | subElements map[string][]string 11 | subMap map[string]map[string]byte 12 | subBitmap []map[string]map[string]bool 13 | retained map[string]*PublishPacket 14 | sync.RWMutex 15 | } 16 | 17 | var subs subscriptionMap 18 | 19 | func newSubMap() *subscriptionMap { 20 | s := &subscriptionMap{} 21 | s.subElements = make(map[string][]string) 22 | s.subMap = make(map[string]map[string]byte) 23 | s.subBitmap = make([]map[string]map[string]bool, 10) 24 | for i, _ := range s.subBitmap { 25 | s.subBitmap[i] = make(map[string]map[string]bool) 26 | } 27 | s.retained = make(map[string]*PublishPacket) 28 | 29 | return s 30 | } 31 | 32 | func (s *subscriptionMap) SetRetained(topic string, message *PublishPacket) { 33 | DEBUG.Println("Setting retained message for", topic) 34 | s.RLock() 35 | defer s.RUnlock() 36 | if len(message.Payload) == 0 { 37 | delete(s.retained, topic) 38 | } else { 39 | s.retained[topic] = message 40 | } 41 | } 42 | 43 | func match(route []string, topic []string) bool { 44 | if len(route) == 0 { 45 | if len(topic) == 0 { 46 | return true 47 | } 48 | return false 49 | } 50 | 51 | if len(topic) == 0 { 52 | if route[0] == "#" { 53 | return true 
54 | } 55 | return false 56 | } 57 | 58 | if route[0] == "#" { 59 | return true 60 | } 61 | 62 | if (route[0] == "+") || (route[0] == topic[0]) { 63 | return match(route[1:], topic[1:]) 64 | } 65 | 66 | return false 67 | } 68 | 69 | func (h *Hrotti) FindRetained(id string, topic string, qos byte) { 70 | var deliverList []*PublishPacket 71 | client := h.getClient(id) 72 | if strings.ContainsAny(topic, "#+") { 73 | for rTopic, msg := range h.subs.retained { 74 | if match(strings.Split(topic, "/"), strings.Split(rTopic, "/")) { 75 | deliveryMsg := msg.Copy() 76 | deliveryMsg.Qos = calcMinQos(msg.Qos, qos) 77 | deliverList = append(deliverList, deliveryMsg) 78 | } 79 | } 80 | } else { 81 | if msg, ok := h.subs.retained[topic]; ok { 82 | deliveryMsg := msg.Copy() 83 | deliveryMsg.Qos = calcMinQos(msg.Qos, qos) 84 | deliverList = append(deliverList, deliveryMsg) 85 | } 86 | } 87 | if len(deliverList) > 0 { 88 | for _, msg := range deliverList { 89 | if msg.Qos > 0 { 90 | if client.Connected() { 91 | h.PersistStore.Add(id, OUTBOUND, msg) 92 | select { 93 | case client.outboundMessages <- msg: 94 | default: 95 | } 96 | } else { 97 | h.PersistStore.Add(id, OUTBOUND, msg) 98 | } 99 | } else if client.Connected() { 100 | select { 101 | case client.outboundMessages <- msg: 102 | default: 103 | } 104 | } 105 | } 106 | } 107 | } 108 | 109 | func (h *Hrotti) AddSub(client string, subscription string, qos byte) { 110 | h.subs.Lock() 111 | defer h.subs.Unlock() 112 | if _, ok := h.subs.subElements[subscription]; !ok { 113 | h.subs.subElements[subscription] = strings.Split(subscription, "/") 114 | } 115 | if _, ok := h.subs.subMap[subscription]; !ok { 116 | h.subs.subMap[subscription] = make(map[string]byte) 117 | } 118 | h.subs.subMap[subscription][client] = qos 119 | for i, element := range append(h.subs.subElements[subscription], "\u0000") { 120 | if _, ok := h.subs.subBitmap[i][element]; !ok { 121 | h.subs.subBitmap[i][element] = make(map[string]bool) 122 | } 123 | 
h.subs.subBitmap[i][element][subscription] = true 124 | } 125 | go h.FindRetained(client, subscription, qos) 126 | } 127 | 128 | func (h *Hrotti) DeleteSub(client string, subscription string) { 129 | h.subs.Lock() 130 | defer h.subs.Unlock() 131 | if _, ok := h.subs.subMap[subscription]; ok { 132 | delete(h.subs.subMap[subscription], client) 133 | } 134 | } 135 | 136 | func (h *Hrotti) DeleteSubAll(client string) { 137 | h.subs.Lock() 138 | defer h.subs.Unlock() 139 | for _, topic := range h.subs.subMap { 140 | if _, ok := topic[client]; ok { 141 | delete(topic, client) 142 | } 143 | } 144 | } 145 | 146 | func (h *Hrotti) DeliverMessage(topic string, message *PublishPacket) { 147 | h.subs.RLock() 148 | topicElements := strings.Split(topic, "/") 149 | var matches []string 150 | var hashMatches []string 151 | deliverList := make(map[string]byte) 152 | for i, element := range append(topicElements, "\u0000") { 153 | DEBUG.Println("Searching bitmap level", i, element) 154 | switch i { 155 | case 0: 156 | for sub, _ := range h.subs.subBitmap[i][element] { 157 | matches = append(matches, sub) 158 | } 159 | for sub, _ := range h.subs.subBitmap[i]["+"] { 160 | matches = append(matches, sub) 161 | } 162 | for sub, _ := range h.subs.subBitmap[i]["#"] { 163 | hashMatches = append(hashMatches, sub) 164 | } 165 | default: 166 | var tmpMatches []string 167 | for _, sub := range matches { 168 | switch element { 169 | case "\u0000": 170 | if h.subs.subBitmap[i][element][sub] { 171 | tmpMatches = append(tmpMatches, sub) 172 | } 173 | default: 174 | if h.subs.subBitmap[i][element][sub] || h.subs.subBitmap[i]["+"][sub] { 175 | tmpMatches = append(tmpMatches, sub) 176 | } else if h.subs.subBitmap[i]["#"][sub] { 177 | hashMatches = append(hashMatches, sub) 178 | } 179 | } 180 | } 181 | matches = tmpMatches 182 | } 183 | if len(matches) == 0 { 184 | break 185 | } 186 | } 187 | h.subs.RUnlock() 188 | 189 | zeroCopy := message.Copy() 190 | zeroCopy.Qos = 0 191 | 192 | for _, sub := range 
append(hashMatches, matches...) { 193 | for c, qos := range h.subs.subMap[sub] { 194 | if currQos, ok := deliverList[c]; ok { 195 | deliverList[c] = calcMinQos(calcMaxQos(currQos, qos), message.Qos) 196 | } else { 197 | deliverList[c] = calcMinQos(qos, message.Qos) 198 | } 199 | } 200 | } 201 | 202 | DEBUG.Println(deliverList) 203 | for cid, subQos := range deliverList { 204 | client := h.getClient(cid) 205 | if subQos > 0 { 206 | go func(c *Client, subQos byte) { 207 | deliveryMessage := message.Copy() 208 | deliveryMessage.Qos = subQos 209 | if c.Connected() { 210 | //deliveryMessage.MessageID = c.getMsgID(deliveryMessage.UUID()) 211 | h.PersistStore.Add(c.clientID, OUTBOUND, deliveryMessage) 212 | select { 213 | case c.outboundMessages <- deliveryMessage: 214 | default: 215 | } 216 | } else { 217 | h.PersistStore.Add(c.clientID, OUTBOUND, deliveryMessage) 218 | } 219 | }(client, subQos) 220 | } else if client.Connected() { 221 | select { 222 | case client.outboundMessages <- zeroCopy: 223 | default: 224 | } 225 | } 226 | } 227 | } 228 | 229 | func calcMinQos(a, b byte) byte { 230 | if a < b { 231 | return a 232 | } 233 | return b 234 | } 235 | 236 | func calcMaxQos(a, b byte) byte { 237 | if a > b { 238 | return a 239 | } 240 | return b 241 | } 242 | -------------------------------------------------------------------------------- /broker/server.go: -------------------------------------------------------------------------------- 1 | package hrotti 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "net/http" 7 | "net/url" 8 | "sync" 9 | 10 | . 
"github.com/alsm/hrotti/packets" 11 | "github.com/google/uuid" 12 | "golang.org/x/net/websocket" 13 | ) 14 | 15 | type Hrotti struct { 16 | PersistStore Persistence 17 | listeners map[string]*internalListener 18 | listenersWaitGroup sync.WaitGroup 19 | maxQueueDepth int 20 | clients *clients 21 | subs *subscriptionMap 22 | } 23 | 24 | type internalListener struct { 25 | name string 26 | url url.URL 27 | connections []net.Conn 28 | stop chan struct{} 29 | } 30 | 31 | func NewHrotti(maxQueueDepth int, persistence Persistence) *Hrotti { 32 | h := &Hrotti{ 33 | PersistStore: persistence, 34 | listeners: make(map[string]*internalListener), 35 | maxQueueDepth: maxQueueDepth, 36 | clients: newClients(), 37 | subs: newSubMap(), 38 | } 39 | //start the goroutine that generates internal message ids for when clients receive messages 40 | //but are not connected. 41 | h.PersistStore.Init() 42 | return h 43 | } 44 | 45 | func (h *Hrotti) getClient(id string) *Client { 46 | h.clients.RLock() 47 | defer h.clients.RUnlock() 48 | return h.clients.list[id] 49 | } 50 | 51 | func (h *Hrotti) AddListener(name string, config *ListenerConfig) error { 52 | listener := &internalListener{name: name, url: *config.URL} 53 | listener.stop = make(chan struct{}) 54 | 55 | h.listeners[name] = listener 56 | 57 | ln, err := net.Listen("tcp", listener.url.Host) 58 | if err != nil { 59 | ERROR.Println(err.Error()) 60 | return err 61 | } 62 | 63 | if listener.url.Scheme == "ws" && len(listener.url.Path) == 0 { 64 | listener.url.Path = "/" 65 | } 66 | 67 | h.listenersWaitGroup.Add(1) 68 | INFO.Println("Starting MQTT listener on", listener.url.String()) 69 | 70 | go func() { 71 | <-listener.stop 72 | INFO.Println("Listener", name, "is stopping...") 73 | ln.Close() 74 | }() 75 | //if this is a WebSocket listener 76 | if listener.url.Scheme == "ws" { 77 | var server websocket.Server 78 | //override the Websocket handshake to accept any protocol name 79 | server.Handshake = func(c *websocket.Config, req 
*http.Request) error { 80 | c.Origin, _ = url.Parse(req.RemoteAddr) 81 | c.Protocol = []string{"mqtt"} 82 | return nil 83 | } 84 | //set up the ws connection handler, ie what we do when we get a new websocket connection 85 | server.Handler = func(ws *websocket.Conn) { 86 | ws.PayloadType = websocket.BinaryFrame 87 | INFO.Println("New incoming websocket connection", ws.RemoteAddr()) 88 | listener.connections = append(listener.connections, ws) 89 | h.InitClient(ws) 90 | } 91 | //set the path that the http server will recognise as related to this websocket 92 | //server, needs to be configurable really. 93 | http.Handle(listener.url.Path, server) 94 | //ListenAndServe loops forever receiving connections and initiating the handler 95 | //for each one. 96 | go func(ln net.Listener) { 97 | defer h.listenersWaitGroup.Done() 98 | err := http.Serve(ln, nil) 99 | if err != nil { 100 | ERROR.Println(err.Error()) 101 | return 102 | } 103 | }(ln) 104 | } else { 105 | //loop forever accepting connections and launch InitClient as a goroutine with the connection 106 | go func() { 107 | defer h.listenersWaitGroup.Done() 108 | for { 109 | conn, err := ln.Accept() 110 | if err != nil { 111 | ERROR.Println(err.Error()) 112 | return 113 | } 114 | INFO.Println("New incoming connection", conn.RemoteAddr()) 115 | listener.connections = append(listener.connections, conn) 116 | go h.InitClient(conn) 117 | } 118 | }() 119 | } 120 | return nil 121 | } 122 | 123 | func (h *Hrotti) StopListener(name string) error { 124 | if listener, ok := h.listeners[name]; ok { 125 | close(listener.stop) 126 | for _, conn := range listener.connections { 127 | conn.Close() 128 | } 129 | delete(h.listeners, name) 130 | return nil 131 | } 132 | return errors.New("Listener not found") 133 | } 134 | 135 | func (h *Hrotti) Stop() { 136 | INFO.Println("Exiting...") 137 | for _, listener := range h.listeners { 138 | close(listener.stop) 139 | } 140 | h.listenersWaitGroup.Wait() 141 | } 142 | 143 | func (h *Hrotti) 
InitClient(conn net.Conn) { 144 | var sendSessionID bool 145 | /*var cph fixedHeader 146 | 147 | //create a bufio conn from the network connection 148 | bufferedConn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) 149 | //first byte off the wire should be the msg type 150 | typeByte, _ := bufferedConn.ReadByte() 151 | //unpack the first byte into the fixed header 152 | cph.unpack(typeByte) 153 | 154 | if cph.messageType != CONNECT { 155 | //If the first packet isn't a CONNECT, it's not MQTT or not compliant, so kill the connection and we're done. 156 | conn.Close() 157 | return 158 | } 159 | 160 | //read the remaining length field from the network, this can be 1-3 bytes generally although in this case 161 | //it should always be 1 byte, but using the generic method. 162 | cph.remainingLength = decodeLength(bufferedConn) 163 | //a buffer to receive the rest of the connect packet 164 | body := make([]byte, cph.remainingLength) 165 | io.ReadFull(bufferedConn, body) 166 | //create a new empty CONNECT packet to unpack the body of the CONNECT into 167 | cp := newControlPacket(CONNECT).(*connectPacket) 168 | cp.fixedHeader = cph 169 | cp.unpack(body)*/ 170 | rp, _ := ReadPacket(conn) 171 | cp := rp.(*ConnectPacket) 172 | 173 | //Validate the CONNECT, check fields, values etc. 174 | rc := cp.Validate() 175 | //If it didn't validate... 176 | if rc != CONN_ACCEPTED { 177 | //and it wasn't because of a protocol violation... 178 | if rc != CONN_PROTOCOL_VIOLATION { 179 | //create and send a CONNACK with the correct rc in it. 180 | ca := NewControlPacket(CONNACK).(*ConnackPacket) 181 | ca.ReturnCode = rc 182 | ca.Write(conn) 183 | } 184 | //Put up a local message indicating an errored connection attempt and close the connection 185 | ERROR.Println(ConnackReturnCodes[rc], conn.RemoteAddr()) 186 | conn.Close() 187 | return 188 | } else { 189 | //Put up an INFO message with the client id and the address they're connecting from. 
190 | INFO.Println(ConnackReturnCodes[rc], cp.ClientIdentifier, conn.RemoteAddr()) 191 | } 192 | 193 | //check for a zero length client id and if it exists create one from the UUID library and return 194 | //it on $SYS/session_identifier 195 | if len(cp.ClientIdentifier) == 0 { 196 | cp.ClientIdentifier = uuid.New().String() 197 | sendSessionID = true 198 | } 199 | //Lock the clients hashmap while we check if we already know this clientid. 200 | h.clients.Lock() 201 | c, ok := h.clients.list[cp.ClientIdentifier] 202 | if ok && cp.CleanSession { 203 | //and if we do, if the clientid is currently connected... 204 | if c.Connected() { 205 | INFO.Println("Clientid", c.clientID, "already connected, stopping first client") 206 | //stop the parts of it that need to stop before we can change the network connection it's using. 207 | c.StopForTakeover() 208 | } else { 209 | //if the clientid known but not connected, ie cleansession false 210 | INFO.Println("Durable client reconnecting", c.clientID) 211 | //disconnected client will no longer have the channels for messages 212 | c.outboundMessages = make(chan *PublishPacket, h.maxQueueDepth) 213 | c.outboundPriority = make(chan ControlPacket, h.maxQueueDepth) 214 | } 215 | //this function stays running until the client disconnects as the function called by an http 216 | //Handler has to remain running until its work is complete. So add one to the client waitgroup. 217 | c.Add(1) 218 | //create a new sync.Once for stopping with later, set the connections and create the stop channel. 219 | c.stopOnce = new(sync.Once) 220 | c.conn = conn 221 | //c.bufferedConn = bufferedConn 222 | c.stop = make(chan struct{}) 223 | //start the client. 
224 | go c.Start(cp, h) 225 | } else { 226 | //This is a brand new client so create a NewClient and add to the clients map 227 | c = newClient(conn, cp.ClientIdentifier, h.maxQueueDepth) 228 | h.clients.list[cp.ClientIdentifier] = c 229 | if sendSessionID { 230 | go func() { 231 | sessionIDPacket := NewControlPacket(PUBLISH).(*PublishPacket) 232 | sessionIDPacket.TopicName = "$SYS/session_identifier" 233 | sessionIDPacket.Payload = []byte(cp.ClientIdentifier) 234 | sessionIDPacket.Qos = 1 235 | c.outboundMessages <- sessionIDPacket 236 | }() 237 | } 238 | //As before this function has to remain running but to avoid races we want to make sure its finished 239 | //before doing anything else so add it to the waitgroup so we can wait on it later 240 | c.Add(1) 241 | go c.Start(cp, h) 242 | } 243 | //finished with the clients hashmap 244 | h.clients.Unlock() 245 | //wait on the stop channel, we never actually send values down this channel but a closed channel with 246 | //return the default empty value for it's type without blocking. 247 | <-c.stop 248 | //call Done() on the client waitgroup. 
249 | c.Done() 250 | } 251 | -------------------------------------------------------------------------------- /broker/state.go: -------------------------------------------------------------------------------- 1 | package hrotti 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | type State struct { 8 | sync.RWMutex 9 | value StateVal 10 | } 11 | 12 | type StateVal uint8 13 | 14 | const ( 15 | DISCONNECTED StateVal = 0x00 16 | CONNECTING StateVal = 0x01 17 | CONNECTED StateVal = 0x02 18 | DISCONNECTING StateVal = 0x03 19 | ) 20 | 21 | func (s *State) SetValue(value StateVal) { 22 | s.Lock() 23 | defer s.Unlock() 24 | s.value = value 25 | } 26 | 27 | func (s *State) Value() StateVal { 28 | s.RLock() 29 | defer s.RUnlock() 30 | return s.value 31 | } 32 | -------------------------------------------------------------------------------- /broker/stats.go: -------------------------------------------------------------------------------- 1 | package hrotti 2 | 3 | import ( 4 | "sync/atomic" 5 | ) 6 | 7 | type stat int64 8 | 9 | type BrokerStats struct { 10 | bytesReceived int64 11 | bytesSent int64 12 | clientsConnected int64 13 | clientsDisconnected int64 14 | clientsMaximum int64 15 | clientsTotal int64 16 | messagesInflight int64 17 | messagesReceived int64 18 | messagesSent int64 19 | messagesStored int64 20 | publishMessagesDropped int64 21 | publishMessagesReceived int64 22 | publishMessagesSent int64 23 | messagesRetained int64 24 | subscriptions int64 25 | brokerTime int64 26 | brokerUptime int64 27 | } 28 | 29 | func (b *BrokerStats) AddClient() { 30 | atomic.AddInt64(&b.clientsConnected, 1) 31 | } 32 | -------------------------------------------------------------------------------- /broker/subscriptions.go: -------------------------------------------------------------------------------- 1 | package hrotti 2 | 3 | //Add a subscription for a client, taking an array of topics to subscribe to and an associated 4 | //slice of QoS values for the topics, return a slice of byte 
values indicating the granted 5 | //QoS values in topics order. 6 | func (h *Hrotti) AddSubscription(c *Client, topics []string, qoss []byte) []byte { 7 | //this is the slice we'll return and needs to be the same length as the input QoS' slice 8 | rQos := make([]byte, len(qoss)) 9 | 10 | //for every topic in the topics slice, also get the index number of the topic... 11 | for i, topic := range topics { 12 | h.AddSub(c.clientID, topic, qoss[i]) 13 | rQos[i] = qoss[i] 14 | } 15 | //return the slice of granted QoS values. 16 | return rQos 17 | } 18 | 19 | func (h *Hrotti) RemoveSubscription(c *Client, topic string) bool { 20 | h.DeleteSub(c.clientID, topic) 21 | return true 22 | } 23 | -------------------------------------------------------------------------------- /broker/unit_router_test.go: -------------------------------------------------------------------------------- 1 | package hrotti 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "strconv" 7 | "strings" 8 | "sync" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | /*func Test_NewNode(t *testing.T) { 14 | rootNode := NewNode("test") 15 | 16 | if rootNode == nil { 17 | t.Fatalf("rootNode is nil") 18 | } 19 | 20 | if rootNode.Name != "test" { 21 | t.Fatalf("rootNode name is %s, not test", rootNode.Name) 22 | } 23 | }*/ 24 | 25 | func Test_AddSub(t *testing.T) { 26 | rootNode := NewNode("") 27 | rand.Seed(time.Now().UnixNano()) 28 | topics := [7]string{"a", "b", "c", "d", "e", "+", "#"} 29 | for i := 0; i < 20; i++ { 30 | c := newClient(nil, "testClientId"+strconv.Itoa(i), 100) 31 | var sub string 32 | r := rand.Intn(7) 33 | for j := 0; j <= r; j++ { 34 | char := topics[rand.Intn(7)] 35 | sub += char 36 | if char == "#" || j == r { 37 | break 38 | } 39 | sub += "/" 40 | } 41 | rootNode.AddSub(c, strings.Split(sub, "/"), 1) 42 | } 43 | } 44 | 45 | /*func Test_DeleteSub(t *testing.T) { 46 | rootNode := NewNode("") 47 | c := newClient(nil, "testClientId", 100) 48 | sub1 := strings.Split("test/test1/test2/test3", "/") 49 | 
sub2 := strings.Split("test/test1/test4/test5", "/") 50 | complete1 := make(chan byte) 51 | complete2 := make(chan byte) 52 | complete3 := make(chan bool) 53 | 54 | rootNode.AddSub(c, sub1, 1, complete1) 55 | <-complete1 56 | rootNode.AddSub(c, sub2, 2, complete2) 57 | <-complete2 58 | 59 | rootNode.Print("") 60 | 61 | rootNode.DeleteSub(c, sub2, complete3) 62 | <-complete3 63 | 64 | rootNode.Print("") 65 | close(complete1) 66 | close(complete2) 67 | close(complete3) 68 | } 69 | 70 | func Test_AddSub2(t *testing.T) { 71 | c := newClient(nil, "testClientId", 100) 72 | sub1 := "test/test1/test2/test3" 73 | sub2 := "test/test1/test4/test5" 74 | 75 | AddSub2(c, sub1, 1) 76 | AddSub2(c, sub2, 2) 77 | }*/ 78 | 79 | /*func BenchmarkFindrecip(b *testing.B) { 80 | rand.Seed(time.Now().UnixNano()) 81 | topics := [7]string{"a", "b", "c", "d", "e", "+", "#"} 82 | for i := 0; i < b.N; i++ { 83 | c := newClient(nil, "testClientId"+strconv.Itoa(i), 100) 84 | var sub string 85 | r := rand.Intn(7) 86 | for j := 0; j <= r; j++ { 87 | char := topics[rand.Intn(7)] 88 | sub += char 89 | if char == "#" || j == r { 90 | break 91 | } 92 | sub += "/" 93 | } 94 | AddSub2(c, sub, 1) 95 | } 96 | b.ResetTimer() 97 | match := FindRecipients2("a/b/c/d/e") 98 | fmt.Println("a/b/c/d/e", len(match)) 99 | }*/ 100 | 101 | func BenchmarkNormalRouter(b *testing.B) { 102 | rootNode := NewNode("") 103 | rand.Seed(time.Now().UnixNano()) 104 | topics := [7]string{"a", "b", "c", "d", "e", "+", "#"} 105 | for i := 0; i < b.N; i++ { 106 | c := newClient(nil, "testClientId"+strconv.Itoa(i), 100) 107 | var sub string 108 | r := rand.Intn(7) 109 | for j := 0; j <= r; j++ { 110 | char := topics[rand.Intn(7)] 111 | sub += char 112 | if char == "#" || j == r { 113 | break 114 | } 115 | sub += "/" 116 | } 117 | rootNode.AddSub(c, strings.Split(sub, "/"), 1) 118 | } 119 | var treeWorkers sync.WaitGroup 120 | recipients := make(chan *Entry) 121 | b.ResetTimer() 122 | treeWorkers.Add(1) 123 | 
rootNode.FindRecipients(strings.Split("a/b/c/d/e", "/"), recipients, &treeWorkers) 124 | treeWorkers.Wait() 125 | close(recipients) 126 | for { 127 | _, ok := <-recipients 128 | if !ok { 129 | break 130 | } 131 | } 132 | } 133 | 134 | func main() { 135 | br := testing.Benchmark(BenchmarkNormalRouter) 136 | fmt.Println(br) 137 | } 138 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | "io/ioutil" 7 | "log" 8 | "net/url" 9 | "os" 10 | 11 | . "github.com/alsm/hrotti/broker" 12 | ) 13 | 14 | type ListenerEntry struct { 15 | URL string `json:"url"` 16 | } 17 | 18 | //Current configuration struct, maxQueueDepth sets the maximum number of unacknowledged mesages 19 | //for a client. Listeners is a slice of ListenerConfigs 20 | type BrokerConfig struct { 21 | MaxQueueDepth int `json:"maxQueueDepth"` 22 | ListenerEntries map[string]*ListenerEntry `json:"listeners"` 23 | Listeners map[string]*ListenerConfig 24 | Logging struct { 25 | Info string `json:"info"` 26 | Protocol string `json:"protocol"` 27 | Errlog string `json:"error"` 28 | Debug string `json:"debug"` 29 | } 30 | } 31 | 32 | var logTargets map[string]io.Writer = map[string]io.Writer{ 33 | "stdout": os.Stdout, 34 | "stderr": os.Stderr, 35 | "discard": ioutil.Discard, 36 | } 37 | 38 | func (c *BrokerConfig) SetLogTargets() { 39 | target, ok := logTargets[c.Logging.Info] 40 | if !ok { 41 | target = os.Stdout 42 | } 43 | INFO = log.New(target, "INFO: ", log.Ldate|log.Ltime) 44 | target, ok = logTargets[c.Logging.Protocol] 45 | if !ok { 46 | target = ioutil.Discard 47 | } 48 | PROTOCOL = log.New(target, "PROTOCOL: ", log.Ldate|log.Ltime) 49 | target, ok = logTargets[c.Logging.Errlog] 50 | if !ok { 51 | target = os.Stderr 52 | } 53 | ERROR = log.New(target, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile) 54 | target, ok = 
logTargets[c.Logging.Debug] 55 | if !ok { 56 | target = ioutil.Discard 57 | } 58 | DEBUG = log.New(target, "DEBUG: ", log.Ldate|log.Ltime|log.Lshortfile) 59 | } 60 | 61 | func ParseConfig(confFile string, confVar *BrokerConfig) error { 62 | file, err := os.Open(confFile) 63 | if err != nil { 64 | return err 65 | } 66 | decoder := json.NewDecoder(file) 67 | 68 | err = decoder.Decode(confVar) 69 | if err != nil { 70 | return err 71 | } 72 | 73 | for name, entry := range confVar.ListenerEntries { 74 | url, err := url.Parse(entry.URL) 75 | if err != nil { 76 | return err 77 | } 78 | confVar.Listeners[name] = &ListenerConfig{URL: url} 79 | } 80 | return nil 81 | } 82 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "os" 7 | "os/signal" 8 | "syscall" 9 | 10 | . "github.com/alsm/hrotti/broker" 11 | ) 12 | 13 | func createConfig() BrokerConfig { 14 | configFile := flag.String("conf", "", "A configuration file") 15 | 16 | flag.Parse() 17 | 18 | var config BrokerConfig 19 | config.ListenerEntries = make(map[string]*ListenerEntry) 20 | config.Listeners = make(map[string]*ListenerConfig) 21 | 22 | if *configFile == "" { 23 | listener := NewListenerConfig(os.Getenv("HROTTI_URL")) 24 | if listener.URL.Host == "" { 25 | listener = NewListenerConfig("tcp://0.0.0.0:1883") 26 | } 27 | config.Listeners["envconfig"] = listener 28 | config.MaxQueueDepth = 100 29 | } else { 30 | fmt.Println("Reading config file", *configFile) 31 | err := ParseConfig(*configFile, &config) 32 | if err != nil { 33 | os.Stderr.WriteString(fmt.Sprintf("%s\n", err.Error())) 34 | } 35 | } 36 | config.SetLogTargets() 37 | return config 38 | } 39 | 40 | func main() { 41 | config := createConfig() 42 | 43 | //r := &RedisPersistence{Server: ":6379"} 44 | r := &MemoryPersistence{} 45 | h := NewHrotti(config.MaxQueueDepth, r) 46 | 47 | 
for name, listener := range config.Listeners { 48 | h.AddListener(name, listener) 49 | } 50 | c := make(chan os.Signal, 1) 51 | signal.Notify(c, os.Interrupt, syscall.SIGTERM) 52 | <-c 53 | h.Stop() 54 | } 55 | -------------------------------------------------------------------------------- /packets/connack.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "github.com/google/uuid" 7 | "io" 8 | ) 9 | 10 | //CONNACK packet 11 | 12 | type ConnackPacket struct { 13 | FixedHeader 14 | TopicNameCompression byte 15 | ReturnCode byte 16 | uuid uuid.UUID 17 | } 18 | 19 | func (ca *ConnackPacket) String() string { 20 | str := fmt.Sprintf("%s\n", ca.FixedHeader) 21 | str += fmt.Sprintf("returncode: %d", ca.ReturnCode) 22 | return str 23 | } 24 | 25 | func (ca *ConnackPacket) Write(w io.Writer) error { 26 | var body bytes.Buffer 27 | var err error 28 | 29 | body.WriteByte(ca.TopicNameCompression) 30 | body.WriteByte(ca.ReturnCode) 31 | ca.FixedHeader.RemainingLength = 2 32 | packet := ca.FixedHeader.pack() 33 | packet.Write(body.Bytes()) 34 | _, err = packet.WriteTo(w) 35 | 36 | return err 37 | } 38 | 39 | func (ca *ConnackPacket) Unpack(b io.Reader) { 40 | ca.TopicNameCompression = decodeByte(b) 41 | ca.ReturnCode = decodeByte(b) 42 | } 43 | 44 | func (ca *ConnackPacket) Details() Details { 45 | return Details{Qos: 0, MessageID: 0} 46 | } 47 | 48 | func (ca *ConnackPacket) UUID() uuid.UUID { 49 | return ca.uuid 50 | } 51 | -------------------------------------------------------------------------------- /packets/connect.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "github.com/google/uuid" 7 | "io" 8 | ) 9 | 10 | //CONNECT packet 11 | 12 | type ConnectPacket struct { 13 | FixedHeader 14 | ProtocolName string 15 | ProtocolVersion byte 16 | CleanSession bool 17 | WillFlag bool 18 | 
WillQos byte 19 | WillRetain bool 20 | UsernameFlag bool 21 | PasswordFlag bool 22 | ReservedBit byte 23 | KeepaliveTimer uint16 24 | 25 | ClientIdentifier string 26 | WillTopic string 27 | WillMessage []byte 28 | Username string 29 | Password []byte 30 | uuid uuid.UUID 31 | } 32 | 33 | func (c *ConnectPacket) String() string { 34 | str := fmt.Sprintf("%s\n", c.FixedHeader) 35 | str += fmt.Sprintf("protocolversion: %d protocolname: %s cleansession: %t willflag: %t WillQos: %d WillRetain: %t Usernameflag: %t Passwordflag: %t keepalivetimer: %d\nclientId: %s\nwilltopic: %s\nwillmessage: %s\nUsername: %s\nPassword: %s\n", c.ProtocolVersion, c.ProtocolName, c.CleanSession, c.WillFlag, c.WillQos, c.WillRetain, c.UsernameFlag, c.PasswordFlag, c.KeepaliveTimer, c.ClientIdentifier, c.WillTopic, c.WillMessage, c.Username, c.Password) 36 | return str 37 | } 38 | 39 | func (c *ConnectPacket) Write(w io.Writer) error { 40 | var body bytes.Buffer 41 | var err error 42 | 43 | body.Write(encodeString(c.ProtocolName)) 44 | body.WriteByte(c.ProtocolVersion) 45 | body.WriteByte(boolToByte(c.CleanSession)<<1 | boolToByte(c.WillFlag)<<2 | c.WillQos<<3 | boolToByte(c.WillRetain)<<5 | boolToByte(c.PasswordFlag)<<6 | boolToByte(c.UsernameFlag)<<7) 46 | body.Write(encodeUint16(c.KeepaliveTimer)) 47 | body.Write(encodeString(c.ClientIdentifier)) 48 | if c.WillFlag { 49 | body.Write(encodeString(c.WillTopic)) 50 | body.Write(encodeBytes(c.WillMessage)) 51 | } 52 | if c.UsernameFlag { 53 | body.Write(encodeString(c.Username)) 54 | } 55 | if c.PasswordFlag { 56 | body.Write(encodeBytes(c.Password)) 57 | } 58 | c.FixedHeader.RemainingLength = body.Len() 59 | packet := c.FixedHeader.pack() 60 | packet.Write(body.Bytes()) 61 | _, err = packet.WriteTo(w) 62 | 63 | return err 64 | } 65 | 66 | func (c *ConnectPacket) Unpack(b io.Reader) { 67 | c.ProtocolName = decodeString(b) 68 | c.ProtocolVersion = decodeByte(b) 69 | options := decodeByte(b) 70 | c.ReservedBit = 1 & options 71 | c.CleanSession = 
1&(options>>1) > 0 72 | c.WillFlag = 1&(options>>2) > 0 73 | c.WillQos = 3 & (options >> 3) 74 | c.WillRetain = 1&(options>>5) > 0 75 | c.PasswordFlag = 1&(options>>6) > 0 76 | c.UsernameFlag = 1&(options>>7) > 0 77 | c.KeepaliveTimer = decodeUint16(b) 78 | c.ClientIdentifier = decodeString(b) 79 | if c.WillFlag { 80 | c.WillTopic = decodeString(b) 81 | c.WillMessage = decodeBytes(b) 82 | } 83 | if c.UsernameFlag { 84 | c.Username = decodeString(b) 85 | } 86 | if c.PasswordFlag { 87 | c.Password = decodeBytes(b) 88 | } 89 | } 90 | 91 | func (c *ConnectPacket) Validate() byte { 92 | if c.PasswordFlag && !c.UsernameFlag { 93 | return CONN_REF_BAD_USER_PASS 94 | } 95 | if c.ReservedBit != 0 { 96 | fmt.Println("Bad reserved bit") 97 | return CONN_PROTOCOL_VIOLATION 98 | } 99 | if (c.ProtocolName == "MQIsdp" && c.ProtocolVersion != 3) || (c.ProtocolName == "MQTT" && c.ProtocolVersion != 4) { 100 | return CONN_REF_BAD_PROTO_VER 101 | } 102 | if c.ProtocolName != "MQIsdp" && c.ProtocolName != "MQTT" { 103 | fmt.Println("Bad protocol name") 104 | return CONN_PROTOCOL_VIOLATION 105 | } 106 | if len(c.ClientIdentifier) > 65535 || len(c.Username) > 65535 || len(c.Password) > 65535 { 107 | fmt.Println("Bad size field") 108 | return CONN_PROTOCOL_VIOLATION 109 | } 110 | return CONN_ACCEPTED 111 | } 112 | 113 | func (c *ConnectPacket) Details() Details { 114 | return Details{Qos: 0, MessageID: 0} 115 | } 116 | 117 | func (c *ConnectPacket) UUID() uuid.UUID { 118 | return c.uuid 119 | } 120 | -------------------------------------------------------------------------------- /packets/disconnect.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "fmt" 5 | "github.com/google/uuid" 6 | "io" 7 | ) 8 | 9 | //DISCONNECT packet 10 | 11 | type DisconnectPacket struct { 12 | FixedHeader 13 | uuid uuid.UUID 14 | } 15 | 16 | func (d *DisconnectPacket) String() string { 17 | str := fmt.Sprintf("%s\n", d.FixedHeader) 18 | 
return str 19 | } 20 | 21 | func (d *DisconnectPacket) Write(w io.Writer) error { 22 | packet := d.FixedHeader.pack() 23 | _, err := packet.WriteTo(w) 24 | 25 | return err 26 | } 27 | 28 | func (d *DisconnectPacket) Unpack(b io.Reader) { 29 | } 30 | 31 | func (d *DisconnectPacket) Details() Details { 32 | return Details{Qos: 0, MessageID: 0} 33 | } 34 | 35 | func (d *DisconnectPacket) UUID() uuid.UUID { 36 | return d.uuid 37 | } 38 | -------------------------------------------------------------------------------- /packets/packets.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "errors" 7 | "fmt" 8 | "github.com/google/uuid" 9 | "io" 10 | ) 11 | 12 | type ControlPacket interface { 13 | Write(io.Writer) error 14 | Unpack(io.Reader) 15 | String() string 16 | Details() Details 17 | UUID() uuid.UUID 18 | } 19 | 20 | var PacketNames = map[uint8]string{ 21 | 1: "CONNECT", 22 | 2: "CONNACK", 23 | 3: "PUBLISH", 24 | 4: "PUBACK", 25 | 5: "PUBREC", 26 | 6: "PUBREL", 27 | 7: "PUBCOMP", 28 | 8: "SUBSCRIBE", 29 | 9: "SUBACK", 30 | 10: "UNSUBSCRIBE", 31 | 11: "UNSUBACK", 32 | 12: "PINGREQ", 33 | 13: "PINGRESP", 34 | 14: "DISCONNECT", 35 | } 36 | 37 | const ( 38 | CONNECT = 1 39 | CONNACK = 2 40 | PUBLISH = 3 41 | PUBACK = 4 42 | PUBREC = 5 43 | PUBREL = 6 44 | PUBCOMP = 7 45 | SUBSCRIBE = 8 46 | SUBACK = 9 47 | UNSUBSCRIBE = 10 48 | UNSUBACK = 11 49 | PINGREQ = 12 50 | PINGRESP = 13 51 | DISCONNECT = 14 52 | ) 53 | 54 | const ( 55 | CONN_ACCEPTED = 0x00 56 | CONN_REF_BAD_PROTO_VER = 0x01 57 | CONN_REF_ID_REJ = 0x02 58 | CONN_REF_SERV_UNAVAIL = 0x03 59 | CONN_REF_BAD_USER_PASS = 0x04 60 | CONN_REF_NOT_AUTH = 0x05 61 | CONN_NETWORK_ERROR = 0xFE 62 | CONN_PROTOCOL_VIOLATION = 0xFF 63 | ) 64 | 65 | var ConnackReturnCodes = map[uint8]string{ 66 | 0: "Connection Accepted", 67 | 1: "Connection Refused: Bad Protocol Version", 68 | 2: "Connection Refused: Client Identifier 
Rejected", 69 | 3: "Connection Refused: Server Unavailable", 70 | 4: "Connection Refused: Username or Password in unknown format", 71 | 5: "Connection Refused: Not Authorised", 72 | 254: "Connection Error", 73 | 255: "Connection Refused: Protocol Violation", 74 | } 75 | 76 | func ReadPacket(r io.Reader) (cp ControlPacket, err error) { 77 | var fh FixedHeader 78 | b := make([]byte, 1) 79 | 80 | _, err = io.ReadFull(r, b) 81 | if err != nil { 82 | return nil, err 83 | } 84 | fh.unpack(b[0], r) 85 | cp = NewControlPacketWithHeader(fh) 86 | if cp == nil { 87 | return nil, errors.New("Bad data from client") 88 | } 89 | packetBytes := make([]byte, fh.RemainingLength) 90 | _, err = io.ReadFull(r, packetBytes) 91 | if err != nil { 92 | return nil, err 93 | } 94 | cp.Unpack(bytes.NewBuffer(packetBytes)) 95 | return cp, nil 96 | } 97 | 98 | func NewControlPacket(packetType byte) (cp ControlPacket) { 99 | switch packetType { 100 | case CONNECT: 101 | cp = &ConnectPacket{FixedHeader: FixedHeader{MessageType: CONNECT}, uuid: uuid.New()} 102 | case CONNACK: 103 | cp = &ConnackPacket{FixedHeader: FixedHeader{MessageType: CONNACK}, uuid: uuid.New()} 104 | case DISCONNECT: 105 | cp = &DisconnectPacket{FixedHeader: FixedHeader{MessageType: DISCONNECT}, uuid: uuid.New()} 106 | case PUBLISH: 107 | cp = &PublishPacket{FixedHeader: FixedHeader{MessageType: PUBLISH}, uuid: uuid.New()} 108 | case PUBACK: 109 | cp = &PubackPacket{FixedHeader: FixedHeader{MessageType: PUBACK}, uuid: uuid.New()} 110 | case PUBREC: 111 | cp = &PubrecPacket{FixedHeader: FixedHeader{MessageType: PUBREC}, uuid: uuid.New()} 112 | case PUBREL: 113 | cp = &PubrelPacket{FixedHeader: FixedHeader{MessageType: PUBREL, Qos: 1}, uuid: uuid.New()} 114 | case PUBCOMP: 115 | cp = &PubcompPacket{FixedHeader: FixedHeader{MessageType: PUBCOMP}, uuid: uuid.New()} 116 | case SUBSCRIBE: 117 | cp = &SubscribePacket{FixedHeader: FixedHeader{MessageType: SUBSCRIBE, Qos: 1}, uuid: uuid.New()} 118 | case SUBACK: 119 | cp = 
&SubackPacket{FixedHeader: FixedHeader{MessageType: SUBACK}, uuid: uuid.New()} 120 | case UNSUBSCRIBE: 121 | cp = &UnsubscribePacket{FixedHeader: FixedHeader{MessageType: UNSUBSCRIBE}, uuid: uuid.New()} 122 | case UNSUBACK: 123 | cp = &UnsubackPacket{FixedHeader: FixedHeader{MessageType: UNSUBACK}, uuid: uuid.New()} 124 | case PINGREQ: 125 | cp = &PingreqPacket{FixedHeader: FixedHeader{MessageType: PINGREQ}, uuid: uuid.New()} 126 | case PINGRESP: 127 | cp = &PingrespPacket{FixedHeader: FixedHeader{MessageType: PINGRESP}, uuid: uuid.New()} 128 | default: 129 | return nil 130 | } 131 | return cp 132 | } 133 | 134 | func NewControlPacketWithHeader(fh FixedHeader) (cp ControlPacket) { 135 | switch fh.MessageType { 136 | case CONNECT: 137 | cp = &ConnectPacket{FixedHeader: fh, uuid: uuid.New()} 138 | case CONNACK: 139 | cp = &ConnackPacket{FixedHeader: fh, uuid: uuid.New()} 140 | case DISCONNECT: 141 | cp = &DisconnectPacket{FixedHeader: fh, uuid: uuid.New()} 142 | case PUBLISH: 143 | cp = &PublishPacket{FixedHeader: fh, uuid: uuid.New()} 144 | case PUBACK: 145 | cp = &PubackPacket{FixedHeader: fh, uuid: uuid.New()} 146 | case PUBREC: 147 | cp = &PubrecPacket{FixedHeader: fh, uuid: uuid.New()} 148 | case PUBREL: 149 | cp = &PubrelPacket{FixedHeader: fh, uuid: uuid.New()} 150 | case PUBCOMP: 151 | cp = &PubcompPacket{FixedHeader: fh, uuid: uuid.New()} 152 | case SUBSCRIBE: 153 | cp = &SubscribePacket{FixedHeader: fh, uuid: uuid.New()} 154 | case SUBACK: 155 | cp = &SubackPacket{FixedHeader: fh, uuid: uuid.New()} 156 | case UNSUBSCRIBE: 157 | cp = &UnsubscribePacket{FixedHeader: fh, uuid: uuid.New()} 158 | case UNSUBACK: 159 | cp = &UnsubackPacket{FixedHeader: fh, uuid: uuid.New()} 160 | case PINGREQ: 161 | cp = &PingreqPacket{FixedHeader: fh, uuid: uuid.New()} 162 | case PINGRESP: 163 | cp = &PingrespPacket{FixedHeader: fh, uuid: uuid.New()} 164 | default: 165 | return nil 166 | } 167 | return cp 168 | } 169 | 170 | type Details struct { 171 | Qos byte 172 | MessageID 
uint16 173 | } 174 | 175 | type FixedHeader struct { 176 | MessageType byte 177 | Dup bool 178 | Qos byte 179 | Retain bool 180 | RemainingLength int 181 | } 182 | 183 | func (fh FixedHeader) String() string { 184 | return fmt.Sprintf("%s: dup: %t qos: %d retain: %t rLength: %d", PacketNames[fh.MessageType], fh.Dup, fh.Qos, fh.Retain, fh.RemainingLength) 185 | } 186 | 187 | func boolToByte(b bool) byte { 188 | switch b { 189 | case true: 190 | return 1 191 | default: 192 | return 0 193 | } 194 | } 195 | 196 | func (fh *FixedHeader) pack() bytes.Buffer { 197 | var header bytes.Buffer 198 | header.WriteByte(fh.MessageType<<4 | boolToByte(fh.Dup)<<3 | fh.Qos<<1 | boolToByte(fh.Retain)) 199 | header.Write(encodeLength(fh.RemainingLength)) 200 | return header 201 | } 202 | 203 | func (fh *FixedHeader) unpack(typeAndFlags byte, r io.Reader) { 204 | fh.MessageType = typeAndFlags >> 4 205 | fh.Dup = (typeAndFlags>>3)&0x01 > 0 206 | fh.Qos = (typeAndFlags >> 1) & 0x03 207 | fh.Retain = typeAndFlags&0x01 > 0 208 | fh.RemainingLength = decodeLength(r) 209 | } 210 | 211 | func decodeByte(b io.Reader) byte { 212 | num := make([]byte, 1) 213 | b.Read(num) 214 | return num[0] 215 | } 216 | 217 | func decodeUint16(b io.Reader) uint16 { 218 | num := make([]byte, 2) 219 | b.Read(num) 220 | return binary.BigEndian.Uint16(num) 221 | } 222 | 223 | func encodeUint16(num uint16) []byte { 224 | bytes := make([]byte, 2) 225 | binary.BigEndian.PutUint16(bytes, num) 226 | return bytes 227 | } 228 | 229 | func encodeString(field string) []byte { 230 | fieldLength := make([]byte, 2) 231 | binary.BigEndian.PutUint16(fieldLength, uint16(len(field))) 232 | return append(fieldLength, []byte(field)...) 
233 | } 234 | 235 | func decodeString(b io.Reader) string { 236 | fieldLength := decodeUint16(b) 237 | field := make([]byte, fieldLength) 238 | b.Read(field) 239 | return string(field) 240 | } 241 | 242 | func decodeBytes(b io.Reader) []byte { 243 | fieldLength := decodeUint16(b) 244 | field := make([]byte, fieldLength) 245 | b.Read(field) 246 | return field 247 | } 248 | 249 | func encodeBytes(field []byte) []byte { 250 | fieldLength := make([]byte, 2) 251 | binary.BigEndian.PutUint16(fieldLength, uint16(len(field))) 252 | return append(fieldLength, field...) 253 | } 254 | 255 | func encodeLength(length int) []byte { 256 | var encLength []byte 257 | for { 258 | digit := byte(length % 128) 259 | length /= 128 260 | if length > 0 { 261 | digit |= 0x80 262 | } 263 | encLength = append(encLength, digit) 264 | if length == 0 { 265 | break 266 | } 267 | } 268 | return encLength 269 | } 270 | 271 | func decodeLength(r io.Reader) int { 272 | var rLength uint32 273 | var multiplier uint32 = 0 274 | b := make([]byte, 1) 275 | for { 276 | io.ReadFull(r, b) 277 | digit := b[0] 278 | rLength |= uint32(digit&127) << multiplier 279 | if (digit & 128) == 0 { 280 | break 281 | } 282 | multiplier += 7 283 | } 284 | return int(rLength) 285 | } 286 | -------------------------------------------------------------------------------- /packets/packets_test.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 | 8 | func TestPacketNames(t *testing.T) { 9 | if PacketNames[1] != "CONNECT" { 10 | t.Errorf("PacketNames[1] is %s, should be %s", PacketNames[1], "CONNECT") 11 | } 12 | if PacketNames[2] != "CONNACK" { 13 | t.Errorf("PacketNames[2] is %s, should be %s", PacketNames[2], "CONNACK") 14 | } 15 | if PacketNames[3] != "PUBLISH" { 16 | t.Errorf("PacketNames[3] is %s, should be %s", PacketNames[3], "PUBLISH") 17 | } 18 | if PacketNames[4] != "PUBACK" { 19 | t.Errorf("PacketNames[4] is %s, 
should be %s", PacketNames[4], "PUBACK") 20 | } 21 | if PacketNames[5] != "PUBREC" { 22 | t.Errorf("PacketNames[5] is %s, should be %s", PacketNames[5], "PUBREC") 23 | } 24 | if PacketNames[6] != "PUBREL" { 25 | t.Errorf("PacketNames[6] is %s, should be %s", PacketNames[6], "PUBREL") 26 | } 27 | if PacketNames[7] != "PUBCOMP" { 28 | t.Errorf("PacketNames[7] is %s, should be %s", PacketNames[7], "PUBCOMP") 29 | } 30 | if PacketNames[8] != "SUBSCRIBE" { 31 | t.Errorf("PacketNames[8] is %s, should be %s", PacketNames[8], "SUBSCRIBE") 32 | } 33 | if PacketNames[9] != "SUBACK" { 34 | t.Errorf("PacketNames[9] is %s, should be %s", PacketNames[9], "SUBACK") 35 | } 36 | if PacketNames[10] != "UNSUBSCRIBE" { 37 | t.Errorf("PacketNames[10] is %s, should be %s", PacketNames[10], "UNSUBSCRIBE") 38 | } 39 | if PacketNames[11] != "UNSUBACK" { 40 | t.Errorf("PacketNames[11] is %s, should be %s", PacketNames[11], "UNSUBACK") 41 | } 42 | if PacketNames[12] != "PINGREQ" { 43 | t.Errorf("PacketNames[12] is %s, should be %s", PacketNames[12], "PINGREQ") 44 | } 45 | if PacketNames[13] != "PINGRESP" { 46 | t.Errorf("PacketNames[13] is %s, should be %s", PacketNames[13], "PINGRESP") 47 | } 48 | if PacketNames[14] != "DISCONNECT" { 49 | t.Errorf("PacketNames[14] is %s, should be %s", PacketNames[14], "DISCONNECT") 50 | } 51 | } 52 | 53 | func TestPacketConsts(t *testing.T) { 54 | if CONNECT != 1 { 55 | t.Errorf("Const for CONNECT is %d, should be %d", CONNECT, 1) 56 | } 57 | if CONNACK != 2 { 58 | t.Errorf("Const for CONNACK is %d, should be %d", CONNACK, 2) 59 | } 60 | if PUBLISH != 3 { 61 | t.Errorf("Const for PUBLISH is %d, should be %d", PUBLISH, 3) 62 | } 63 | if PUBACK != 4 { 64 | t.Errorf("Const for PUBACK is %d, should be %d", PUBACK, 4) 65 | } 66 | if PUBREC != 5 { 67 | t.Errorf("Const for PUBREC is %d, should be %d", PUBREC, 5) 68 | } 69 | if PUBREL != 6 { 70 | t.Errorf("Const for PUBREL is %d, should be %d", PUBREL, 6) 71 | } 72 | if PUBCOMP != 7 { 73 | t.Errorf("Const for 
PUBCOMP is %d, should be %d", PUBCOMP, 7) 74 | } 75 | if SUBSCRIBE != 8 { 76 | t.Errorf("Const for SUBSCRIBE is %d, should be %d", SUBSCRIBE, 8) 77 | } 78 | if SUBACK != 9 { 79 | t.Errorf("Const for SUBACK is %d, should be %d", SUBACK, 9) 80 | } 81 | if UNSUBSCRIBE != 10 { 82 | t.Errorf("Const for UNSUBSCRIBE is %d, should be %d", UNSUBSCRIBE, 10) 83 | } 84 | if UNSUBACK != 11 { 85 | t.Errorf("Const for UNSUBACK is %d, should be %d", UNSUBACK, 11) 86 | } 87 | if PINGREQ != 12 { 88 | t.Errorf("Const for PINGREQ is %d, should be %d", PINGREQ, 12) 89 | } 90 | if PINGRESP != 13 { 91 | t.Errorf("Const for PINGRESP is %d, should be %d", PINGRESP, 13) 92 | } 93 | if DISCONNECT != 14 { 94 | t.Errorf("Const for DISCONNECT is %d, should be %d", DISCONNECT, 14) 95 | } 96 | } 97 | 98 | func TestConnackConsts(t *testing.T) { 99 | if CONN_ACCEPTED != 0x00 { 100 | t.Errorf("Const for CONN_ACCEPTED is %d, should be %d", CONN_ACCEPTED, 0) 101 | } 102 | if CONN_REF_BAD_PROTO_VER != 0x01 { 103 | t.Errorf("Const for CONN_REF_BAD_PROTO_VER is %d, should be %d", CONN_REF_BAD_PROTO_VER, 1) 104 | } 105 | if CONN_REF_ID_REJ != 0x02 { 106 | t.Errorf("Const for CONN_REF_ID_REJ is %d, should be %d", CONN_REF_ID_REJ, 2) 107 | } 108 | if CONN_REF_SERV_UNAVAIL != 0x03 { 109 | t.Errorf("Const for CONN_REF_SERV_UNAVAIL is %d, should be %d", CONN_REF_SERV_UNAVAIL, 3) 110 | } 111 | if CONN_REF_BAD_USER_PASS != 0x04 { 112 | t.Errorf("Const for CONN_REF_BAD_USER_PASS is %d, should be %d", CONN_REF_BAD_USER_PASS, 4) 113 | } 114 | if CONN_REF_NOT_AUTH != 0x05 { 115 | t.Errorf("Const for CONN_REF_NOT_AUTH is %d, should be %d", CONN_REF_NOT_AUTH, 5) 116 | } 117 | } 118 | 119 | func TestConnectPacket(t *testing.T) { 120 | connectPacketBytes := bytes.NewBuffer([]byte{16, 52, 0, 4, 77, 81, 84, 84, 4, 204, 0, 0, 0, 0, 0, 4, 116, 101, 115, 116, 0, 12, 84, 101, 115, 116, 32, 80, 97, 121, 108, 111, 97, 100, 0, 8, 116, 101, 115, 116, 117, 115, 101, 114, 0, 8, 116, 101, 115, 116, 112, 97, 115, 115}) 121 | packet, 
err := ReadPacket(connectPacketBytes) 122 | if err != nil { 123 | t.Fatalf("Error reading packet: %s", err.Error()) 124 | } 125 | cp := packet.(*ConnectPacket) 126 | if cp.ProtocolName != "MQTT" { 127 | t.Errorf("Connect Packet ProtocolName is %s, should be %s", cp.ProtocolName, "MQTT") 128 | } 129 | if cp.ProtocolVersion != 4 { 130 | t.Errorf("Connect Packet ProtocolVersion is %d, should be %d", cp.ProtocolVersion, 4) 131 | } 132 | if cp.UsernameFlag != true { 133 | t.Errorf("Connect Packet UsernameFlag is %t, should be %t", cp.UsernameFlag, true) 134 | } 135 | if cp.Username != "testuser" { 136 | t.Errorf("Connect Packet Username is %s, should be %s", cp.Username, "testuser") 137 | } 138 | if cp.PasswordFlag != true { 139 | t.Errorf("Connect Packet PasswordFlag is %t, should be %t", cp.PasswordFlag, true) 140 | } 141 | if string(cp.Password) != "testpass" { 142 | t.Errorf("Connect Packet Password is %s, should be %s", string(cp.Password), "testpass") 143 | } 144 | if cp.WillFlag != true { 145 | t.Errorf("Connect Packet WillFlag is %t, should be %t", cp.WillFlag, true) 146 | } 147 | if cp.WillTopic != "test" { 148 | t.Errorf("Connect Packet WillTopic is %s, should be %s", cp.WillTopic, "test") 149 | } 150 | if cp.WillQos != 1 { 151 | t.Errorf("Connect Packet WillQos is %d, should be %d", cp.WillQos, 1) 152 | } 153 | if cp.WillRetain != false { 154 | t.Errorf("Connect Packet WillRetain is %t, should be %t", cp.WillRetain, false) 155 | } 156 | if string(cp.WillMessage) != "Test Payload" { 157 | t.Errorf("Connect Packet WillMessage is %s, should be %s", string(cp.WillMessage), "Test Payload") 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /packets/pingreq.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "fmt" 5 | "github.com/google/uuid" 6 | "io" 7 | ) 8 | 9 | //PINGREQ packet 10 | 11 | type PingreqPacket struct { 12 | FixedHeader 13 | 
uuid uuid.UUID 14 | } 15 | 16 | func (pr *PingreqPacket) String() string { 17 | str := fmt.Sprintf("%s", pr.FixedHeader) 18 | return str 19 | } 20 | 21 | func (pr *PingreqPacket) Write(w io.Writer) error { 22 | packet := pr.FixedHeader.pack() 23 | _, err := packet.WriteTo(w) 24 | 25 | return err 26 | } 27 | 28 | func (pr *PingreqPacket) Unpack(b io.Reader) { 29 | } 30 | 31 | func (pr *PingreqPacket) Details() Details { 32 | return Details{Qos: 0, MessageID: 0} 33 | } 34 | 35 | func (pr *PingreqPacket) UUID() uuid.UUID { 36 | return pr.uuid 37 | } 38 | -------------------------------------------------------------------------------- /packets/pingresp.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "fmt" 5 | "github.com/google/uuid" 6 | "io" 7 | ) 8 | 9 | //PINGRESP packet 10 | 11 | type PingrespPacket struct { 12 | FixedHeader 13 | uuid uuid.UUID 14 | } 15 | 16 | func (pr *PingrespPacket) String() string { 17 | str := fmt.Sprintf("%s", pr.FixedHeader) 18 | return str 19 | } 20 | 21 | func (pr *PingrespPacket) Write(w io.Writer) error { 22 | packet := pr.FixedHeader.pack() 23 | _, err := packet.WriteTo(w) 24 | 25 | return err 26 | } 27 | 28 | func (pr *PingrespPacket) Unpack(b io.Reader) { 29 | } 30 | 31 | func (pr *PingrespPacket) Details() Details { 32 | return Details{Qos: 0, MessageID: 0} 33 | } 34 | 35 | func (pr *PingrespPacket) UUID() uuid.UUID { 36 | return pr.uuid 37 | } 38 | -------------------------------------------------------------------------------- /packets/puback.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "fmt" 5 | "github.com/google/uuid" 6 | "io" 7 | ) 8 | 9 | //PUBACK packet 10 | 11 | type PubackPacket struct { 12 | FixedHeader 13 | MessageID uint16 14 | uuid uuid.UUID 15 | } 16 | 17 | func (pa *PubackPacket) String() string { 18 | str := fmt.Sprintf("%s\n", pa.FixedHeader) 19 | str += 
fmt.Sprintf("messageID: %d", pa.MessageID) 20 | return str 21 | } 22 | 23 | func (pa *PubackPacket) Write(w io.Writer) error { 24 | var err error 25 | pa.FixedHeader.RemainingLength = 2 26 | packet := pa.FixedHeader.pack() 27 | packet.Write(encodeUint16(pa.MessageID)) 28 | _, err = packet.WriteTo(w) 29 | 30 | return err 31 | } 32 | 33 | func (pa *PubackPacket) Unpack(b io.Reader) { 34 | pa.MessageID = decodeUint16(b) 35 | } 36 | 37 | func (pa *PubackPacket) Details() Details { 38 | return Details{Qos: pa.Qos, MessageID: pa.MessageID} 39 | } 40 | 41 | func (pa *PubackPacket) UUID() uuid.UUID { 42 | return pa.uuid 43 | } 44 | -------------------------------------------------------------------------------- /packets/pubcomp.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "fmt" 5 | "github.com/google/uuid" 6 | "io" 7 | ) 8 | 9 | //PUBCOMP packet 10 | 11 | type PubcompPacket struct { 12 | FixedHeader 13 | MessageID uint16 14 | uuid uuid.UUID 15 | } 16 | 17 | func (pc *PubcompPacket) String() string { 18 | str := fmt.Sprintf("%s\n", pc.FixedHeader) 19 | str += fmt.Sprintf("MessageID: %d", pc.MessageID) 20 | return str 21 | } 22 | 23 | func (pc *PubcompPacket) Write(w io.Writer) error { 24 | var err error 25 | pc.FixedHeader.RemainingLength = 2 26 | packet := pc.FixedHeader.pack() 27 | packet.Write(encodeUint16(pc.MessageID)) 28 | _, err = packet.WriteTo(w) 29 | 30 | return err 31 | } 32 | 33 | func (pc *PubcompPacket) Unpack(b io.Reader) { 34 | pc.MessageID = decodeUint16(b) 35 | } 36 | 37 | func (pc *PubcompPacket) Details() Details { 38 | return Details{Qos: pc.Qos, MessageID: pc.MessageID} 39 | } 40 | 41 | func (pc *PubcompPacket) UUID() uuid.UUID { 42 | return pc.uuid 43 | } 44 | -------------------------------------------------------------------------------- /packets/publish.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | 
import ( 4 | "bytes" 5 | "fmt" 6 | "github.com/google/uuid" 7 | "io" 8 | ) 9 | 10 | //PUBLISH packet 11 | 12 | type PublishPacket struct { 13 | FixedHeader 14 | TopicName string 15 | MessageID uint16 16 | Payload []byte 17 | uuid uuid.UUID 18 | } 19 | 20 | func (p *PublishPacket) String() string { 21 | str := fmt.Sprintf("%s\n", p.FixedHeader) 22 | str += fmt.Sprintf("topicName: %s MessageID: %d\n", p.TopicName, p.MessageID) 23 | str += fmt.Sprintf("payload: %s\n", string(p.Payload)) 24 | return str 25 | } 26 | 27 | func (p *PublishPacket) Write(w io.Writer) error { 28 | var body bytes.Buffer 29 | var err error 30 | 31 | body.Write(encodeString(p.TopicName)) 32 | if p.Qos > 0 { 33 | body.Write(encodeUint16(p.MessageID)) 34 | } 35 | p.FixedHeader.RemainingLength = body.Len() + len(p.Payload) 36 | packet := p.FixedHeader.pack() 37 | packet.Write(body.Bytes()) 38 | packet.Write(p.Payload) 39 | _, err = w.Write(packet.Bytes()) 40 | 41 | return err 42 | } 43 | 44 | func (p *PublishPacket) Unpack(b io.Reader) { 45 | var payloadLength = p.FixedHeader.RemainingLength 46 | p.TopicName = decodeString(b) 47 | if p.Qos > 0 { 48 | p.MessageID = decodeUint16(b) 49 | payloadLength -= len(p.TopicName) + 4 50 | } else { 51 | payloadLength -= len(p.TopicName) + 2 52 | } 53 | p.Payload = make([]byte, payloadLength) 54 | b.Read(p.Payload) 55 | } 56 | 57 | func (p *PublishPacket) Copy() *PublishPacket { 58 | newP := NewControlPacket(PUBLISH).(*PublishPacket) 59 | newP.TopicName = p.TopicName 60 | newP.Payload = p.Payload 61 | 62 | return newP 63 | } 64 | 65 | func (p *PublishPacket) Details() Details { 66 | return Details{Qos: p.Qos, MessageID: p.MessageID} 67 | } 68 | 69 | func (p *PublishPacket) UUID() uuid.UUID { 70 | return p.uuid 71 | } 72 | -------------------------------------------------------------------------------- /packets/pubrec.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "fmt" 5 | 
"github.com/google/uuid" 6 | "io" 7 | ) 8 | 9 | //PUBREC packet 10 | 11 | type PubrecPacket struct { 12 | FixedHeader 13 | MessageID uint16 14 | uuid uuid.UUID 15 | } 16 | 17 | func (pr *PubrecPacket) String() string { 18 | str := fmt.Sprintf("%s\n", pr.FixedHeader) 19 | str += fmt.Sprintf("MessageID: %d", pr.MessageID) 20 | return str 21 | } 22 | 23 | func (pr *PubrecPacket) Write(w io.Writer) error { 24 | var err error 25 | pr.FixedHeader.RemainingLength = 2 26 | packet := pr.FixedHeader.pack() 27 | packet.Write(encodeUint16(pr.MessageID)) 28 | _, err = packet.WriteTo(w) 29 | 30 | return err 31 | } 32 | 33 | func (pr *PubrecPacket) Unpack(b io.Reader) { 34 | pr.MessageID = decodeUint16(b) 35 | } 36 | 37 | func (pr *PubrecPacket) Details() Details { 38 | return Details{Qos: pr.Qos, MessageID: pr.MessageID} 39 | } 40 | 41 | func (pr *PubrecPacket) UUID() uuid.UUID { 42 | return pr.uuid 43 | } 44 | -------------------------------------------------------------------------------- /packets/pubrel.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "fmt" 5 | "github.com/google/uuid" 6 | "io" 7 | ) 8 | 9 | //PUBREL packet 10 | 11 | type PubrelPacket struct { 12 | FixedHeader 13 | MessageID uint16 14 | uuid uuid.UUID 15 | } 16 | 17 | func (pr *PubrelPacket) String() string { 18 | str := fmt.Sprintf("%s\n", pr.FixedHeader) 19 | str += fmt.Sprintf("MessageID: %d", pr.MessageID) 20 | return str 21 | } 22 | 23 | func (pr *PubrelPacket) Write(w io.Writer) error { 24 | var err error 25 | pr.FixedHeader.RemainingLength = 2 26 | packet := pr.FixedHeader.pack() 27 | packet.Write(encodeUint16(pr.MessageID)) 28 | _, err = packet.WriteTo(w) 29 | 30 | return err 31 | } 32 | 33 | func (pr *PubrelPacket) Unpack(b io.Reader) { 34 | pr.MessageID = decodeUint16(b) 35 | } 36 | 37 | func (pr *PubrelPacket) Details() Details { 38 | return Details{Qos: pr.Qos, MessageID: pr.MessageID} 39 | } 40 | 41 | func (pr 
*PubrelPacket) UUID() uuid.UUID { 42 | return pr.uuid 43 | } 44 | -------------------------------------------------------------------------------- /packets/suback.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "github.com/google/uuid" 7 | "io" 8 | ) 9 | 10 | //SUBACK packet 11 | 12 | type SubackPacket struct { 13 | FixedHeader 14 | MessageID uint16 15 | GrantedQoss []byte 16 | uuid uuid.UUID 17 | } 18 | 19 | func (sa *SubackPacket) String() string { 20 | str := fmt.Sprintf("%s\n", sa.FixedHeader) 21 | str += fmt.Sprintf("MessageID: %d", sa.MessageID) 22 | return str 23 | } 24 | 25 | func (sa *SubackPacket) Write(w io.Writer) error { 26 | var body bytes.Buffer 27 | var err error 28 | body.Write(encodeUint16(sa.MessageID)) 29 | body.Write(sa.GrantedQoss) 30 | sa.FixedHeader.RemainingLength = body.Len() 31 | packet := sa.FixedHeader.pack() 32 | packet.Write(body.Bytes()) 33 | _, err = packet.WriteTo(w) 34 | 35 | return err 36 | } 37 | 38 | func (sa *SubackPacket) Unpack(b io.Reader) { 39 | var qosBuffer bytes.Buffer 40 | sa.MessageID = decodeUint16(b) 41 | qosBuffer.ReadFrom(b) 42 | sa.GrantedQoss = qosBuffer.Bytes() 43 | } 44 | 45 | func (sa *SubackPacket) Details() Details { 46 | return Details{Qos: 0, MessageID: sa.MessageID} 47 | } 48 | 49 | func (sa *SubackPacket) UUID() uuid.UUID { 50 | return sa.uuid 51 | } 52 | -------------------------------------------------------------------------------- /packets/subscribe.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "github.com/google/uuid" 7 | "io" 8 | ) 9 | 10 | //SUBSCRIBE packet 11 | 12 | type SubscribePacket struct { 13 | FixedHeader 14 | MessageID uint16 15 | Topics []string 16 | Qoss []byte 17 | uuid uuid.UUID 18 | } 19 | 20 | func (s *SubscribePacket) String() string { 21 | str := fmt.Sprintf("%s\n", s.FixedHeader) 
22 | str += fmt.Sprintf("MessageID: %d topics: %s", s.MessageID, s.Topics) 23 | return str 24 | } 25 | 26 | func (s *SubscribePacket) Write(w io.Writer) error { 27 | var body bytes.Buffer 28 | var err error 29 | 30 | body.Write(encodeUint16(s.MessageID)) 31 | for i, topic := range s.Topics { 32 | body.Write(encodeString(topic)) 33 | body.WriteByte(s.Qoss[i]) 34 | } 35 | s.FixedHeader.RemainingLength = body.Len() 36 | packet := s.FixedHeader.pack() 37 | packet.Write(body.Bytes()) 38 | _, err = packet.WriteTo(w) 39 | 40 | return err 41 | } 42 | 43 | func (s *SubscribePacket) Unpack(b io.Reader) { 44 | s.MessageID = decodeUint16(b) 45 | payloadLength := s.FixedHeader.RemainingLength - 2 46 | for payloadLength > 0 { 47 | topic := decodeString(b) 48 | s.Topics = append(s.Topics, topic) 49 | qos := decodeByte(b) 50 | s.Qoss = append(s.Qoss, qos) 51 | payloadLength -= 2 + len(topic) + 1 //2 bytes of string length, plus string, plus 1 byte for Qos 52 | } 53 | } 54 | 55 | func (s *SubscribePacket) Details() Details { 56 | return Details{Qos: 1, MessageID: s.MessageID} 57 | } 58 | 59 | func (s *SubscribePacket) UUID() uuid.UUID { 60 | return s.uuid 61 | } 62 | -------------------------------------------------------------------------------- /packets/unsuback.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "fmt" 5 | "github.com/google/uuid" 6 | "io" 7 | ) 8 | 9 | //UNSUBACK packet 10 | 11 | type UnsubackPacket struct { 12 | FixedHeader 13 | MessageID uint16 14 | uuid uuid.UUID 15 | } 16 | 17 | func (ua *UnsubackPacket) String() string { 18 | str := fmt.Sprintf("%s\n", ua.FixedHeader) 19 | str += fmt.Sprintf("MessageID: %d", ua.MessageID) 20 | return str 21 | } 22 | 23 | func (ua *UnsubackPacket) Write(w io.Writer) error { 24 | var err error 25 | ua.FixedHeader.RemainingLength = 2 26 | packet := ua.FixedHeader.pack() 27 | packet.Write(encodeUint16(ua.MessageID)) 28 | _, err = packet.WriteTo(w) 29 | 30 
| return err 31 | } 32 | 33 | func (ua *UnsubackPacket) Unpack(b io.Reader) { 34 | ua.MessageID = decodeUint16(b) 35 | } 36 | 37 | func (ua *UnsubackPacket) Details() Details { 38 | return Details{Qos: 0, MessageID: ua.MessageID} 39 | } 40 | 41 | func (ua *UnsubackPacket) UUID() uuid.UUID { 42 | return ua.uuid 43 | } 44 | -------------------------------------------------------------------------------- /packets/unsubscribe.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "github.com/google/uuid" 7 | "io" 8 | ) 9 | 10 | //UNSUBSCRIBE packet 11 | 12 | type UnsubscribePacket struct { 13 | FixedHeader 14 | MessageID uint16 15 | Topics []string 16 | uuid uuid.UUID 17 | } 18 | 19 | func (u *UnsubscribePacket) String() string { 20 | str := fmt.Sprintf("%s\n", u.FixedHeader) 21 | str += fmt.Sprintf("MessageID: %d", u.MessageID) 22 | return str 23 | } 24 | 25 | func (u *UnsubscribePacket) Write(w io.Writer) error { 26 | var body bytes.Buffer 27 | var err error 28 | body.Write(encodeUint16(u.MessageID)) 29 | for _, topic := range u.Topics { 30 | body.Write(encodeString(topic)) 31 | } 32 | u.FixedHeader.RemainingLength = body.Len() 33 | packet := u.FixedHeader.pack() 34 | packet.Write(body.Bytes()) 35 | _, err = packet.WriteTo(w) 36 | 37 | return err 38 | } 39 | 40 | func (u *UnsubscribePacket) Unpack(b io.Reader) { 41 | u.MessageID = decodeUint16(b) 42 | var topic string 43 | for topic = decodeString(b); topic != ""; topic = decodeString(b) { 44 | u.Topics = append(u.Topics, topic) 45 | } 46 | } 47 | 48 | func (u *UnsubscribePacket) Details() Details { 49 | return Details{Qos: 1, MessageID: u.MessageID} 50 | } 51 | 52 | func (u *UnsubscribePacket) UUID() uuid.UUID { 53 | return u.uuid 54 | } 55 | -------------------------------------------------------------------------------- /plugins/plugin.go: -------------------------------------------------------------------------------- 1 | 
package hrotti 2 | 3 | import ( 4 | "encoding/json" 5 | "os" 6 | "sync" 7 | ) 8 | 9 | var pluginNodes map[string]Plugin 10 | var pluginMutex sync.Mutex 11 | 12 | //define the interface for a Plugin, any struct with these methods can be a plugin 13 | type Plugin interface { 14 | Initialise() error 15 | AddSub(*Client, []string, byte, chan byte) 16 | DeleteSub(*Client, []string, chan bool) 17 | } 18 | 19 | func init() { 20 | //set up the plugin map if it's nil. This code should also be in an init() for 21 | //every plugin that's written as the init function for a plugin is where it 22 | //registers itself and we can't guarantee the order init functions are called in 23 | pluginMutex.Lock() 24 | if pluginNodes == nil { 25 | pluginNodes = make(map[string]Plugin) 26 | } 27 | pluginMutex.Unlock() 28 | } 29 | 30 | func StartPlugins() { 31 | pluginMutex.Lock() 32 | defer pluginMutex.Unlock() 33 | //plugins have already registered themselves as part of their init functions 34 | //so range on map and call Initialise() for each plugin. 35 | for topic, plugin := range pluginNodes { 36 | err := plugin.Initialise() 37 | if err != nil { 38 | ERROR.Println("Failed to initialise plugin for", topic) 39 | delete(pluginNodes, topic) 40 | } else { 41 | INFO.Println("Initialised plugin for", topic) 42 | } 43 | } 44 | } 45 | 46 | //when a client disconnects and is cleansession true we want to remove all 47 | //subscriptions that client held in all plugins. 
48 | func DeleteSubAllPlugins(client *Client) { 49 | complete := make(chan bool, 1) 50 | defer close(complete) 51 | for _, plugin := range pluginNodes { 52 | plugin.DeleteSub(client, nil, complete) 53 | <-complete 54 | } 55 | } 56 | 57 | func ReadPluginConfig(confFile string, result interface{}) error { 58 | file, err := os.Open(confFile) 59 | if err != nil { 60 | return err 61 | } 62 | decoder := json.NewDecoder(file) 63 | 64 | err = decoder.Decode(result) 65 | if err != nil { 66 | return err 67 | } 68 | return nil 69 | } 70 | -------------------------------------------------------------------------------- /plugins/redirect_plugin.go: -------------------------------------------------------------------------------- 1 | package plugins 2 | 3 | import ( 4 | "strings" 5 | "sync" 6 | ) 7 | 8 | type RedirectPlugin struct { 9 | sync.RWMutex 10 | Redirects map[string]string 11 | fauxClient *Client 12 | stop chan struct{} 13 | } 14 | 15 | func init() { 16 | pluginMutex.Lock() 17 | if pluginNodes == nil { 18 | pluginNodes = make(map[string]Plugin) 19 | } 20 | pluginNodes["$redirect"] = &RedirectPlugin{} 21 | pluginMutex.Unlock() 22 | } 23 | 24 | func (rp *RedirectPlugin) Initialise() error { 25 | rp.Redirects = make(map[string]string) 26 | if err := ReadPluginConfig("redirect_plugin_config.json", &rp.Redirects); err != nil { 27 | return err 28 | } 29 | rp.fauxClient = NewClient(nil, nil, "$redirectpluginclient") 30 | INFO.Println("Redirects:", rp.Redirects) 31 | for source, _ := range rp.Redirects { 32 | rp.fauxClient.AddSubscription([]string{source}, []byte{0}) 33 | } 34 | go rp.Run() 35 | return nil 36 | } 37 | 38 | func (rp *RedirectPlugin) AddSub(client *Client, topic []string, qos byte, complete chan byte) { 39 | rp.Lock() 40 | defer rp.Unlock() 41 | sourceAndDest := strings.Split(strings.Join(topic[1:], ""), ",") 42 | rp.Redirects[sourceAndDest[0]] = sourceAndDest[1] 43 | rp.fauxClient.AddSubscription([]string{sourceAndDest[0]}, []byte{0}) 44 | complete <- 0 45 | } 46 
| 47 | func (rp *RedirectPlugin) DeleteSub(client *Client, topic []string, complete chan bool) { 48 | rp.Lock() 49 | defer rp.Unlock() 50 | if topic != nil { 51 | source := strings.Join(topic[1:], "") 52 | rp.fauxClient.RemoveSubscription(source) 53 | delete(rp.Redirects, source) 54 | } 55 | complete <- true 56 | } 57 | 58 | func (rp *RedirectPlugin) Run() { 59 | for { 60 | select { 61 | case <-rp.stop: 62 | return 63 | case msg := <-rp.fauxClient.outboundMessages: 64 | rp.RLock() 65 | if dest, ok := rp.Redirects[msg.topicName]; ok { 66 | rp.fauxClient.rootNode.DeliverMessage(strings.Split(dest, "/"), msg) 67 | } 68 | rp.RUnlock() 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /plugins/twitter_plugin.go: -------------------------------------------------------------------------------- 1 | package plugins 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | 7 | "github.com/darkhelmet/twitterstream" 8 | ) 9 | 10 | type Secrets struct { 11 | ConsumerKey string `json:"consumerKey,omitempty"` 12 | ConsumerSecret string `json:"consumerSecret,omitempty"` 13 | AccessToken string `json:"accessToken,omitempty"` 14 | AccessSecret string `json:"accessSecret,omitempty"` 15 | } 16 | 17 | type TwitterPlugin struct { 18 | sync.RWMutex 19 | client *twitterstream.Client 20 | conn *twitterstream.Connection 21 | config *Secrets 22 | Subscribed map[*Client]byte 23 | stop chan struct{} 24 | filter string 25 | } 26 | 27 | func init() { 28 | pluginMutex.Lock() 29 | if pluginNodes == nil { 30 | pluginNodes = make(map[string]Plugin) 31 | } 32 | pluginNodes["$twitter"] = &TwitterPlugin{} 33 | pluginMutex.Unlock() 34 | } 35 | 36 | func (tp *TwitterPlugin) Initialise() error { 37 | if err := ReadPluginConfig("twitter_plugin_config.json", &tp.config); err != nil { 38 | return err 39 | } 40 | if tp.config.ConsumerKey == "" || tp.config.ConsumerSecret == "" || tp.config.AccessToken == "" || tp.config.AccessSecret == "" { 41 | return 
errors.New("Not all twitter secrets defined") 42 | } 43 | tp.Subscribed = make(map[*Client]byte) 44 | tp.client = twitterstream.NewClient(tp.config.ConsumerKey, tp.config.ConsumerSecret, tp.config.AccessToken, tp.config.AccessSecret) 45 | return nil 46 | } 47 | 48 | func (tp *TwitterPlugin) AddSub(client *Client, subscription []string, qos byte, complete chan byte) { 49 | tp.Lock() 50 | defer func() { 51 | complete <- 0 52 | tp.Subscribed[client] = qos 53 | tp.Unlock() 54 | }() 55 | var err error 56 | INFO.Println("Adding $twitter sub for", subscription[1], client.clientId) 57 | if subscription[1] != tp.filter { 58 | if tp.conn != nil { 59 | close(tp.stop) 60 | tp.conn.Close() 61 | } 62 | tp.stop = make(chan struct{}) 63 | tp.conn, err = tp.client.Track(subscription[1]) 64 | if err != nil { 65 | ERROR.Println(err.Error()) 66 | return 67 | } 68 | tp.filter = subscription[1] 69 | go tp.Run() 70 | } 71 | } 72 | 73 | func (tp *TwitterPlugin) DeleteSub(client *Client, topic []string, complete chan bool) { 74 | tp.Lock() 75 | defer tp.Unlock() 76 | delete(tp.Subscribed, client) 77 | if len(tp.Subscribed) == 0 && tp.conn != nil { 78 | INFO.Println("All subscriptions gone, closing twitter connection") 79 | close(tp.stop) 80 | tp.conn.Close() 81 | tp.conn = nil 82 | tp.filter = "" 83 | } 84 | complete <- true 85 | } 86 | 87 | func (tp *TwitterPlugin) Run() { 88 | tweetChan := make(chan *twitterstream.Tweet, 2) 89 | go func() { 90 | for { 91 | tweet, err := tp.conn.Next() 92 | if err != nil { 93 | ERROR.Println("Twitter receive error") 94 | return 95 | } 96 | tweetChan <- tweet 97 | select { 98 | case <-tp.stop: 99 | return 100 | default: 101 | } 102 | } 103 | }() 104 | 105 | for { 106 | select { 107 | case tweet := <-tweetChan: 108 | tp.RLock() 109 | message := New(PUBLISH).(*publishPacket) 110 | message.Qos = 0 111 | message.topicName = "$twitter/" + tweet.User.ScreenName 112 | message.payload = []byte(tweet.Text) 113 | for client := range tp.Subscribed { 114 | select { 
115 | case client.outboundMessages <- message: 116 | default: 117 | } 118 | } 119 | tp.RUnlock() 120 | case <-tp.stop: 121 | return 122 | } 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /samples/simple_broker.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "os/signal" 6 | "syscall" 7 | 8 | "github.com/alsm/hrotti/broker" 9 | ) 10 | 11 | func main() { 12 | h := hrotti.NewHrotti(100, &hrotti.MemoryPersistence{}) 13 | hrotti.INFO.SetOutput(os.Stdout) 14 | hrotti.DEBUG.SetOutput(os.Stdout) 15 | hrotti.ERROR.SetOutput(os.Stdout) 16 | h.AddListener("test", hrotti.NewListenerConfig("tcp://0.0.0.0:1883")) 17 | 18 | c := make(chan os.Signal, 1) 19 | signal.Notify(c, os.Interrupt, syscall.SIGTERM) 20 | <-c 21 | h.Stop() 22 | } 23 | --------------------------------------------------------------------------------