├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── config ├── config.go ├── parse.go └── parse_test.go ├── dbinit └── mysql.sql ├── evthandler ├── evthandler.go └── webhook │ └── webhook.go ├── examples ├── printer │ └── main.go └── testclient │ └── main.go ├── http.go ├── main.go ├── msgcache ├── cache.go ├── cache_test.go ├── mysqlcache.go ├── mysqlcache_test.go ├── rediscache.go └── rediscache_test.go ├── msgcenter ├── connmap.go ├── connmap_test.go ├── connset.go ├── msgcenter.go ├── msgcenter_test.go └── srvcenter.go ├── proto ├── client │ ├── conn.go │ ├── dial.go │ ├── digestproc.go │ └── redirproc.go ├── cmd.go ├── cmd_test.go ├── cmdio.go ├── cmdio_test.go ├── const.go ├── keyex.go ├── keyex_test.go ├── keyset.go ├── server │ ├── auth.go │ ├── auth_test.go │ ├── conn.go │ ├── conn_test.go │ ├── digest_test.go │ ├── fwd_test.go │ ├── fwdproc.go │ ├── msgretriever.go │ ├── redir_test.go │ ├── retrieveall.go │ ├── retrieveall_test.go │ ├── settingsproc.go │ ├── sub_test.go │ ├── subproc.go │ ├── vis_test.go │ └── visproc.go ├── utils.go └── utils_test.go ├── push └── push.go ├── rpc ├── fwdreq.go ├── mc.go ├── msg.go ├── multipeer.go ├── peer.go ├── redirreq.go ├── result.go ├── sendreq.go ├── subreq.go └── usrstatus.go └── tools └── connect-cluster └── main.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.1 4 | services: 5 | - redis-server 6 | 
before_script: 7 | - go get labix.org/v2/mgo/bson 8 | - mysql -e 'create database uniqush;' 9 | script: 10 | - go test -v -race github.com/uniqush/uniqush-conn/msgcache 11 | - go test -v -race github.com/uniqush/uniqush-conn/proto 12 | - go test -v -race github.com/uniqush/uniqush-conn/proto/server 13 | - go test -v -race github.com/uniqush/uniqush-conn/msgcenter 14 | - go test -v -race github.com/uniqush/uniqush-conn/config 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | uniqush-conn 2 | ============ 3 | 4 | [![Build Status](https://travis-ci.org/uniqush/uniqush-conn.png?branch=master)](https://travis-ci.org/uniqush/uniqush-conn) 5 | 6 | A server-side program that maintains communication channels between a server and mobile devices. 7 | 8 | For more details on this program, please read [this blog post](http://blog.uniqush.org/uniqush-after-go1.html). 9 | 10 | This program is under construction. **Please do not use it right now.** 11 | 12 | For early birds: please check out our [Wiki pages](https://github.com/uniqush/uniqush-conn/wiki/_pages). 13 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * 16 | */ 17 | 18 | package config 19 | 20 | import ( 21 | "fmt" 22 | "github.com/uniqush/uniqush-conn/evthandler" 23 | "github.com/uniqush/uniqush-conn/msgcache" 24 | 25 | "github.com/uniqush/uniqush-conn/push" 26 | "github.com/uniqush/uniqush-conn/rpc" 27 | "net" 28 | "time" 29 | ) 30 | 31 | type Config struct { 32 | HandshakeTimeout time.Duration 33 | HttpAddr string 34 | Auth evthandler.Authenticator 35 | ErrorHandler evthandler.ErrorHandler 36 | filename string 37 | srvConfig map[string]*ServiceConfig 38 | defaultConfig *ServiceConfig 39 | } 40 | 41 | func (self *Config) OnError(addr net.Addr, err error) { 42 | if self == nil || self.ErrorHandler == nil { 43 | return 44 | } 45 | go self.ErrorHandler.OnError("", "", "", addr.String(), err) 46 | return 47 | } 48 | 49 | func (self *Config) AllServices() []string { 50 | ret := make([]string, 0, len(self.srvConfig)) 51 | for srv, _ := range self.srvConfig { 52 | ret = append(ret, srv) 53 | } 54 | return ret 55 | } 56 | 57 | func (self *Config) ReadConfig(srv string) *ServiceConfig { 58 | if ret, ok := self.srvConfig[srv]; ok { 59 | return ret 60 | } 61 | return self.defaultConfig 62 | } 63 | 64 | func (self *Config) Authenticate(srv, usr, connId, token, addr string) (bool, []string, error) { 65 | if self == nil || self.Auth == nil { 66 | return false, nil, nil 67 | } 68 | return self.Auth.Authenticate(srv, usr, connId, token, addr) 69 | } 70 | 71 | type ServiceConfig struct { 72 | ServiceName string 73 | MaxNrConns int 74 | MaxNrUsers int 75 | MaxNrConnsPerUser int 76 | 77 | MsgCache msgcache.Cache 78 | 79 | LoginHandler evthandler.LoginHandler 80 | LogoutHandler evthandler.LogoutHandler 81 | MessageHandler evthandler.MessageHandler 82 | ForwardRequestHandler evthandler.ForwardRequestHandler 83 | ErrorHandler evthandler.ErrorHandler 84 | 85 | // Push related web hooks 86 | SubscribeHandler evthandler.SubscribeHandler 87 | UnsubscribeHandler evthandler.UnsubscribeHandler 88 | 89 | PushService push.Push 90 | } 
91 | 92 | func (self *ServiceConfig) clone(srv string, dst *ServiceConfig) *ServiceConfig { 93 | if self == nil { 94 | dst = new(ServiceConfig) 95 | return dst 96 | } 97 | if dst == nil { 98 | dst = new(ServiceConfig) 99 | } 100 | dst.ServiceName = srv 101 | dst.MaxNrConns = self.MaxNrConns 102 | dst.MaxNrUsers = self.MaxNrUsers 103 | dst.MaxNrConnsPerUser = self.MaxNrConnsPerUser 104 | 105 | dst.MsgCache = self.MsgCache 106 | 107 | dst.LoginHandler = self.LoginHandler 108 | dst.LogoutHandler = self.LogoutHandler 109 | dst.MessageHandler = self.MessageHandler 110 | dst.ForwardRequestHandler = self.ForwardRequestHandler 111 | dst.ErrorHandler = self.ErrorHandler 112 | 113 | // Push related web hooks 114 | dst.SubscribeHandler = self.SubscribeHandler 115 | dst.UnsubscribeHandler = self.UnsubscribeHandler 116 | 117 | dst.PushService = self.PushService 118 | return dst 119 | } 120 | 121 | func (self *ServiceConfig) Cache() msgcache.Cache { 122 | if self == nil { 123 | return nil 124 | } 125 | return self.MsgCache 126 | } 127 | func (self *ServiceConfig) Subscribe(req *rpc.SubscribeRequest) { 128 | if req.Subscribe { 129 | self.subscribe(req.Username, req.Params) 130 | } else { 131 | self.unsubscribe(req.Username, req.Params) 132 | } 133 | } 134 | 135 | func (self *ServiceConfig) subscribe(username string, info map[string]string) { 136 | if self == nil || self.PushService == nil { 137 | return 138 | } 139 | go func() { 140 | if self.shouldSubscribe(self.ServiceName, username, info) { 141 | self.PushService.Subscribe(self.ServiceName, username, info) 142 | } 143 | }() 144 | return 145 | } 146 | 147 | func (self *ServiceConfig) unsubscribe(username string, info map[string]string) { 148 | if self == nil || self.PushService == nil { 149 | return 150 | } 151 | go self.PushService.Unsubscribe(self.ServiceName, username, info) 152 | return 153 | } 154 | 155 | type connDescriptor interface { 156 | RemoteAddr() net.Addr 157 | Service() string 158 | Username() string 159 | 
UniqId() string 160 | } 161 | 162 | func (self *ServiceConfig) OnError(c connDescriptor, err error) { 163 | if self == nil || self.ErrorHandler == nil { 164 | return 165 | } 166 | go self.ErrorHandler.OnError(c.Service(), c.Username(), c.UniqId(), c.RemoteAddr().String(), err) 167 | return 168 | } 169 | 170 | func (self *ServiceConfig) OnLogin(c connDescriptor) { 171 | if self == nil || self.LoginHandler == nil { 172 | return 173 | } 174 | go self.LoginHandler.OnLogin(c.Service(), c.Username(), c.UniqId(), c.RemoteAddr().String()) 175 | return 176 | } 177 | 178 | func (self *ServiceConfig) OnLogout(c connDescriptor, reason error) { 179 | if self == nil || self.LogoutHandler == nil { 180 | return 181 | } 182 | go self.LogoutHandler.OnLogout(c.Service(), c.Username(), c.UniqId(), c.RemoteAddr().String(), reason) 183 | return 184 | } 185 | 186 | func (self *ServiceConfig) OnMessage(c connDescriptor, msg *rpc.Message) { 187 | if self == nil || self.MessageHandler == nil { 188 | return 189 | } 190 | go self.MessageHandler.OnMessage(c.Service(), c.Username(), c.UniqId(), msg) 191 | return 192 | } 193 | 194 | func (self *ServiceConfig) CacheMessage(username string, mc *rpc.MessageContainer, ttl time.Duration) (string, error) { 195 | if self == nil || self.MsgCache == nil { 196 | return "", fmt.Errorf("no cache available") 197 | } 198 | return self.MsgCache.CacheMessage(self.ServiceName, username, mc, ttl) 199 | } 200 | 201 | func (self *ServiceConfig) Push(username, senderService, senderName string, info map[string]string, msgId string, size int) { 202 | if self == nil || self.PushService == nil { 203 | return 204 | } 205 | self.PushService.Push(self.ServiceName, username, senderService, senderName, info, msgId, size) 206 | } 207 | 208 | func (self *ServiceConfig) ShouldForward(fwdreq *rpc.ForwardRequest) (shouldForward, shouldPush bool, pushInfo map[string]string) { 209 | if self == nil || self.ForwardRequestHandler == nil { 210 | return false, false, nil 211 | } 212 | 
return self.ForwardRequestHandler.ShouldForward(fwdreq.SenderService, fwdreq.Sender, fwdreq.ReceiverService, fwdreq.Receivers, fwdreq.TTL, fwdreq.Message) 213 | } 214 | 215 | func (self *ServiceConfig) shouldSubscribe(srv, usr string, info map[string]string) bool { 216 | if self == nil || self.SubscribeHandler == nil { 217 | return false 218 | } 219 | return self.SubscribeHandler.ShouldSubscribe(srv, usr, info) 220 | } 221 | -------------------------------------------------------------------------------- /config/parse_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package config 19 | 20 | import ( 21 | "bytes" 22 | "os" 23 | "testing" 24 | ) 25 | 26 | var configFileContent string = ` 27 | # This is a comment 28 | http-addr: 127.0.0.1:8088 29 | handshake-timeout: 10s 30 | auth: 31 | default: disallow 32 | url: http://localhost:8080/auth 33 | timeout: 3s 34 | err: 35 | url: http://localhost:8080/err 36 | timeout: 3s 37 | default: 38 | msg: 39 | # A web hook takes either a url or a list of url 40 | url: 41 | - http://localhost:8080/msg 42 | - http://localhost:8081/msg 43 | timeout: 3s 44 | err: 45 | url: http://localhost:8080/err 46 | timeout: 3s 47 | login: 48 | url: http://localhost:8080/login 49 | timeout: 3s 50 | logout: 51 | url: http://localhost:8080/logout 52 | timeout: 3s 53 | fwd: 54 | default: allow 55 | url: http://localhost:8080/fwd 56 | timeout: 3s 57 | max-ttl: 36h 58 | subscribe: 59 | default: allow 60 | url: http://localhost:8080/subscribe 61 | timeout: 3s 62 | unsubscribe: 63 | default: allow 64 | url: http://localhost:8080/unsubscribe 65 | timeout: 3s 66 | uniqush-push: 67 | addr: localhost:9898 68 | timeout: 3s 69 | max-conns: 2048 70 | max-online-users: 2048 71 | max-conns-per-user: 10 72 | db: 73 | engine: redis 74 | host: 127.0.0.1 75 | port: 6379 76 | database: 1 77 | ` 78 | 79 | func writeConfigFile(filename string) { 80 | file, _ := os.Create(filename) 81 | file.WriteString(configFileContent) 82 | file.Close() 83 | } 84 | 85 | func deleteConfigFile(filename string) { 86 | os.Remove(filename) 87 | } 88 | 89 | func TestParseFile(t *testing.T) { 90 | filename := "config.yaml" 91 | writeConfigFile(filename) 92 | defer deleteConfigFile(filename) 93 | _, err := ParseFile(filename) 94 | if err != nil { 95 | t.Errorf("Error: %v\n", err) 96 | } 97 | } 98 | 99 | func TestParseReader(t *testing.T) { 100 | reader := bytes.NewBufferString(configFileContent) 101 | _, err := Parse(reader) 102 | if err != nil { 103 | t.Errorf("Error: %v\n", err) 104 | } 105 | } 106 | 
--------------------------------------------------------------------------------
/dbinit/mysql.sql:
--------------------------------------------------------------------------------
-- I know it is not normalized. But it is easier.
CREATE TABLE IF NOT EXISTS messages
(
	mid CHAR(255) NOT NULL PRIMARY KEY,

	owner_service CHAR(255) NOT NULL,
	owner_name CHAR(255) NOT NULL,

	sender_service CHAR(255),
	sender_name CHAR(255),

	create_time BIGINT,
	deadline BIGINT,
	content BLOB,

	-- Declared inline rather than with a separate CREATE INDEX so that the
	-- whole script is idempotent: MySQL has no CREATE INDEX IF NOT EXISTS,
	-- and re-running a standalone CREATE INDEX fails with a duplicate key
	-- name error.
	INDEX idx_owner_time (owner_service, owner_name, create_time, deadline)
);

-- SELECT * FROM messages
-- WHERE mid = ?

-- SELECT * FROM messages
-- WHERE owner_service = ?
-- AND owner_name = ?
-- AND create_time > ?
-- AND deadline > NOW();
-- NOTE(review): deadline is a BIGINT while NOW() is a DATETIME; the real
-- query presumably compares against a numeric timestamp in deadline's unit.
-- Confirm against msgcache/mysqlcache.go.

--------------------------------------------------------------------------------
/evthandler/evthandler.go:
--------------------------------------------------------------------------------
/*
 * Copyright 2013 Nan Deng
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
15 | * 16 | */ 17 | 18 | package evthandler 19 | 20 | import ( 21 | "github.com/uniqush/uniqush-conn/rpc" 22 | "time" 23 | ) 24 | 25 | type Authenticator interface { 26 | Authenticate(srv, usr, connId, token, addr string) (bool, []string, error) 27 | } 28 | 29 | type LoginHandler interface { 30 | OnLogin(service, username, connId, addr string) 31 | } 32 | 33 | type LogoutHandler interface { 34 | OnLogout(service, username, connId, addr string, reason error) 35 | } 36 | 37 | type MessageHandler interface { 38 | OnMessage(service, username, connId string, msg *rpc.Message) 39 | } 40 | 41 | type ForwardRequestHandler interface { 42 | ShouldForward(senderService, sender, receiverService string, receivers []string, ttl time.Duration, msg *rpc.Message) (shouldForward bool, shouldPush bool, pushInfo map[string]string) 43 | MaxTTL() time.Duration 44 | } 45 | 46 | type ErrorHandler interface { 47 | OnError(service, username, connId, addr string, err error) 48 | } 49 | 50 | type SubscribeHandler interface { 51 | ShouldSubscribe(service, username string, info map[string]string) bool 52 | } 53 | 54 | type UnsubscribeHandler interface { 55 | OnUnsubscribe(service, username string, info map[string]string) 56 | } 57 | -------------------------------------------------------------------------------- /evthandler/webhook/webhook.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package webhook 19 | 20 | import ( 21 | "bytes" 22 | "encoding/json" 23 | "github.com/uniqush/uniqush-conn/rpc" 24 | "net" 25 | "net/http" 26 | "time" 27 | ) 28 | 29 | type WebHook interface { 30 | SetURL(urls ...string) 31 | SetTimeout(timeout time.Duration) 32 | SetDefault(d int) 33 | } 34 | 35 | type webHook struct { 36 | URLs []string 37 | Timeout time.Duration 38 | Default int 39 | } 40 | 41 | func (self *webHook) SetURL(urls ...string) { 42 | self.URLs = urls 43 | } 44 | 45 | func (self *webHook) SetTimeout(timeout time.Duration) { 46 | self.Timeout = timeout 47 | } 48 | 49 | func (self *webHook) SetDefault(d int) { 50 | self.Default = d 51 | } 52 | 53 | func timeoutDialler(ns time.Duration) func(net, addr string) (c net.Conn, err error) { 54 | return func(netw, addr string) (net.Conn, error) { 55 | c, err := net.Dial(netw, addr) 56 | if err != nil { 57 | return nil, err 58 | } 59 | if ns.Seconds() > 0.0 { 60 | c.SetDeadline(time.Now().Add(ns)) 61 | } 62 | return c, nil 63 | } 64 | } 65 | 66 | func (self *webHook) post(data interface{}, out interface{}, requireOut bool) int { 67 | ret := self.Default 68 | for _, url := range self.URLs { 69 | status := self.postSingle(url, data, out, requireOut) 70 | if status == 200 { 71 | if out != nil { 72 | return status 73 | } 74 | ret = status 75 | } 76 | } 77 | return ret 78 | } 79 | 80 | func (self *webHook) postSingle(url string, data interface{}, out interface{}, requireOut bool) int { 81 | if len(url) == 0 || url == "none" { 82 | return self.Default 83 | } 84 | jdata, err := json.Marshal(data) 85 | if err != nil { 86 | return self.Default 87 | } 88 | c := http.Client{ 89 | Transport: &http.Transport{ 90 | Dial: timeoutDialler(self.Timeout), 91 | }, 92 | } 93 | resp, err := c.Post(url, "application/json", bytes.NewReader(jdata)) 94 | if err != nil { 95 | return self.Default 96 | } 
97 | defer resp.Body.Close() 98 | 99 | if out != nil { 100 | e := json.NewDecoder(resp.Body) 101 | err = e.Decode(out) 102 | if err != nil && requireOut { 103 | return self.Default 104 | } 105 | } 106 | return resp.StatusCode 107 | } 108 | 109 | type loginEvent struct { 110 | Service string `json:"service"` 111 | Username string `json:"username"` 112 | ConnID string `json:"connId"` 113 | Addr string `json:"addr"` 114 | } 115 | 116 | type LoginHandler struct { 117 | webHook 118 | } 119 | 120 | func (self *LoginHandler) OnLogin(service, username, connId, addr string) { 121 | self.post(&loginEvent{service, username, connId, addr}, nil, false) 122 | } 123 | 124 | type logoutEvent struct { 125 | Service string `json:"service"` 126 | Username string `json:"username"` 127 | ConnID string `json:"connId"` 128 | Addr string `json:"addr"` 129 | Reason string `json:"reason,omitempty"` 130 | } 131 | 132 | type LogoutHandler struct { 133 | webHook 134 | } 135 | 136 | func (self *LogoutHandler) OnLogout(service, username, connId, addr string, reason error) { 137 | evt := &logoutEvent{ 138 | Service: service, 139 | Username: username, 140 | ConnID: connId, 141 | Addr: addr, 142 | } 143 | 144 | if reason != nil { 145 | evt.Reason = reason.Error() 146 | } 147 | self.post(evt, nil, false) 148 | } 149 | 150 | type messageEvent struct { 151 | ConnID string `json:"connId"` 152 | Msg *rpc.Message `json:"msg"` 153 | Service string `json:"service"` 154 | Username string `json:"username"` 155 | } 156 | 157 | type MessageHandler struct { 158 | webHook 159 | } 160 | 161 | func (self *MessageHandler) OnMessage(service, username, connId string, msg *rpc.Message) { 162 | evt := &messageEvent{ 163 | Service: service, 164 | Username: username, 165 | ConnID: connId, 166 | Msg: msg, 167 | } 168 | self.post(evt, nil, false) 169 | } 170 | 171 | type errorEvent struct { 172 | Service string `json:"service"` 173 | Username string `json:"username"` 174 | ConnID string `json:"connId"` 175 | Addr string 
`json:"addr"` 176 | Reason string `json:"reason"` 177 | } 178 | 179 | type ErrorHandler struct { 180 | webHook 181 | } 182 | 183 | func (self *ErrorHandler) OnError(service, username, connId, addr string, reason error) { 184 | self.post(&errorEvent{service, username, connId, addr, reason.Error()}, nil, false) 185 | } 186 | 187 | type ForwardRequestHandler struct { 188 | webHook 189 | maxTTL time.Duration 190 | } 191 | 192 | type forwardEvent struct { 193 | SenderService string `json:"sender-service"` 194 | Sender string `json:"sender"` 195 | ReceiverService string `json:"receiver-service"` 196 | Receivers []string `json:"receivers"` 197 | Message *rpc.Message `json:"msg"` 198 | TTL time.Duration `json:"ttl"` 199 | } 200 | 201 | type forwardDecision struct { 202 | ShouldForward bool `json:"should-forward"` 203 | ShouldPush bool `json:"should-push"` 204 | PushInfo map[string]string `json:"push-info"` 205 | } 206 | 207 | func (self *ForwardRequestHandler) ShouldForward(senderService, sender, receiverService string, receivers []string, 208 | ttl time.Duration, msg *rpc.Message) (shouldForward, shouldPush bool, pushInfo map[string]string) { 209 | fwd := &forwardEvent{ 210 | Sender: sender, 211 | SenderService: senderService, 212 | Receivers: receivers, 213 | ReceiverService: receiverService, 214 | TTL: ttl, 215 | Message: msg, 216 | } 217 | 218 | res := &forwardDecision{ 219 | ShouldForward: true, 220 | ShouldPush: true, 221 | PushInfo: make(map[string]string, 10), 222 | } 223 | 224 | status := self.post(fwd, res, false) 225 | if status != 200 { 226 | res.ShouldForward = false 227 | res.ShouldPush = false 228 | } 229 | return res.ShouldForward, res.ShouldPush, res.PushInfo 230 | } 231 | 232 | func (self *ForwardRequestHandler) SetMaxTTL(ttl time.Duration) { 233 | self.maxTTL = ttl 234 | } 235 | 236 | func (self *ForwardRequestHandler) MaxTTL() time.Duration { 237 | return self.maxTTL 238 | } 239 | 240 | type authEvent struct { 241 | Service string `json:"service"` 242 | 
Username string `json:"username"` 243 | ConnId string `json:"conn-id"` 244 | Token string `json:"token"` 245 | Addr string `json:"addr"` 246 | } 247 | 248 | type AuthHandler struct { 249 | webHook 250 | } 251 | 252 | func (self *AuthHandler) Authenticate(srv, usr, connId, token, addr string) (pass bool, redir []string, err error) { 253 | evt := new(authEvent) 254 | evt.Service = srv 255 | evt.Username = usr 256 | evt.ConnId = connId 257 | evt.Token = token 258 | evt.Addr = addr 259 | pass = self.post(evt, redir, false) == 200 260 | return 261 | } 262 | 263 | type pushRelatedEvent struct { 264 | Service string `json:"service"` 265 | Username string `json:"username"` 266 | Info map[string]string `json:"info"` 267 | } 268 | 269 | type SubscribeHandler struct { 270 | webHook 271 | } 272 | 273 | func (self *SubscribeHandler) ShouldSubscribe(service, username string, info map[string]string) bool { 274 | evt := new(pushRelatedEvent) 275 | evt.Service = service 276 | evt.Username = username 277 | evt.Info = info 278 | return self.post(evt, nil, false) == 200 279 | } 280 | 281 | type UnsubscribeHandler struct { 282 | webHook 283 | } 284 | 285 | func (self *UnsubscribeHandler) OnUnsubscribe(service, username string, info map[string]string) { 286 | evt := &pushRelatedEvent{} 287 | evt.Service = service 288 | evt.Username = username 289 | evt.Info = info 290 | self.post(evt, nil, false) 291 | return 292 | } 293 | -------------------------------------------------------------------------------- /examples/printer/main.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package main 19 | 20 | import ( 21 | "net/http" 22 | "fmt" 23 | "bufio" 24 | ) 25 | 26 | func PrintData(w http.ResponseWriter, r *http.Request) { 27 | defer r.Body.Close() 28 | fmt.Printf("%v:\n", r.URL.Path) 29 | 30 | reader := bufio.NewReader(r.Body) 31 | for { 32 | line, _, err := reader.ReadLine() 33 | if err != nil { 34 | return 35 | } 36 | fmt.Printf("\t%v\n", string(line)) 37 | } 38 | w.WriteHeader(200) 39 | return 40 | } 41 | 42 | func main() { 43 | http.HandleFunc("/", PrintData) 44 | err := http.ListenAndServe(":8080", nil) 45 | if err != nil { 46 | fmt.Printf("Error: %v\n", err) 47 | } 48 | return 49 | } 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /examples/testclient/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "crypto/rsa" 6 | "crypto/x509" 7 | "encoding/json" 8 | "encoding/pem" 9 | "flag" 10 | "fmt" 11 | "github.com/uniqush/uniqush-conn/rpc" 12 | 13 | "github.com/uniqush/uniqush-conn/proto/client" 14 | "io" 15 | "io/ioutil" 16 | "net" 17 | "os" 18 | "strings" 19 | "time" 20 | ) 21 | 22 | func loadRSAPublicKey(keyFileName string) (rsapub *rsa.PublicKey, err error) { 23 | keyData, err := ioutil.ReadFile(keyFileName) 24 | if err != nil { 25 | fmt.Printf("Error: %v\n", err) 26 | return 27 | } 28 | b, _ := pem.Decode(keyData) 29 | if b == nil { 30 | err = fmt.Errorf("No key in the file") 31 | return 32 | } 33 | key, err := 
x509.ParsePKIXPublicKey(b.Bytes) 34 | if err != nil { 35 | fmt.Printf("Error: %v\n", err) 36 | return 37 | } 38 | rsapub, ok := key.(*rsa.PublicKey) 39 | 40 | if !ok { 41 | err = fmt.Errorf("Not an RSA public key") 42 | return 43 | } 44 | return 45 | } 46 | 47 | var argvPubKey = flag.String("key", "pub.pem", "public key file") 48 | var argvService = flag.String("s", "service", "service") 49 | var argvUsername = flag.String("u", "username", "username") 50 | var argvPassword = flag.String("p", "", "password") 51 | var argvDigestThrd = flag.Int("d", 512, "digest threshold") 52 | var argvCompressThrd = flag.Int("c", 1024, "compress threshold") 53 | 54 | func messagePrinter(conn client.Conn, msgChan <-chan *rpc.MessageContainer, digestChan <-chan *client.Digest) { 55 | encoder := json.NewEncoder(os.Stdout) 56 | for { 57 | select { 58 | case msg := <-msgChan: 59 | if msg == nil { 60 | return 61 | } 62 | encoder.Encode(msg) 63 | if msg.Message.Body != nil { 64 | fmt.Printf("\n%v: %v", msg.Sender, string(msg.Message.Body)) 65 | } else { 66 | fmt.Printf("\n") 67 | } 68 | case digest := <-digestChan: 69 | if digest == nil { 70 | return 71 | } 72 | encoder.Encode(digest) 73 | fmt.Printf("\n") 74 | conn.RequestMessage(digest.MsgId) 75 | } 76 | } 77 | } 78 | 79 | func messageReceiver(conn client.Conn, msgChan chan<- *rpc.MessageContainer) { 80 | defer conn.Close() 81 | for { 82 | msg, err := conn.ReceiveMessage() 83 | if err != nil { 84 | if err != io.EOF { 85 | fmt.Fprintf(os.Stderr, "%v\n", err) 86 | } 87 | return 88 | } 89 | msgChan <- msg 90 | } 91 | } 92 | 93 | func messageSender(conn client.Conn) { 94 | stdin := bufio.NewReader(os.Stdin) 95 | for { 96 | line, err := stdin.ReadString('\n') 97 | if err != nil { 98 | if err != io.EOF { 99 | fmt.Fprintf(os.Stderr, "%v\n", err) 100 | } 101 | return 102 | } 103 | msg := &rpc.Message{} 104 | 105 | elems := strings.SplitN(line, ":", 2) 106 | if len(elems) == 2 { 107 | msg.Body = []byte(elems[1]) 108 | 109 | recvers := 
strings.Split(elems[0], ",") 110 | err = conn.SendMessageToUsers(msg, 1*time.Hour, conn.Service(), recvers...) 111 | } else { 112 | msg.Body = []byte(line) 113 | err = conn.SendMessageToServer(msg) 114 | } 115 | if err != nil { 116 | if err != io.EOF { 117 | fmt.Fprintf(os.Stderr, "%v\n", err) 118 | } 119 | return 120 | } 121 | } 122 | } 123 | 124 | func main() { 125 | flag.Parse() 126 | pk, err := loadRSAPublicKey(*argvPubKey) 127 | if err != nil { 128 | fmt.Fprintf(os.Stderr, "%v\n", err) 129 | return 130 | } 131 | addr := "127.0.0.1:8964" 132 | if flag.NArg() > 0 { 133 | addr = flag.Arg(0) 134 | _, err := net.ResolveTCPAddr("tcp", addr) 135 | if err != nil { 136 | fmt.Fprintf(os.Stderr, "Invalid address: %v\n", err) 137 | return 138 | } 139 | } 140 | 141 | c, err := net.Dial("tcp", addr) 142 | if err != nil { 143 | fmt.Fprintf(os.Stderr, "%v\n", err) 144 | return 145 | } 146 | conn, err := client.Dial(c, pk, *argvService, *argvUsername, *argvPassword, 3*time.Second) 147 | if err != nil { 148 | fmt.Fprintf(os.Stderr, "Login Error: %v\n", err) 149 | return 150 | } 151 | err = conn.Config(*argvDigestThrd, *argvCompressThrd) 152 | if err != nil { 153 | fmt.Fprintf(os.Stderr, "Config Error: %v\n", err) 154 | return 155 | } 156 | 157 | msgChan := make(chan *rpc.MessageContainer) 158 | digestChan := make(chan *client.Digest) 159 | conn.SetDigestChannel(digestChan) 160 | go messageReceiver(conn, msgChan) 161 | go messagePrinter(conn, msgChan, digestChan) 162 | messageSender(conn) 163 | } 164 | -------------------------------------------------------------------------------- /http.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package main 19 | 20 | import ( 21 | "bytes" 22 | "crypto/rand" 23 | "encoding/base64" 24 | "encoding/json" 25 | "fmt" 26 | "github.com/uniqush/uniqush-conn/msgcenter" 27 | "io" 28 | "io/ioutil" 29 | "net/url" 30 | 31 | "strings" 32 | "time" 33 | 34 | "github.com/uniqush/uniqush-conn/rpc" 35 | 36 | "net/http" 37 | ) 38 | 39 | func (self *HttpRequestProcessor) processJsonRequest(w http.ResponseWriter, 40 | r *http.Request, req interface{}, 41 | center *msgcenter.MessageCenter, 42 | proc func(p *HttpRequestProcessor, center *msgcenter.MessageCenter, req interface{}) *rpc.Result) { 43 | if center == nil || req == nil { 44 | return 45 | } 46 | decoder := json.NewDecoder(r.Body) 47 | encoder := json.NewEncoder(w) 48 | err := decoder.Decode(req) 49 | 50 | if err != nil { 51 | res := &rpc.Result{ 52 | Error: err.Error(), 53 | } 54 | encoder.Encode(res) 55 | return 56 | } 57 | 58 | res := proc(self, center, req) 59 | encoder.Encode(res) 60 | return 61 | } 62 | 63 | func send(proc *HttpRequestProcessor, center *msgcenter.MessageCenter, req interface{}) *rpc.Result { 64 | if r, ok := req.(*rpc.SendRequest); ok { 65 | return center.Send(r) 66 | } 67 | return &rpc.Result{Error: "invalid req type"} 68 | } 69 | 70 | func forward(proc *HttpRequestProcessor, center *msgcenter.MessageCenter, req interface{}) *rpc.Result { 71 | if r, ok := req.(*rpc.ForwardRequest); ok { 72 | return center.Forward(r) 73 | } 74 | return &rpc.Result{Error: "invalid req type"} 75 | } 76 | 77 | func redirect(proc *HttpRequestProcessor, 
center *msgcenter.MessageCenter, req interface{}) *rpc.Result { 78 | if r, ok := req.(*rpc.RedirectRequest); ok { 79 | return center.Redirect(r) 80 | } 81 | return &rpc.Result{Error: "invalid req type"} 82 | } 83 | 84 | func checkUserStatus(proc *HttpRequestProcessor, center *msgcenter.MessageCenter, req interface{}) *rpc.Result { 85 | if r, ok := req.(*rpc.UserStatusQuery); ok { 86 | return center.CheckUserStatus(r) 87 | } 88 | return &rpc.Result{Error: "invalid req type"} 89 | } 90 | 91 | func addInstance(proc *HttpRequestProcessor, center *msgcenter.MessageCenter, req interface{}) *rpc.Result { 92 | if r, ok := req.(*rpc.UniqushConnInstance); ok { 93 | u, err := url.Parse(r.Addr) 94 | if err != nil { 95 | return &rpc.Result{Error: err.Error()} 96 | } 97 | 98 | isme, err := proc.isMyself(u.String()) 99 | if err != nil { 100 | return &rpc.Result{Error: err.Error()} 101 | } 102 | if isme { 103 | return &rpc.Result{Error: "This is me"} 104 | } 105 | instance, err := rpc.NewUniqushConnInstance(u, r.Timeout) 106 | if err != nil { 107 | return &rpc.Result{Error: err.Error()} 108 | } 109 | 110 | center.AddPeer(instance) 111 | return &rpc.Result{Error: "Success"} 112 | } 113 | return &rpc.Result{Error: "invalid req type"} 114 | } 115 | 116 | type HttpRequestProcessor struct { 117 | center *msgcenter.MessageCenter 118 | addr string 119 | myId string 120 | } 121 | 122 | func (self *HttpRequestProcessor) isMyself(addr string) (bool, error) { 123 | resp, err := http.Get(addr + "/id") 124 | if err != nil { 125 | return false, err 126 | } 127 | defer resp.Body.Close() 128 | body, err := ioutil.ReadAll(resp.Body) 129 | if err != nil { 130 | return false, err 131 | } 132 | 133 | peerId := string(body) 134 | 135 | if peerId == self.myId { 136 | return true, nil 137 | } 138 | return false, nil 139 | } 140 | 141 | func (self *HttpRequestProcessor) ServeHTTP(w http.ResponseWriter, r *http.Request) { 142 | defer r.Body.Close() 143 | 144 | upath := strings.TrimSpace(r.URL.Path) 145 | 
switch upath { 146 | case rpc.SEND_MESSAGE_PATH: 147 | sendReq := &rpc.SendRequest{} 148 | self.processJsonRequest(w, r, sendReq, self.center, send) 149 | case rpc.FORWARD_MESSAGE_PATH: 150 | fwdReq := &rpc.ForwardRequest{} 151 | self.processJsonRequest(w, r, fwdReq, self.center, forward) 152 | case rpc.USER_STATUS_QUERY_PATH: 153 | usrStatusQuery := &rpc.UserStatusQuery{} 154 | self.processJsonRequest(w, r, usrStatusQuery, self.center, checkUserStatus) 155 | case rpc.REDIRECT_CLIENT_PATH: 156 | redirReq := &rpc.RedirectRequest{} 157 | self.processJsonRequest(w, r, redirReq, self.center, redirect) 158 | case "/join.json": 159 | instance := &rpc.UniqushConnInstance{} 160 | self.processJsonRequest(w, r, instance, self.center, addInstance) 161 | case "/id": 162 | fmt.Fprintf(w, "%v\r\n", self.myId) 163 | case "/services": 164 | if r.Method == "GET" { 165 | srvs, err := json.Marshal(self.center.AllServices()) 166 | if err != nil { 167 | fmt.Fprintf(w, "[]\r\n") 168 | return 169 | } 170 | fmt.Fprintf(w, "%v", string(srvs)) 171 | } 172 | default: 173 | pathb := bytes.Trim([]byte(upath), "/") 174 | elems := bytes.Split(pathb, []byte("/")) 175 | var action string 176 | var srv string 177 | if len(elems) == 2 { 178 | action = string(elems[0]) 179 | srv = string(elems[1]) 180 | } 181 | switch action { 182 | case "nr-conns": 183 | fmt.Fprintf(w, "%v\r\n", self.center.NrConns(srv)) 184 | case "nr-users": 185 | fmt.Fprintf(w, "%v\r\n", self.center.NrUsers(srv)) 186 | case "all-users": 187 | usrs, err := json.Marshal(self.center.AllUsernames(srv)) 188 | if err != nil { 189 | fmt.Fprintf(w, "[]\r\n") 190 | return 191 | } 192 | fmt.Fprintf(w, "%v", string(usrs)) 193 | } 194 | } 195 | } 196 | 197 | func (self *HttpRequestProcessor) Start() error { 198 | http.Handle("/", self) 199 | err := http.ListenAndServe(self.addr, nil) 200 | return err 201 | } 202 | 203 | func NewHttpRequestProcessor(addr string, center *msgcenter.MessageCenter) *HttpRequestProcessor { 204 | ret := 
new(HttpRequestProcessor) 205 | ret.addr = addr 206 | ret.center = center 207 | var d [16]byte 208 | io.ReadFull(rand.Reader, d[:]) 209 | ret.myId = fmt.Sprintf("%x-%v", time.Now().Unix(), base64.URLEncoding.EncodeToString(d[:])) 210 | return ret 211 | } 212 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package main 19 | 20 | import ( 21 | "crypto/rsa" 22 | "crypto/x509" 23 | "encoding/pem" 24 | "flag" 25 | "fmt" 26 | "github.com/uniqush/uniqush-conn/config" 27 | "github.com/uniqush/uniqush-conn/msgcenter" 28 | "io/ioutil" 29 | "net" 30 | "os" 31 | ) 32 | 33 | func readPrivateKey(keyFileName string) (priv *rsa.PrivateKey, err error) { 34 | keyData, err := ioutil.ReadFile(keyFileName) 35 | if err != nil { 36 | return 37 | } 38 | 39 | b, _ := pem.Decode(keyData) 40 | priv, err = x509.ParsePKCS1PrivateKey(b.Bytes) 41 | if err != nil { 42 | return 43 | } 44 | return 45 | } 46 | 47 | var argvKeyFile = flag.String("key", "key.pem", "private key") 48 | var argvConfigFile = flag.String("config", "config.yaml", "config file path") 49 | 50 | // In memory of the blood on the square. 
51 | var argvPort = flag.Int("port", 0x2304, "port number") 52 | 53 | func main() { 54 | flag.Parse() 55 | ln, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%v", *argvPort)) 56 | if err != nil { 57 | fmt.Fprintf(os.Stderr, "Network error: %v\n", err) 58 | return 59 | } 60 | 61 | privkey, err := readPrivateKey(*argvKeyFile) 62 | if err != nil { 63 | fmt.Fprintf(os.Stderr, "Key error: %v\n", err) 64 | return 65 | } 66 | config, err := config.ParseFile(*argvConfigFile) 67 | if err != nil { 68 | fmt.Fprintf(os.Stderr, "Config error: %v\n", err) 69 | return 70 | } 71 | if config.Auth == nil { 72 | fmt.Fprintf(os.Stderr, "Config error: You should provide the auth url\n") 73 | return 74 | } 75 | 76 | center := msgcenter.NewMessageCenter(ln, privkey, config) 77 | proc := NewHttpRequestProcessor(config.HttpAddr, center) 78 | go center.Start() 79 | err = proc.Start() 80 | if err != nil { 81 | fmt.Fprintf(os.Stderr, "%v", err) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /msgcache/cache.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package msgcache 19 | 20 | import ( 21 | "crypto/rand" 22 | "encoding/base64" 23 | 24 | "fmt" 25 | "github.com/uniqush/uniqush-conn/rpc" 26 | "io" 27 | "sync" 28 | "time" 29 | ) 30 | 31 | func init() { 32 | Register(&mysqlCacheManager{}) 33 | Register(&redisCacheManager{}) 34 | } 35 | 36 | type Cache interface { 37 | CacheMessage(service, username string, msg *rpc.MessageContainer, ttl time.Duration) (id string, err error) 38 | // XXX Is there any better way to support retrieve all feature? 39 | Get(service, username, id string) (msg *rpc.MessageContainer, err error) 40 | RetrieveAllSince(service, username string, since time.Time) (msgs []*rpc.MessageContainer, err error) 41 | Close() error 42 | } 43 | 44 | type CacheManager interface { 45 | GetCache(host, username, password, database string, port int) (Cache, error) 46 | Engine() string 47 | } 48 | 49 | var cacheEngineMapLock sync.Mutex 50 | var cacheEngineMap map[string]CacheManager 51 | 52 | func Register(cm CacheManager) { 53 | cacheEngineMapLock.Lock() 54 | defer cacheEngineMapLock.Unlock() 55 | if cacheEngineMap == nil { 56 | cacheEngineMap = make(map[string]CacheManager, 10) 57 | } 58 | cacheEngineMap[cm.Engine()] = cm 59 | } 60 | 61 | func GetCache(engine, host, username, password, database string, port int) (Cache, error) { 62 | cacheEngineMapLock.Lock() 63 | defer cacheEngineMapLock.Unlock() 64 | if c, ok := cacheEngineMap[engine]; ok { 65 | return c.GetCache(host, username, password, database, port) 66 | } 67 | return nil, fmt.Errorf("%v is not supported", engine) 68 | } 69 | 70 | func randomId() string { 71 | var d [8]byte 72 | io.ReadFull(rand.Reader, d[:]) 73 | return fmt.Sprintf("%x-%v", time.Now().Unix(), base64.URLEncoding.EncodeToString(d[:])) 74 | } 75 | -------------------------------------------------------------------------------- /msgcache/mysqlcache.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan 
Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package msgcache 19 | 20 | import ( 21 | "database/sql" 22 | "encoding/json" 23 | "fmt" 24 | _ "github.com/go-sql-driver/mysql" 25 | "github.com/uniqush/uniqush-conn/rpc" 26 | "net" 27 | "sync" 28 | "time" 29 | ) 30 | 31 | const ( 32 | maxMessageIdLength = 255 33 | maxUsernameLength = 255 34 | maxServicenameLength = 255 35 | ) 36 | 37 | type mysqlCacheManager struct { 38 | } 39 | 40 | func (self *mysqlCacheManager) Engine() string { 41 | return "mysql" 42 | } 43 | 44 | func (self *mysqlCacheManager) GetCache(host, username, password, database string, port int) (Cache, error) { 45 | if len(host) == 0 { 46 | host = "127.0.0.1" 47 | } 48 | if port <= 0 { 49 | port = 3306 50 | } 51 | addr := net.JoinHostPort(host, fmt.Sprintf("%v", port)) 52 | return newMySQLMessageCache(username, password, addr, database) 53 | } 54 | 55 | type mysqlMessageCache struct { 56 | dsn string 57 | lock sync.RWMutex 58 | db *sql.DB 59 | cacheStmt *sql.Stmt 60 | getMultiMsgStmt *sql.Stmt 61 | getSingleMsgStmt *sql.Stmt 62 | } 63 | 64 | func (self *mysqlMessageCache) init() error { 65 | createTblStmt := `CREATE TABLE IF NOT EXISTS uniqush_messages 66 | ( 67 | id CHAR(255) NOT NULL PRIMARY KEY, 68 | mid CHAR(255), 69 | 70 | owner_service CHAR(255) NOT NULL, 71 | owner_name CHAR(255) NOT NULL, 72 | 73 | sender_service CHAR(255), 74 | sender_name CHAR(255), 75 | 76 | create_time 
BIGINT, 77 | deadline BIGINT, 78 | content BLOB 79 | );` 80 | createIdxStmt := `CREATE INDEX idx_owner_time ON uniqush_messages (owner_service, owner_name, create_time, deadline);` 81 | 82 | tx, err := self.db.Begin() 83 | if err != nil { 84 | return err 85 | } 86 | 87 | _, err = tx.Exec(createTblStmt) 88 | if err != nil { 89 | return err 90 | } 91 | // XXX we will ignore the error. 92 | // Because it will always be an error if there exists 93 | // an index with same name. 94 | // Is there any way to only create index if it does not exist? 95 | tx.Exec(createIdxStmt) 96 | err = tx.Commit() 97 | if err != nil { 98 | return err 99 | } 100 | return nil 101 | } 102 | 103 | func (self *mysqlMessageCache) Close() error { 104 | self.lock.RLock() 105 | defer self.lock.RUnlock() 106 | if self.cacheStmt != nil { 107 | self.cacheStmt.Close() 108 | } 109 | if self.getMultiMsgStmt != nil { 110 | self.getMultiMsgStmt.Close() 111 | } 112 | if self.getSingleMsgStmt != nil { 113 | self.getSingleMsgStmt.Close() 114 | } 115 | if self.db != nil { 116 | self.db.Close() 117 | } 118 | return nil 119 | } 120 | 121 | func (self *mysqlMessageCache) reconnect() error { 122 | self.Close() 123 | self.lock.Lock() 124 | defer self.lock.Unlock() 125 | if len(self.dsn) == 0 { 126 | return fmt.Errorf("No DSN") 127 | } 128 | db, err := sql.Open("mysql", self.dsn) 129 | if err != nil { 130 | return fmt.Errorf("Data base error: %v", err) 131 | } 132 | self.db = db 133 | err = self.init() 134 | if err != nil { 135 | return fmt.Errorf("Data base init error: %v", err) 136 | } 137 | 138 | stmt, err := db.Prepare(`INSERT INTO uniqush_messages 139 | (id, mid, owner_service, owner_name, sender_service, sender_name, create_time, deadline, content) 140 | VALUES 141 | (?, ?, ?, ?, ?, ?, ?, ?, ?) 
142 | `) 143 | if err != nil { 144 | return fmt.Errorf("Data base prepare statement error: %v; insert stmt", err) 145 | } 146 | 147 | self.cacheStmt = stmt 148 | 149 | stmt, err = db.Prepare(`SELECT mid, sender_service, sender_name, create_time, content 150 | FROM uniqush_messages 151 | WHERE owner_service=? AND owner_name=? AND create_time>=? AND (deadline>=? OR deadline<=0) ORDER BY create_time; 152 | `) 153 | if err != nil { 154 | return fmt.Errorf("Data base prepare error: %v; select multi stmt", err) 155 | } 156 | self.getMultiMsgStmt = stmt 157 | stmt, err = db.Prepare(`SELECT mid, sender_service, sender_name, create_time, content 158 | FROM uniqush_messages 159 | WHERE id=? AND (deadline>? OR deadline<=0); 160 | `) 161 | if err != nil { 162 | return fmt.Errorf("Data base prepare error: %v; select single stmt", err) 163 | } 164 | self.getSingleMsgStmt = stmt 165 | return nil 166 | } 167 | 168 | func newMySQLMessageCache(username, password, address, dbname string) (c *mysqlMessageCache, err error) { 169 | c = new(mysqlMessageCache) 170 | if len(address) == 0 { 171 | address = "127.0.0.1:3306" 172 | } 173 | c.dsn = fmt.Sprintf("%v:%v@tcp(%v)/%v", username, password, address, dbname) 174 | err = c.reconnect() 175 | return 176 | } 177 | 178 | func getUniqMessageId(service, username, id string) string { 179 | return fmt.Sprintf("%v,%v,%v", service, username, id) 180 | } 181 | 182 | func (self *mysqlMessageCache) expBackoffRetry(N int, initWaitSec int, f func() error) error { 183 | s := initWaitSec 184 | var err error 185 | for i := 0; i < N; i++ { 186 | if err == nil { 187 | self.lock.RLock() 188 | err = f() 189 | self.lock.RUnlock() 190 | if err == nil { 191 | return nil 192 | } 193 | } 194 | time.Sleep(time.Duration(s) * time.Second) 195 | s = s * s 196 | err = self.reconnect() 197 | } 198 | return err 199 | } 200 | 201 | func (self *mysqlMessageCache) CacheMessage(service, username string, mc *rpc.MessageContainer, ttl time.Duration) (id string, err error) { 
202 | data, err := json.Marshal(mc.Message) 203 | if err != nil { 204 | return 205 | } 206 | 207 | if mc.Id == "" { 208 | id = randomId() 209 | mc.Id = id 210 | } else { 211 | id = mc.Id 212 | } 213 | 214 | uniqid := getUniqMessageId(service, username, id) 215 | if len(uniqid) > maxMessageIdLength { 216 | err = fmt.Errorf("message id length is greater than %v characters", maxMessageIdLength) 217 | return 218 | } 219 | 220 | if len(username) > maxUsernameLength { 221 | err = fmt.Errorf("user %v's name is too long", username) 222 | return 223 | } 224 | if len(mc.Sender) > maxUsernameLength { 225 | err = fmt.Errorf("user %v's name is too long", mc.Sender) 226 | return 227 | } 228 | if len(service) > maxServicenameLength { 229 | err = fmt.Errorf("service %v's name is too long", service) 230 | return 231 | } 232 | if len(mc.SenderService) > maxServicenameLength { 233 | err = fmt.Errorf("service %v's name is too long", mc.SenderService) 234 | return 235 | } 236 | 237 | now := time.Now() 238 | mc.Birthday = now 239 | deadline := now.Add(ttl) 240 | if ttl < 1*time.Second { 241 | // max possible value for int64 242 | deadline = time.Unix(0, 0) 243 | } 244 | var result sql.Result 245 | 246 | err = self.expBackoffRetry(3, 2, func() error { 247 | result, err = self.cacheStmt.Exec(uniqid, id, service, username, mc.SenderService, mc.Sender, now.Unix(), deadline.Unix(), data) 248 | return err 249 | }) 250 | if err != nil { 251 | err = fmt.Errorf("Data base error: %v; insert error", err) 252 | return 253 | } 254 | n, err := result.RowsAffected() 255 | if err != nil { 256 | return 257 | } 258 | if n != 1 { 259 | err = fmt.Errorf("affected %v rows, which is weird", n) 260 | return 261 | } 262 | return 263 | } 264 | 265 | func (self *mysqlMessageCache) Get(service, username, id string) (mc *rpc.MessageContainer, err error) { 266 | err = self.expBackoffRetry(3, 2, func() error { 267 | mc, err = self.getOnce(service, username, id) 268 | return err 269 | }) 270 | return 271 | } 272 | 
273 | func (self *mysqlMessageCache) getOnce(service, username, id string) (mc *rpc.MessageContainer, err error) { 274 | uniqid := getUniqMessageId(service, username, id) 275 | row := self.getSingleMsgStmt.QueryRow(uniqid, time.Now().Unix()) 276 | if err != nil { 277 | err = fmt.Errorf("Data base error: %v; query error", err) 278 | return 279 | } 280 | 281 | mc = new(rpc.MessageContainer) 282 | var data []byte 283 | var createTime int64 284 | err = row.Scan(&mc.Id, &mc.SenderService, &mc.Sender, &createTime, &data) 285 | if err != nil { 286 | if err == sql.ErrNoRows { 287 | err = nil 288 | mc = nil 289 | return 290 | } 291 | return 292 | } 293 | mc.Message = new(rpc.Message) 294 | err = json.Unmarshal(data, mc.Message) 295 | if err != nil { 296 | return 297 | } 298 | mc.Birthday = time.Unix(createTime, 0) 299 | return 300 | } 301 | 302 | func (self *mysqlMessageCache) RetrieveAllSince(service, username string, since time.Time) (msgs []*rpc.MessageContainer, err error) { 303 | err = self.expBackoffRetry(3, 2, func() error { 304 | msgs, err = self.retrieveOnce(service, username, since) 305 | return err 306 | }) 307 | return 308 | } 309 | 310 | func (self *mysqlMessageCache) retrieveOnce(service, username string, since time.Time) (msgs []*rpc.MessageContainer, err error) { 311 | self.lock.RLock() 312 | defer self.lock.RUnlock() 313 | rows, err := self.getMultiMsgStmt.Query(service, username, since.Unix(), time.Now().Unix()) 314 | if err != nil { 315 | err = fmt.Errorf("Data base error: %v; query multi-msg error", err) 316 | return 317 | } 318 | defer rows.Close() 319 | 320 | msgs = make([]*rpc.MessageContainer, 0, 128) 321 | for rows.Next() { 322 | mc := new(rpc.MessageContainer) 323 | var data []byte 324 | var createTime int64 325 | err = rows.Scan(&mc.Id, &mc.SenderService, &mc.Sender, &createTime, &data) 326 | if err != nil { 327 | return 328 | } 329 | mc.Message = new(rpc.Message) 330 | err = json.Unmarshal(data, mc.Message) 331 | if err != nil { 332 | return 333 
| } 334 | mc.Birthday = time.Unix(createTime, 0) 335 | msgs = append(msgs, mc) 336 | } 337 | return 338 | } 339 | -------------------------------------------------------------------------------- /msgcache/mysqlcache_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package msgcache 19 | 20 | import "testing" 21 | 22 | type mysqlTestCacheManager struct { 23 | } 24 | 25 | func (self *mysqlTestCacheManager) Name() string { 26 | return "mysql" 27 | } 28 | 29 | func (self *mysqlTestCacheManager) GetCache() (Cache, error) { 30 | return GetCache("mysql", "127.0.0.1", "travis", "", "uniqush", 3306) 31 | } 32 | 33 | func (self *mysqlTestCacheManager) ClearCache(c Cache) { 34 | if cache, ok := c.(*mysqlMessageCache); ok { 35 | cache.db.Exec("DELETE FROM uniqush_messages") 36 | } 37 | } 38 | 39 | func TestMysqlCache(t *testing.T) { 40 | testCacheImpl(&mysqlTestCacheManager{}, t) 41 | } 42 | -------------------------------------------------------------------------------- /msgcache/rediscache.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package msgcache 19 | 20 | import ( 21 | "encoding/json" 22 | "fmt" 23 | "github.com/garyburd/redigo/redis" 24 | "github.com/uniqush/uniqush-conn/rpc" 25 | "net" 26 | 27 | "strconv" 28 | "time" 29 | ) 30 | 31 | type redisCacheManager struct { 32 | } 33 | 34 | func (self *redisCacheManager) GetCache(host, username, password, database string, port int) (Cache, error) { 35 | db := 0 36 | if len(database) > 0 { 37 | var err error 38 | db, err = strconv.Atoi(database) 39 | if err != nil { 40 | return nil, fmt.Errorf("bad database %v: %v", database, err) 41 | } 42 | } 43 | 44 | return newRedisMessageCache(host, password, port, db), nil 45 | } 46 | 47 | func (self *redisCacheManager) Engine() string { 48 | return "redis" 49 | } 50 | 51 | type redisMessageCache struct { 52 | pool *redis.Pool 53 | } 54 | 55 | func newRedisMessageCache(host, password string, port, db int) Cache { 56 | if len(host) == 0 { 57 | host = "localhost" 58 | } 59 | if port <= 0 { 60 | port = 6379 61 | } 62 | if db < 0 { 63 | db = 0 64 | } 65 | 66 | addr := net.JoinHostPort(host, fmt.Sprintf("%v", port)) 67 | dial := func() (redis.Conn, error) { 68 | c, err := redis.Dial("tcp", addr) 69 | if err != nil { 70 | return nil, err 71 | } 72 | if len(password) > 0 { 73 | if _, err := c.Do("AUTH", password); err != nil { 74 | c.Close() 75 | return nil, err 76 | } 77 | } 78 | if _, err := c.Do("SELECT", db); err != nil { 79 | c.Close() 80 | return nil, err 81 | } 82 | return c, err 83 | } 84 | testOnBorrow := func(c redis.Conn, t time.Time) error 
{ 85 | _, err := c.Do("PING") 86 | return err 87 | } 88 | 89 | pool := &redis.Pool{ 90 | MaxIdle: 3, 91 | IdleTimeout: 240 * time.Second, 92 | Dial: dial, 93 | TestOnBorrow: testOnBorrow, 94 | } 95 | 96 | ret := new(redisMessageCache) 97 | ret.pool = pool 98 | return ret 99 | } 100 | 101 | func (self *redisMessageCache) CacheMessage(service, username string, msg *rpc.MessageContainer, ttl time.Duration) (id string, err error) { 102 | if msg.Id == "" { 103 | id = randomId() 104 | } else { 105 | id = msg.Id 106 | } 107 | err = self.set(service, username, id, msg, ttl) 108 | if err != nil { 109 | id = "" 110 | return 111 | } 112 | return 113 | } 114 | 115 | func msgKey(service, username, id string) string { 116 | return fmt.Sprintf("mcache:%v:%v:%v", service, username, id) 117 | } 118 | 119 | func msgKeyPattern(service, username string) string { 120 | return fmt.Sprintf("mcache:%v:%v:*", service, username) 121 | } 122 | 123 | func msgQueueKey(service, username string) string { 124 | return fmt.Sprintf("mqueue:%v:%v", service, username) 125 | } 126 | 127 | func msgWeightKey(service, username, id string) string { 128 | return fmt.Sprintf("w_mcache:%v:%v:%v", service, username, id) 129 | } 130 | 131 | func msgWeightPattern(service, username string) string { 132 | return fmt.Sprintf("w_mcache:%v:%v:*", service, username) 133 | } 134 | 135 | func counterKey(service, username string) string { 136 | return "msgCounter" 137 | } 138 | 139 | func msgMarshal(msg *rpc.MessageContainer) (data []byte, err error) { 140 | data, err = json.Marshal(msg) 141 | return 142 | } 143 | 144 | func msgUnmarshal(data []byte) (msg *rpc.MessageContainer, err error) { 145 | msg = new(rpc.MessageContainer) 146 | err = json.Unmarshal(data, msg) 147 | if err != nil { 148 | msg = nil 149 | return 150 | } 151 | return 152 | } 153 | 154 | func (self *redisMessageCache) set(service, username, id string, msg *rpc.MessageContainer, ttl time.Duration) error { 155 | msg.Id = id 156 | msg.Birthday = 
time.Now() 157 | key := msgKey(service, username, id) 158 | conn := self.pool.Get() 159 | defer conn.Close() 160 | 161 | data, err := msgMarshal(msg) 162 | if err != nil { 163 | return err 164 | } 165 | 166 | /* 167 | reply, err := conn.Do("INCR", counterKey(service, username)) 168 | if err != nil { 169 | return err 170 | } 171 | 172 | weight, err := redis.Int64(reply, err) 173 | if err != nil { 174 | return err 175 | } 176 | */ 177 | weight := time.Now().Unix() 178 | wkey := msgWeightKey(service, username, id) 179 | 180 | err = conn.Send("MULTI") 181 | if err != nil { 182 | return err 183 | } 184 | 185 | if ttl.Seconds() <= 0.0 { 186 | err = conn.Send("SET", key, data) 187 | if err != nil { 188 | conn.Do("DISCARD") 189 | return err 190 | } 191 | err = conn.Send("SET", wkey, weight) 192 | } else { 193 | err = conn.Send("SETEX", key, int64(ttl.Seconds()), data) 194 | if err != nil { 195 | conn.Do("DISCARD") 196 | return err 197 | } 198 | err = conn.Send("SETEX", wkey, int64(ttl.Seconds()), weight) 199 | } 200 | if err != nil { 201 | conn.Do("DISCARD") 202 | return err 203 | } 204 | msgQK := msgQueueKey(service, username) 205 | err = conn.Send("SADD", msgQK, id) 206 | if err != nil { 207 | conn.Do("DISCARD") 208 | return err 209 | } 210 | _, err = conn.Do("EXEC") 211 | if err != nil { 212 | return err 213 | } 214 | return nil 215 | } 216 | 217 | func (self *redisMessageCache) Get(service, username, id string) (msg *rpc.MessageContainer, err error) { 218 | key := msgKey(service, username, id) 219 | conn := self.pool.Get() 220 | defer conn.Close() 221 | 222 | reply, err := conn.Do("GET", key) 223 | if err != nil { 224 | return 225 | } 226 | if reply == nil { 227 | return 228 | } 229 | data, err := redis.Bytes(reply, err) 230 | if err != nil { 231 | return 232 | } 233 | msg, err = msgUnmarshal(data) 234 | return 235 | } 236 | 237 | /* 238 | * We may not need Delete 239 | func (self *redisMessageCache) Del(service, username, id string) error { 240 | key := 
msgKey(service, username, id) 241 | wkey := msgWeightKey(service, username, id) 242 | conn := self.pool.Get() 243 | defer conn.Close() 244 | 245 | err := conn.Send("MULTI") 246 | if err != nil { 247 | return err 248 | } 249 | err = conn.Send("DEL", key) 250 | if err != nil { 251 | conn.Do("DISCARD") 252 | return err 253 | } 254 | err = conn.Send("DEL", wkey) 255 | if err != nil { 256 | conn.Do("DISCARD") 257 | return err 258 | } 259 | msgQK := msgQueueKey(service, username) 260 | err = conn.Send("SREM", msgQK, id) 261 | if err != nil { 262 | conn.Do("DISCARD") 263 | return err 264 | } 265 | _, err = conn.Do("EXEC") 266 | if err != nil { 267 | return err 268 | } 269 | return nil 270 | } 271 | */ 272 | 273 | /* 274 | * We may not need get then delete. 275 | func (self *redisMessageCache) GetThenDel(service, username, id string) (msg *proto.Message, err error) { 276 | key := msgKey(service, username, id) 277 | wkey := msgWeightKey(service, username, id) 278 | conn := self.pool.Get() 279 | defer conn.Close() 280 | 281 | err = conn.Send("MULTI") 282 | if err != nil { 283 | return 284 | } 285 | err = conn.Send("GET", key) 286 | if err != nil { 287 | conn.Do("DISCARD") 288 | return 289 | } 290 | err = conn.Send("DEL", key) 291 | if err != nil { 292 | conn.Do("DISCARD") 293 | return 294 | } 295 | err = conn.Send("DEL", wkey) 296 | if err != nil { 297 | conn.Do("DISCARD") 298 | return 299 | } 300 | msgQK := msgQueueKey(service, username) 301 | err = conn.Send("SREM", msgQK, id) 302 | if err != nil { 303 | conn.Do("DISCARD") 304 | return 305 | } 306 | reply, err := conn.Do("EXEC") 307 | if err != nil { 308 | return 309 | } 310 | 311 | bulkReply, err := redis.Values(reply, err) 312 | if err != nil { 313 | return 314 | } 315 | if len(bulkReply) != 4 { 316 | return 317 | } 318 | if bulkReply[0] == nil { 319 | return 320 | } 321 | data, err := redis.Bytes(bulkReply[0], err) 322 | if err != nil { 323 | return 324 | } 325 | if len(data) == 0 { 326 | return 327 | } 328 | msg, err = 
msgUnmarshal(data) 329 | return 330 | } 331 | */ 332 | 333 | func (self *redisMessageCache) RetrieveAllSince(service, username string, since time.Time) (msgs []*rpc.MessageContainer, err error) { 334 | msgQK := msgQueueKey(service, username) 335 | conn := self.pool.Get() 336 | defer conn.Close() 337 | 338 | err = conn.Send("MULTI") 339 | if err != nil { 340 | return 341 | } 342 | err = conn.Send("SORT", msgQK, 343 | "BY", 344 | msgWeightPattern(service, username), 345 | "GET", 346 | msgKeyPattern(service, username)) 347 | if err != nil { 348 | conn.Do("DISCARD") 349 | return 350 | } 351 | err = conn.Send("SORT", msgQK, 352 | "BY", 353 | msgWeightPattern(service, username)) 354 | if err != nil { 355 | conn.Do("DISCARD") 356 | return 357 | } 358 | 359 | reply, err := conn.Do("EXEC") 360 | if err != nil { 361 | return 362 | } 363 | bulkReply, err := redis.Values(reply, err) 364 | if err != nil { 365 | return 366 | } 367 | if len(bulkReply) != 2 { 368 | return 369 | } 370 | 371 | msgObjs, err := redis.Values(bulkReply[0], nil) 372 | if err != nil { 373 | return 374 | } 375 | msgIds, err := redis.Values(bulkReply[1], nil) 376 | if err != nil { 377 | return 378 | } 379 | n := len(msgObjs) 380 | if n == 0 { 381 | return 382 | } 383 | msgShadow := make([]*rpc.MessageContainer, 0, n) 384 | removed := make([]interface{}, 1, n+1) 385 | removed[0] = msgQK 386 | 387 | for i, reply := range msgObjs { 388 | var data []byte 389 | var msg *rpc.MessageContainer 390 | if reply == nil { 391 | id, err := redis.String(msgIds[i], nil) 392 | if err == nil { 393 | removed = append(removed, id) 394 | } 395 | continue 396 | } 397 | data, err = redis.Bytes(reply, err) 398 | if err != nil { 399 | return 400 | } 401 | if len(data) == 0 { 402 | id, err := redis.String(msgIds[i], nil) 403 | if err == nil { 404 | removed = append(removed, id) 405 | } 406 | continue 407 | } 408 | msg, err = msgUnmarshal(data) 409 | 410 | if since.Before(msg.Birthday) { 411 | msgShadow = append(msgShadow, msg) 412 | 
} 413 | } 414 | 415 | if len(removed) > 1 { 416 | _, err = conn.Do("SREM", removed...) 417 | if err != nil { 418 | return 419 | } 420 | } 421 | msgs = msgShadow 422 | return 423 | } 424 | 425 | func (self *redisMessageCache) Close() error { 426 | return nil 427 | } 428 | -------------------------------------------------------------------------------- /msgcache/rediscache_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package msgcache 19 | 20 | import ( 21 | "fmt" 22 | "testing" 23 | ) 24 | 25 | type redisTestCacheManager struct { 26 | db int 27 | } 28 | 29 | func (self *redisTestCacheManager) Name() string { 30 | return "redis" 31 | } 32 | // ClearCache flushes the Redis database behind a *redisMessageCache; other Cache implementations are ignored. 33 | func (self *redisTestCacheManager) ClearCache(c Cache) { 34 | if cache, ok := c.(*redisMessageCache); ok { conn := cache.pool.Get(); defer conn.Close() // one pooled conn, so SELECT applies to the FLUSHDB, and it goes back to the pool 35 | conn.Do("SELECT", self.db) 36 | conn.Do("FLUSHDB") 37 | } 38 | } 39 | // GetCache opens a Redis cache on database self.db, with the remaining GetCache arguments empty or zero. 40 | func (self *redisTestCacheManager) GetCache() (Cache, error) { 41 | return GetCache("redis", "", "", "", fmt.Sprintf("%v", self.db), 0) 42 | } 43 | // TestRedisCache runs testCacheImpl against a Redis-backed cache on database 1. 44 | func TestRedisCache(t *testing.T) { 45 | testCacheImpl(&redisTestCacheManager{1}, t) 46 | } 47 | 48 | /* 49 | func TestGetSetMessage(t *testing.T) { 50 | N := 10 51 | msgs := multiRandomMessage(N) 52 | cache := getCache() 53 | defer clearDb() 54 | srv := "srv" 55 | usr := "usr" 56 | 57 | ids := make([]string, N) 58 | 59 | for i, msg := range msgs { 60 | id, err := cache.CacheMessage(srv, usr, msg, 0*time.Second) 61 | if err != nil { 62 | t.Errorf("Set error: %v", err) 63 | return 64 | } 65 | ids[i] = id 66 | } 67 | for i, msg := range msgs { 68 | m, err := cache.Get(srv, usr, ids[i]) 69 | if err != nil { 70 | t.Errorf("Get error: %v", err) 71 | return 72 | } 73 | if !m.Message.Eq(msg.Message) { 74 | t.Errorf("%vth message does not same", i) 75 | } 76 | } 77 | } 78 | 79 | func TestGetSetMessageTTL(t *testing.T) { 80 | N := 10 81 | msgs := multiRandomMessage(N) 82 | cache := getCache() 83 | defer clearDb() 84 | srv := "srv" 85 | usr := "usr" 86 | 87 | ids := make([]string, N) 88 | 89 | for i, msg := range msgs { 90 | id, err := cache.CacheMessage(srv, usr, msg, 1*time.Second) 91 | if err != nil { 92 | t.Errorf("Set error: %v", err) 93 | return 94 | } 95 | ids[i] = id 96 | } 97 | time.Sleep(2 * time.Second) 98 | for i, id := range ids { 99 | m, err := cache.Get(srv, usr, id) 100 | if err != nil { 101 | t.Errorf("Get error: %v", err) 102 | return 103 | } 104 |
if m != nil { 105 | t.Errorf("%vth message should be deleted", i) 106 | } 107 | } 108 | } 109 | 110 | func TestCacheThenRetrieveAll(t *testing.T) { 111 | N := 10 112 | msgs := multiRandomMessage(N) 113 | cache := getCache() 114 | defer clearDb() 115 | srv := "srv" 116 | usr := "usr" 117 | 118 | ids := make([]string, N) 119 | 120 | for i, msg := range msgs { 121 | id, err := cache.CacheMessage(srv, usr, msg, 0*time.Second) 122 | if err != nil { 123 | t.Errorf("Set error: %v", err) 124 | return 125 | } 126 | ids[i] = id 127 | } 128 | 129 | retrievedMsgs, err := cache.GetCachedMessages(srv, usr) 130 | if err != nil { 131 | t.Errorf("Set error: %v", err) 132 | return 133 | } 134 | for i, id := range ids { 135 | if retrievedMsgs[i].Id != id { 136 | t.Errorf("retrieved different ids: %v != %v", retrievedMsgs, ids) 137 | return 138 | } 139 | } 140 | } 141 | 142 | func TestGetNonExistMsg(t *testing.T) { 143 | cache := getCache() 144 | defer clearDb() 145 | srv := "srv" 146 | usr := "usr" 147 | 148 | msg, err := cache.Get(srv, usr, "wont-be-a-good-message-id") 149 | if err != nil { 150 | t.Errorf("%v", err) 151 | return 152 | } 153 | if msg != nil { 154 | t.Errorf("should be nil message") 155 | return 156 | } 157 | } 158 | 159 | func TestCacheThenRetrieveAllWithTTL(t *testing.T) { 160 | N := 10 161 | msgs := multiRandomMessage(N) 162 | cache := getCache() 163 | defer clearDb() 164 | srv := "srv" 165 | usr := "usr" 166 | 167 | ids := make([]string, N) 168 | 169 | ttl := 0 170 | nrDead := 2 171 | for i, msg := range msgs { 172 | if i == len(msgs)-nrDead { 173 | ttl = 1 174 | } 175 | id, err := cache.CacheMessage(srv, usr, msg, time.Duration(ttl)*time.Second) 176 | if err != nil { 177 | t.Errorf("Set error: %v", err) 178 | return 179 | } 180 | ids[i] = id 181 | } 182 | time.Sleep(2 * time.Second) 183 | retrievedMsgs, err := cache.GetCachedMessages(srv, usr) 184 | if err != nil { 185 | t.Errorf("Set error: %v", err) 186 | return 187 | } 188 | if len(retrievedMsgs) != 
len(msgs)-nrDead { 189 | t.Errorf("retrieved %v objects", len(retrievedMsgs)) 190 | return 191 | } 192 | for i, id := range ids[:len(msgs)-nrDead] { 193 | if retrievedMsgs[i].Id != id { 194 | t.Errorf("retrieved different ids: %v != %v", retrievedMsgs, ids) 195 | return 196 | } 197 | } 198 | } 199 | */ 200 | -------------------------------------------------------------------------------- /msgcenter/connmap.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * 16 | */ 17 | 18 | package msgcenter 19 | 20 | import ( 21 | "errors" 22 | 23 | "github.com/petar/GoLLRB/llrb" 24 | "github.com/uniqush/uniqush-conn/proto/server" 25 | "sync" 26 | ) 27 | // connMap tracks the connections of one service, grouped by connKey (the username). 28 | type connMap interface { 29 | NrConns() int 30 | NrUsers() int 31 | AllUsernames() []string 32 | AddConn(conn server.Conn) error 33 | GetConn(username string) ConnSet 34 | DelConn(conn server.Conn) server.Conn 35 | CloseAll() 36 | } 37 | type treeBasedConnMap struct { 38 | tree *llrb.LLRB 39 | maxNrConn int 40 | maxNrUsers int 41 | maxNrConnsPerUser int 42 | 43 | nrConn int 44 | lock sync.RWMutex 45 | } 46 | 47 | func (self *treeBasedConnMap) AllUsernames() []string { 48 | self.lock.RLock() 49 | defer self.lock.RUnlock() 50 | ret := make([]string, 0, self.tree.Len()) 51 | 52 | self.tree.AscendGreaterOrEqual(nil, func(item llrb.Item) bool { 53 | if cs, ok := item.(*connSet); ok { 54 | ret = append(ret, cs.username()) 55 | } 56 | return true 57 | }) 58 | return ret 59 | } 60 | 61 | func (self *treeBasedConnMap) NrConns() int { 62 | self.lock.RLock() 63 | defer self.lock.RUnlock() 64 | return self.nrConn 65 | } 66 | 67 | func (self *treeBasedConnMap) NrUsers() int { 68 | self.lock.RLock() 69 | defer self.lock.RUnlock() 70 | return self.tree.Len() 71 | } 72 | // GetConn returns the user's connSet. For an unknown user the result wraps a nil *connSet, so it is never == nil; its NrConn, Traverse and CloseAll treat the nil receiver as an empty set. 73 | func (self *treeBasedConnMap) GetConn(user string) ConnSet { 74 | self.lock.RLock() 75 | defer self.lock.RUnlock() 76 | cset := self.getConn(user) 77 | return cset 78 | } 79 | 80 | func (self *treeBasedConnMap) getConn(user string) *connSet { 81 | key := &connSet{name: user, list: nil} 82 | clif := self.tree.Get(key) 83 | if clif == nil { 84 | return nil 85 | } 86 | cl, ok := clif.(*connSet) 87 | if !ok || cl == nil { 88 | return nil 89 | } 90 | return cl 91 | } 92 | 93 | var ErrTooManyUsers = errors.New("too many users") 94 | var ErrTooManyConnForThisUser = errors.New("too many connections under this user") 95 | var ErrTooManyConns = errors.New("too many connections") 96 | 97 | // There's no way back! CloseAll closes every connSet, empties the tree and never releases the write lock, so every later call on this map blocks.
98 | func (self *treeBasedConnMap) CloseAll() { 99 | self.lock.Lock() 100 | var nilcs *connSet 101 | self.tree.AscendGreaterOrEqual(nilcs, func(i llrb.Item) bool { 102 | if cs, ok := i.(*connSet); ok { 103 | cs.CloseAll() 104 | } 105 | return true 106 | }) 107 | self.tree = llrb.New() 108 | } 109 | // AddConn registers conn under connKey(conn), enforcing the total-connection, user-count and per-user limits (a limit <= 0 means unlimited). 110 | func (self *treeBasedConnMap) AddConn(conn server.Conn) error { 111 | self.lock.Lock() 112 | defer self.lock.Unlock() 113 | if conn == nil { 114 | return nil 115 | } 116 | 117 | if self.maxNrConn > 0 && self.nrConn >= self.maxNrConn { 118 | return ErrTooManyConns 119 | } 120 | cset := self.getConn(connKey(conn)) 121 | if cset == nil { 122 | // Only a new user adds to the user count; an existing user opening another connection must not get ErrTooManyUsers. 123 | if self.maxNrUsers > 0 && self.tree.Len() >= self.maxNrUsers { 124 | return ErrTooManyUsers 125 | } 126 | cset = &connSet{name: connKey(conn), list: make([]server.Conn, 0, 3)} 127 | } 128 | 129 | cset.lock() 130 | defer cset.unlock() 131 | before := cset.nrConn() 132 | err := cset.add(conn, self.maxNrConnsPerUser) 133 | if err != nil { 134 | return err 135 | } 136 | // add() is a no-op for a UniqId already in the set; only count connections that were really added. 137 | if cset.nrConn() > before { self.nrConn++ } 138 | self.tree.ReplaceOrInsert(cset) 139 | return nil 140 | } 141 | // DelConn removes conn (matched by UniqId) and returns it, or nil if it was not registered; a user left with no connections is dropped from the tree. 142 | func (self *treeBasedConnMap) DelConn(conn server.Conn) server.Conn { 143 | self.lock.Lock() 144 | defer self.lock.Unlock() 145 | if conn == nil { 146 | return nil 147 | } 148 | cset := self.getConn(connKey(conn)) 149 | if cset == nil { 150 | return nil 151 | } 152 | 153 | cset.lock() 154 | defer cset.unlock() 155 | 156 | ret := cset.del(conn.UniqId()) 157 | 158 | if ret == nil { 159 | return ret 160 | } 161 | self.nrConn-- 162 | if cset.nrConn() == 0 { 163 | self.tree.Delete(cset) 164 | } else { 165 | self.tree.ReplaceOrInsert(cset) 166 | } 167 | 168 | return ret 169 | } 170 | // newTreeBasedConnMap returns an empty connMap with the given limits (a limit <= 0 means unlimited). 171 | func newTreeBasedConnMap(maxNrConn, maxNrUsers, maxNrConnsPerUser int) connMap { 172 | ret := new(treeBasedConnMap) 173 | ret.tree = llrb.New() 174 | ret.maxNrConn = maxNrConn 175 | ret.maxNrUsers = maxNrUsers 176 | ret.maxNrConnsPerUser = maxNrConnsPerUser 177 | return ret 178 | } 179 |
-------------------------------------------------------------------------------- /msgcenter/connset.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package msgcenter 19 | 20 | import ( 21 | "github.com/petar/GoLLRB/llrb" 22 | "github.com/uniqush/uniqush-conn/proto/server" 23 | "sync" 24 | ) 25 | // ConnSet exposes the connections that share one key (the username, see connKey). 26 | type ConnSet interface { 27 | NrConn() int 28 | Traverse(f func(c server.Conn) error) error 29 | CloseAll() 30 | } 31 | // connSet is one user's entry in treeBasedConnMap's tree; mutex guards list, and name is the key used while list is empty. 32 | type connSet struct { 33 | mutex sync.RWMutex 34 | name string 35 | list []server.Conn 36 | } 37 | // CloseAll closes every connection and empties the set. It is a no-op on a nil receiver and never releases the write lock it takes, so later operations on this set block. 38 | func (self *connSet) CloseAll() { 39 | if self == nil { 40 | return 41 | } 42 | self.mutex.Lock() 43 | for _, c := range self.list { 44 | c.Close() 45 | } 46 | self.list = nil 47 | } 48 | 49 | func (self *connSet) unlock() { 50 | self.mutex.Unlock() 51 | } 52 | 53 | // Never manipulate another connSet inside the function f. 54 | // It may lead to deadlock. 55 | // Only perform read-only operation (to the connSet, not the Conn) in the function f().
56 | func (self *connSet) Traverse(f func(c server.Conn) error) error { 57 | if self == nil { 58 | return nil 59 | } 60 | self.mutex.RLock() 61 | defer self.mutex.RUnlock() 62 | 63 | if self == nil { 64 | return nil 65 | } 66 | 67 | for _, c := range self.list { 68 | err := f(c) 69 | if err != nil { 70 | return err 71 | } 72 | } 73 | return nil 74 | } 75 | 76 | func (self *connSet) lock() { 77 | self.mutex.Lock() 78 | } 79 | 80 | func (self *connSet) NrConn() int { 81 | if self == nil { 82 | return 0 83 | } 84 | self.mutex.RLock() 85 | defer self.mutex.RUnlock() 86 | 87 | return self.nrConn() 88 | } 89 | 90 | func (self *connSet) nrConn() int { 91 | if self == nil { 92 | return 0 93 | } 94 | return len(self.list) 95 | } 96 | // del removes the connection whose UniqId equals key (list order is not preserved) and returns it, or nil if absent; DelConn calls it with the write lock held. 97 | func (self *connSet) del(key string) server.Conn { 98 | i := -1 99 | for j, c := range self.list { 100 | if c.UniqId() == key { 101 | i = j 102 | break 103 | } 104 | } 105 | if i < 0 || i >= len(self.list) { 106 | return nil 107 | } 108 | c := self.list[i] 109 | last := len(self.list) - 1 110 | self.list[i], self.list[last] = self.list[last], nil // nil the vacated tail slot so the backing array does not keep the removed Conn alive 111 | self.list = self.list[:last] 112 | return c 113 | } 114 | // add appends conn unless its UniqId is already present (then it is a no-op) or the set already holds max connections (max <= 0 means unlimited). 115 | func (self *connSet) add(conn server.Conn, max int) error { 116 | for _, c := range self.list { 117 | if c.UniqId() == conn.UniqId() { 118 | return nil 119 | } 120 | } 121 | if max > 0 && len(self.list) >= max { 122 | return ErrTooManyConnForThisUser 123 | } 124 | self.list = append(self.list, conn) 125 | return nil 126 | } 127 | 128 | func (self *connSet) username() string { 129 | if len(self.list) == 0 { 130 | return self.name 131 | } 132 | return self.list[0].Username() 133 | } 134 | 135 | func (self *connSet) key() string { 136 | if len(self.list) == 0 { 137 | return self.name 138 | } 139 | return connKey(self.list[0]) 140 | } 141 | // Less orders connSets by key() for the llrb tree; a nil receiver sorts first and a nil item sorts last. 142 | func (self *connSet) Less(than llrb.Item) bool { 143 | if self == nil { 144 | return true 145 | } 146 | if than == nil { 147 | return false 148 | } 149 | thanCs := than.(*connSet) 150 | if thanCs == nil { 151 |
return false 152 | } 153 | selfKey := llrb.String(self.key()) 154 | thanKey := llrb.String(thanCs.key()) 155 | return selfKey.Less(thanKey) 156 | } 157 | // connKey is the key connections are grouped by: the username. 158 | func connKey(conn server.Conn) string { 159 | return conn.Username() 160 | } 161 | -------------------------------------------------------------------------------- /msgcenter/msgcenter.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * 16 | */ 17 | 18 | package msgcenter 19 | 20 | import ( 21 | "crypto/rsa" 22 | "fmt" 23 | "github.com/uniqush/uniqush-conn/config" 24 | "github.com/uniqush/uniqush-conn/proto/server" 25 | "github.com/uniqush/uniqush-conn/rpc" 26 | "net" 27 | "sync" 28 | ) 29 | 30 | type MessageCenter struct { 31 | ln net.Listener 32 | privkey *rsa.PrivateKey 33 | config *config.Config 34 | 35 | srvCentersLock sync.Mutex 36 | serviceCenterMap map[string]*serviceCenter 37 | fwdChan chan *rpc.ForwardRequest 38 | peers *rpc.MultiPeer 39 | } 40 | // processForwardRequest hands each request from fwdChan to its receiver service's center in a new goroutine; it returns when fwdChan is closed or yields nil. 41 | func (self *MessageCenter) processForwardRequest() { 42 | for req := range self.fwdChan { 43 | if req == nil { 44 | return 45 | } 46 | if len(req.ReceiverService) == 0 { 47 | continue 48 | } 49 | center := self.getServiceCenter(req.ReceiverService) 50 | if center != nil { 51 | go center.Forward(req) 52 | } 53 | } 54 | } 55 | // NewMessageCenter creates a MessageCenter for ln and starts its forward-request dispatcher goroutine. 56 | func NewMessageCenter(ln net.Listener, privkey *rsa.PrivateKey, conf *config.Config) *MessageCenter { 57 | ret := new(MessageCenter) 58 | ret.ln = ln 59 | ret.privkey = privkey 60 | ret.config = conf 61 | 62 | ret.peers = rpc.NewMultiPeer() 63 | ret.fwdChan = make(chan *rpc.ForwardRequest) 64 | ret.serviceCenterMap = make(map[string]*serviceCenter, 10) 65 | 66 | go ret.processForwardRequest() 67 | return ret 68 | } 69 | // ServiceNames returns the names of the service centers created so far, in unspecified order. 70 | func (self *MessageCenter) ServiceNames() []string { 71 | self.srvCentersLock.Lock() 72 | defer self.srvCentersLock.Unlock() 73 | 74 | ret := make([]string, 0, len(self.serviceCenterMap)) 75 | 76 | for srv := range self.serviceCenterMap { 77 | ret = append(ret, srv) 78 | } 79 | return ret 80 | } 81 | // getServiceCenter returns the center for srv, creating it on first use from config.ReadConfig; it returns nil when srv has no configuration. 82 | func (self *MessageCenter) getServiceCenter(srv string) *serviceCenter { 83 | self.srvCentersLock.Lock() 84 | defer self.srvCentersLock.Unlock() 85 | 86 | center, ok := self.serviceCenterMap[srv] 87 | if !ok { 88 | conf := self.config.ReadConfig(srv) 89 | if conf == nil { 90 | return nil 91 | } 92 | center = newServiceCenter(conf, self.fwdChan, self.peers) 93 | if center != nil { 94 |
self.serviceCenterMap[srv] = center 95 | } 96 | } 97 | return center 98 | } 99 | // serveConn authenticates c and hands it to its service's center; c is closed if authentication fails or the service is unknown. 100 | func (self *MessageCenter) serveConn(c net.Conn) { 101 | if tcpConn, ok := c.(*net.TCPConn); ok { 102 | // Rather than keeping an application-level heartbeat, 103 | // we rely on TCP-level keep-alive. 104 | // XXX Is this a good idea? 105 | tcpConn.SetKeepAlive(true) 106 | 107 | // Use Nagle. 108 | tcpConn.SetNoDelay(false) 109 | } 110 | conn, err := server.AuthConn(c, self.privkey, self.config, self.config.HandshakeTimeout) 111 | if err != nil { 112 | if err != server.ErrAuthFail { 113 | self.config.OnError(c.RemoteAddr(), err) 114 | } 115 | c.Close() 116 | return 117 | } 118 | srv := conn.Service() 119 | 120 | center := self.getServiceCenter(srv) 121 | if center == nil { 122 | self.config.OnError(c.RemoteAddr(), fmt.Errorf("unknown service: %v", srv)) 123 | c.Close() 124 | return 125 | } 126 | center.NewConn(conn) 127 | return 128 | } 129 | // Start accepts connections until Accept fails with a non-temporary error, serving each connection in its own goroutine. 130 | func (self *MessageCenter) Start() { 131 | for { 132 | conn, err := self.ln.Accept() 133 | if err != nil { 134 | if ne, ok := err.(net.Error); ok { 135 | // It's a temporary error. 136 | if ne.Temporary() { 137 | continue 138 | } 139 | } 140 | self.config.OnError(self.ln.Addr(), err) 141 | return 142 | } 143 | go self.serveConn(conn) 144 | } 145 | } 146 | 147 | // NOTE: you cannot restart it! Stop stops every service center, closes the listener and fwdChan, and never releases srvCentersLock, so later calls that need it (e.g. getServiceCenter) block.
148 | func (self *MessageCenter) Stop() { 149 | self.srvCentersLock.Lock() 150 | 151 | for _, center := range self.serviceCenterMap { 152 | center.Stop() 153 | } 154 | self.ln.Close() 155 | close(self.fwdChan) 156 | } 157 | // AllServices returns the names of the service centers created so far, in unspecified order. 158 | func (self *MessageCenter) AllServices() []string { 159 | self.srvCentersLock.Lock() 160 | defer self.srvCentersLock.Unlock() // was a deferred Lock(): it blocked the caller forever and left the mutex held 161 | ret := make([]string, 0, len(self.serviceCenterMap)) 162 | for srv := range self.serviceCenterMap { 163 | ret = append(ret, srv) 164 | } 165 | return ret 166 | } 167 | // AllUsernames returns the usernames connected to service srv, or nil if srv has no configuration. 168 | func (self *MessageCenter) AllUsernames(srv string) []string { 169 | center := self.getServiceCenter(srv) 170 | if center == nil { 171 | return nil 172 | } 173 | return center.AllUsernames() 174 | } 175 | // NrConns returns the number of connections of service srv (0 if srv has no configuration). 176 | func (self *MessageCenter) NrConns(srv string) int { 177 | center := self.getServiceCenter(srv) 178 | if center == nil { 179 | return 0 180 | } 181 | return center.NrConns() 182 | } 183 | // NrUsers returns the number of users of service srv (0 if srv has no configuration). 184 | func (self *MessageCenter) NrUsers(srv string) int { 185 | center := self.getServiceCenter(srv) 186 | if center == nil { 187 | return 0 188 | } 189 | return center.NrUsers() 190 | } 191 | // do runs f on the center for srv; for an unknown service it returns an empty Result (the SetError call is commented out). 192 | func (self *MessageCenter) do(srv string, f func(center *serviceCenter) *rpc.Result) *rpc.Result { 193 | center := self.getServiceCenter(srv) 194 | if center == nil { 195 | ret := new(rpc.Result) 196 | //ret.SetError(fmt.Errorf("unknown service: %v", srv)) 197 | return ret 198 | } 199 | return f(center) 200 | } 201 | // Send delivers req through the center of req.ReceiverService. 202 | func (self *MessageCenter) Send(req *rpc.SendRequest) *rpc.Result { 203 | return self.do(req.ReceiverService, func(center *serviceCenter) *rpc.Result { 204 | return center.Send(req) 205 | }) 206 | } 207 | // Forward delivers req through the center of req.ReceiverService. 208 | func (self *MessageCenter) Forward(req *rpc.ForwardRequest) *rpc.Result { 209 | return self.do(req.ReceiverService, func(center *serviceCenter) *rpc.Result { 210 | return center.Forward(req) 211 | }) 212 | } 213 | // Redirect applies req through the center of req.ReceiverService. 214 | func (self *MessageCenter) Redirect(req *rpc.RedirectRequest) *rpc.Result { 215 | return
self.do(req.ReceiverService, func(center *serviceCenter) *rpc.Result { 216 | return center.Redirect(req) 217 | }) 218 | } 219 | // CheckUserStatus queries the center of req.Service. 220 | func (self *MessageCenter) CheckUserStatus(req *rpc.UserStatusQuery) *rpc.Result { 221 | return self.do(req.Service, func(center *serviceCenter) *rpc.Result { 222 | return center.CheckUserStatus(req) 223 | }) 224 | } 225 | // AddPeer registers peer with the MultiPeer shared by every service center. 226 | func (self *MessageCenter) AddPeer(peer rpc.UniqushConnPeer) { 227 | self.peers.AddPeer(peer) 228 | } 229 | -------------------------------------------------------------------------------- /msgcenter/srvcenter.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * 16 | */ 17 | 18 | package msgcenter 19 | 20 | import ( 21 | "errors" 22 | "fmt" 23 | "github.com/uniqush/uniqush-conn/config" 24 | "github.com/uniqush/uniqush-conn/proto/server" 25 | "github.com/uniqush/uniqush-conn/rpc" 26 | "io" 27 | "strings" 28 | ) 29 | 30 | type serviceCenter struct { 31 | config *config.ServiceConfig 32 | fwdChan chan<- *rpc.ForwardRequest 33 | subReqChan chan *rpc.SubscribeRequest 34 | conns connMap 35 | peer rpc.UniqushConnPeer 36 | } 37 | // Stop closes every connection of this service and the subscription channel. 38 | func (self *serviceCenter) Stop() { 39 | self.conns.CloseAll() 40 | close(self.subReqChan) 41 | } 42 | // serveConn passes conn's incoming messages to config.OnMessage until a read error, then reports OnLogout, unregisters and closes conn. 43 | func (self *serviceCenter) serveConn(conn server.Conn) { 44 | var reason error 45 | defer func() { 46 | self.config.OnLogout(conn, reason) 47 | self.conns.DelConn(conn) 48 | conn.Close() 49 | }() 50 | for { 51 | msg, err := conn.ReceiveMessage() 52 | if err != nil { 53 | if err != io.EOF { 54 | self.config.OnError(conn, err) 55 | reason = err 56 | } 57 | return 58 | } 59 | if msg != nil { 60 | self.config.OnMessage(conn, msg) 61 | } 62 | } 63 | } 64 | // NewConn validates the username (non-empty, no ':' or newline), wires conn to the cache and channels, registers it and starts serving it. 65 | func (self *serviceCenter) NewConn(conn server.Conn) { 66 | if conn == nil { 67 | //self.config.OnError(conn, fmt.Errorf("Nil conn") 68 | return 69 | } 70 | usr := conn.Username() 71 | if len(usr) == 0 || strings.Contains(usr, ":") || strings.Contains(usr, "\n") { 72 | self.config.OnError(conn, fmt.Errorf("invalid username")) 73 | conn.Close() 74 | return 75 | } 76 | conn.SetMessageCache(self.config.Cache()) 77 | conn.SetForwardRequestChannel(self.fwdChan) 78 | conn.SetSubscribeRequestChan(self.subReqChan) 79 | err := self.conns.AddConn(conn) 80 | if err != nil { 81 | self.config.OnError(conn, err) 82 | conn.Close() 83 | return 84 | } 85 | 86 | self.config.OnLogin(conn) 87 | go self.serveConn(conn) 88 | return 89 | } 90 | // Send delivers req.Message to each receiver: it optionally caches it, writes it to local connections, propagates to peers, and pushes when no connection received it. 91 | func (self *serviceCenter) Send(req *rpc.SendRequest) *rpc.Result { 92 | ret := new(rpc.Result) 93 | 94 | if req == nil { 95 | ret.SetError(fmt.Errorf("invalid request")) 96 | return ret 97 | } 98 | if
req.Message == nil || req.Message.IsEmpty() { 99 | ret.SetError(fmt.Errorf("invalid request: empty message")) 100 | return ret 101 | } 102 | if len(req.Receivers) == 0 { 103 | ret.SetError(fmt.Errorf("invalid request: no receiver")) 104 | return ret 105 | } 106 | 107 | shouldPush := !req.DontPush 108 | shouldCache := !req.DontCache 109 | shouldPropagate := !req.DontPropagate 110 | receivers := req.Receivers 111 | 112 | for _, recver := range receivers { 113 | mid := req.Id 114 | msg := req.Message 115 | 116 | if shouldCache { 117 | mc := &rpc.MessageContainer{ 118 | Sender: "", 119 | SenderService: "", 120 | Message: msg, 121 | } 122 | var err error 123 | mid, err = self.config.CacheMessage(recver, mc, req.TTL) 124 | if err != nil { 125 | ret.SetError(err) 126 | return ret 127 | } 128 | } 129 | 130 | if len(mid) == 0 { 131 | ret.SetError(fmt.Errorf("undefined message Id")) 132 | return ret 133 | } 134 | 135 | n := 0 136 | 137 | conns := self.conns.GetConn(recver) 138 | conns.Traverse(func(conn server.Conn) error { 139 | err := conn.SendMessage(msg, mid, nil, !req.NeverDigest) 140 | ret.Append(conn, err) 141 | if err != nil { 142 | conn.Close() 143 | // We won't delete this connection here. 144 | // Instead, we close it and let the reader 145 | // goroutine detect the error and close it. 146 | } else { 147 | n++ 148 | } 149 | // Instead of returning an error, 150 | // we would rather let the Traverse() move forward. 151 | return nil 152 | }) 153 | 154 | // Don't propagate this request to other instances in the cluster. 155 | req.DontPropagate = true 156 | // Don't push the message. We will push it on this node. 157 | req.DontPush = true 158 | // Don't cache the message. We have already cached it.
159 | req.DontCache = true 160 | req.Id = mid 161 | req.Receivers = []string{recver} 162 | 163 | if shouldPropagate { 164 | r := self.peer.Send(req) 165 | n += r.NrSuccess() 166 | ret.Join(r) 167 | } 168 | 169 | if n == 0 && shouldPush { 170 | self.config.Push(recver, "", "", req.PushInfo, mid, msg.Size()) 171 | } 172 | } 173 | return ret 174 | } 175 | // Forward is Send for user-to-user messages; unless DontAsk is set it first asks config.ShouldForward for permission and push settings. 176 | func (self *serviceCenter) Forward(req *rpc.ForwardRequest) *rpc.Result { 177 | ret := new(rpc.Result) 178 | 179 | if req == nil { 180 | ret.SetError(fmt.Errorf("invalid request")) 181 | return ret 182 | } 183 | if req.Message == nil || req.Message.IsEmpty() { 184 | ret.SetError(fmt.Errorf("invalid request: empty message")) 185 | return ret 186 | } 187 | if len(req.Receivers) == 0 { 188 | ret.SetError(fmt.Errorf("invalid request: no receiver")) 189 | return ret 190 | } 191 | 192 | mid := req.Id 193 | msg := req.Message 194 | mc := &req.MessageContainer 195 | 196 | var pushInfo map[string]string 197 | var shouldForward bool 198 | shouldPush := !req.DontPush 199 | shouldCache := !req.DontCache 200 | shouldPropagate := !req.DontPropagate 201 | 202 | if !req.DontAsk { 203 | // We need to ask for permission to forward this message. 204 | // This means the forward request is generated directly from a user, 205 | // not from a uniqush-conn node in a cluster.
206 | 207 | mc.Id = "" 208 | shouldForward, shouldPush, pushInfo = self.config.ShouldForward(req) 209 | // Permission denied: return the empty result (as MessageCenter.do does) rather than a nil *rpc.Result that callers would dereference. 210 | if !shouldForward { 211 | return ret 212 | } 213 | } 214 | 215 | receivers := req.Receivers 216 | 217 | for _, recver := range receivers { 218 | if shouldCache { 219 | var err error 220 | mid, err = self.config.CacheMessage(recver, mc, req.TTL) 221 | if err != nil { 222 | ret.SetError(err) 223 | return ret 224 | } 225 | } 226 | 227 | if len(mid) == 0 { 228 | ret.SetError(fmt.Errorf("undefined message Id")) 229 | return ret 230 | } 231 | 232 | n := 0 233 | 234 | conns := self.conns.GetConn(recver) 235 | conns.Traverse(func(conn server.Conn) error { 236 | err := conn.ForwardMessage(req.Sender, req.SenderService, msg, mid, !req.NeverDigest) 237 | ret.Append(conn, err) 238 | if err != nil { 239 | conn.Close() 240 | // We won't delete this connection here. 241 | // Instead, we close it and let the reader 242 | // goroutine detect the error and close it. 243 | } else { 244 | n++ 245 | } 246 | // Instead of returning an error, 247 | // we would rather let the Traverse() move forward. 248 | return nil 249 | }) 250 | 251 | // forward the message if possible, 252 | // Don't propagate this request to other instances in the cluster. 253 | req.DontPropagate = true 254 | // Don't ask the permission to forward (we have already got the permission) 255 | req.DontAsk = true 256 | // And don't push the message. We will push it on this node.
257 | req.DontPush = true 258 | // Don't cache it 259 | req.DontCache = true 260 | req.Id = mid 261 | req.Receivers = []string{recver} 262 | 263 | if shouldPropagate { 264 | r := self.peer.Forward(req) 265 | n += r.NrSuccess() 266 | ret.Join(r) 267 | } 268 | 269 | if n == 0 && shouldPush { 270 | self.config.Push(recver, req.SenderService, req.Sender, pushInfo, mid, msg.Size()) 271 | } 272 | } 273 | return ret 274 | } 275 | // Redirect redirects and closes the first matching local connection (any connection when req.ConnId is empty); if none matches it propagates the request to peers once. 276 | func (self *serviceCenter) Redirect(req *rpc.RedirectRequest) *rpc.Result { 277 | conns := self.conns.GetConn(req.Receiver) 278 | var sc server.Conn 279 | result := new(rpc.Result) 280 | conns.Traverse(func(conn server.Conn) error { 281 | if len(req.ConnId) == 0 || conn.UniqId() == req.ConnId { 282 | sc = conn 283 | result.Append(sc, nil) 284 | return errors.New("done") 285 | } 286 | return nil 287 | }) 288 | if sc != nil { 289 | // self.conns.DelConn(sc) 290 | sc.Redirect(req.CandidateSersers...) 291 | sc.Close() 292 | return result 293 | } 294 | 295 | if req.DontPropagate { 296 | return result 297 | } 298 | req.DontPropagate = true 299 | return self.peer.Redirect(req) 300 | } 301 | // CheckUserStatus lists the user's local connections and, unless DontPropagate is set, merges in the peers' answer. 302 | func (self *serviceCenter) CheckUserStatus(req *rpc.UserStatusQuery) *rpc.Result { 303 | conns := self.conns.GetConn(req.Username) 304 | result := new(rpc.Result) 305 | conns.Traverse(func(conn server.Conn) error { 306 | result.Append(conn, nil) 307 | return nil 308 | }) 309 | if req.DontPropagate { 310 | return result 311 | } 312 | req.DontPropagate = true 313 | r := self.peer.CheckUserStatus(req) 314 | result.Join(r) 315 | return result 316 | } 317 | // processSubscription hands each request from subReqChan to config.Subscribe in its own goroutine until the channel is closed or yields nil. 318 | func (self *serviceCenter) processSubscription() { 319 | for req := range self.subReqChan { 320 | if req == nil { 321 | return 322 | } 323 | go self.config.Subscribe(req) 324 | } 325 | } 326 | 327 | func (self *serviceCenter) NrConns() int { 328 | return self.conns.NrConns() 329 | } 330 | 331 | func (self *serviceCenter) NrUsers() int { 332 | return self.conns.NrUsers() 333 | } 334 | 335 | func (self
*serviceCenter) AllUsernames() []string { 336 | return self.conns.AllUsernames() 337 | } 338 | // newServiceCenter builds a center for conf and starts its subscription loop; it returns nil if conf or fwdChan is nil. 339 | func newServiceCenter(conf *config.ServiceConfig, fwdChan chan<- *rpc.ForwardRequest, peer rpc.UniqushConnPeer) *serviceCenter { 340 | if conf == nil || fwdChan == nil { 341 | return nil 342 | } 343 | ret := new(serviceCenter) 344 | ret.config = conf 345 | ret.conns = newTreeBasedConnMap(conf.MaxNrConns, conf.MaxNrUsers, conf.MaxNrConnsPerUser) 346 | ret.fwdChan = fwdChan 347 | ret.subReqChan = make(chan *rpc.SubscribeRequest) 348 | ret.peer = peer 349 | 350 | go ret.processSubscription() 351 | return ret 352 | } 353 | -------------------------------------------------------------------------------- /proto/client/conn.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * 16 | */ 17 | 18 | package client 19 | 20 | import ( 21 | "fmt" 22 | "github.com/uniqush/uniqush-conn/proto" 23 | "github.com/uniqush/uniqush-conn/rpc" 24 | "io" 25 | "math/rand" 26 | "net" 27 | "strings" 28 | "sync/atomic" 29 | "time" 30 | ) 31 | // Conn is the client side of a uniqush-conn connection. 32 | type Conn interface { 33 | Close() error 34 | Service() string 35 | Username() string 36 | UniqId() string 37 | 38 | SendMessageToUsers(msg *rpc.Message, ttl time.Duration, service string, receiver ...string) error 39 | SendMessageToServer(msg *rpc.Message) error 40 | ReceiveMessage() (mc *rpc.MessageContainer, err error) 41 | 42 | Config(digestThreshold, compressThreshold int, digestFields ...string) error 43 | SetDigestChannel(digestChan chan<- *Digest) 44 | SetRedirectChannel(redirChan chan<- *RedirectRequest) 45 | RequestMessage(id string) error 46 | SetVisibility(v bool) error 47 | Subscribe(params map[string]string) error 48 | Unsubscribe(params map[string]string) error 49 | RequestAllCachedMessages(since time.Time) error 50 | } 51 | // CommandProcessor handles commands of one type; clientConn.cmdProcs is indexed by command type. 52 | type CommandProcessor interface { 53 | ProcessCommand(cmd *proto.Command) (mc *rpc.MessageContainer, err error) 54 | } 55 | 56 | type clientConn struct { 57 | cmdio *proto.CommandIO 58 | conn net.Conn 59 | compressThreshold int32 60 | digestThreshold int32 61 | service string 62 | username string 63 | connId string 64 | cmdProcs []CommandProcessor 65 | } 66 | 67 | func (self *clientConn) Service() string { 68 | return self.service 69 | } 70 | 71 | func (self *clientConn) Username() string { 72 | return self.username 73 | } 74 | 75 | func (self *clientConn) UniqId() string { 76 | return self.connId 77 | } 78 | 79 | func (self *clientConn) Close() error { 80 | return self.conn.Close() 81 | } 82 | // shouldCompress reports whether a payload of size bytes exceeds the compress threshold (a threshold <= 0 disables compression). 83 | func (self *clientConn) shouldCompress(size int) bool { 84 | t := int(atomic.LoadInt32(&self.compressThreshold)) 85 | if t > 0 && t < size { 86 | return true 87 | } 88 | return false 89 | } 90 | // SendMessageToServer sends msg to the server as a CMD_DATA command. 91 | func (self *clientConn) SendMessageToServer(msg *rpc.Message) error { 92 |
compress := self.shouldCompress(msg.Size()) 93 | 94 | cmd := &proto.Command{} 95 | cmd.Message = msg 96 | cmd.Type = proto.CMD_DATA 97 | err := self.cmdio.WriteCommand(cmd, compress) 98 | return err 99 | } 100 | // SendMessageToUsers asks the server (CMD_FWD_REQ) to forward msg to receiver; service is sent only when it differs from this connection's service. 101 | func (self *clientConn) SendMessageToUsers(msg *rpc.Message, ttl time.Duration, service string, receiver ...string) error { 102 | if len(receiver) == 0 { 103 | return nil 104 | } 105 | cmd := &proto.Command{} 106 | cmd.Type = proto.CMD_FWD_REQ 107 | cmd.Params = make([]string, 2, 3) 108 | cmd.Params[0] = fmt.Sprintf("%v", ttl) 109 | cmd.Params[1] = strings.Join(receiver, ",") 110 | if len(service) > 0 && service != self.Service() { 111 | cmd.Params = append(cmd.Params, service) 112 | } 113 | cmd.Message = msg 114 | compress := self.shouldCompress(msg.Size()) 115 | return self.cmdio.WriteCommand(cmd, compress) 116 | } 117 | // processCommand dispatches cmd to cmdProcs[cmd.Type]; nil commands, out-of-range types and unregistered types are ignored. 118 | func (self *clientConn) processCommand(cmd *proto.Command) (mc *rpc.MessageContainer, err error) { 119 | if cmd == nil { 120 | return 121 | } 122 | 123 | t := int(cmd.Type) 124 | if t < 0 || t >= len(self.cmdProcs) { // was t > len: t == len(cmdProcs) indexed past the end and panicked 125 | return 126 | } 127 | proc := self.cmdProcs[t] 128 | if proc != nil { 129 | mc, err = proc.ProcessCommand(cmd) 130 | } 131 | return 132 | } 133 | 134 | func (self *clientConn) ReceiveMessage() (mc *rpc.MessageContainer, err error) { 135 | var cmd *proto.Command 136 | for { 137 | cmd, err = self.cmdio.ReadCommand() 138 | if err != nil { 139 | return 140 | } 141 | switch cmd.Type { 142 | case proto.CMD_DATA: 143 | mc = new(rpc.MessageContainer) 144 | mc.Message = cmd.Message 145 | if len(cmd.Params[0]) > 0 { 146 | mc.Id = cmd.Params[0] 147 | } 148 | return 149 | case proto.CMD_FWD: 150 | if len(cmd.Params) < 1 { 151 | err = proto.ErrBadPeerImpl 152 | return 153 | } 154 | mc = new(rpc.MessageContainer) 155 | mc.Message = cmd.Message 156 | mc.Sender = cmd.Params[0] 157 | if len(cmd.Params) > 1 { 158 | mc.SenderService = cmd.Params[1] 159 | } else { 160 | mc.SenderService = self.Service() 161 | } 162 | if len(cmd.Params) > 
2 { 163 | mc.Id = cmd.Params[2] 164 | } 165 | return 166 | case proto.CMD_BYE: 167 | err = io.EOF 168 | return 169 | default: 170 | mc, err = self.processCommand(cmd) 171 | if err != nil || mc != nil { 172 | return 173 | } 174 | } 175 | } 176 | return 177 | } 178 | 179 | func (self *clientConn) setCommandProcessor(cmdType uint8, proc CommandProcessor) { 180 | if cmdType >= proto.CMD_NR_CMDS { 181 | return 182 | } 183 | if len(self.cmdProcs) <= int(cmdType) { 184 | self.cmdProcs = make([]CommandProcessor, proto.CMD_NR_CMDS) 185 | } 186 | self.cmdProcs[cmdType] = proc 187 | } 188 | 189 | func (self *clientConn) SetRedirectChannel(redirChan chan<- *RedirectRequest) { 190 | if redirChan == nil { 191 | return 192 | } 193 | 194 | proc := &redirectProcessor{ 195 | redirChan: redirChan, 196 | } 197 | 198 | self.setCommandProcessor(proto.CMD_REDIRECT, proc) 199 | } 200 | 201 | func (self *clientConn) SetDigestChannel(digestChan chan<- *Digest) { 202 | if digestChan == nil { 203 | return 204 | } 205 | proc := &digestProcessor{} 206 | proc.digestChan = digestChan 207 | proc.service = self.Service() 208 | self.setCommandProcessor(proto.CMD_DIGEST, proc) 209 | } 210 | 211 | func (self *clientConn) Config(digestThreshold, compressThreshold int, digestFields ...string) error { 212 | self.digestThreshold = int32(digestThreshold) 213 | self.compressThreshold = int32(compressThreshold) 214 | cmd := &proto.Command{} 215 | cmd.Type = proto.CMD_SETTING 216 | cmd.Params = make([]string, 2, 2+len(digestFields)) 217 | cmd.Params[0] = fmt.Sprintf("%v", self.digestThreshold) 218 | cmd.Params[1] = fmt.Sprintf("%v", self.compressThreshold) 219 | for _, f := range digestFields { 220 | cmd.Params = append(cmd.Params, f) 221 | } 222 | err := self.cmdio.WriteCommand(cmd, false) 223 | return err 224 | } 225 | 226 | func (self *clientConn) RequestMessage(id string) error { 227 | cmd := &proto.Command{ 228 | Type: proto.CMD_MSG_RETRIEVE, 229 | Params: []string{id}, 230 | } 231 | return 
self.cmdio.WriteCommand(cmd, false) 232 | } 233 | 234 | func (self *clientConn) SetVisibility(v bool) error { 235 | cmd := &proto.Command{ 236 | Type: proto.CMD_SET_VISIBILITY, 237 | } 238 | if v { 239 | cmd.Params = []string{"1"} 240 | } else { 241 | cmd.Params = []string{"0"} 242 | } 243 | return self.cmdio.WriteCommand(cmd, false) 244 | } 245 | 246 | func (self *clientConn) subscribe(params map[string]string, sub bool) error { 247 | cmd := &proto.Command{} 248 | cmd.Type = proto.CMD_SUBSCRIPTION 249 | if sub { 250 | cmd.Params = []string{"1"} 251 | } else { 252 | cmd.Params = []string{"0"} 253 | } 254 | cmd.Message = &rpc.Message{} 255 | cmd.Message.Header = params 256 | return self.cmdio.WriteCommand(cmd, false) 257 | } 258 | 259 | func (self *clientConn) Subscribe(params map[string]string) error { 260 | return self.subscribe(params, true) 261 | } 262 | 263 | func (self *clientConn) Unsubscribe(params map[string]string) error { 264 | return self.subscribe(params, false) 265 | } 266 | 267 | func (self *clientConn) RequestAllCachedMessages(since time.Time) error { 268 | cmd := &proto.Command{} 269 | cmd.Type = proto.CMD_REQ_ALL_CACHED 270 | if !since.IsZero() { 271 | cmd.Params = []string{fmt.Sprintf("%v", since.Unix())} 272 | } 273 | return self.cmdio.WriteCommand(cmd, false) 274 | } 275 | 276 | func NewConn(cmdio *proto.CommandIO, service, username string, conn net.Conn) Conn { 277 | ret := new(clientConn) 278 | ret.conn = conn 279 | ret.cmdio = cmdio 280 | ret.service = service 281 | ret.username = username 282 | ret.connId = fmt.Sprintf("%x-%x", time.Now().UnixNano(), rand.Int63()) 283 | 284 | ret.cmdProcs = make([]CommandProcessor, proto.CMD_NR_CMDS) 285 | return ret 286 | } 287 | -------------------------------------------------------------------------------- /proto/client/dial.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 
(the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package client 19 | 20 | import ( 21 | "crypto/rsa" 22 | "errors" 23 | "github.com/uniqush/uniqush-conn/proto" 24 | 25 | "net" 26 | "strings" 27 | "time" 28 | ) 29 | 30 | var ErrBadServiceOrUserName = errors.New("service name or user name should not contain '\\n' or ':'") 31 | 32 | // The conn will be closed if any error occur 33 | func Dial(conn net.Conn, pubkey *rsa.PublicKey, service, username, token string, timeout time.Duration) (c Conn, err error) { 34 | if strings.Contains(service, "\n") || strings.Contains(username, "\n") || 35 | strings.Contains(service, ":") || strings.Contains(username, ":") || 36 | strings.Contains(service, ",") || strings.Contains(username, ",") { 37 | err = ErrBadServiceOrUserName 38 | return 39 | } 40 | conn.SetDeadline(time.Now().Add(timeout)) 41 | defer func() { 42 | conn.SetDeadline(time.Time{}) 43 | if err != nil { 44 | conn.Close() 45 | } 46 | }() 47 | 48 | ks, err := proto.ClientKeyExchange(pubkey, conn) 49 | if err != nil { 50 | return 51 | } 52 | cmdio := ks.ClientCommandIO(conn) 53 | 54 | cmd := new(proto.Command) 55 | cmd.Type = proto.CMD_AUTH 56 | cmd.Params = make([]string, 3) 57 | cmd.Params[0] = service 58 | cmd.Params[1] = username 59 | cmd.Params[2] = token 60 | 61 | // don't compress, but encrypt it 62 | cmdio.WriteCommand(cmd, false) 63 | 64 | cmd, err = cmdio.ReadCommand() 65 | if err != nil { 66 | return 67 | } 68 | if cmd.Type == proto.CMD_REDIRECT { 69 | r := 
new(RedirectRequest) 70 | r.Addresses = cmd.Params 71 | return nil, r 72 | } 73 | if cmd.Type != proto.CMD_AUTHOK { 74 | return 75 | } 76 | c = NewConn(cmdio, service, username, conn) 77 | err = nil 78 | return 79 | } 80 | -------------------------------------------------------------------------------- /proto/client/digestproc.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package client 19 | 20 | import ( 21 | "github.com/uniqush/uniqush-conn/proto" 22 | "github.com/uniqush/uniqush-conn/rpc" 23 | "strconv" 24 | ) 25 | 26 | type Digest struct { 27 | MsgId string 28 | Sender string 29 | SenderService string 30 | Size int 31 | Info map[string]string 32 | } 33 | 34 | type digestProcessor struct { 35 | digestChan chan<- *Digest 36 | service string 37 | } 38 | 39 | func (self *digestProcessor) ProcessCommand(cmd *proto.Command) (mc *rpc.MessageContainer, err error) { 40 | if cmd.Type != proto.CMD_DIGEST || self.digestChan == nil { 41 | return 42 | } 43 | if len(cmd.Params) < 2 { 44 | err = proto.ErrBadPeerImpl 45 | return 46 | } 47 | digest := new(Digest) 48 | digest.Size, err = strconv.Atoi(cmd.Params[0]) 49 | if err != nil { 50 | err = proto.ErrBadPeerImpl 51 | return 52 | } 53 | digest.MsgId = cmd.Params[1] 54 | if cmd.Message != nil { 55 | digest.Info = cmd.Message.Header 56 | } 57 | if len(cmd.Params) > 2 { 58 | digest.Sender = cmd.Params[2] 59 | if len(cmd.Params) > 3 { 60 | digest.SenderService = cmd.Params[3] 61 | } else { 62 | digest.SenderService = self.service 63 | } 64 | } 65 | self.digestChan <- digest 66 | 67 | return 68 | } 69 | -------------------------------------------------------------------------------- /proto/client/redirproc.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package client 19 | 20 | import ( 21 | "github.com/uniqush/uniqush-conn/proto" 22 | "github.com/uniqush/uniqush-conn/rpc" 23 | ) 24 | 25 | type RedirectRequest struct { 26 | Addresses []string 27 | } 28 | 29 | func (self *RedirectRequest) Error() string { 30 | return "redirect" 31 | } 32 | 33 | type redirectProcessor struct { 34 | redirChan chan<- *RedirectRequest 35 | } 36 | 37 | func (self *redirectProcessor) ProcessCommand(cmd *proto.Command) (mc *rpc.MessageContainer, err error) { 38 | if cmd.Type != proto.CMD_REDIRECT || self.redirChan == nil { 39 | return 40 | } 41 | 42 | if len(cmd.Params) == 0 { 43 | return 44 | } 45 | 46 | req := new(RedirectRequest) 47 | req.Addresses = cmd.Params 48 | self.redirChan <- req 49 | return 50 | } 51 | -------------------------------------------------------------------------------- /proto/cmd_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package proto 19 | 20 | import ( 21 | "fmt" 22 | "github.com/uniqush/uniqush-conn/rpc" 23 | "labix.org/v2/mgo/bson" 24 | "testing" 25 | ) 26 | 27 | func marshalUnmarshal(cmd *Command) error { 28 | data, err := cmd.Marshal() 29 | if err != nil { 30 | return err 31 | } 32 | 33 | c, err := UnmarshalCommand(data) 34 | if !c.eq(cmd) { 35 | return fmt.Errorf("Not same") 36 | } 37 | return nil 38 | } 39 | 40 | func TestCommandMarshalNoParamsNoMessage(t *testing.T) { 41 | cmd := new(Command) 42 | cmd.Type = 1 43 | marshalUnmarshal(cmd) 44 | } 45 | 46 | func TestCommandMarshalNoParams(t *testing.T) { 47 | cmd := new(Command) 48 | cmd.Type = 1 49 | cmd.Params = make([]string, 2) 50 | cmd.Params[0] = "hello" 51 | cmd.Params[1] = "" 52 | marshalUnmarshal(cmd) 53 | } 54 | 55 | func TestCommandMarshalNoBody(t *testing.T) { 56 | cmd := new(Command) 57 | cmd.Type = 1 58 | cmd.Params = make([]string, 2) 59 | cmd.Params[0] = "hello" 60 | cmd.Params[1] = "" 61 | cmd.Message = new(rpc.Message) 62 | cmd.Message.Header = make(map[string]string, 3) 63 | cmd.Message.Header["a"] = "h" 64 | cmd.Message.Header["b"] = "i" 65 | cmd.Message.Header["b"] = "j" 66 | marshalUnmarshal(cmd) 67 | } 68 | 69 | func TestCommandMarshal(t *testing.T) { 70 | cmd := new(Command) 71 | cmd.Type = 1 72 | cmd.Params = make([]string, 2) 73 | cmd.Params[0] = "hello" 74 | cmd.Params[1] = "new" 75 | cmd.Message = new(rpc.Message) 76 | cmd.Message.Header = make(map[string]string, 3) 77 | cmd.Message.Header["a"] = "h" 78 | cmd.Message.Header["b"] = "i" 79 | cmd.Message.Header["b"] = "j" 80 | cmd.Message.Body = []byte{1, 2, 3, 3} 81 | marshalUnmarshal(cmd) 82 | } 83 | 84 | func TestRandomize(t *testing.T) { 85 | cmd := new(Command) 86 | cmd.Type = 1 87 | cmd.Params = make([]string, 2) 88 | cmd.Params[0] = "hello" 89 | cmd.Params[1] = "new" 90 | cmd.Message = new(rpc.Message) 91 | cmd.Message.Header = make(map[string]string, 3) 92 | cmd.Message.Header["a"] = "h" 93 | cmd.Message.Header["b"] = 
"i" 94 | cmd.Message.Header["b"] = "j" 95 | cmd.Message.Body = []byte{1, 2, 3, 3} 96 | marshalUnmarshal(cmd) 97 | cmd.Randomize() 98 | } 99 | 100 | func BenchmarkCommandMarshalUnmarshal(b *testing.B) { 101 | b.StopTimer() 102 | cmds := make([]*Command, b.N) 103 | for i, _ := range cmds { 104 | cmd := randomCommand() 105 | cmds[i] = cmd 106 | } 107 | b.StartTimer() 108 | for _, cmd := range cmds { 109 | data, _ := cmd.Marshal() 110 | UnmarshalCommand(data) 111 | } 112 | } 113 | 114 | func BenchmarkCommandMarshalUnmarshalBson(b *testing.B) { 115 | b.StopTimer() 116 | cmds := make([]*Command, b.N) 117 | for i, _ := range cmds { 118 | cmd := randomCommand() 119 | cmds[i] = cmd 120 | } 121 | b.StartTimer() 122 | for _, cmd := range cmds { 123 | data, _ := bson.Marshal(cmd) 124 | c := new(Command) 125 | bson.Unmarshal(data, c) 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /proto/cmdio.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package proto 19 | 20 | import ( 21 | "code.google.com/p/snappy-go/snappy" 22 | "crypto/aes" 23 | "crypto/cipher" 24 | "crypto/hmac" 25 | "crypto/sha256" 26 | "encoding/binary" 27 | "hash" 28 | "io" 29 | "sync" 30 | ) 31 | 32 | type CommandIO struct { 33 | writeAuth hash.Hash 34 | cryptWriter io.Writer 35 | readAuth hash.Hash 36 | cryptReader io.Reader 37 | conn io.ReadWriter 38 | 39 | writeLock *sync.Mutex 40 | } 41 | 42 | func (self *CommandIO) writeThenHmac(data []byte) (mac []byte, err error) { 43 | writer := self.cryptWriter 44 | self.writeAuth.Reset() 45 | var datalen uint16 46 | datalen = uint16(len(data)) 47 | err = binary.Write(self.writeAuth, binary.LittleEndian, datalen) 48 | if err != nil { 49 | return 50 | } 51 | err = writen(writer, data) 52 | if err != nil { 53 | return 54 | } 55 | mac = self.writeAuth.Sum(nil) 56 | return 57 | } 58 | 59 | func (self *CommandIO) readThenHmac(data []byte) (mac []byte, err error) { 60 | reader := self.cryptReader 61 | self.readAuth.Reset() 62 | 63 | var datalen uint16 64 | datalen = uint16(len(data)) 65 | err = binary.Write(self.readAuth, binary.LittleEndian, datalen) 66 | if err != nil { 67 | return 68 | } 69 | n, err := io.ReadFull(reader, data) 70 | if err != nil { 71 | return 72 | } 73 | if n != len(data) { 74 | err = io.EOF 75 | return 76 | } 77 | mac = self.readAuth.Sum(nil) 78 | return 79 | } 80 | 81 | func (self *CommandIO) writeHmac(mac []byte) error { 82 | if len(mac) == 0 { 83 | return nil 84 | } 85 | return writen(self.conn, mac) 86 | } 87 | 88 | func (self *CommandIO) readAndCmpHmac(mac []byte) error { 89 | if len(mac) == 0 { 90 | return nil 91 | } 92 | macRecved := make([]byte, self.readAuth.Size()) 93 | n, err := io.ReadFull(self.conn, macRecved) 94 | if err != nil { 95 | return err 96 | } 97 | if n != len(macRecved) { 98 | return ErrCorruptedData 99 | } 100 | if !xorBytesEq(mac, macRecved) { 101 | return ErrCorruptedData 102 | } 103 | return nil 104 | } 105 | 106 | func (self 
*CommandIO) decodeCommand(data []byte) (cmd *Command, err error) { 107 | // Flag: 8 bit 108 | // Most significant 5 bits: number of bytes of padding 109 | // Least significant bit: compress bit 110 | compress := ((data[0] & cmdflag_COMPRESS) != 0) 111 | var npadding int 112 | npadding = int(data[0] >> 3) 113 | data = data[1 : len(data)-npadding] 114 | decoded := data 115 | if compress { 116 | decoded, err = snappy.Decode(nil, data) 117 | if err != nil { 118 | return 119 | } 120 | } 121 | cmd, err = UnmarshalCommand(decoded) 122 | if err != nil { 123 | return 124 | } 125 | return 126 | } 127 | 128 | func (self *CommandIO) encodeCommand(cmd *Command, compress bool) (data []byte, err error) { 129 | bsonEncoded, err := cmd.Marshal() 130 | if err != nil { 131 | return 132 | } 133 | 134 | data = bsonEncoded 135 | if compress { 136 | data, err = snappy.Encode(nil, bsonEncoded) 137 | if err != nil { 138 | return 139 | } 140 | } 141 | var flag byte 142 | if compress { 143 | flag |= cmdflag_COMPRESS 144 | } 145 | // one byte flag 146 | nrBlk := (len(data) + blkLen) / blkLen 147 | npadding := (nrBlk * blkLen) - (len(data) + 1) 148 | flag |= byte((npadding & 0xFF) << 3) 149 | 150 | data = append(data, 0) 151 | copy(data[1:], data[:len(data)-1]) 152 | data[0] = flag 153 | 154 | data = append(data, make([]byte, npadding)...) 155 | return 156 | } 157 | 158 | // WriteCommand() is goroutine-safe. i.e. Multiple goroutine could write concurrently. 
159 | func (self *CommandIO) WriteCommand(cmd *Command, compress bool) error { 160 | data, err := self.encodeCommand(cmd, compress) 161 | if err != nil { 162 | return err 163 | } 164 | var cmdLen uint16 165 | cmdLen = uint16(len(data)) 166 | if cmdLen == 0 { 167 | return nil 168 | } 169 | self.writeLock.Lock() 170 | defer self.writeLock.Unlock() 171 | err = binary.Write(self.conn, binary.LittleEndian, cmdLen) 172 | if err != nil { 173 | return err 174 | } 175 | mac, err := self.writeThenHmac(data) 176 | if err != nil { 177 | return err 178 | } 179 | err = self.writeHmac(mac) 180 | if err != nil { 181 | return err 182 | } 183 | return nil 184 | } 185 | 186 | // ReadCommand() is not goroutine-safe. 187 | func (self *CommandIO) ReadCommand() (cmd *Command, err error) { 188 | var cmdLen uint16 189 | err = binary.Read(self.conn, binary.LittleEndian, &cmdLen) 190 | if err != nil { 191 | return 192 | } 193 | 194 | data := make([]byte, int(cmdLen)) 195 | mac, err := self.readThenHmac(data) 196 | if err != nil { 197 | return 198 | } 199 | err = self.readAndCmpHmac(mac) 200 | if err != nil { 201 | return 202 | } 203 | cmd, err = self.decodeCommand(data) 204 | if cmd.Type == CMD_BYE { 205 | err = io.EOF 206 | return 207 | } 208 | return 209 | } 210 | 211 | func NewCommandIO(writeKey, writeAuthKey, readKey, readAuthKey []byte, conn io.ReadWriter) *CommandIO { 212 | ret := new(CommandIO) 213 | ret.writeAuth = hmac.New(sha256.New, writeAuthKey) 214 | ret.readAuth = hmac.New(sha256.New, readAuthKey) 215 | ret.conn = conn 216 | ret.writeLock = new(sync.Mutex) 217 | 218 | writeBlkCipher, _ := aes.NewCipher(writeKey) 219 | readBlkCipher, _ := aes.NewCipher(readKey) 220 | 221 | // IV: 0 for all. Since we change keys for each connection, letting IV=0 won't hurt. 
222 | writeIV := make([]byte, writeBlkCipher.BlockSize()) 223 | readIV := make([]byte, readBlkCipher.BlockSize()) 224 | 225 | writeStream := cipher.NewCTR(writeBlkCipher, writeIV) 226 | readStream := cipher.NewCTR(readBlkCipher, readIV) 227 | 228 | // Then for each encrypted bit, 229 | // it will be written to both the connection and the hmac 230 | // We use encrypt-then-hmac scheme. 231 | mwriter := io.MultiWriter(conn, ret.writeAuth) 232 | swriter := new(cipher.StreamWriter) 233 | swriter.S = writeStream 234 | swriter.W = mwriter 235 | ret.cryptWriter = swriter 236 | 237 | // Similarly, for each bit read from the connection, 238 | // it will be written to the hmac as well. 239 | tee := io.TeeReader(conn, ret.readAuth) 240 | sreader := new(cipher.StreamReader) 241 | sreader.S = readStream 242 | sreader.R = tee 243 | ret.cryptReader = sreader 244 | return ret 245 | } 246 | -------------------------------------------------------------------------------- /proto/cmdio_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package proto 19 | 20 | import ( 21 | "bytes" 22 | "crypto/hmac" 23 | "crypto/rand" 24 | "crypto/sha256" 25 | "fmt" 26 | "github.com/uniqush/uniqush-conn/rpc" 27 | "io" 28 | "testing" 29 | ) 30 | 31 | type opBetweenWriteAndRead interface { 32 | Op() 33 | } 34 | 35 | func testSendingCommands(t *testing.T, op opBetweenWriteAndRead, compress, encrypt bool, from, to *CommandIO, cmds ...*Command) { 36 | errCh := make(chan error) 37 | startRead := make(chan bool) 38 | go func() { 39 | defer close(errCh) 40 | if op != nil { 41 | <-startRead 42 | } 43 | for i, cmd := range cmds { 44 | recved, err := to.ReadCommand() 45 | if err != nil { 46 | errCh <- err 47 | return 48 | } 49 | if !cmd.eq(recved) { 50 | errCh <- fmt.Errorf("%vth command does not equal", i) 51 | } 52 | } 53 | }() 54 | 55 | for _, cmd := range cmds { 56 | err := from.WriteCommand(cmd, compress) 57 | if err != nil { 58 | t.Errorf("Error on write: %v", err) 59 | } 60 | } 61 | if op != nil { 62 | op.Op() 63 | startRead <- true 64 | } 65 | 66 | for err := range errCh { 67 | if err != nil { 68 | t.Errorf("Error on read: %v", err) 69 | } 70 | } 71 | } 72 | 73 | func getBufferCommandIOs(t *testing.T) (io1, io2 *CommandIO, buffer *bytes.Buffer, ks *keySet) { 74 | keybuf := make([]byte, 2*(authKeyLen+encrKeyLen)) 75 | io.ReadFull(rand.Reader, keybuf) 76 | sen := keybuf[:encrKeyLen] 77 | keybuf = keybuf[encrKeyLen:] 78 | sau := keybuf[:authKeyLen] 79 | keybuf = keybuf[authKeyLen:] 80 | cen := keybuf[:encrKeyLen] 81 | keybuf = keybuf[encrKeyLen:] 82 | cau := keybuf[:authKeyLen] 83 | keybuf = keybuf[authKeyLen:] 84 | 85 | buffer = new(bytes.Buffer) 86 | ks = newKeySet(sen, sau, cen, cau) 87 | scmdio := ks.ServerCommandIO(buffer) 88 | ccmdio := ks.ClientCommandIO(buffer) 89 | io1 = scmdio 90 | io2 = ccmdio 91 | return 92 | } 93 | 94 | func getNetworkCommandIOs(t *testing.T) (io1, io2 *CommandIO) { 95 | sks, cks, s2c, c2s := exchangeKeysOrReport(t, true) 96 | if sks == nil || cks == nil || s2c == 
nil || c2s == nil { 97 | return 98 | } 99 | 100 | scmdio := sks.ServerCommandIO(s2c) 101 | ccmdio := cks.ClientCommandIO(c2s) 102 | io1 = scmdio 103 | io2 = ccmdio 104 | return 105 | } 106 | 107 | func TestExchangingFullCommandNoCompressNoEncrypt(t *testing.T) { 108 | cmd := randomCommand() 109 | compress := false 110 | encrypt := false 111 | io1, io2 := getNetworkCommandIOs(t) 112 | testSendingCommands(t, nil, compress, encrypt, io1, io2, cmd) 113 | testSendingCommands(t, nil, compress, encrypt, io2, io1, cmd) 114 | } 115 | 116 | func TestExchangingFullCommandNoCompress(t *testing.T) { 117 | cmd := randomCommand() 118 | compress := false 119 | encrypt := true 120 | io1, io2 := getNetworkCommandIOs(t) 121 | testSendingCommands(t, nil, compress, encrypt, io1, io2, cmd) 122 | testSendingCommands(t, nil, compress, encrypt, io2, io1, cmd) 123 | } 124 | 125 | func TestExchangingFullCommandNoEncrypt(t *testing.T) { 126 | cmd := randomCommand() 127 | compress := true 128 | encrypt := false 129 | io1, io2 := getNetworkCommandIOs(t) 130 | testSendingCommands(t, nil, compress, encrypt, io1, io2, cmd) 131 | testSendingCommands(t, nil, compress, encrypt, io2, io1, cmd) 132 | } 133 | 134 | type bufPrinter struct { 135 | buf *bytes.Buffer 136 | authKey []byte 137 | } 138 | 139 | func (self *bufPrinter) Op() { 140 | fmt.Printf("--------------\n") 141 | fmt.Printf("Data in buffer: %v\n", self.buf.Bytes()) 142 | 143 | data := self.buf.Bytes() 144 | data = data[16:] 145 | 146 | hash := hmac.New(sha256.New, self.authKey) 147 | hash.Write(data) 148 | 149 | fmt.Printf("HMAC: %v\n", hash.Sum(nil)) 150 | fmt.Printf("--------------\n") 151 | } 152 | 153 | func TestExchangingFullCommandOverNetwork(t *testing.T) { 154 | cmd := randomCommand() 155 | compress := true 156 | encrypt := true 157 | io1, io2 := getNetworkCommandIOs(t) 158 | testSendingCommands(t, nil, compress, encrypt, io1, io2, cmd) 159 | testSendingCommands(t, nil, compress, encrypt, io2, io1, cmd) 160 | } 161 | 162 | func 
TestExchangingFullCommandInBuffer(t *testing.T) { 163 | cmd := randomCommand() 164 | compress := true 165 | encrypt := true 166 | io1, io2 := getNetworkCommandIOs(t) 167 | io1, io2, buffer, ks := getBufferCommandIOs(t) 168 | op := &bufPrinter{buffer, ks.serverAuthKey} 169 | testSendingCommands(t, op, compress, encrypt, io1, io2, cmd) 170 | 171 | op = &bufPrinter{buffer, ks.clientAuthKey} 172 | testSendingCommands(t, op, compress, encrypt, io2, io1, cmd) 173 | } 174 | 175 | func randomCommand() *Command { 176 | cmd := new(Command) 177 | cmd.Type = 1 178 | cmd.Params = make([]string, 2) 179 | cmd.Params[0] = "123" 180 | cmd.Params[1] = "223" 181 | cmd.Message = new(rpc.Message) 182 | cmd.Message.Header = make(map[string]string, 2) 183 | cmd.Message.Header["a"] = "hello" 184 | cmd.Message.Header["b"] = "hell" 185 | cmd.Message.Body = make([]byte, 10) 186 | io.ReadFull(rand.Reader, cmd.Message.Body) 187 | return cmd 188 | } 189 | 190 | func TestExchangingMultiFullCommandOverNetwork(t *testing.T) { 191 | cmds := make([]*Command, 100) 192 | for i, _ := range cmds { 193 | cmd := randomCommand() 194 | cmds[i] = cmd 195 | } 196 | 197 | compress := true 198 | encrypt := true 199 | io1, io2 := getNetworkCommandIOs(t) 200 | testSendingCommands(t, nil, compress, encrypt, io1, io2, cmds...) 201 | testSendingCommands(t, nil, compress, encrypt, io2, io1, cmds...) 
202 | } 203 | 204 | func TestConcurrentWrite(t *testing.T) { 205 | N := 100 206 | cmds := make([]*Command, N) 207 | for i, _ := range cmds { 208 | cmd := randomCommand() 209 | cmds[i] = cmd 210 | } 211 | io1, io2 := getNetworkCommandIOs(nil) 212 | done := make(chan bool) 213 | go func() { 214 | defer close(done) 215 | for i := 0; i < N; i++ { 216 | io2.ReadCommand() 217 | } 218 | }() 219 | 220 | for _, cmd := range cmds { 221 | go io1.WriteCommand(cmd, true) 222 | } 223 | <-done 224 | } 225 | 226 | func BenchmarkExchangingMultiFullCommandOverNetwork(b *testing.B) { 227 | b.StopTimer() 228 | cmds := make([]*Command, b.N) 229 | for i, _ := range cmds { 230 | cmd := randomCommand() 231 | cmds[i] = cmd 232 | } 233 | io1, io2 := getNetworkCommandIOs(nil) 234 | done := make(chan bool) 235 | go func() { 236 | defer close(done) 237 | for i := 0; i < b.N; i++ { 238 | io2.ReadCommand() 239 | } 240 | }() 241 | 242 | b.StartTimer() 243 | for _, cmd := range cmds { 244 | io1.WriteCommand(cmd, true) 245 | } 246 | <-done 247 | } 248 | -------------------------------------------------------------------------------- /proto/const.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package proto 19 | 20 | const ( 21 | encrKeyLen int = 32 22 | authKeyLen int = 32 23 | ivLen int = 16 24 | blkLen int = 16 25 | hmacLen int = 32 26 | pssSaltLen int = 32 27 | dhGroupID int = 0 28 | dhPubkeyLen int = 256 29 | nonceLen int = 32 30 | ) 31 | -------------------------------------------------------------------------------- /proto/keyex.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package proto 19 | 20 | import ( 21 | "crypto" 22 | "crypto/rand" 23 | "crypto/rsa" 24 | "crypto/sha256" 25 | "github.com/monnand/dhkx" 26 | pss "github.com/monnand/rsa" 27 | "io" 28 | "net" 29 | ) 30 | 31 | const currentProtocolVersion byte = 1 32 | 33 | // The authentication here is quite similar with, if not same as, tarsnap's auth algorithm. 34 | // 35 | // First, server generate a Diffie-Hellman public key, dhpub1, sign it with 36 | // server's private key using RSASSA-PSS signing algorithm. 37 | // Send dhpub1, its signature and a nonce to client. 38 | // An nonce is just a sequence of random bytes. 39 | // 40 | // Server -- dhpub1 + sign(dhpub1) + nonce --> Client 41 | // 42 | // Then client generate its own Diffie-Hellman key. It now can calculate 43 | // a key, K, using its own Diffie-Hellman key and server's DH public key. 
44 | // (According to DH key exchange algorithm) 45 | // 46 | // Now, we can use K to derive any key we need on server and client side. 47 | // master key, mkey = MGF1(nonce || K, 48) 48 | func ServerKeyExchange(privKey *rsa.PrivateKey, conn net.Conn) (ks *keySet, err error) { 49 | group, _ := dhkx.GetGroup(dhGroupID) 50 | priv, _ := group.GeneratePrivateKey(nil) 51 | 52 | mypub := priv.Bytes() 53 | mypub = leftPaddingZero(mypub, dhPubkeyLen) 54 | 55 | salt := make([]byte, pssSaltLen) 56 | n, err := io.ReadFull(rand.Reader, salt) 57 | if err != nil || n != len(salt) { 58 | err = ErrZeroEntropy 59 | return 60 | } 61 | 62 | sha := sha256.New() 63 | hashed := make([]byte, sha.Size()) 64 | sha.Write([]byte{currentProtocolVersion}) 65 | sha.Write(mypub) 66 | hashed = sha.Sum(hashed[:0]) 67 | 68 | sig, err := pss.SignPSS(rand.Reader, privKey, crypto.SHA256, hashed, salt) 69 | if err != nil { 70 | return 71 | } 72 | 73 | siglen := (privKey.N.BitLen() + 7) / 8 74 | keyExPkt := make([]byte, dhPubkeyLen+siglen+nonceLen+1) 75 | keyExPkt[0] = currentProtocolVersion 76 | copy(keyExPkt[1:], mypub) 77 | copy(keyExPkt[dhPubkeyLen+1:], sig) 78 | nonce := keyExPkt[dhPubkeyLen+siglen+1:] 79 | n, err = io.ReadFull(rand.Reader, nonce) 80 | if err != nil || n != len(nonce) { 81 | err = ErrZeroEntropy 82 | return 83 | } 84 | 85 | // Send to client: 86 | // - Server's version (1 byte) 87 | // - DH public key: g ^ x 88 | // - Signature of DH public key RSASSA-PSS(version || g ^ x) 89 | // - nonce 90 | err = writen(conn, keyExPkt) 91 | if err != nil { 92 | return 93 | } 94 | 95 | // Receive from client: 96 | // - Client's version (1 byte) 97 | // - Client's DH public key: g ^ y 98 | // - HMAC of client's DH public key: HMAC(version || g ^ y, clientAuthKey) 99 | keyExPkt = keyExPkt[:1+dhPubkeyLen+authKeyLen] 100 | 101 | // Receive the data from client 102 | n, err = io.ReadFull(conn, keyExPkt) 103 | if err != nil { 104 | return 105 | } 106 | if n != len(keyExPkt) { 107 | err = 
ErrBadKeyExchangePacket 108 | return 109 | } 110 | 111 | version := keyExPkt[0] 112 | if version > currentProtocolVersion { 113 | err = ErrImcompatibleProtocol 114 | return 115 | } 116 | // First, recover client's DH public key 117 | clientpub := dhkx.NewPublicKey(keyExPkt[1 : dhPubkeyLen+1]) 118 | 119 | // Compute a shared key K. 120 | K, err := group.ComputeKey(clientpub, priv) 121 | if err != nil { 122 | return 123 | } 124 | 125 | // Generate keys from the shared key 126 | ks, err = generateKeys(K.Bytes(), nonce) 127 | if err != nil { 128 | return 129 | } 130 | 131 | // Check client's hmac 132 | err = ks.checkClientHMAC(keyExPkt[:dhPubkeyLen+1], keyExPkt[dhPubkeyLen+1:]) 133 | if err != nil { 134 | return 135 | } 136 | return 137 | } 138 | 139 | func ClientKeyExchange(pubKey *rsa.PublicKey, conn net.Conn) (ks *keySet, err error) { 140 | // Receive the data from server, which contains: 141 | // - version 142 | // - Server's DH public key: g ^ x 143 | // - Signature of server's DH public key RSASSA-PSS(g ^ x) 144 | // - nonce 145 | siglen := (pubKey.N.BitLen() + 7) / 8 146 | keyExPkt := make([]byte, dhPubkeyLen+siglen+nonceLen+1) 147 | n, err := io.ReadFull(conn, keyExPkt) 148 | if err != nil { 149 | return 150 | } 151 | if n != len(keyExPkt) { 152 | err = ErrBadKeyExchangePacket 153 | return 154 | } 155 | 156 | version := keyExPkt[0] 157 | if version != currentProtocolVersion { 158 | err = ErrImcompatibleProtocol 159 | return 160 | } 161 | 162 | serverPubData := keyExPkt[1 : dhPubkeyLen+1] 163 | signature := keyExPkt[dhPubkeyLen+1 : dhPubkeyLen+siglen+1] 164 | nonce := keyExPkt[dhPubkeyLen+siglen+1:] 165 | 166 | sha := sha256.New() 167 | hashed := make([]byte, sha.Size()) 168 | sha.Write(keyExPkt[:dhPubkeyLen+1]) 169 | hashed = sha.Sum(hashed[:0]) 170 | 171 | // Verify the signature 172 | err = pss.VerifyPSS(pubKey, crypto.SHA256, hashed, signature, pssSaltLen) 173 | 174 | if err != nil { 175 | return 176 | } 177 | 178 | // Generate a DH key 179 | group, _ := 
dhkx.GetGroup(dhGroupID) 180 | priv, _ := group.GeneratePrivateKey(nil) 181 | mypub := leftPaddingZero(priv.Bytes(), dhPubkeyLen) 182 | 183 | // Generate the shared key from server's DH public key and client DH private key 184 | serverpub := dhkx.NewPublicKey(serverPubData) 185 | K, err := group.ComputeKey(serverpub, priv) 186 | if err != nil { 187 | return 188 | } 189 | 190 | ks, err = generateKeys(K.Bytes(), nonce) 191 | if err != nil { 192 | return 193 | } 194 | 195 | keyExPkt = keyExPkt[:1+dhPubkeyLen+authKeyLen] 196 | keyExPkt[0] = currentProtocolVersion 197 | copy(keyExPkt[1:], mypub) 198 | err = ks.clientHMAC(keyExPkt[:dhPubkeyLen+1], keyExPkt[dhPubkeyLen+1:]) 199 | if err != nil { 200 | return 201 | } 202 | 203 | // Send the client message to server, which contains: 204 | // - Protocol version (1 byte) 205 | // - Client's DH public key: g ^ y 206 | // - HMAC of client's DH public key: HMAC(g ^ y, clientAuthKey) 207 | err = writen(conn, keyExPkt) 208 | 209 | return 210 | } 211 | -------------------------------------------------------------------------------- /proto/keyex_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package proto 19 | 20 | import ( 21 | "crypto/rand" 22 | "crypto/rsa" 23 | //"fmt" 24 | "net" 25 | "testing" 26 | "time" 27 | ) 28 | 29 | func serverGetOneClient(addr string) (conn net.Conn, err error) { 30 | ln, err := net.Listen("tcp", addr) 31 | if err != nil { 32 | return 33 | } 34 | defer ln.Close() 35 | conn, err = ln.Accept() 36 | if err != nil { 37 | return 38 | } 39 | return 40 | } 41 | 42 | func clientConnectServer(addr string) (conn net.Conn, err error) { 43 | conn, err = net.Dial("tcp", addr) 44 | if err != nil { 45 | return 46 | } 47 | return 48 | } 49 | 50 | func buildServerClient(addr string) (server net.Conn, client net.Conn, err error) { 51 | ch := make(chan error) 52 | go func() { 53 | var e error 54 | client, e = serverGetOneClient(addr) 55 | ch <- e 56 | }() 57 | // It is enough to setup a server for a test. 58 | time.Sleep(1 * time.Second) 59 | server, err = clientConnectServer(addr) 60 | if err != nil { 61 | return 62 | } 63 | err = <-ch 64 | if err != nil { 65 | return 66 | } 67 | return 68 | } 69 | 70 | func exchangeKeysOrReport(t *testing.T, succ bool) (serverKeySet, clientKeySet *keySet, server2client, client2server net.Conn) { 71 | addr := "127.0.0.1:8080" 72 | priv, err := rsa.GenerateKey(rand.Reader, 2048) 73 | if err != nil { 74 | if t != nil { 75 | t.Errorf("Error: %v", err) 76 | } 77 | return 78 | } 79 | pub := &priv.PublicKey 80 | if !succ { 81 | priv, err := rsa.GenerateKey(rand.Reader, 2048) 82 | if err != nil { 83 | if t != nil { 84 | t.Errorf("Error: %v", err) 85 | } 86 | return 87 | } 88 | pub = &priv.PublicKey 89 | } 90 | server, client, err := buildServerClient(addr) 91 | if err != nil { 92 | if t != nil { 93 | t.Errorf("Error: %v", err) 94 | } 95 | return 96 | } 97 | if server == nil || client == nil { 98 | if t != nil { 99 | t.Errorf("Nil pointer: server=%v; client=%v", server, client) 100 | } 101 | return 102 | } 103 | server2client = client 104 | client2server = server 105 | var es error 106 | var 
ec error 107 | ch := make(chan bool) 108 | go func() { 109 | //start := time.Now() 110 | serverKeySet, es = ServerKeyExchange(priv, client) 111 | //delta := time.Since(start) 112 | //fmt.Printf("Key exchange: Server used %v\n", delta) 113 | ch <- true 114 | }() 115 | go func() { 116 | //start := time.Now() 117 | clientKeySet, ec = ClientKeyExchange(pub, server) 118 | //delta := time.Since(start) 119 | //fmt.Printf("Key exchange: Client used %v\n", delta) 120 | if ec != nil { 121 | server.Close() 122 | } 123 | ch <- true 124 | }() 125 | <-ch 126 | <-ch 127 | if es == nil && !succ { 128 | if t != nil { 129 | t.Errorf("Should be failed. Run again") 130 | return 131 | } 132 | return 133 | } 134 | if ec == nil && !succ { 135 | if t != nil { 136 | t.Errorf("Should be failed. Run again") 137 | return 138 | } 139 | return 140 | } 141 | if !succ { 142 | return 143 | } 144 | if es != nil && succ { 145 | serverKeySet = nil 146 | clientKeySet = nil 147 | if t != nil { 148 | t.Errorf("Error from server: %v", es) 149 | return 150 | } 151 | } 152 | if ec != nil && succ { 153 | serverKeySet = nil 154 | clientKeySet = nil 155 | if t != nil { 156 | t.Errorf("Error from client: %v", ec) 157 | return 158 | } 159 | } 160 | if !serverKeySet.eq(clientKeySet) { 161 | serverKeySet = nil 162 | clientKeySet = nil 163 | 164 | if t != nil { 165 | t.Errorf("Key set Not equal") 166 | return 167 | } 168 | } 169 | return 170 | } 171 | 172 | func TestKeyExchange(t *testing.T) { 173 | exchangeKeysOrReport(t, true) 174 | } 175 | 176 | func TestKeyExchangeFail(t *testing.T) { 177 | exchangeKeysOrReport(t, false) 178 | } 179 | -------------------------------------------------------------------------------- /proto/keyset.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package proto 19 | 20 | import ( 21 | "bytes" 22 | "crypto/hmac" 23 | "crypto/sha256" 24 | "fmt" 25 | "io" 26 | ) 27 | 28 | type keySet struct { 29 | serverEncrKey []byte 30 | serverAuthKey []byte 31 | clientEncrKey []byte 32 | clientAuthKey []byte 33 | } 34 | 35 | func (self *keySet) String() string { 36 | return fmt.Sprintf("serverEncr: %v; serverAuth: %v\nclientEncr: %v; clientAuth: %v", self.serverEncrKey, self.serverAuthKey, self.clientEncrKey, self.clientAuthKey) 37 | } 38 | 39 | func (self *keySet) eq(ks *keySet) bool { 40 | if !bytes.Equal(self.serverEncrKey, ks.serverEncrKey) { 41 | return false 42 | } 43 | if !bytes.Equal(self.serverAuthKey, ks.serverAuthKey) { 44 | return false 45 | } 46 | if !bytes.Equal(self.clientEncrKey, ks.clientEncrKey) { 47 | return false 48 | } 49 | if !bytes.Equal(self.clientAuthKey, ks.clientAuthKey) { 50 | return false 51 | } 52 | return true 53 | } 54 | 55 | func newKeySet(serverEncrKey, serverAuthKey, clientEncrKey, clientAuthKey []byte) *keySet { 56 | result := new(keySet) 57 | 58 | result.serverEncrKey = serverEncrKey 59 | result.serverAuthKey = serverAuthKey 60 | result.clientEncrKey = clientEncrKey 61 | result.clientAuthKey = clientAuthKey 62 | 63 | return result 64 | } 65 | 66 | func (self *keySet) serverHMAC(data, mac []byte) error { 67 | hash := hmac.New(sha256.New, self.serverAuthKey) 68 | err := writen(hash, data) 69 | if err != nil { 70 | return err 71 | } 72 | mac = hash.Sum(mac[:0]) 73 | return nil 74 | } 75 | 76 | func (self *keySet) 
checkServerHMAC(data, mac []byte) error { 77 | if len(mac) != authKeyLen { 78 | return ErrCorruptedData 79 | } 80 | hmac := make([]byte, len(mac)) 81 | err := self.serverHMAC(data, hmac) 82 | if err != nil { 83 | return err 84 | } 85 | if !xorBytesEq(hmac, mac) { 86 | return ErrCorruptedData 87 | } 88 | return nil 89 | } 90 | 91 | func (self *keySet) ClientCommandIO(conn io.ReadWriter) *CommandIO { 92 | ret := NewCommandIO(self.clientEncrKey, self.clientAuthKey, self.serverEncrKey, self.serverAuthKey, conn) 93 | return ret 94 | } 95 | 96 | func (self *keySet) ServerCommandIO(conn io.ReadWriter) *CommandIO { 97 | ret := NewCommandIO(self.serverEncrKey, self.serverAuthKey, self.clientEncrKey, self.clientAuthKey, conn) 98 | return ret 99 | } 100 | 101 | func (self *keySet) clientHMAC(data, mac []byte) error { 102 | hash := hmac.New(sha256.New, self.clientAuthKey) 103 | err := writen(hash, data) 104 | if err != nil { 105 | return err 106 | } 107 | mac = hash.Sum(mac[:0]) 108 | return nil 109 | } 110 | 111 | func (self *keySet) checkClientHMAC(data, mac []byte) error { 112 | if len(mac) != authKeyLen { 113 | return ErrCorruptedData 114 | } 115 | hmac := make([]byte, len(mac)) 116 | err := self.clientHMAC(data, hmac) 117 | if err != nil { 118 | return err 119 | } 120 | if !xorBytesEq(hmac, mac) { 121 | return ErrCorruptedData 122 | } 123 | return nil 124 | } 125 | 126 | func generateKeys(k, nonce []byte) (ks *keySet, err error) { 127 | mkey := make([]byte, 48) 128 | mgf1XOR(mkey, sha256.New(), append(k, nonce...)) 129 | 130 | h := hmac.New(sha256.New, mkey) 131 | 132 | serverEncrKey := make([]byte, encrKeyLen) 133 | h.Write([]byte("ServerEncr")) 134 | serverEncrKey = h.Sum(serverEncrKey[:0]) 135 | h.Reset() 136 | 137 | serverAuthKey := make([]byte, authKeyLen) 138 | h.Write([]byte("ServerAuth")) 139 | serverAuthKey = h.Sum(serverAuthKey[:0]) 140 | h.Reset() 141 | 142 | clientEncrKey := make([]byte, encrKeyLen) 143 | h.Write([]byte("ClientEncr")) 144 | clientEncrKey = 
h.Sum(clientEncrKey[:0]) 145 | h.Reset() 146 | 147 | clientAuthKey := make([]byte, authKeyLen) 148 | h.Write([]byte("ClientAuth")) 149 | clientAuthKey = h.Sum(clientAuthKey[:0]) 150 | h.Reset() 151 | 152 | ks = newKeySet(serverEncrKey, serverAuthKey, clientEncrKey, clientAuthKey) 153 | return 154 | } 155 | -------------------------------------------------------------------------------- /proto/server/auth.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "crypto/rand" 22 | "crypto/rsa" 23 | "encoding/base64" 24 | "errors" 25 | "fmt" 26 | . 
"github.com/uniqush/uniqush-conn/evthandler" 27 | "github.com/uniqush/uniqush-conn/proto" 28 | "io" 29 | "net" 30 | "strings" 31 | "time" 32 | ) 33 | 34 | var ErrAuthFail = errors.New("authentication failed") 35 | 36 | // The conn will be closed if any error occur 37 | func AuthConn(conn net.Conn, privkey *rsa.PrivateKey, auth Authenticator, timeout time.Duration) (c Conn, err error) { 38 | conn.SetDeadline(time.Now().Add(timeout)) 39 | defer func() { 40 | if err == nil { 41 | err = conn.SetDeadline(time.Time{}) 42 | if err != nil { 43 | conn.Close() 44 | } 45 | } 46 | }() 47 | 48 | ks, err := proto.ServerKeyExchange(privkey, conn) 49 | if err != nil { 50 | conn.Close() 51 | return 52 | } 53 | cmdio := ks.ServerCommandIO(conn) 54 | cmd, err := cmdio.ReadCommand() 55 | if err != nil { 56 | return 57 | } 58 | if cmd.Type != proto.CMD_AUTH { 59 | err = fmt.Errorf("invalid command type") 60 | return 61 | } 62 | if len(cmd.Params) != 3 { 63 | err = fmt.Errorf("invalid parameters") 64 | return 65 | } 66 | service := cmd.Params[0] 67 | username := cmd.Params[1] 68 | token := cmd.Params[2] 69 | 70 | // Username and service should not contain "\n", ":", "," 71 | if strings.Contains(service, "\n") || strings.Contains(username, "\n") || 72 | strings.Contains(service, ":") || strings.Contains(username, ":") || 73 | strings.Contains(service, ",") || strings.Contains(username, ",") { 74 | err = fmt.Errorf("invalid service name or username") 75 | return 76 | } 77 | 78 | var d [16]byte 79 | io.ReadFull(rand.Reader, d[:]) 80 | connId := fmt.Sprintf("%x-%v", time.Now().Unix(), base64.URLEncoding.EncodeToString(d[:])) 81 | ok, redir, err := auth.Authenticate(service, username, connId, token, conn.RemoteAddr().String()) 82 | if err != nil { 83 | return 84 | } 85 | if len(redir) > 0 { 86 | cmd.Type = proto.CMD_REDIRECT 87 | cmd.Params = redir 88 | cmd.Message = nil 89 | } else if !ok { 90 | err = ErrAuthFail 91 | return 92 | } else { 93 | cmd.Type = proto.CMD_AUTHOK 94 | cmd.Params = 
nil 95 | cmd.Message = nil 96 | cmd.Randomize() 97 | } 98 | err = cmdio.WriteCommand(cmd, false) 99 | if err != nil { 100 | return 101 | } 102 | 103 | if len(redir) > 0 { 104 | c = nil 105 | err = io.EOF 106 | } else { 107 | c = NewConn(cmdio, service, username, connId, conn) 108 | err = nil 109 | } 110 | return 111 | } 112 | -------------------------------------------------------------------------------- /proto/server/auth_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "crypto/rand" 22 | "crypto/rsa" 23 | . 
"github.com/uniqush/uniqush-conn/evthandler" 24 | "github.com/uniqush/uniqush-conn/proto/client" 25 | "net" 26 | "sync" 27 | "testing" 28 | "time" 29 | ) 30 | 31 | type singleUserAuth struct { 32 | service, username, token string 33 | } 34 | 35 | func (self *singleUserAuth) Authenticate(srv, usr, connId, token, addr string) (bool, []string, error) { 36 | if self.service == srv && self.username == usr && self.token == token { 37 | return true, nil, nil 38 | } 39 | return false, nil, nil 40 | } 41 | 42 | func getClient(addr string, priv *rsa.PrivateKey, auth Authenticator, timeout time.Duration) (conn Conn, err error) { 43 | ln, err := net.Listen("tcp", addr) 44 | if err != nil { 45 | return 46 | } 47 | c, err := ln.Accept() 48 | if err != nil { 49 | return 50 | } 51 | ln.Close() 52 | conn, err = AuthConn(c, priv, auth, timeout) 53 | return 54 | } 55 | 56 | func connectServer(addr string, pub *rsa.PublicKey, service, username, token string, timeout time.Duration) (conn client.Conn, err error) { 57 | c, err := net.Dial("tcp", addr) 58 | if err != nil { 59 | return 60 | } 61 | conn, err = client.Dial(c, pub, service, username, token, timeout) 62 | return 63 | } 64 | 65 | func buildServerClientConns(addr string, token string, timeout time.Duration) (servConn Conn, cliConn client.Conn, err error) { 66 | priv, err := rsa.GenerateKey(rand.Reader, 2048) 67 | if err != nil { 68 | return 69 | } 70 | pub := &priv.PublicKey 71 | 72 | auth := new(singleUserAuth) 73 | auth.service = "service" 74 | auth.username = "username" 75 | auth.token = "token" 76 | 77 | wg := new(sync.WaitGroup) 78 | wg.Add(2) 79 | 80 | var ec error 81 | var es error 82 | go func() { 83 | servConn, es = getClient(addr, priv, auth, timeout) 84 | wg.Done() 85 | }() 86 | 87 | time.Sleep(1 * time.Second) 88 | 89 | go func() { 90 | cliConn, ec = connectServer(addr, pub, auth.service, auth.username, token, timeout) 91 | wg.Done() 92 | }() 93 | wg.Wait() 94 | if es != nil { 95 | err = es 96 | return 97 | } 98 | if 
ec != nil { 99 | err = ec 100 | return 101 | } 102 | return 103 | } 104 | 105 | func TestAuthOK(t *testing.T) { 106 | addr := "127.0.0.1:8088" 107 | token := "token" 108 | servConn, cliConn, err := buildServerClientConns(addr, token, 3*time.Second) 109 | if err != nil { 110 | t.Errorf("Error: %v", err) 111 | } 112 | if servConn != nil { 113 | servConn.Close() 114 | } 115 | if cliConn != nil { 116 | cliConn.Close() 117 | } 118 | } 119 | 120 | func TestAuthFail(t *testing.T) { 121 | addr := "127.0.0.1:8088" 122 | token := "wrong token" 123 | servConn, cliConn, err := buildServerClientConns(addr, token, 3*time.Second) 124 | if err == nil { 125 | t.Errorf("Error: Should be failed") 126 | } 127 | if servConn != nil { 128 | servConn.Close() 129 | } 130 | if cliConn != nil { 131 | cliConn.Close() 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /proto/server/conn.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "fmt" 22 | "github.com/uniqush/uniqush-conn/msgcache" 23 | "github.com/uniqush/uniqush-conn/proto" 24 | "github.com/uniqush/uniqush-conn/rpc" 25 | "io" 26 | "strings" 27 | 28 | "net" 29 | "sync" 30 | "sync/atomic" 31 | ) 32 | 33 | // SendMessage() and ForwardMessage() are goroutine-safe. 
34 | // SendMessage() and ForwardMessage() will send a message ditest, 35 | // instead of the message itself, if the message is too large. 36 | // ReceiveMessage() should nevery be called concurrently. 37 | type Conn interface { 38 | RemoteAddr() net.Addr 39 | Service() string 40 | Username() string 41 | UniqId() string 42 | Close() error 43 | 44 | // If the message is generated from the server, then use SendMessage() 45 | // to send it to the client. 46 | SendMessage(msg *rpc.Message, id string, extra map[string]string, tryDigest bool) error 47 | 48 | // If the message is generated from another client, then 49 | // use ForwardMessage() to send it to the client. 50 | ForwardMessage(sender, senderService string, msg *rpc.Message, id string, tryDigest bool) error 51 | 52 | // ReceiveMessage() will keep receiving Commands from the client 53 | // until it receives a Command with type CMD_DATA. 54 | ReceiveMessage() (msg *rpc.Message, err error) 55 | 56 | // Ask the client to connect to other servers. 57 | // Redirect() will not close the connection. The user should call Close() 58 | // seprately to close the connection. 
59 | Redirect(addrs ...string) error 60 | 61 | SetMessageCache(cache msgcache.Cache) 62 | SetForwardRequestChannel(fwdChan chan<- *rpc.ForwardRequest) 63 | SetSubscribeRequestChan(subChan chan<- *rpc.SubscribeRequest) 64 | Visible() bool 65 | } 66 | 67 | type serverConn struct { 68 | cmdio *proto.CommandIO 69 | conn net.Conn 70 | compressThreshold int32 71 | digestThreshold int32 72 | service string 73 | username string 74 | connId string 75 | digestFielsLock sync.Mutex 76 | digestFields []string 77 | cmdProcs []CommandProcessor 78 | visible int32 79 | } 80 | 81 | type CommandProcessor interface { 82 | ProcessCommand(cmd *proto.Command) (msg *rpc.Message, err error) 83 | } 84 | 85 | func (self *serverConn) Visible() bool { 86 | v := atomic.LoadInt32(&self.visible) 87 | return v > 0 88 | } 89 | 90 | func (self *serverConn) RemoteAddr() net.Addr { 91 | return self.conn.RemoteAddr() 92 | } 93 | 94 | func (self *serverConn) Close() error { 95 | if self == nil { 96 | return nil 97 | } 98 | cmd := &proto.Command{ 99 | Type: proto.CMD_BYE, 100 | } 101 | cmd.Randomize() 102 | self.cmdio.WriteCommand(cmd, false) 103 | return self.conn.Close() 104 | } 105 | 106 | func (self *serverConn) Service() string { 107 | if self == nil { 108 | return "" 109 | } 110 | return self.service 111 | } 112 | 113 | func (self *serverConn) Username() string { 114 | if self == nil { 115 | return "" 116 | } 117 | return self.username 118 | } 119 | 120 | func (self *serverConn) UniqId() string { 121 | if self == nil { 122 | return "" 123 | } 124 | return self.connId 125 | } 126 | 127 | func (self *serverConn) shouldCompress(size int) bool { 128 | t := int(atomic.LoadInt32(&self.compressThreshold)) 129 | if t > 0 && t < size { 130 | return true 131 | } 132 | return false 133 | } 134 | 135 | func (self *serverConn) shouldDigest(sz int) bool { 136 | d := atomic.LoadInt32(&self.digestThreshold) 137 | if d >= 0 && d < int32(sz) { 138 | return true 139 | } 140 | return false 141 | } 142 | 143 | func 
(self *serverConn) writeDigest(mc *rpc.MessageContainer, extra map[string]string, sz int) error { 144 | digest := &proto.Command{ 145 | Type: proto.CMD_DIGEST, 146 | } 147 | params := [4]string{fmt.Sprintf("%v", sz), mc.Id} 148 | 149 | if mc.FromUser() { 150 | params[2] = mc.Sender 151 | params[3] = mc.SenderService 152 | digest.Params = params[:4] 153 | } else { 154 | digest.Params = params[:2] 155 | } 156 | 157 | msg := mc.Message 158 | header := make(map[string]string, len(extra)+len(msg.Header)) 159 | self.digestFielsLock.Lock() 160 | defer self.digestFielsLock.Unlock() 161 | 162 | for _, f := range self.digestFields { 163 | if len(msg.Header) > 0 { 164 | if v, ok := msg.Header[f]; ok { 165 | header[f] = v 166 | } 167 | } 168 | if len(extra) > 0 { 169 | if v, ok := extra[f]; ok { 170 | header[f] = v 171 | } 172 | } 173 | } 174 | if len(header) > 0 { 175 | digest.Message = &rpc.Message{ 176 | Header: header, 177 | } 178 | } 179 | 180 | compress := self.shouldCompress(digest.Message.Size()) 181 | return self.cmdio.WriteCommand(digest, compress) 182 | } 183 | 184 | func (self *serverConn) Redirect(addrs ...string) error { 185 | if len(addrs) == 0 { 186 | return nil 187 | } 188 | cmd := &proto.Command{ 189 | Type: proto.CMD_REDIRECT, 190 | Params: addrs, 191 | } 192 | return self.cmdio.WriteCommand(cmd, false) 193 | } 194 | 195 | func (self *serverConn) SendMessage(msg *rpc.Message, id string, extra map[string]string, tryDigest bool) error { 196 | if msg == nil { 197 | cmd := &proto.Command{ 198 | Type: proto.CMD_EMPTY, 199 | } 200 | if len(id) > 0 { 201 | cmd.Params = []string{id} 202 | } 203 | return self.cmdio.WriteCommand(cmd, false) 204 | } 205 | sz := msg.Size() 206 | if tryDigest && self.shouldDigest(sz) { 207 | container := &rpc.MessageContainer{ 208 | Id: id, 209 | Message: msg, 210 | } 211 | return self.writeDigest(container, extra, sz) 212 | } 213 | cmd := &proto.Command{ 214 | Type: proto.CMD_DATA, 215 | Message: msg, 216 | } 217 | cmd.Params = 
[]string{id} 218 | return self.cmdio.WriteCommand(cmd, self.shouldCompress(sz)) 219 | } 220 | 221 | func (self *serverConn) ForwardMessage(sender, senderService string, msg *rpc.Message, id string, tryDigest bool) error { 222 | sz := msg.Size() 223 | if sz == 0 { 224 | return nil 225 | } 226 | if tryDigest && self.shouldDigest(sz) { 227 | container := &rpc.MessageContainer{ 228 | Id: id, 229 | Sender: sender, 230 | SenderService: senderService, 231 | Message: msg, 232 | } 233 | return self.writeDigest(container, nil, sz) 234 | } 235 | cmd := &proto.Command{ 236 | Type: proto.CMD_FWD, 237 | Message: msg, 238 | } 239 | cmd.Params = []string{sender, senderService, id} 240 | return self.cmdio.WriteCommand(cmd, self.shouldCompress(sz)) 241 | } 242 | 243 | func (self *serverConn) processCommand(cmd *proto.Command) (msg *rpc.Message, err error) { 244 | if cmd == nil { 245 | return 246 | } 247 | 248 | t := int(cmd.Type) 249 | if t > len(self.cmdProcs) { 250 | return 251 | } 252 | proc := self.cmdProcs[t] 253 | if proc != nil { 254 | msg, err = proc.ProcessCommand(cmd) 255 | } 256 | return 257 | } 258 | 259 | func isEOFlikeError(err error) bool { 260 | if err == io.ErrUnexpectedEOF || err == io.EOF { 261 | return true 262 | } 263 | 264 | // XXX any better idea? 
265 | if strings.HasSuffix(err.Error(), "use of closed network connection") { 266 | return true 267 | } 268 | return false 269 | } 270 | 271 | func (self *serverConn) ReceiveMessage() (msg *rpc.Message, err error) { 272 | var cmd *proto.Command 273 | for { 274 | cmd, err = self.cmdio.ReadCommand() 275 | if err != nil { 276 | if isEOFlikeError(err) { 277 | err = io.EOF 278 | } 279 | return 280 | } 281 | switch cmd.Type { 282 | case proto.CMD_DATA: 283 | msg = cmd.Message 284 | return 285 | case proto.CMD_BYE: 286 | err = io.EOF 287 | return 288 | default: 289 | msg, err = self.processCommand(cmd) 290 | if err != nil || msg != nil { 291 | return 292 | } 293 | } 294 | } 295 | } 296 | 297 | func (self *serverConn) SetMessageCache(cache msgcache.Cache) { 298 | if cache == nil { 299 | return 300 | } 301 | proc := new(messageRetriever) 302 | proc.cache = cache 303 | proc.conn = self 304 | self.setCommandProcessor(proto.CMD_MSG_RETRIEVE, proc) 305 | 306 | p2 := new(retriaveAllMessages) 307 | p2.cache = cache 308 | p2.conn = self 309 | self.setCommandProcessor(proto.CMD_REQ_ALL_CACHED, p2) 310 | } 311 | 312 | func (self *serverConn) SetForwardRequestChannel(fwdChan chan<- *rpc.ForwardRequest) { 313 | if fwdChan == nil { 314 | return 315 | } 316 | proc := new(forwardProcessor) 317 | proc.conn = self 318 | proc.fwdChan = fwdChan 319 | self.setCommandProcessor(proto.CMD_FWD_REQ, proc) 320 | } 321 | 322 | func (self *serverConn) SetSubscribeRequestChan(subChan chan<- *rpc.SubscribeRequest) { 323 | if subChan == nil { 324 | return 325 | } 326 | proc := new(subscribeProcessor) 327 | proc.conn = self 328 | proc.subChan = subChan 329 | self.setCommandProcessor(proto.CMD_SUBSCRIPTION, proc) 330 | } 331 | 332 | func (self *serverConn) setCommandProcessor(cmdType uint8, proc CommandProcessor) { 333 | if cmdType >= proto.CMD_NR_CMDS { 334 | return 335 | } 336 | if len(self.cmdProcs) <= int(cmdType) { 337 | self.cmdProcs = make([]CommandProcessor, proto.CMD_NR_CMDS) 338 | } 339 | 
self.cmdProcs[cmdType] = proc 340 | } 341 | 342 | func NewConn(cmdio *proto.CommandIO, service, username, connId string, conn net.Conn) Conn { 343 | ret := new(serverConn) 344 | ret.conn = conn 345 | ret.cmdio = cmdio 346 | ret.service = service 347 | ret.username = username 348 | ret.connId = connId //fmt.Sprintf("%x-%x", time.Now().UnixNano(), rand.Int63()) 349 | ret.digestThreshold = 1024 350 | ret.compressThreshold = 1024 351 | 352 | settingproc := new(settingProcessor) 353 | settingproc.conn = ret 354 | ret.setCommandProcessor(proto.CMD_SETTING, settingproc) 355 | 356 | visproc := new(visibilityProcessor) 357 | visproc.conn = ret 358 | ret.setCommandProcessor(proto.CMD_SET_VISIBILITY, visproc) 359 | 360 | ret.visible = 1 361 | return ret 362 | } 363 | -------------------------------------------------------------------------------- /proto/server/conn_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "crypto/rand" 22 | "errors" 23 | "fmt" 24 | "github.com/garyburd/redigo/redis" 25 | "github.com/uniqush/uniqush-conn/msgcache" 26 | 27 | "github.com/uniqush/uniqush-conn/proto/client" 28 | "github.com/uniqush/uniqush-conn/rpc" 29 | "io" 30 | "sync" 31 | "testing" 32 | "time" 33 | ) 34 | 35 | func clearCache() { 36 | db := 1 37 | c, _ := redis.Dial("tcp", "localhost:6379") 38 | c.Do("SELECT", db) 39 | c.Do("FLUSHDB") 40 | c.Close() 41 | } 42 | 43 | func getCache() msgcache.Cache { 44 | clearCache() 45 | cache, _ := msgcache.GetCache("redis", "", "", "", "1", 0) 46 | return cache 47 | } 48 | 49 | type messageContainerProcessor interface { 50 | ProcessMessageContainer(mc *rpc.MessageContainer) error 51 | } 52 | 53 | func iterateOverContainers(srcProc, dstProc messageContainerProcessor, mcs ...*rpc.MessageContainer) error { 54 | wg := new(sync.WaitGroup) 55 | wg.Add(2) 56 | 57 | var es error 58 | var ed error 59 | 60 | go func() { 61 | defer wg.Done() 62 | for _, mc := range mcs { 63 | es = srcProc.ProcessMessageContainer(mc) 64 | } 65 | }() 66 | 67 | go func() { 68 | defer wg.Done() 69 | for _, mc := range mcs { 70 | ed = dstProc.ProcessMessageContainer(mc) 71 | } 72 | }() 73 | wg.Wait() 74 | if es != nil { 75 | return es 76 | } 77 | if ed != nil { 78 | return ed 79 | } 80 | return nil 81 | } 82 | 83 | func randomMessage() *rpc.Message { 84 | msg := new(rpc.Message) 85 | msg.Body = make([]byte, 10) 86 | io.ReadFull(rand.Reader, msg.Body) 87 | msg.Header = make(map[string]string, 2) 88 | msg.Header["aaa"] = "hello" 89 | msg.Header["aa"] = "hell" 90 | return msg 91 | } 92 | 93 | type serverSender struct { 94 | conn Conn 95 | extra map[string]string 96 | } 97 | 98 | func (self *serverSender) ProcessMessageContainer(mc *rpc.MessageContainer) error { 99 | if mc.FromUser() { 100 | return self.conn.ForwardMessage(mc.Sender, mc.SenderService, mc.Message, mc.Id, true) 101 | } 102 | return 
self.conn.SendMessage(mc.Message, mc.Id, self.extra, true) 103 | } 104 | 105 | type serverReceiver struct { 106 | conn Conn 107 | } 108 | 109 | func (self *serverReceiver) ProcessMessageContainer(mc *rpc.MessageContainer) error { 110 | msg, err := self.conn.ReceiveMessage() 111 | if err != nil { 112 | return err 113 | } 114 | if !msg.Eq(mc.Message) { 115 | return errors.New("corrupted data") 116 | } 117 | return nil 118 | } 119 | 120 | type clientReceiver struct { 121 | conn client.Conn 122 | } 123 | 124 | func (self *clientReceiver) ProcessMessageContainer(mc *rpc.MessageContainer) error { 125 | rmc, err := self.conn.ReceiveMessage() 126 | if err != nil { 127 | return err 128 | } 129 | if !rmc.Eq(mc) { 130 | return errors.New("corrupted data") 131 | } 132 | return nil 133 | } 134 | 135 | type clientSender struct { 136 | conn client.Conn 137 | } 138 | 139 | func (self *clientSender) ProcessMessageContainer(mc *rpc.MessageContainer) error { 140 | return self.conn.SendMessageToServer(mc.Message) 141 | } 142 | 143 | func TestSendMessageFromServerToClient(t *testing.T) { 144 | addr := "127.0.0.1:8088" 145 | token := "token" 146 | servConn, cliConn, err := buildServerClientConns(addr, token, 3*time.Second) 147 | if err != nil { 148 | t.Errorf("Error: %v", err) 149 | } 150 | defer servConn.Close() 151 | defer cliConn.Close() 152 | N := 100 153 | mcs := make([]*rpc.MessageContainer, N) 154 | 155 | for i := 0; i < N; i++ { 156 | mcs[i] = &rpc.MessageContainer{ 157 | Message: randomMessage(), 158 | Id: fmt.Sprintf("%v", i), 159 | } 160 | } 161 | 162 | src := &serverSender{ 163 | conn: servConn, 164 | } 165 | 166 | dst := &clientReceiver{ 167 | conn: cliConn, 168 | } 169 | err = iterateOverContainers(src, dst, mcs...) 
170 | if err != nil { 171 | t.Errorf("Error: %v", err) 172 | } 173 | } 174 | 175 | func TestSendMessageFromClientToServer(t *testing.T) { 176 | addr := "127.0.0.1:8088" 177 | token := "token" 178 | servConn, cliConn, err := buildServerClientConns(addr, token, 3*time.Second) 179 | if err != nil { 180 | t.Errorf("Error: %v", err) 181 | } 182 | defer servConn.Close() 183 | defer cliConn.Close() 184 | N := 100 185 | mcs := make([]*rpc.MessageContainer, N) 186 | 187 | for i := 0; i < N; i++ { 188 | mcs[i] = &rpc.MessageContainer{ 189 | Message: randomMessage(), 190 | Id: fmt.Sprintf("%v", i), 191 | } 192 | } 193 | 194 | src := &clientSender{ 195 | conn: cliConn, 196 | } 197 | 198 | dst := &serverReceiver{ 199 | conn: servConn, 200 | } 201 | err = iterateOverContainers(src, dst, mcs...) 202 | if err != nil { 203 | t.Errorf("Error: %v", err) 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /proto/server/digest_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "fmt" 22 | "github.com/uniqush/uniqush-conn/proto/client" 23 | "github.com/uniqush/uniqush-conn/rpc" 24 | "testing" 25 | "time" 26 | ) 27 | 28 | func TestSendMessageDigestFromServerToClient(t *testing.T) { 29 | addr := "127.0.0.1:8088" 30 | token := "token" 31 | servConn, cliConn, err := buildServerClientConns(addr, token, 3*time.Second) 32 | if err != nil { 33 | t.Errorf("Error: %v", err) 34 | } 35 | defer servConn.Close() 36 | defer cliConn.Close() 37 | N := 100 38 | mcs := make([]*rpc.MessageContainer, N) 39 | 40 | cache := getCache() 41 | defer clearCache() 42 | difields := make(map[string]string, 2) 43 | difields["df1"] = "df1value" 44 | difields["df2"] = "df2value" 45 | 46 | difieldNames := []string{"df1", "df2"} 47 | ttl := 1 * time.Hour 48 | 49 | for i := 0; i < N; i++ { 50 | mcs[i] = &rpc.MessageContainer{ 51 | Message: randomMessage(), 52 | Id: fmt.Sprintf("%v", i), 53 | } 54 | for k, v := range difields { 55 | mcs[i].Message.Header[k] = v 56 | } 57 | id, err := cache.CacheMessage(servConn.Service(), servConn.Username(), mcs[i], ttl) 58 | if err != nil { 59 | t.Errorf("dberror: %v", err) 60 | } 61 | mcs[i].Id = id 62 | } 63 | 64 | servConn.SetMessageCache(cache) 65 | src := &serverSender{ 66 | conn: servConn, 67 | } 68 | dst := &clientReceiver{ 69 | conn: cliConn, 70 | } 71 | 72 | err = cliConn.Config(0, 2048, difieldNames...) 
73 | if err != nil { 74 | t.Errorf("Error: %v\n") 75 | } 76 | go func() { 77 | servConn.ReceiveMessage() 78 | }() 79 | 80 | digestChan := make(chan *client.Digest) 81 | cliConn.SetDigestChannel(digestChan) 82 | 83 | go func() { 84 | i := 0 85 | for digest := range digestChan { 86 | mc := mcs[i] 87 | i++ 88 | if len(difieldNames) != len(digest.Info) { 89 | t.Errorf("Error: wrong digest") 90 | } 91 | for k, v := range difields { 92 | if df, ok := digest.Info[k]; ok { 93 | if df != v { 94 | t.Errorf("Error: wrong digest value on field %v", k) 95 | } 96 | } else { 97 | t.Errorf("cannot find field %v in the digest", k) 98 | } 99 | } 100 | if mc.Id != digest.MsgId { 101 | t.Errorf("wrong id: %v != %v", mc.Id, digest.MsgId) 102 | } 103 | cliConn.RequestMessage(digest.MsgId) 104 | } 105 | if i != N { 106 | t.Errorf("received %v digest", i) 107 | } 108 | }() 109 | err = iterateOverContainers(src, dst, mcs...) 110 | if err != nil { 111 | t.Errorf("Error: %v", err) 112 | } 113 | close(digestChan) 114 | cliConn.Close() 115 | } 116 | -------------------------------------------------------------------------------- /proto/server/fwd_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "fmt" 22 | 23 | "github.com/uniqush/uniqush-conn/proto/client" 24 | "github.com/uniqush/uniqush-conn/rpc" 25 | 26 | "testing" 27 | "time" 28 | ) 29 | 30 | func TestForwardMessageFromServerToClient(t *testing.T) { 31 | addr := "127.0.0.1:8088" 32 | token := "token" 33 | servConn, cliConn, err := buildServerClientConns(addr, token, 3*time.Second) 34 | if err != nil { 35 | t.Errorf("Error: %v", err) 36 | } 37 | defer servConn.Close() 38 | defer cliConn.Close() 39 | N := 100 40 | mcs := make([]*rpc.MessageContainer, N) 41 | 42 | for i := 0; i < N; i++ { 43 | mcs[i] = &rpc.MessageContainer{ 44 | Message: randomMessage(), 45 | Id: fmt.Sprintf("%v", i), 46 | Sender: "sender", 47 | SenderService: "someservice", 48 | } 49 | } 50 | 51 | src := &serverSender{ 52 | conn: servConn, 53 | } 54 | 55 | dst := &clientReceiver{ 56 | conn: cliConn, 57 | } 58 | err = iterateOverContainers(src, dst, mcs...) 59 | if err != nil { 60 | t.Errorf("Error: %v", err) 61 | } 62 | } 63 | 64 | type clientForwarder struct { 65 | conn client.Conn 66 | } 67 | 68 | func (self *clientForwarder) ProcessMessageContainer(mc *rpc.MessageContainer) error { 69 | err := self.conn.SendMessageToUsers(mc.Message, 1*time.Hour, mc.SenderService, mc.Sender) 70 | if err != nil { 71 | return err 72 | } 73 | return self.conn.SendMessageToServer(mc.Message) 74 | } 75 | 76 | func TestForwardRequestFromClientToServer(t *testing.T) { 77 | addr := "127.0.0.1:8088" 78 | token := "token" 79 | servConn, cliConn, err := buildServerClientConns(addr, token, 3*time.Second) 80 | if err != nil { 81 | t.Errorf("Error: %v", err) 82 | } 83 | defer servConn.Close() 84 | defer cliConn.Close() 85 | N := 100 86 | mcs := make([]*rpc.MessageContainer, N) 87 | 88 | receiver := "receiver" 89 | receiverService := "someservice" 90 | 91 | for i := 0; i < N; i++ { 92 | mcs[i] = &rpc.MessageContainer{ 93 | Message: randomMessage(), 94 | Id: fmt.Sprintf("%v", i), 95 | Sender: receiver, 
// This is confusing. We hacked the struct. 96 | SenderService: receiverService, 97 | } 98 | } 99 | 100 | fwdChan := make(chan *rpc.ForwardRequest) 101 | 102 | servConn.SetForwardRequestChannel(fwdChan) 103 | src := &clientForwarder{ 104 | conn: cliConn, 105 | } 106 | 107 | dst := &serverReceiver{ 108 | conn: servConn, 109 | } 110 | 111 | go func() { 112 | i := 0 113 | for fwdreq := range fwdChan { 114 | mc := mcs[i] 115 | i++ 116 | if !mc.Message.Eq(fwdreq.Message) { 117 | t.Errorf("corrupted data") 118 | } 119 | if len(fwdreq.Receivers) != 1 { 120 | t.Errorf("receivers: %v", fwdreq.Receivers) 121 | } 122 | if fwdreq.Receivers[0] != receiver { 123 | t.Errorf("receiver is %v, not %v", fwdreq.Receivers, receiver) 124 | } 125 | if fwdreq.ReceiverService != receiverService { 126 | t.Errorf("receiver's service is %v, not %v", fwdreq.ReceiverService, receiverService) 127 | } 128 | } 129 | if i != N { 130 | t.Errorf("received only %v fwdreq", i) 131 | } 132 | }() 133 | err = iterateOverContainers(src, dst, mcs...) 134 | if err != nil { 135 | t.Errorf("Error: %v", err) 136 | } 137 | 138 | close(fwdChan) 139 | cliConn.Close() 140 | } 141 | -------------------------------------------------------------------------------- /proto/server/fwdproc.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "github.com/uniqush/uniqush-conn/proto" 22 | "github.com/uniqush/uniqush-conn/rpc" 23 | "strings" 24 | "time" 25 | ) 26 | 27 | type forwardProcessor struct { 28 | conn *serverConn 29 | fwdChan chan<- *rpc.ForwardRequest 30 | } 31 | 32 | func (self *forwardProcessor) ProcessCommand(cmd *proto.Command) (msg *rpc.Message, err error) { 33 | if cmd == nil || cmd.Type != proto.CMD_FWD_REQ || self.conn == nil || self.fwdChan == nil { 34 | return 35 | } 36 | 37 | if len(cmd.Params) < 2 { 38 | err = proto.ErrBadPeerImpl 39 | return 40 | } 41 | if self.fwdChan == nil { 42 | return 43 | } 44 | fwdreq := new(rpc.ForwardRequest) 45 | fwdreq.Sender = self.conn.Username() 46 | fwdreq.SenderService = self.conn.Service() 47 | fwdreq.Message = cmd.Message 48 | fwdreq.TTL, err = time.ParseDuration(cmd.Params[0]) 49 | 50 | if err != nil { 51 | err = nil 52 | fwdreq.TTL = 72 * time.Hour 53 | } 54 | fwdreq.Receivers = strings.Split(cmd.Params[1], ",") 55 | if len(cmd.Params) > 2 { 56 | fwdreq.ReceiverService = cmd.Params[2] 57 | } else { 58 | fwdreq.ReceiverService = self.conn.Service() 59 | } 60 | self.fwdChan <- fwdreq 61 | return 62 | } 63 | -------------------------------------------------------------------------------- /proto/server/msgretriever.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "github.com/uniqush/uniqush-conn/msgcache" 22 | "github.com/uniqush/uniqush-conn/proto" 23 | "github.com/uniqush/uniqush-conn/rpc" 24 | ) 25 | 26 | type messageRetriever struct { 27 | conn *serverConn 28 | cache msgcache.Cache 29 | } 30 | 31 | func (self *messageRetriever) ProcessCommand(cmd *proto.Command) (msg *rpc.Message, err error) { 32 | if cmd == nil || cmd.Type != proto.CMD_MSG_RETRIEVE || self.conn == nil || self.cache == nil { 33 | return 34 | } 35 | if len(cmd.Params) < 1 { 36 | err = proto.ErrBadPeerImpl 37 | return 38 | } 39 | id := cmd.Params[0] 40 | mc, err := self.cache.Get(self.conn.Service(), self.conn.Username(), id) 41 | if err != nil { 42 | return 43 | } 44 | if mc == nil || mc.Message == nil { 45 | err = self.conn.SendMessage(nil, id, nil, false) 46 | return 47 | } 48 | if mc.FromServer() { 49 | err = self.conn.SendMessage(mc.Message, id, nil, false) 50 | } else { 51 | err = self.conn.ForwardMessage(mc.Sender, mc.SenderService, mc.Message, id, false) 52 | } 53 | return 54 | } 55 | -------------------------------------------------------------------------------- /proto/server/redir_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "crypto/rand" 22 | "crypto/rsa" 23 | "io" 24 | 25 | "github.com/uniqush/uniqush-conn/proto/client" 26 | "net" 27 | 28 | "testing" 29 | "time" 30 | ) 31 | 32 | func TestRedirectCommand(t *testing.T) { 33 | addr := "127.0.0.1:8088" 34 | token := "token" 35 | servConn, cliConn, err := buildServerClientConns(addr, token, 3*time.Second) 36 | if err != nil { 37 | t.Errorf("Error: %v", err) 38 | } 39 | defer servConn.Close() 40 | defer cliConn.Close() 41 | 42 | go func() { 43 | servConn.ReceiveMessage() 44 | }() 45 | 46 | redirChan := make(chan *client.RedirectRequest) 47 | cliConn.SetRedirectChannel(redirChan) 48 | 49 | addresses := []string{"other-server.mydomain.com:8964", "others.com:8964"} 50 | 51 | go func() { 52 | cliConn.ReceiveMessage() 53 | }() 54 | 55 | servConn.Redirect(addresses...) 56 | 57 | go func() { 58 | for redir := range redirChan { 59 | if len(redir.Addresses) != len(addresses) { 60 | t.Errorf("Address length is not same: %v", len(redir.Addresses)) 61 | } 62 | 63 | for i, a := range redir.Addresses { 64 | if addresses[i] != a { 65 | t.Errorf("I got a weird address: %v", a) 66 | } 67 | } 68 | } 69 | 70 | }() 71 | close(redirChan) 72 | cliConn.Close() 73 | } 74 | 75 | type redirAuth struct { 76 | servers []string 77 | } 78 | 79 | func (self *redirAuth) Authenticate(srv, usr, connId, token, addr string) (bool, []string, error) { 80 | return false, self.servers, nil 81 | } 82 | 83 | func TestRedirectOnAuth(t *testing.T) { 84 | priv, err := rsa.GenerateKey(rand.Reader, 2048) 85 | if err != nil { 86 | t.Errorf("Error: %v", err) 87 | return 88 | } 89 | pub := &priv.PublicKey 90 | 91 | servers := []string{"server1:1234", "server2:1234"} 92 | addr := "127.0.0.1:8088" 93 | 94 | ready := make(chan bool) 95 | 96 | go func() { 97 | ln, err := net.Listen("tcp", addr) 98 | if err != nil { 99 
| t.Errorf("Error: %v", err) 100 | return 101 | } 102 | ready <- true 103 | c, err := ln.Accept() 104 | if err != nil { 105 | t.Errorf("Error: %v", err) 106 | return 107 | } 108 | ln.Close() 109 | 110 | auth := new(redirAuth) 111 | auth.servers = servers 112 | _, err = AuthConn(c, priv, auth, 3*time.Second) 113 | if err != io.EOF { 114 | t.Errorf("Error: should be EOF %v", err) 115 | return 116 | } 117 | }() 118 | 119 | <-ready 120 | c, err := net.Dial("tcp", addr) 121 | if err != nil { 122 | t.Errorf("Error: %v", err) 123 | return 124 | } 125 | _, err = client.Dial(c, pub, "service", "user", "token", 3*time.Second) 126 | if e, ok := err.(*client.RedirectRequest); ok { 127 | for i, a := range e.Addresses { 128 | if a != servers[i] { 129 | t.Errorf("%v is not %v", a, servers[i]) 130 | } 131 | } 132 | } else { 133 | t.Errorf("Should be a redirect request") 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /proto/server/retrieveall.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "github.com/uniqush/uniqush-conn/msgcache" 22 | "github.com/uniqush/uniqush-conn/proto" 23 | "github.com/uniqush/uniqush-conn/rpc" 24 | "strconv" 25 | "time" 26 | ) 27 | 28 | type retriaveAllMessages struct { 29 | conn *serverConn 30 | cache msgcache.Cache 31 | } 32 | 33 | func cutString(data []byte) (str, rest []byte, err error) { 34 | var idx int 35 | var d byte 36 | idx = -1 37 | for idx, d = range data { 38 | if d == 0 { 39 | break 40 | } 41 | } 42 | if idx < 0 { 43 | err = proto.ErrMalformedCommand 44 | return 45 | } 46 | str = data[:idx] 47 | rest = data[idx+1:] 48 | return 49 | } 50 | 51 | func (self *retriaveAllMessages) sendAllCachedMessage(since time.Time) error { 52 | mcs, err := self.cache.RetrieveAllSince(self.conn.Service(), self.conn.Username(), since) 53 | if err != nil { 54 | return err 55 | } 56 | if len(mcs) == 0 { 57 | return nil 58 | } 59 | for _, mc := range mcs { 60 | if mc == nil { 61 | continue 62 | } 63 | if mc.FromServer() { 64 | err = self.conn.SendMessage(mc.Message, mc.Id, nil, true) 65 | } else { 66 | err = self.conn.ForwardMessage(mc.Sender, mc.SenderService, mc.Message, mc.Id, true) 67 | } 68 | } 69 | return nil 70 | } 71 | 72 | func (self *retriaveAllMessages) ProcessCommand(cmd *proto.Command) (msg *rpc.Message, err error) { 73 | if cmd == nil || cmd.Type != proto.CMD_REQ_ALL_CACHED || self.conn == nil || self.cache == nil { 74 | return 75 | } 76 | since := time.Time{} 77 | if len(cmd.Params) > 0 { 78 | unix, err := strconv.ParseInt(cmd.Params[0], 10, 64) 79 | if err != nil { 80 | return nil, proto.ErrBadPeerImpl 81 | } 82 | since = time.Unix(unix, 0) 83 | } 84 | err = self.sendAllCachedMessage(since) 85 | return 86 | } 87 | -------------------------------------------------------------------------------- /proto/server/retrieveall_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * 
Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "fmt" 22 | "sync" 23 | 24 | "github.com/uniqush/uniqush-conn/msgcache" 25 | "github.com/uniqush/uniqush-conn/rpc" 26 | 27 | "testing" 28 | "time" 29 | ) 30 | 31 | type serverCache struct { 32 | cache msgcache.Cache 33 | conn Conn 34 | } 35 | 36 | func (self *serverCache) ProcessMessageContainer(mc *rpc.MessageContainer) error { 37 | _, err := self.cache.CacheMessage(self.conn.Service(), self.conn.Username(), mc, 1*time.Hour) 38 | return err 39 | } 40 | 41 | func TestRequestAllCachedMessages(t *testing.T) { 42 | addr := "127.0.0.1:8088" 43 | token := "token" 44 | servConn, cliConn, err := buildServerClientConns(addr, token, 3*time.Second) 45 | if err != nil { 46 | t.Errorf("Error: %v", err) 47 | } 48 | defer servConn.Close() 49 | defer cliConn.Close() 50 | 51 | cache := getCache() 52 | servConn.SetMessageCache(cache) 53 | 54 | N := 100 55 | mcs := make([]*rpc.MessageContainer, N) 56 | 57 | for i := 0; i < N; i++ { 58 | mcs[i] = &rpc.MessageContainer{ 59 | Message: randomMessage(), 60 | Id: fmt.Sprintf("%v", i), 61 | } 62 | _, err := cache.CacheMessage(servConn.Service(), servConn.Username(), mcs[i], 1*time.Hour) 63 | if err != nil { 64 | t.Errorf("Error: %v", err) 65 | } 66 | } 67 | 68 | wg := &sync.WaitGroup{} 69 | wg.Add(1) 70 | 71 | go func() { 72 | cliConn.RequestAllCachedMessages(time.Time{}) 73 | for _, _ = range mcs { 74 | 
rmc, err := cliConn.ReceiveMessage() 75 | if err != nil { 76 | t.Errorf("Error: %v", err) 77 | } 78 | 79 | found := false 80 | for _, m := range mcs { 81 | if rmc.Id == m.Id { 82 | if !rmc.Eq(m) { 83 | t.Errorf("corrupted data: %+v != %+v", rmc, m) 84 | } 85 | found = true 86 | } 87 | } 88 | if !found { 89 | t.Errorf("not found") 90 | } 91 | } 92 | wg.Done() 93 | }() 94 | 95 | go func() { 96 | servConn.ReceiveMessage() 97 | }() 98 | wg.Wait() 99 | } 100 | -------------------------------------------------------------------------------- /proto/server/settingsproc.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "github.com/uniqush/uniqush-conn/proto" 22 | "github.com/uniqush/uniqush-conn/rpc" 23 | "strconv" 24 | "sync/atomic" 25 | ) 26 | 27 | type settingProcessor struct { 28 | conn *serverConn 29 | } 30 | 31 | func (self *settingProcessor) ProcessCommand(cmd *proto.Command) (msg *rpc.Message, err error) { 32 | if cmd.Type != proto.CMD_SETTING || self.conn == nil { 33 | return 34 | } 35 | if len(cmd.Params) < 2 { 36 | err = proto.ErrBadPeerImpl 37 | return 38 | } 39 | if len(cmd.Params[0]) > 0 { 40 | var d int 41 | d, err = strconv.Atoi(cmd.Params[0]) 42 | if err != nil { 43 | err = proto.ErrBadPeerImpl 44 | return 45 | } 46 | atomic.StoreInt32(&self.conn.digestThreshold, int32(d)) 47 | 48 | } 49 | if len(cmd.Params[1]) > 0 { 50 | var c int 51 | c, err = strconv.Atoi(cmd.Params[1]) 52 | if err != nil { 53 | err = proto.ErrBadPeerImpl 54 | return 55 | } 56 | atomic.StoreInt32(&self.conn.compressThreshold, int32(c)) 57 | } 58 | nrPreDigestFields := 2 59 | if len(cmd.Params) > nrPreDigestFields { 60 | self.conn.digestFielsLock.Lock() 61 | defer self.conn.digestFielsLock.Unlock() 62 | self.conn.digestFields = make([]string, len(cmd.Params)-nrPreDigestFields) 63 | for i, f := range cmd.Params[nrPreDigestFields:] { 64 | self.conn.digestFields[i] = f 65 | } 66 | } 67 | return 68 | } 69 | -------------------------------------------------------------------------------- /proto/server/sub_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "github.com/uniqush/uniqush-conn/rpc" 22 | "sync" 23 | "testing" 24 | "time" 25 | ) 26 | 27 | func subeq(a *rpc.SubscribeRequest, b *rpc.SubscribeRequest) bool { 28 | if a.Subscribe != b.Subscribe { 29 | return false 30 | } 31 | if a.Service != b.Service { 32 | return false 33 | } 34 | if a.Username != b.Username { 35 | return false 36 | } 37 | if len(a.Params) != len(b.Params) { 38 | return false 39 | } 40 | for k, v := range a.Params { 41 | if v1, ok := b.Params[k]; ok { 42 | if v1 != v { 43 | return false 44 | } 45 | } else { 46 | return false 47 | } 48 | } 49 | return true 50 | } 51 | 52 | func TestSubscription(t *testing.T) { 53 | addr := "127.0.0.1:8088" 54 | token := "token" 55 | servConn, cliConn, err := buildServerClientConns(addr, token, 3*time.Second) 56 | if err != nil { 57 | t.Errorf("Error: %v", err) 58 | } 59 | defer servConn.Close() 60 | defer cliConn.Close() 61 | subChan := make(chan *rpc.SubscribeRequest) 62 | servConn.SetSubscribeRequestChan(subChan) 63 | 64 | params := map[string]string{ 65 | "pushservicetype": "gcm", 66 | "regid": "someregid", 67 | } 68 | 69 | wg := &sync.WaitGroup{} 70 | go func() { 71 | err := cliConn.Subscribe(params) 72 | if err != nil { 73 | t.Errorf("sub error: %v", err) 74 | } 75 | err = cliConn.Unsubscribe(params) 76 | if err != nil { 77 | t.Errorf("unsub error: %v", err) 78 | } 79 | }() 80 | 81 | go func() { 82 | servConn.ReceiveMessage() 83 | }() 84 | 85 | wg.Add(1) 86 | go func() { 87 | subreq := <-subChan 88 | 89 
| req := &rpc.SubscribeRequest{ 90 | Subscribe: true, 91 | Service: servConn.Service(), 92 | Username: servConn.Username(), 93 | Params: params, 94 | } 95 | 96 | if !subeq(req, subreq) { 97 | t.Errorf("%+v is wrong", subreq) 98 | } 99 | subreq = <-subChan 100 | req.Subscribe = false 101 | if !subeq(req, subreq) { 102 | t.Errorf("%+v is wrong", subreq) 103 | } 104 | wg.Done() 105 | }() 106 | 107 | wg.Wait() 108 | } 109 | -------------------------------------------------------------------------------- /proto/server/subproc.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "github.com/uniqush/uniqush-conn/proto" 22 | "github.com/uniqush/uniqush-conn/rpc" 23 | ) 24 | 25 | type subscribeProcessor struct { 26 | conn *serverConn 27 | subChan chan<- *rpc.SubscribeRequest 28 | } 29 | 30 | func (self *subscribeProcessor) ProcessCommand(cmd *proto.Command) (msg *rpc.Message, err error) { 31 | if cmd == nil || cmd.Type != proto.CMD_SUBSCRIPTION || self.conn == nil || self.subChan == nil { 32 | return 33 | } 34 | if len(cmd.Params) < 1 { 35 | err = proto.ErrBadPeerImpl 36 | return 37 | } 38 | if cmd.Message == nil { 39 | err = proto.ErrBadPeerImpl 40 | return 41 | } 42 | if len(cmd.Message.Header) == 0 { 43 | err = proto.ErrBadPeerImpl 44 | return 45 | } 46 | sub := true 47 | if cmd.Params[0] == "0" { 48 | sub = false 49 | } else if cmd.Params[0] == "1" { 50 | sub = true 51 | } else { 52 | return 53 | } 54 | req := new(rpc.SubscribeRequest) 55 | req.Params = cmd.Message.Header 56 | req.Service = self.conn.Service() 57 | req.Username = self.conn.Username() 58 | req.Subscribe = sub 59 | self.subChan <- req 60 | return 61 | } 62 | -------------------------------------------------------------------------------- /proto/server/vis_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "fmt" 22 | "github.com/uniqush/uniqush-conn/rpc" 23 | "testing" 24 | "time" 25 | ) 26 | 27 | func TestSetInvisible(t *testing.T) { 28 | addr := "127.0.0.1:8088" 29 | token := "token" 30 | servConn, cliConn, err := buildServerClientConns(addr, token, 3*time.Second) 31 | if err != nil { 32 | t.Errorf("Error: %v", err) 33 | } 34 | defer servConn.Close() 35 | defer cliConn.Close() 36 | N := 1 37 | mcs := make([]*rpc.MessageContainer, N) 38 | 39 | for i := 0; i < N; i++ { 40 | mcs[i] = &rpc.MessageContainer{ 41 | Message: randomMessage(), 42 | Id: fmt.Sprintf("%v", i), 43 | } 44 | } 45 | 46 | src := &clientSender{ 47 | conn: cliConn, 48 | } 49 | 50 | dst := &serverReceiver{ 51 | conn: servConn, 52 | } 53 | 54 | cliConn.SetVisibility(false) 55 | err = iterateOverContainers(src, dst, mcs...) 56 | if err != nil { 57 | t.Errorf("Error: %v", err) 58 | } 59 | 60 | if servConn.Visible() { 61 | t.Errorf("Error: should be invisible") 62 | } 63 | cliConn.SetVisibility(true) 64 | err = iterateOverContainers(src, dst, mcs...) 65 | if err != nil { 66 | t.Errorf("Error: %v", err) 67 | } 68 | if !servConn.Visible() { 69 | t.Errorf("Error: should be visible") 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /proto/server/visproc.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package server 19 | 20 | import ( 21 | "github.com/uniqush/uniqush-conn/proto" 22 | "github.com/uniqush/uniqush-conn/rpc" 23 | "sync/atomic" 24 | ) 25 | 26 | type visibilityProcessor struct { 27 | conn *serverConn 28 | } 29 | 30 | func (self *visibilityProcessor) ProcessCommand(cmd *proto.Command) (msg *rpc.Message, err error) { 31 | if cmd == nil || cmd.Type != proto.CMD_SET_VISIBILITY { 32 | return 33 | } 34 | if len(cmd.Params) < 1 { 35 | err = proto.ErrBadPeerImpl 36 | return 37 | } 38 | if cmd.Params[0] == "0" { 39 | atomic.StoreInt32(&self.conn.visible, 0) 40 | } else if cmd.Params[0] == "1" { 41 | atomic.StoreInt32(&self.conn.visible, 1) 42 | } 43 | return 44 | } 45 | -------------------------------------------------------------------------------- /proto/utils.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package proto 19 | 20 | import ( 21 | "errors" 22 | "hash" 23 | "io" 24 | 25 | "net" 26 | ) 27 | 28 | var ErrZeroEntropy = errors.New("Need more random number") 29 | var ErrImcompatibleProtocol = errors.New("incompatible protocol") 30 | var ErrBadServer = errors.New("Unkown Server") 31 | var ErrCorruptedData = errors.New("corrupted data") 32 | var ErrBadKeyExchangePacket = errors.New("Bad Key-exchange Packet") 33 | var ErrBadPeerImpl = errors.New("bad protocol implementation on peer") 34 | 35 | // incCounter increments a four byte, big-endian counter. 36 | func incCounter(c *[4]byte) { 37 | if c[3]++; c[3] != 0 { 38 | return 39 | } 40 | if c[2]++; c[2] != 0 { 41 | return 42 | } 43 | if c[1]++; c[1] != 0 { 44 | return 45 | } 46 | c[0]++ 47 | } 48 | 49 | // mgf1XOR XORs the bytes in out with a mask generated using the MGF1 function 50 | // specified in PKCS#1 v2.1. 51 | // out = out xor MGF1(seed, hash) 52 | func mgf1XOR(out []byte, hash hash.Hash, seed []byte) { 53 | var counter [4]byte 54 | var digest []byte 55 | 56 | done := 0 57 | for done < len(out) { 58 | hash.Write(seed) 59 | hash.Write(counter[0:4]) 60 | digest = hash.Sum(digest[:0]) 61 | hash.Reset() 62 | 63 | for i := 0; i < len(digest) && done < len(out); i++ { 64 | out[done] ^= digest[i] 65 | done++ 66 | } 67 | incCounter(&counter) 68 | } 69 | } 70 | 71 | func clearBytes(data []byte) { 72 | for i, _ := range data { 73 | data[i] = byte(0) 74 | } 75 | } 76 | 77 | func leftPaddingZero(data []byte, l int) []byte { 78 | if len(data) >= l { 79 | return data 80 | } 81 | ret := make([]byte, l-len(data), l) 82 | ret = append(ret, data...) 83 | return ret 84 | } 85 | 86 | // copyWithLeftPad copies src to the end of dest, padding with zero bytes as 87 | // needed. 
88 | func copyWithLeftPad(dest, src []byte) { 89 | numPaddingBytes := len(dest) - len(src) 90 | for i := 0; i < numPaddingBytes; i++ { 91 | dest[i] = 0 92 | } 93 | copy(dest[numPaddingBytes:], src) 94 | } 95 | 96 | func xorBytes(longer, shorter []byte) []byte { 97 | if len(longer) == 0 { 98 | return nil 99 | } 100 | ret := make([]byte, len(longer)) 101 | 102 | for i, c := range shorter { 103 | ret[i] = c ^ longer[i] 104 | } 105 | 106 | for i := len(shorter); i < len(longer); i++ { 107 | ret[i] = longer[i] 108 | } 109 | 110 | return ret 111 | } 112 | 113 | func xorBytesEq(a, b []byte) bool { 114 | if len(b) > len(a) { 115 | // to prevent timing attack 116 | xorBytes(b, a) 117 | return false 118 | } else if len(a) > len(b) { 119 | xorBytes(a, b) 120 | return false 121 | } 122 | 123 | x := xorBytes(a, b) 124 | if x == nil { 125 | return true 126 | } 127 | 128 | ret := true 129 | for _, c := range x { 130 | if c != 0 { 131 | ret = false 132 | } 133 | } 134 | return ret 135 | } 136 | 137 | func writen(w io.Writer, buf []byte) error { 138 | n := len(buf) 139 | for n >= 0 { 140 | l, err := w.Write(buf) 141 | if err != nil { 142 | if ne, ok := err.(net.Error); ok { 143 | if ne.Temporary() { 144 | continue 145 | } 146 | } 147 | return err 148 | } 149 | if l >= n { 150 | return nil 151 | } 152 | n -= l 153 | buf = buf[l:] 154 | } 155 | return nil 156 | } 157 | -------------------------------------------------------------------------------- /proto/utils_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package proto 19 | 20 | import "testing" 21 | 22 | func TestXorBytesEq(t *testing.T) { 23 | a := []byte{0, 2, 1, 3} 24 | b := []byte{0, 2, 1, 3} 25 | 26 | if !xorBytesEq(a, b) { 27 | t.Errorf("should be eq") 28 | } 29 | if !xorBytesEq(b, a) { 30 | t.Errorf("should be eq") 31 | } 32 | if !xorBytesEq(nil, nil) { 33 | t.Errorf("should be eq") 34 | } 35 | 36 | b = []byte{1, 2, 3, 4} 37 | if xorBytesEq(a, b) { 38 | t.Errorf("should not be eq") 39 | } 40 | if xorBytesEq(b, a) { 41 | t.Errorf("should not be eq") 42 | } 43 | 44 | b = []byte{1, 2, 3} 45 | if xorBytesEq(a, b) { 46 | t.Errorf("should not be eq") 47 | } 48 | if xorBytesEq(b, a) { 49 | t.Errorf("should not be eq") 50 | } 51 | 52 | if xorBytesEq(a, nil) { 53 | t.Errorf("should not be eq") 54 | } 55 | if xorBytesEq(b, nil) { 56 | t.Errorf("should not be eq") 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /push/push.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package push 19 | 20 | import ( 21 | "bufio" 22 | "fmt" 23 | "net" 24 | "net/http" 25 | "net/url" 26 | "strconv" 27 | "strings" 28 | "time" 29 | ) 30 | 31 | // TODO: Use decorator pattern to implement an aggregate Push interface 32 | 33 | type Push interface { 34 | Subscribe(service, username string, info map[string]string) error 35 | Unsubscribe(service, username string, info map[string]string) error 36 | Push(service, username, senderService, senderUsername string, info map[string]string, id string, size int) error 37 | } 38 | 39 | type uniqushPush struct { 40 | addr string 41 | timeout time.Duration 42 | } 43 | 44 | func NewUniqushPushClient(addr string, timeout time.Duration) Push { 45 | ret := new(uniqushPush) 46 | ret.addr = addr 47 | ret.timeout = timeout 48 | return ret 49 | } 50 | 51 | func timeoutDialler(ns time.Duration) func(net, addr string) (c net.Conn, err error) { 52 | return func(netw, addr string) (net.Conn, error) { 53 | c, err := net.Dial(netw, addr) 54 | if err != nil { 55 | return nil, err 56 | } 57 | if ns.Seconds() > 0.0 { 58 | c.SetDeadline(time.Now().Add(ns)) 59 | } 60 | return c, nil 61 | } 62 | } 63 | 64 | func (self *uniqushPush) postReadLines(path string, data url.Values, nrLines int) (value string, err error) { 65 | if len(path) == 0 { 66 | return 67 | } 68 | 69 | url := fmt.Sprintf("http://%v/%v", self.addr, path) 70 | 71 | c := http.Client{ 72 | Transport: &http.Transport{ 73 | Dial: timeoutDialler(self.timeout), 74 | }, 75 | } 76 | resp, err := c.PostForm(url, data) 
77 | if err != nil { 78 | return 79 | } 80 | defer resp.Body.Close() 81 | if nrLines > 0 { 82 | respBuf := bufio.NewReader(resp.Body) 83 | line := make([]byte, 0, nrLines*512) 84 | for i := 0; i < nrLines; i++ { 85 | l, _, e := respBuf.ReadLine() 86 | if e != nil { 87 | err = e 88 | return 89 | } 90 | line = append(line, l...) 91 | } 92 | value = string(line) 93 | } 94 | return 95 | } 96 | 97 | func (self *uniqushPush) post(path string, data url.Values) error { 98 | _, err := self.postReadLines(path, data, 0) 99 | return err 100 | } 101 | 102 | func (self *uniqushPush) subscribe(service, username string, info map[string]string, sub bool) error { 103 | data := url.Values{} 104 | data.Add("service", service) 105 | data.Add("subscriber", username) 106 | 107 | for k, v := range info { 108 | switch k { 109 | case "pushservicetype": 110 | fallthrough 111 | case "regid": 112 | fallthrough 113 | case "devtoken": 114 | fallthrough 115 | case "account": 116 | data.Add(k, v) 117 | } 118 | } 119 | path := "unsubscribe" 120 | if sub { 121 | path = "subscribe" 122 | } 123 | err := self.post(path, data) 124 | return err 125 | } 126 | 127 | func (self *uniqushPush) NrDeliveryPoints(service, username string) int { 128 | data := url.Values{} 129 | data.Add("service", service) 130 | data.Add("subscriber", username) 131 | v, err := self.postReadLines("nrdp", data, 1) 132 | if err != nil { 133 | return 0 134 | } 135 | n, err := strconv.Atoi(strings.TrimSpace(v)) 136 | if err != nil { 137 | return 0 138 | } 139 | return n 140 | } 141 | 142 | func (self *uniqushPush) Subscribe(service, username string, info map[string]string) error { 143 | return self.subscribe(service, username, info, true) 144 | } 145 | 146 | func (self *uniqushPush) Unsubscribe(service, username string, info map[string]string) error { 147 | return self.subscribe(service, username, info, false) 148 | } 149 | 150 | func (self *uniqushPush) Push(service, username, senderService, senderUsername string, info 
map[string]string, id string, size int) error { 151 | if len(service) == 0 { 152 | return fmt.Errorf("NoService") 153 | } 154 | if len(username) == 0 { 155 | return fmt.Errorf("NoReceiver") 156 | } 157 | data := url.Values{} 158 | for k, v := range info { 159 | if strings.HasPrefix(strings.ToLower(k), "uq.") || 160 | strings.HasPrefix(strings.ToLower(k), "uniqush.") { 161 | // reserved prefixes. 162 | continue 163 | } 164 | data.Set(k, v) 165 | } 166 | 167 | // The format of the parameter string is: 168 | // id,size,service,username,senderService,senderUsername 169 | // The last two are optional if the message is sent from the server (not forwarded by another user) 170 | // This id part is hex number --- so that we can save space. 171 | param := make([]rune, 0, len(id)+len(service)+len(username)+len(senderService)+len(senderUsername)+32) 172 | 173 | param = append(param, []rune(id)...) 174 | param = append(param, rune(',')) 175 | param = append(param, []rune(fmt.Sprintf("%x,", size))...) 176 | param = append(param, []rune(service)...) 177 | param = append(param, rune(',')) 178 | param = append(param, []rune(username)...) 179 | if len(senderUsername) > 0 { 180 | param = append(param, rune(',')) 181 | if len(senderService) > 0 { 182 | param = append(param, []rune(senderService)...) 183 | param = append(param, rune(',')) 184 | } else { 185 | param = append(param, []rune(service)...) 186 | param = append(param, rune(',')) 187 | } 188 | param = append(param, []rune(senderUsername)...) 
189 | } 190 | data.Set("uq.", string(param)) 191 | data.Set("service", service) 192 | data.Set("subscriber", username) 193 | 194 | err := self.post("push", data) 195 | return err 196 | } 197 | -------------------------------------------------------------------------------- /rpc/fwdreq.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package rpc 19 | 20 | import "time" 21 | 22 | type ForwardRequest struct { 23 | NeverDigest bool `json:"never-digest,omitempty"` 24 | DontPropagate bool `json:"dont-propagate,omitempty"` 25 | DontPush bool `json:"dont-push,omitempty"` 26 | DontCache bool `json:"dont-cache,omitempty"` 27 | 28 | DontAsk bool `json:"dont-ask-permission,omitempty"` 29 | 30 | Receivers []string `json:"receivers"` 31 | ReceiverService string `json:"receiver-service"` 32 | TTL time.Duration `json:"ttl"` 33 | MessageContainer 34 | } 35 | -------------------------------------------------------------------------------- /rpc/mc.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package rpc 19 | 20 | import "time" 21 | 22 | // MessageContainer is used to represent a message inside 23 | // the program. It has meta-data about a message like: 24 | // the message id, the sender and the service of the sender. 25 | type MessageContainer struct { 26 | Message *Message `json:"msg"` 27 | Id string `json:"id,omitempty"` 28 | Sender string `json:"s,omitempty"` 29 | SenderService string `json:"ss,omitempty"` 30 | Birthday time.Time `json:"b,omitempty"` 31 | } 32 | 33 | func (self *MessageContainer) FromServer() bool { 34 | return len(self.Sender) == 0 35 | } 36 | 37 | func (self *MessageContainer) FromUser() bool { 38 | return !self.FromServer() 39 | } 40 | 41 | func (a *MessageContainer) Eq(b *MessageContainer) bool { 42 | if a.Id != b.Id { 43 | return false 44 | } 45 | if a.Sender != b.Sender { 46 | return false 47 | } 48 | if a.SenderService != b.SenderService { 49 | return false 50 | } 51 | return a.Message.Eq(b.Message) 52 | } 53 | -------------------------------------------------------------------------------- /rpc/msg.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package rpc 19 | 20 | import "bytes" 21 | 22 | type Message struct { 23 | Header map[string]string `json:"header,omitempty"` 24 | Body []byte `json:"body,omitempty"` 25 | } 26 | 27 | func (self *Message) IsEmpty() bool { 28 | if self == nil { 29 | return true 30 | } 31 | return len(self.Header) == 0 && len(self.Body) == 0 32 | } 33 | 34 | func (self *Message) Size() int { 35 | if self == nil { 36 | return 0 37 | } 38 | ret := len(self.Body) 39 | for k, v := range self.Header { 40 | ret += len(k) + 1 41 | ret += len(v) + 1 42 | } 43 | ret += 8 44 | return ret 45 | } 46 | 47 | func (a *Message) Eq(b *Message) bool { 48 | if a == nil { 49 | if b == nil { 50 | return true 51 | } else { 52 | return false 53 | } 54 | } 55 | if len(a.Header) != len(b.Header) { 56 | return false 57 | } 58 | for k, v := range a.Header { 59 | if bv, ok := b.Header[k]; ok { 60 | if bv != v { 61 | return false 62 | } 63 | } else { 64 | return false 65 | } 66 | } 67 | return bytes.Equal(a.Body, b.Body) 68 | } 69 | 70 | /* 71 | func (a *Message) Eq(b *Message) bool { 72 | if !a.EqContent(b) { 73 | return false 74 | } 75 | if a.Id != b.Id { 76 | return false 77 | } 78 | if a.Sender != b.Sender { 79 | return false 80 | } 81 | if a.SenderService != b.SenderService { 82 | return false 83 | } 84 | return true 85 | } 86 | */ 87 | -------------------------------------------------------------------------------- /rpc/multipeer.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 
Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package rpc 19 | 20 | import ( 21 | "fmt" 22 | "math/rand" 23 | "sync" 24 | "time" 25 | ) 26 | 27 | type MultiPeer struct { 28 | peers []UniqushConnPeer 29 | lock sync.RWMutex 30 | id string 31 | } 32 | 33 | func NewMultiPeer(peers ...UniqushConnPeer) *MultiPeer { 34 | ret := new(MultiPeer) 35 | ret.peers = peers 36 | ret.id = fmt.Sprintf("%v-%v", time.Now().UnixNano(), rand.Int63()) 37 | return ret 38 | } 39 | 40 | func (self *MultiPeer) Id() string { 41 | return self.id 42 | } 43 | 44 | func (self *MultiPeer) AddPeer(p UniqushConnPeer) { 45 | if self == nil { 46 | return 47 | } 48 | self.lock.Lock() 49 | defer self.lock.Unlock() 50 | 51 | peerId := p.Id() 52 | 53 | // We don't want to add same peer twice 54 | for _, i := range self.peers { 55 | if i.Id() == peerId { 56 | return 57 | } 58 | } 59 | self.peers = append(self.peers, p) 60 | } 61 | 62 | func (self *MultiPeer) do(f func(p UniqushConnPeer) *Result) *Result { 63 | if self == nil { 64 | return nil 65 | } 66 | ret := new(Result) 67 | self.lock.RLock() 68 | defer self.lock.RUnlock() 69 | 70 | for _, p := range self.peers { 71 | r := f(p) 72 | if r == nil { 73 | continue 74 | } 75 | if r.Error != "" { 76 | ret.Error = r.Error 77 | return ret 78 | } 79 | 80 | ret.Results = append(ret.Results, r.Results...) 
81 | } 82 | return ret 83 | } 84 | 85 | func (self *MultiPeer) Send(req *SendRequest) *Result { 86 | return self.do(func(p UniqushConnPeer) *Result { 87 | return p.Send(req) 88 | }) 89 | } 90 | 91 | func (self *MultiPeer) Forward(req *ForwardRequest) *Result { 92 | return self.do(func(p UniqushConnPeer) *Result { 93 | return p.Forward(req) 94 | }) 95 | } 96 | 97 | func (self *MultiPeer) Redirect(req *RedirectRequest) *Result { 98 | return self.do(func(p UniqushConnPeer) *Result { 99 | return p.Redirect(req) 100 | }) 101 | } 102 | 103 | func (self *MultiPeer) CheckUserStatus(req *UserStatusQuery) *Result { 104 | return self.do(func(p UniqushConnPeer) *Result { 105 | return p.CheckUserStatus(req) 106 | }) 107 | } 108 | -------------------------------------------------------------------------------- /rpc/peer.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package rpc 19 | 20 | import ( 21 | "bytes" 22 | "encoding/json" 23 | "fmt" 24 | "net" 25 | "net/http" 26 | "net/url" 27 | "time" 28 | ) 29 | 30 | const ( 31 | SEND_MESSAGE_PATH = "/send.json" 32 | FORWARD_MESSAGE_PATH = "/fwd.json" 33 | REDIRECT_CLIENT_PATH = "/redir.json" 34 | USER_STATUS_QUERY_PATH = "/user-status.json" 35 | ) 36 | 37 | type UniqushConnPeer interface { 38 | Send(req *SendRequest) *Result 39 | Forward(req *ForwardRequest) *Result 40 | Redirect(req *RedirectRequest) *Result 41 | CheckUserStatus(req *UserStatusQuery) *Result 42 | Id() string 43 | } 44 | 45 | type UniqushConnInstance struct { 46 | Addr string `json:"addr"` 47 | Timeout time.Duration `json:"timeout,omitempty"` 48 | } 49 | 50 | func NewUniqushConnInstance(u *url.URL, timeout time.Duration) (instance *UniqushConnInstance, err error) { 51 | if u.Scheme != "http" && u.Scheme != "https" { 52 | return nil, fmt.Errorf("%v is not supported", u.Scheme) 53 | } 54 | instance = new(UniqushConnInstance) 55 | if timeout < 3*time.Second { 56 | timeout = 3 * time.Second 57 | } 58 | instance.Timeout = timeout 59 | instance.Addr = u.String() 60 | return 61 | } 62 | 63 | func (self *UniqushConnInstance) Id() string { 64 | return self.Addr 65 | } 66 | 67 | func timeoutDialler(ns time.Duration) func(net, addr string) (c net.Conn, err error) { 68 | return func(netw, addr string) (net.Conn, error) { 69 | c, err := net.Dial(netw, addr) 70 | if err != nil { 71 | return nil, err 72 | } 73 | if ns.Seconds() > 0.0 { 74 | c.SetDeadline(time.Now().Add(ns)) 75 | } 76 | return c, nil 77 | } 78 | } 79 | 80 | func (self *UniqushConnInstance) post(url string, data interface{}, out interface{}) int { 81 | if len(url) == 0 || url == "none" { 82 | return 400 83 | } 84 | jdata, err := json.Marshal(data) 85 | if err != nil { 86 | return 400 87 | } 88 | c := http.Client{ 89 | Transport: &http.Transport{ 90 | Dial: timeoutDialler(self.Timeout), 91 | }, 92 | } 93 | resp, err := c.Post(url, 
"application/json", bytes.NewReader(jdata)) 94 | if err != nil { 95 | return 400 96 | } 97 | defer resp.Body.Close() 98 | 99 | if out != nil { 100 | e := json.NewDecoder(resp.Body) 101 | err = e.Decode(out) 102 | if err != nil { 103 | return 400 104 | } 105 | } 106 | return resp.StatusCode 107 | } 108 | 109 | func (self *UniqushConnInstance) requestThenResult(path string, req interface{}) *Result { 110 | result := new(Result) 111 | status := self.post(self.Addr+path, req, result) 112 | if status != 200 { 113 | return nil 114 | } 115 | return result 116 | } 117 | 118 | func (self *UniqushConnInstance) Send(req *SendRequest) *Result { 119 | return self.requestThenResult(SEND_MESSAGE_PATH, req) 120 | } 121 | 122 | func (self *UniqushConnInstance) Forward(req *ForwardRequest) *Result { 123 | return self.requestThenResult(FORWARD_MESSAGE_PATH, req) 124 | } 125 | 126 | func (self *UniqushConnInstance) Redirect(req *RedirectRequest) *Result { 127 | return self.requestThenResult(REDIRECT_CLIENT_PATH, req) 128 | } 129 | 130 | func (self *UniqushConnInstance) CheckUserStatus(req *UserStatusQuery) *Result { 131 | return self.requestThenResult(USER_STATUS_QUERY_PATH, req) 132 | } 133 | -------------------------------------------------------------------------------- /rpc/redirreq.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package rpc 19 | 20 | type RedirectRequest struct { 21 | DontPropagate bool `json:"dont-propagate,omitempty` 22 | Receiver string `json:"receivers"` 23 | ReceiverService string `json:"receiver-service"` 24 | ConnId string `json:"conn-id"` 25 | CandidateSersers []string `json:"candidate-servers"` 26 | } 27 | -------------------------------------------------------------------------------- /rpc/result.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package rpc 19 | 20 | import "net" 21 | 22 | type Result struct { 23 | Error string `json:"error,omitempty"` 24 | Results []*ConnResult `json:"results,omitempty"` 25 | } 26 | 27 | func (self *Result) SetError(err error) { 28 | if self == nil || err == nil { 29 | return 30 | } 31 | self.Error = err.Error() 32 | } 33 | 34 | func (self *Result) NrResults() int { 35 | if self == nil { 36 | return 0 37 | } 38 | return len(self.Results) 39 | } 40 | 41 | func (self *Result) NrSuccess() int { 42 | if self == nil { 43 | return 0 44 | } 45 | ret := 0 46 | for _, r := range self.Results { 47 | if r.Error == "" { 48 | ret++ 49 | } 50 | } 51 | return ret 52 | } 53 | 54 | func (self *Result) NrSuccessForUser(service, user string) int { 55 | if self == nil { 56 | return 0 57 | } 58 | ret := 0 59 | for _, r := range self.Results { 60 | if r.Service == service && r.Username == user && r.Error == "" { 61 | ret += 1 62 | } 63 | } 64 | return ret 65 | } 66 | 67 | func (self *Result) Join(r *Result) { 68 | if self == nil { 69 | return 70 | } 71 | if r == nil { 72 | return 73 | } 74 | if self.Error != "" { 75 | return 76 | } 77 | if r.Error != "" { 78 | self.Error = r.Error 79 | return 80 | } 81 | self.Results = append(self.Results, r.Results...) 
82 | } 83 | 84 | type connDescriptor interface { 85 | RemoteAddr() net.Addr 86 | Service() string 87 | Username() string 88 | UniqId() string 89 | Visible() bool 90 | } 91 | 92 | func (self *Result) Append(c connDescriptor, err error) { 93 | if self == nil { 94 | return 95 | } 96 | if self.Results == nil { 97 | self.Results = make([]*ConnResult, 0, 10) 98 | } 99 | r := new(ConnResult) 100 | r.ConnId = c.UniqId() 101 | if err != nil { 102 | r.Error = err.Error() 103 | } 104 | r.Visible = c.Visible() 105 | r.Username = c.Username() 106 | r.Service = c.Service() 107 | r.Address = c.RemoteAddr().String() 108 | self.Results = append(self.Results, r) 109 | } 110 | 111 | type ConnResult struct { 112 | Address string `json:"address"` 113 | ConnId string `json:"conn-id"` 114 | Error string `json:"error,omitempty"` 115 | Visible bool `json:"visible"` 116 | Username string `josn:"username"` 117 | Service string `json:"service"` 118 | } 119 | -------------------------------------------------------------------------------- /rpc/sendreq.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package rpc 19 | 20 | import "time" 21 | 22 | type SendRequest struct { 23 | NeverDigest bool `json:"never-digest,omitempty"` 24 | DontPropagate bool `json:"dont-propagate,omitempty"` 25 | DontCache bool `json:"dont-cache,omitempty"` 26 | DontPush bool `json:"dont-push,omitempty"` 27 | 28 | Receivers []string `json:"receivers"` 29 | ReceiverService string `json:"receiver-service"` 30 | TTL time.Duration `json:"ttl"` 31 | Id string `json:"id,omitempty"` 32 | 33 | PushInfo map[string]string `json:"extra-push-info,omitempty"` 34 | Message *Message `json:"msg"` 35 | } 36 | -------------------------------------------------------------------------------- /rpc/subreq.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2012 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package rpc 19 | 20 | type SubscribeRequest struct { 21 | Subscribe bool // false: unsubscribe; true: subscribe 22 | Service string 23 | Username string 24 | Params map[string]string 25 | } 26 | -------------------------------------------------------------------------------- /rpc/usrstatus.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2013 Nan Deng 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package rpc 19 | // UserStatusQuery identifies the user to query by Service and Username; DontPropagate is omitted from encoded JSON while false. 20 | type UserStatusQuery struct { 21 | DontPropagate bool `json:"dont-propagate,omitempty"` 22 | Service string `json:"service"` 23 | Username string `json:"username"` 24 | } 25 | -------------------------------------------------------------------------------- /tools/connect-cluster/main.go: -------------------------------------------------------------------------------- 
1 | // Command connect-cluster sends each listed uniqush-conn instance (target) one request per other listed instance (peer), POSTing the peer's address to target + "/join.json". 2 | package main 3 | 4 | import ( 5 | "bufio" 6 | "bytes" 7 | "encoding/json" 8 | "flag" 9 | "fmt" 10 | "github.com/uniqush/uniqush-conn/rpc" 11 | "io" 12 | "io/ioutil" 13 | "net/http" 14 | "os" 15 | "strings" 16 | "time" 17 | ) 18 | 19 | // readInstanceList reads one instance base URL per line from r. Surrounding 20 | // whitespace (including the '\r' of CRLF line endings) is trimmed and blank 21 | // lines are skipped, so every returned entry is non-empty. 22 | func readInstanceList(r io.Reader) (list []string, err error) { 23 | scanner := bufio.NewScanner(r) 24 | list = make([]string, 0, 30) 25 | 26 | for scanner.Scan() { 27 | line := strings.TrimSpace(scanner.Text()) 28 | if line == "" { 29 | continue 30 | } 31 | list = append(list, line) 32 | } 33 | err = scanner.Err() 34 | return 35 | } 36 | 37 | // loadInstanceList reads the instance list from the file at path, or from 38 | // stdin when path is empty. A file opened here is closed before returning. 39 | func loadInstanceList(path string) ([]string, error) { 40 | if path == "" { 41 | return readInstanceList(os.Stdin) 42 | } 43 | f, err := os.Open(path) 44 | if err != nil { 45 | return nil, err 46 | } 47 | defer f.Close() 48 | return readInstanceList(f) 49 | } 50 | 51 | // joinPeer POSTs the JSON payload data to target + "/join.json" and returns 52 | // the response body. A non-2xx status is returned as an error carrying the 53 | // body. NOTE: http.Post uses http.DefaultClient, which has no timeout. 54 | func joinPeer(target string, data []byte) ([]byte, error) { 55 | resp, err := http.Post(target+"/join.json", "application/json", bytes.NewReader(data)) 56 | if err != nil { 57 | return nil, err 58 | } 59 | // Close the body on every path so each request's connection is released 60 | // before the next request is sent. 61 | defer resp.Body.Close() 62 | body, err := ioutil.ReadAll(resp.Body) 63 | if err != nil { 64 | return nil, err 65 | } 66 | if resp.StatusCode < 200 || resp.StatusCode > 299 { 67 | return nil, fmt.Errorf("%v: %v: %s", target, resp.Status, bytes.TrimSpace(body)) 68 | } 69 | return body, nil 70 | } 71 | 72 | var flagInputFile = flag.String("f", "", "input file (stdin by default)") 73 | 74 | // flagTimeout is copied into the Timeout field of every join request. 75 | var flagTimeout = flag.Duration("timeout", 3*time.Second, "timeout") 76 | 77 | // main sends each target in the list a join request for every other list 78 | // entry (peer). Failures are reported on stderr and the remaining pairs are 79 | // still tried; the exit status is 1 if any step failed. 80 | func main() { 81 | flag.Parse() 82 | 83 | list, err := loadInstanceList(*flagInputFile) 84 | if err != nil { 85 | fmt.Fprintf(os.Stderr, "Error: %v\n", err) 86 | os.Exit(1) 87 | } 88 | req := &rpc.UniqushConnInstance{Timeout: *flagTimeout} 89 | failed := false 90 | 91 | for i, target := range list { 92 | fmt.Printf("Adding peers for %v...\n", target) 93 | for j, peer := range list { 94 | if i == j { 95 | // uniqush-conn can handle an instance joining itself, but there 96 | // is no reason to send that request. 97 | continue 98 | } 99 | req.Addr = peer 100 | data, err := json.Marshal(req) 101 | if err != nil { 102 | fmt.Fprintf(os.Stderr, "Error: %v\n", err) 103 | failed = true 104 | continue 105 | } 106 | body, err := joinPeer(target, data) 107 | if err != nil { 108 | fmt.Fprintf(os.Stderr, "Error: %v\n", err) 109 | failed = true 110 | continue 111 | } 112 | // Trim and re-terminate so each result sits on its own line even 113 | // when the response body has no trailing newline. 114 | fmt.Printf("\t%v: %s\n", peer, bytes.TrimSpace(body)) 115 | } 116 | } 117 | if failed { 118 | os.Exit(1) 119 | } 120 | } 121 | --------------------------------------------------------------------------------