├── .gitignore ├── README.md ├── client ├── client.go └── client_option.go ├── codec └── codec.go ├── go.mod ├── go.sum ├── main └── main.go ├── protocol └── protocol.go ├── server ├── server.go └── server_option.go └── transport └── transport.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rpc-demo 2 | 参考 3 | 4 | [从零开始实现一个RPC框架(零)](https://juejin.im/post/5c7b9967518825470368d8d4) 5 | 6 | [从零开始实现一个RPC框架(一) 7 | ](https://juejin.im/post/5c7bcdb2e51d452f5a38e461) -------------------------------------------------------------------------------- /client/client.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "github.com/megaredfan/rpc-demo/codec" 7 | "github.com/megaredfan/rpc-demo/protocol" 8 | "github.com/megaredfan/rpc-demo/transport" 9 | "io" 10 | "log" 11 | "strings" 12 | "sync" 13 | "sync/atomic" 14 | "time" 15 | ) 16 | 17 | var ErrorShutdown = errors.New("client is shut down") 18 | 19 | type RPCClient interface { 20 | Go(ctx context.Context, serviceMethod string, arg interface{}, reply interface{}, done chan *Call) *Call 21 | Call(ctx context.Context, serviceMethod string, arg interface{}, reply interface{}) error 22 | Close() error 23 | } 24 | 25 | type Call struct { 26 | ServiceMethod string // 服务名.方法名 27 | Args interface{} // 参数 28 | Reply interface{} // 返回值(指针类型) 29 | Error error // 错误信息 30 | Done chan *Call // 在调用结束时激活 31 | } 32 | 33 | type simpleClient struct { 34 | codec codec.Codec 35 | rwc io.ReadWriteCloser 36 | pendingCalls sync.Map 37 | mutex sync.Mutex 38 | shutdown bool 39 | option Option 40 | seq uint64 41 | } 42 | 43 | func NewRPCClient(network string, addr string, option Option) (RPCClient, error) { 44 | client := new(simpleClient) 45 | client.option = option 46 | 47 | client.codec = codec.GetCodec(option.SerializeType) 48 | 49 | tr := transport.NewTransport(option.TransportType) 50 | err := tr.Dial(network, addr) 51 | if err != nil { 52 | return nil, err 53 | } 54 | 55 | client.rwc = tr 56 | 57 | go client.input() 58 | return client, nil 59 | } 60 | 61 | func (c *Call) done() { 62 | c.Done <- c 63 | } 64 | 65 | func (c *simpleClient) Go(ctx context.Context, serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call { 66 | call := new(Call) 67 | call.ServiceMethod = serviceMethod 68 | call.Args = args 69 | call.Reply = reply 70 | 71 | if done == nil { 72 | done = make(chan *Call, 10) // buffered. 73 | } else { 74 | if cap(done) == 0 { 75 | log.Panic("rpc: done channel is unbuffered") 76 | } 77 | } 78 | call.Done = done 79 | 80 | c.send(ctx, call) 81 | 82 | return call 83 | } 84 | 85 | func (c *simpleClient) Call(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error { 86 | seq := atomic.AddUint64(&c.seq, 1) 87 | ctx = context.WithValue(ctx, protocol.RequestSeqKey, seq) 88 | 89 | canFn := func() {} 90 | if c.option.RequestTimeout != time.Duration(0) { 91 | ctx, canFn = context.WithTimeout(ctx, c.option.RequestTimeout) 92 | metaDataInterface := ctx.Value(protocol.MetaDataKey) 93 | var metaData map[string]string 94 | if metaDataInterface == nil { 95 | metaData = make(map[string]string) 96 | } else { 97 | metaData = metaDataInterface.(map[string]string) 98 | } 99 | metaData[protocol.RequestTimeoutKey] = c.option.RequestTimeout.String() 100 | ctx = context.WithValue(ctx, protocol.MetaDataKey, metaData) 101 | } 102 | 103 | done := make(chan *Call, 1) 104 | call := c.Go(ctx, serviceMethod, args, reply, done) 105 | select { 106 | case <-ctx.Done(): 107 | canFn() 108 | c.pendingCalls.Delete(seq) 109 | call.Error = errors.New("client request time out") 110 | case <-call.Done: 111 | } 112 | return call.Error 113 | } 114 | 115 | func (c *simpleClient) Close() error { 116 | c.mutex.Lock() 117 | defer c.mutex.Unlock() 118 | c.shutdown = true 119 | 120 | c.pendingCalls.Range(func(key, value interface{}) bool { 121 | call, ok := value.(*Call) 122 | if ok { 123 | call.Error = ErrorShutdown 124 | call.done() 125 | } 126 | 127 | c.pendingCalls.Delete(key) 128 | return true 129 | }) 130 | return nil 131 | } 132 | 133 | func (c *simpleClient) send(ctx context.Context, call *Call) { 134 | seq := ctx.Value(protocol.RequestSeqKey).(uint64) 135 | 136 | c.pendingCalls.Store(seq, call) 137 | 138 | request := protocol.NewMessage(c.option.ProtocolType) 139 | request.Seq = seq 140 | request.MessageType = protocol.MessageTypeRequest 141 | serviceMethod := strings.SplitN(call.ServiceMethod, ".", 2) 142 | request.ServiceName = serviceMethod[0] 143 | request.MethodName = serviceMethod[1] 144 | request.SerializeType = codec.MessagePack 145 | request.CompressType = protocol.CompressTypeNone 146 | if ctx.Value(protocol.MetaDataKey) != nil { 147 | request.MetaData = ctx.Value(protocol.MetaDataKey).(map[string]string) 148 | } 149 | 150 | requestData, err := c.codec.Encode(call.Args) 151 | if err != nil { 152 | log.Println(err) 153 | c.pendingCalls.Delete(seq) 154 | call.Error = err 155 | call.done() 156 | return 157 | } 158 | request.Data = requestData 159 | 160 | data := protocol.EncodeMessage(c.option.ProtocolType, request) 161 | 162 | _, err = c.rwc.Write(data) 163 | if err != nil { 164 | log.Println(err) 165 | c.pendingCalls.Delete(seq) 166 | call.Error = err 167 | call.done() 168 | return 169 | } 170 | } 171 | 172 | func (c *simpleClient) input() { 173 | var err error 174 | var response *protocol.Message 175 | for err == nil { 176 | response, err = protocol.DecodeMessage(c.option.ProtocolType, c.rwc) 177 | if err != nil { 178 | break 179 | } 180 | 181 | seq := response.Seq 182 | callInterface, _ := c.pendingCalls.Load(seq) 183 | call := callInterface.(*Call) 184 | c.pendingCalls.Delete(seq) 185 | 186 | switch { 187 | case call == nil: 188 | //请求已经被清理掉了,可能是已经超时了 189 | case response.Error != "": 190 | call.Error = errors.New(response.Error) 191 | call.done() 192 | default: 193 | err = c.codec.Decode(response.Data, call.Reply) 194 | if err != nil { 195 | call.Error = errors.New("reading body " + err.Error()) 196 | } 197 | call.done() 198 | } 199 | } 200 | } 201 | -------------------------------------------------------------------------------- /client/client_option.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "github.com/megaredfan/rpc-demo/codec" 5 | "github.com/megaredfan/rpc-demo/protocol" 6 | "github.com/megaredfan/rpc-demo/transport" 7 | "time" 8 | ) 9 | 10 | type Option struct { 11 | ProtocolType protocol.ProtocolType 12 | SerializeType codec.SerializeType 13 | CompressType protocol.CompressType 14 | TransportType transport.TransportType 15 | 16 | RequestTimeout time.Duration 17 | } 18 | 19 | var DefaultOption = Option{ 20 | ProtocolType: protocol.Default, 21 | SerializeType: codec.MessagePack, 22 | CompressType: protocol.CompressTypeNone, 23 | TransportType: transport.TCPTransport, 24 | } 25 | -------------------------------------------------------------------------------- /codec/codec.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "github.com/vmihailenco/msgpack" 5 | ) 6 | 7 | type SerializeType byte 8 | 9 | const ( 10 | MessagePack SerializeType = iota 11 | ) 12 | 13 | var codecs = map[SerializeType]Codec{ 14 | MessagePack: &MessagePackCodec{}, 15 | } 16 | 17 | type Codec interface { 18 | Encode(value interface{}) ([]byte, error) 19 | Decode(data []byte, value interface{}) error 20 | } 21 | 22 | func GetCodec(t SerializeType) Codec { 23 | return codecs[t] 24 | } 25 | 26 | type MessagePackCodec struct{} 27 | 28 | func (c MessagePackCodec) Encode(v interface{}) ([]byte, error) { 29 | return msgpack.Marshal(v) 30 | } 31 | 32 | func (c MessagePackCodec) Decode(data []byte, v interface{}) error { 33 | return msgpack.Unmarshal(data, v) 34 | } 35 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/megaredfan/rpc-demo 2 | 3 | require github.com/vmihailenco/msgpack v4.0.2+incompatible 4 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/vmihailenco/msgpack v4.0.2+incompatible h1:6ujmmycMfB62Mwv2N4atpnf8CKLSzhgodqMenpELKIQ= 2 | github.com/vmihailenco/msgpack v4.0.2+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk= 3 | -------------------------------------------------------------------------------- /main/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "github.com/megaredfan/rpc-demo/client" 8 | "github.com/megaredfan/rpc-demo/server" 9 | "log" 10 | "math/rand" 11 | "sync" 12 | "time" 13 | ) 14 | 15 | func main() { 16 | 17 | s := server.NewSimpleServer(server.DefaultOption) 18 | err := s.Register(Arith{}, make(map[string]string)) 19 | if err != nil { 20 | panic(err) 21 | } 22 | go func() { 23 | err = s.Serve("tcp", ":8888") 24 | if err != nil { 25 | panic(err) 26 | } 27 | }() 28 | 29 | time.Sleep(1e9) 30 | 31 | wg := new(sync.WaitGroup) 32 | wg.Add(100) 33 | for i := 0; i < 100; i++ { 34 | go func() { 35 | c, err := client.NewRPCClient("tcp", ":8888", client.DefaultOption) 36 | if err != nil { 37 | panic(err) 38 | } 39 | 40 | args := Args{A: rand.Intn(200), B: rand.Intn(100)} 41 | reply := &Reply{} 42 | err = c.Call(context.TODO(), "Arith.Add", args, reply) 43 | if err != nil { 44 | panic(err) 45 | } 46 | if reply.C != args.A+args.B { 47 | log.Fatal(reply.C) 48 | } else { 49 | fmt.Println(reply.C) 50 | } 51 | 52 | err = c.Call(context.TODO(), "Arith.Minus", args, reply) 53 | if err != nil { 54 | panic(err) 55 | } 56 | if reply.C != args.A-args.B { 57 | log.Fatal(reply.C) 58 | } else { 59 | fmt.Println(reply.C) 60 | } 61 | 62 | err = c.Call(context.TODO(), "Arith.Mul", args, reply) 63 | if err != nil { 64 | panic(err) 65 | } 66 | if reply.C != args.A*args.B { 67 | log.Fatal(reply.C) 68 | } else { 69 | fmt.Println(reply.C) 70 | } 71 | 72 | err = c.Call(context.TODO(), "Arith.Divide", args, reply) 73 | if err != nil { 74 | log.Println(err) 75 | 76 | } 77 | if err != nil && err.Error() == "divided by 0" { 78 | log.Println(err) 79 | } else if reply.C != args.A/args.B { 80 | log.Fatal(reply.C) 81 | } else { 82 | fmt.Println(reply.C) 83 | } 84 | wg.Done() 85 | }() 86 | } 87 | wg.Wait() 88 | } 89 | 90 | type Arith struct{} 91 | 92 | type Args struct { 93 | A int 94 | B int 95 | } 96 | 97 | type Reply struct { 98 | C int 99 | } 100 | 101 | //arg可以是指针类型,也可以是指针类型 102 | func (a Arith) Add(ctx context.Context, arg *Args, reply *Reply) error { 103 | reply.C = arg.A + arg.B 104 | return nil 105 | } 106 | 107 | func (a Arith) Minus(ctx context.Context, arg Args, reply *Reply) error { 108 | reply.C = arg.A - arg.B 109 | return nil 110 | } 111 | 112 | func (a Arith) Mul(ctx context.Context, arg Args, reply *Reply) error { 113 | reply.C = arg.A * arg.B 114 | return nil 115 | } 116 | 117 | func (a Arith) Divide(ctx context.Context, arg *Args, reply *Reply) error { 118 | if arg.B == 0 { 119 | return errors.New("divided by 0") 120 | } 121 | reply.C = arg.A / arg.B 122 | return nil 123 | } 124 | -------------------------------------------------------------------------------- /protocol/protocol.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "github.com/megaredfan/rpc-demo/codec" 7 | "github.com/vmihailenco/msgpack" 8 | "io" 9 | ) 10 | 11 | //------------------------------------------------------------------------------------------------- 12 | //|2byte|1byte |4byte |4byte | header length |(total length - header length - 4byte)| 13 | //------------------------------------------------------------------------------------------------- 14 | //|magic|version|total length|header length| header | body | 15 | //------------------------------------------------------------------------------------------------- 16 | 17 | type MessageType byte 18 | 19 | //请求类型 20 | const ( 21 | MessageTypeRequest MessageType = iota 22 | MessageTypeResponse 23 | ) 24 | 25 | type CompressType byte 26 | 27 | const ( 28 | CompressTypeNone CompressType = iota 29 | ) 30 | 31 | type StatusCode byte 32 | 33 | const ( 34 | StatusOK StatusCode = iota 35 | StatusError 36 | ) 37 | 38 | type ProtocolType byte 39 | 40 | const ( 41 | Default ProtocolType = iota 42 | ) 43 | 44 | type Protocol interface { 45 | NewMessage() *Message 46 | DecodeMessage(r io.Reader) (*Message, error) 47 | EncodeMessage(message *Message) []byte 48 | } 49 | 50 | var protocols = map[ProtocolType]Protocol{ 51 | Default: &RPCProtocol{}, 52 | } 53 | 54 | const ( 55 | RequestSeqKey = "rpc_request_seq" 56 | RequestTimeoutKey = "rpc_request_timeout" 57 | MetaDataKey = "rpc_meta_data" 58 | ) 59 | 60 | type Header struct { 61 | Seq uint64 //序号, 用来唯一标识请求或响应 62 | MessageType MessageType //消息类型,用来标识一个消息是请求还是响应 63 | CompressType CompressType //压缩类型,用来标识一个消息的压缩方式 64 | SerializeType codec.SerializeType //序列化类型,用来标识消息体采用的编码方式 65 | StatusCode StatusCode //状态类型,用来标识一个请求是正常还是异常 66 | ServiceName string //服务名 67 | MethodName string //方法名 68 | Error string //方法调用发生的异常 69 | MetaData map[string]string //其他元数据 70 | } 71 | 72 | func NewMessage(t ProtocolType) *Message { 73 | return protocols[t].NewMessage() 74 | } 75 | 76 | func DecodeMessage(t ProtocolType, r io.Reader) (*Message, error) { 77 | return protocols[t].DecodeMessage(r) 78 | } 79 | 80 | func EncodeMessage(t ProtocolType, m *Message) []byte { 81 | return protocols[t].EncodeMessage(m) 82 | } 83 | 84 | type Message struct { 85 | *Header 86 | Data []byte 87 | } 88 | 89 | func (m Message) Clone() *Message { 90 | header := *m.Header 91 | c := new(Message) 92 | c.Header = &header 93 | c.Data = m.Data 94 | return c 95 | } 96 | 97 | type RPCProtocol struct { 98 | } 99 | 100 | func (RPCProtocol) NewMessage() *Message { 101 | return &Message{Header: &Header{}} 102 | } 103 | 104 | func (RPCProtocol) DecodeMessage(r io.Reader) (msg *Message, err error) { 105 | first3bytes := make([]byte, 3) 106 | _, err = io.ReadFull(r, first3bytes) 107 | if err != nil { 108 | return 109 | } 110 | if !checkMagic(first3bytes[:2]) { 111 | err = errors.New("wrong protocol") 112 | return 113 | } 114 | totalLenBytes := make([]byte, 4) 115 | _, err = io.ReadFull(r, totalLenBytes) 116 | if err != nil { 117 | return 118 | } 119 | totalLen := int(binary.BigEndian.Uint32(totalLenBytes)) 120 | if totalLen < 4 { 121 | err = errors.New("invalid total length") 122 | return 123 | } 124 | data := make([]byte, totalLen) 125 | _, err = io.ReadFull(r, data) 126 | headerLen := int(binary.BigEndian.Uint32(data[:4])) 127 | headerBytes := data[4 : headerLen+4] 128 | header := &Header{} 129 | err = msgpack.Unmarshal(headerBytes, header) 130 | if err != nil { 131 | return 132 | } 133 | msg = new(Message) 134 | msg.Header = header 135 | msg.Data = data[headerLen+4:] 136 | return 137 | } 138 | 139 | func (RPCProtocol) EncodeMessage(msg *Message) []byte { 140 | first3bytes := []byte{0xab, 0xba, 0x00} 141 | headerBytes, _ := msgpack.Marshal(msg.Header) 142 | 143 | totalLen := 4 + len(headerBytes) + len(msg.Data) 144 | totalLenBytes := make([]byte, 4) 145 | binary.BigEndian.PutUint32(totalLenBytes, uint32(totalLen)) 146 | 147 | data := make([]byte, totalLen+7) 148 | start := 0 149 | copyFullWithOffset(data, first3bytes, &start) 150 | copyFullWithOffset(data, totalLenBytes, &start) 151 | 152 | headerLenBytes := make([]byte, 4) 153 | binary.BigEndian.PutUint32(headerLenBytes, uint32(len(headerBytes))) 154 | copyFullWithOffset(data, headerLenBytes, &start) 155 | copyFullWithOffset(data, headerBytes, &start) 156 | copyFullWithOffset(data, msg.Data, &start) 157 | return data 158 | } 159 | 160 | func checkMagic(bytes []byte) bool { 161 | return bytes[0] == 0xab && bytes[1] == 0xba 162 | } 163 | 164 | func copyFullWithOffset(dst []byte, src []byte, start *int) { 165 | copy(dst[*start:*start+len(src)], src) 166 | *start = *start + len(src) 167 | } 168 | -------------------------------------------------------------------------------- /server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "github.com/megaredfan/rpc-demo/codec" 7 | "github.com/megaredfan/rpc-demo/protocol" 8 | "github.com/megaredfan/rpc-demo/transport" 9 | "io" 10 | "log" 11 | "reflect" 12 | "strings" 13 | "sync" 14 | "unicode" 15 | "unicode/utf8" 16 | ) 17 | 18 | type RPCServer interface { 19 | Register(rcvr interface{}, metaData map[string]string) error 20 | Serve(network string, addr string) error 21 | Close() error 22 | } 23 | 24 | type simpleServer struct { 25 | codec codec.Codec 26 | serviceMap sync.Map 27 | tr transport.ServerTransport 28 | mutex sync.Mutex 29 | shutdown bool 30 | 31 | option Option 32 | } 33 | 34 | type methodType struct { 35 | method reflect.Method 36 | ArgType reflect.Type 37 | ReplyType reflect.Type 38 | } 39 | 40 | type service struct { 41 | name string 42 | typ reflect.Type 43 | rcvr reflect.Value 44 | methods map[string]*methodType 45 | } 46 | 47 | func NewSimpleServer(option Option) RPCServer { 48 | s := new(simpleServer) 49 | s.option = option 50 | s.codec = codec.GetCodec(option.SerializeType) 51 | return s 52 | } 53 | 54 | func (s *simpleServer) Register(rcvr interface{}, metaData map[string]string) error { 55 | typ := reflect.TypeOf(rcvr) 56 | name := typ.Name() 57 | srv := new(service) 58 | srv.name = name 59 | srv.rcvr = reflect.ValueOf(rcvr) 60 | srv.typ = typ 61 | methods := suitableMethods(typ, true) 62 | srv.methods = methods 63 | 64 | if len(srv.methods) == 0 { 65 | var errorStr string 66 | 67 | // 如果对应的类型没有任何符合规则的方法,扫描对应的指针类型 68 | // 也是从net.rpc包里抄来的 69 | method := suitableMethods(reflect.PtrTo(srv.typ), false) 70 | if len(method) != 0 { 71 | errorStr = "rpcx.Register: type " + name + " has no exported methods of suitable type (hint: pass a pointer to value of that type)" 72 | } else { 73 | errorStr = "rpcx.Register: type " + name + " has no exported methods of suitable type" 74 | } 75 | log.Println(errorStr) 76 | return errors.New(errorStr) 77 | } 78 | if _, duplicate := s.serviceMap.LoadOrStore(name, srv); duplicate { 79 | return errors.New("rpc: service already defined: " + name) 80 | } 81 | return nil 82 | } 83 | 84 | // Precompute the reflect type for error. Can't use error directly 85 | // because Typeof takes an empty interface value. This is annoying. 86 | var typeOfError = reflect.TypeOf((*error)(nil)).Elem() 87 | var typeOfContext = reflect.TypeOf((*context.Context)(nil)).Elem() 88 | 89 | //过滤符合规则的方法,从net.rpc包抄的 90 | func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType { 91 | methods := make(map[string]*methodType) 92 | for m := 0; m < typ.NumMethod(); m++ { 93 | method := typ.Method(m) 94 | mtype := method.Type 95 | mname := method.Name 96 | 97 | // 方法必须是可导出的 98 | if method.PkgPath != "" { 99 | continue 100 | } 101 | // 需要有四个参数: receiver, Context, args, *reply. 102 | if mtype.NumIn() != 4 { 103 | if reportErr { 104 | log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) 105 | } 106 | continue 107 | } 108 | // 第一个参数必须是context.Context 109 | ctxType := mtype.In(1) 110 | if !ctxType.Implements(typeOfContext) { 111 | if reportErr { 112 | log.Println("method", mname, " must use context.Context as the first parameter") 113 | } 114 | continue 115 | } 116 | 117 | // 第二个参数是arg 118 | argType := mtype.In(2) 119 | if !isExportedOrBuiltinType(argType) { 120 | if reportErr { 121 | log.Println(mname, "parameter type not exported:", argType) 122 | } 123 | continue 124 | } 125 | // 第三个参数是返回值,必须是指针类型的 126 | replyType := mtype.In(3) 127 | if replyType.Kind() != reflect.Ptr { 128 | if reportErr { 129 | log.Println("method", mname, "reply type not a pointer:", replyType) 130 | } 131 | continue 132 | } 133 | // 返回值的类型必须是可导出的 134 | if !isExportedOrBuiltinType(replyType) { 135 | if reportErr { 136 | log.Println("method", mname, "reply type not exported:", replyType) 137 | } 138 | continue 139 | } 140 | // 必须有一个返回值 141 | if mtype.NumOut() != 1 { 142 | if reportErr { 143 | log.Println("method", mname, "has wrong number of outs:", mtype.NumOut()) 144 | } 145 | continue 146 | } 147 | // 返回值类型必须是error 148 | if returnType := mtype.Out(0); returnType != typeOfError { 149 | if reportErr { 150 | log.Println("method", mname, "returns", returnType.String(), "not error") 151 | } 152 | continue 153 | } 154 | methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType} 155 | } 156 | return methods 157 | } 158 | 159 | // Is this type exported or a builtin? 160 | func isExportedOrBuiltinType(t reflect.Type) bool { 161 | for t.Kind() == reflect.Ptr { 162 | t = t.Elem() 163 | } 164 | // PkgPath will be non-empty even for an exported type, 165 | // so we need to check the type name as well. 166 | return isExported(t.Name()) || t.PkgPath() == "" 167 | } 168 | 169 | // Is this an exported - upper case - name? 170 | func isExported(name string) bool { 171 | rune, _ := utf8.DecodeRuneInString(name) 172 | return unicode.IsUpper(rune) 173 | } 174 | 175 | func (s *simpleServer) Serve(network string, addr string) error { 176 | s.tr = transport.NewServerTransport(s.option.TransportType) 177 | err := s.tr.Listen(network, addr) 178 | if err != nil { 179 | log.Println(err) 180 | return err 181 | } 182 | for { 183 | conn, err := s.tr.Accept() 184 | if err != nil { 185 | log.Println(err) 186 | return err 187 | } 188 | go s.serveTransport(conn) 189 | } 190 | 191 | } 192 | 193 | func (s *simpleServer) Close() error { 194 | s.mutex.Lock() 195 | defer s.mutex.Unlock() 196 | s.shutdown = true 197 | 198 | err := s.tr.Close() 199 | 200 | s.serviceMap.Range(func(key, value interface{}) bool { 201 | s.serviceMap.Delete(key) 202 | return true 203 | }) 204 | return err 205 | } 206 | 207 | type Request struct { 208 | Seq uint32 209 | Reply interface{} 210 | Data []byte 211 | } 212 | 213 | func (s *simpleServer) serveTransport(tr transport.Transport) { 214 | for { 215 | request, err := protocol.DecodeMessage(s.option.ProtocolType, tr) 216 | 217 | if err != nil { 218 | if err == io.EOF { 219 | log.Printf("client has closed this connection: %s", tr.RemoteAddr().String()) 220 | } else if strings.Contains(err.Error(), "use of closed network connection") { 221 | log.Printf("rpcx: connection %s is closed", tr.RemoteAddr().String()) 222 | } else { 223 | log.Printf("rpcx: failed to read request: %v", err) 224 | } 225 | return 226 | } 227 | response := request.Clone() 228 | response.MessageType = protocol.MessageTypeResponse 229 | 230 | sname := request.ServiceName 231 | mname := request.MethodName 232 | srvInterface, ok := s.serviceMap.Load(sname) 233 | if !ok { 234 | s.writeErrorResponse(response, tr, "can not find service") 235 | return 236 | } 237 | srv, ok := srvInterface.(*service) 238 | if !ok { 239 | s.writeErrorResponse(response, tr, "not *service type") 240 | return 241 | 242 | } 243 | 244 | mtype, ok := srv.methods[mname] 245 | if !ok { 246 | s.writeErrorResponse(response, tr, "can not find method") 247 | return 248 | } 249 | argv := newValue(mtype.ArgType) 250 | replyv := newValue(mtype.ReplyType) 251 | 252 | ctx := context.Background() 253 | err = s.codec.Decode(request.Data, argv) 254 | 255 | var returns []reflect.Value 256 | if mtype.ArgType.Kind() != reflect.Ptr { 257 | returns = mtype.method.Func.Call([]reflect.Value{srv.rcvr, 258 | reflect.ValueOf(ctx), 259 | reflect.ValueOf(argv).Elem(), 260 | reflect.ValueOf(replyv)}) 261 | } else { 262 | returns = mtype.method.Func.Call([]reflect.Value{srv.rcvr, 263 | reflect.ValueOf(ctx), 264 | reflect.ValueOf(argv), 265 | reflect.ValueOf(replyv)}) 266 | } 267 | if len(returns) > 0 && returns[0].Interface() != nil { 268 | err = returns[0].Interface().(error) 269 | s.writeErrorResponse(response, tr, err.Error()) 270 | return 271 | } 272 | 273 | responseData, err := codec.GetCodec(request.SerializeType).Encode(replyv) 274 | if err != nil { 275 | s.writeErrorResponse(response, tr, err.Error()) 276 | return 277 | } 278 | 279 | response.StatusCode = protocol.StatusOK 280 | response.Data = responseData 281 | 282 | _, err = tr.Write(protocol.EncodeMessage(s.option.ProtocolType, response)) 283 | if err != nil { 284 | log.Println(err) 285 | return 286 | } 287 | } 288 | } 289 | 290 | func newValue(t reflect.Type) interface{} { 291 | if t.Kind() == reflect.Ptr { 292 | return reflect.New(t.Elem()).Interface() 293 | } else { 294 | return reflect.New(t).Interface() 295 | } 296 | } 297 | 298 | func (s *simpleServer) writeErrorResponse(response *protocol.Message, w io.Writer, err string) { 299 | response.Error = err 300 | log.Println(response.Error) 301 | response.StatusCode = protocol.StatusError 302 | response.Data = response.Data[:0] 303 | _, _ = w.Write(protocol.EncodeMessage(s.option.ProtocolType, response)) 304 | } 305 | -------------------------------------------------------------------------------- /server/server_option.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/megaredfan/rpc-demo/codec" 5 | "github.com/megaredfan/rpc-demo/protocol" 6 | "github.com/megaredfan/rpc-demo/transport" 7 | ) 8 | 9 | type Option struct { 10 | ProtocolType protocol.ProtocolType 11 | SerializeType codec.SerializeType 12 | CompressType protocol.CompressType 13 | TransportType transport.TransportType 14 | } 15 | 16 | var DefaultOption = Option{ 17 | ProtocolType: protocol.Default, 18 | SerializeType: codec.MessagePack, 19 | CompressType: protocol.CompressTypeNone, 20 | TransportType: transport.TCPTransport, 21 | } 22 | -------------------------------------------------------------------------------- /transport/transport.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "io" 5 | "net" 6 | ) 7 | 8 | type transportMaker func() Transport 9 | 10 | type serverTransportMaker func() ServerTransport 11 | 12 | type TransportType byte 13 | 14 | const TCPTransport TransportType = iota 15 | 16 | var makeTransport = map[TransportType]transportMaker{ 17 | TCPTransport: func() Transport { 18 | return new(Socket) 19 | }, 20 | } 21 | 22 | var makeServerTransport = map[TransportType]serverTransportMaker{ 23 | TCPTransport: func() ServerTransport { 24 | return new(ServerSocket) 25 | }, 26 | } 27 | 28 | type Transport interface { 29 | Dial(network, addr string) error 30 | io.ReadWriteCloser 31 | RemoteAddr() net.Addr 32 | LocalAddr() net.Addr 33 | } 34 | 35 | type Socket struct { 36 | conn net.Conn 37 | } 38 | 39 | func NewTransport(t TransportType) Transport { 40 | return makeTransport[t]() 41 | } 42 | 43 | func (s *Socket) Dial(network, addr string) error { 44 | conn, err := net.Dial(network, addr) 45 | s.conn = conn 46 | return err 47 | } 48 | 49 | func (s *Socket) Read(p []byte) (n int, err error) { 50 | return s.conn.Read(p) 51 | } 52 | 53 | func (s *Socket) Write(p []byte) (n int, err error) { 54 | return s.conn.Write(p) 55 | } 56 | 57 | func (s *Socket) Close() error { 58 | return s.conn.Close() 59 | } 60 | 61 | func (s *Socket) RemoteAddr() net.Addr { 62 | return s.conn.RemoteAddr() 63 | } 64 | 65 | func (s Socket) LocalAddr() net.Addr { 66 | return s.conn.LocalAddr() 67 | } 68 | 69 | type ServerTransport interface { 70 | Listen(network, addr string) error 71 | Accept() (Transport, error) 72 | io.Closer 73 | } 74 | 75 | type ServerSocket struct { 76 | ln net.Listener 77 | } 78 | 79 | func NewServerTransport(t TransportType) ServerTransport { 80 | return makeServerTransport[t]() 81 | } 82 | 83 | func (s *ServerSocket) Listen(network, addr string) error { 84 | ln, err := net.Listen(network, addr) 85 | s.ln = ln 86 | return err 87 | } 88 | 89 | func (s *ServerSocket) Accept() (Transport, error) { 90 | conn, err := s.ln.Accept() 91 | return &Socket{conn: conn}, err 92 | } 93 | 94 | func (s *ServerSocket) Close() error { 95 | return s.ln.Close() 96 | } 97 | --------------------------------------------------------------------------------