├── __init__.py ├── proto ├── __init__.py ├── pbproto │ ├── .gitignore │ ├── Makefile │ ├── proto_dict.py │ ├── gen_proto.lua │ ├── __init__.py │ └── protobuf_to_dict.py ├── proto.py ├── rc4.py └── rpcservice.py ├── test ├── __init__.py ├── Makefile ├── test.sproto ├── config.py ├── rc4.py ├── server.py └── sprotoparser.lua ├── .gitignore ├── .gitmodules ├── Makefile ├── simulator.py ├── main.py ├── cfg.py ├── cmdclient.py ├── game.py ├── README.md └── command.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /proto/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.spb 3 | -------------------------------------------------------------------------------- /proto/pbproto/.gitignore: -------------------------------------------------------------------------------- 1 | pb 2 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "proto/sproto"] 2 | path = proto/sproto 3 | url = git@github.com:hqwrong/py-sproto.git 4 | -------------------------------------------------------------------------------- /test/Makefile: -------------------------------------------------------------------------------- 1 | all: # compile test.sproto with sprotoparser.lua into the binary schema test.spb read via test/config.py proto_path 2 | @ lua -e 'p = require"sprotoparser"; io.write(p.parse(io.open("./test.sproto"):read("*a")))' > test.spb 3 | 4 |
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | cd proto/sproto/ && make 3 | cd test && make 4 | 5 | clean: 6 | find . -type d -exec rm -f *.pyc \; 7 | rm test/*.spb 8 | -------------------------------------------------------------------------------- /proto/proto.py: -------------------------------------------------------------------------------- 1 | # /usr/bin/python2 2 | # -*- coding: utf-8 -*- 3 | 4 | from cfg import * 5 | 6 | proto = None 7 | 8 | 9 | if Config["proto"] == "protobuf": 10 | from .pbproto import PbRpc 11 | proto = PbRpc(Config["proto_path"]) 12 | else: 13 | from .sproto.sproto import SprotoRpc 14 | 15 | with open(Config["proto_path"][0], "rb") as f: 16 | _c2s_chunk = f.read() 17 | with open(Config["proto_path"][1], "rb") as f: 18 | _s2c_chunk = f.read() 19 | 20 | proto = SprotoRpc(_c2s_chunk, _s2c_chunk, Config["proto_header"]) 21 | 22 | -------------------------------------------------------------------------------- /simulator.py: -------------------------------------------------------------------------------- 1 | import gevent 2 | import command 3 | from game import Game 4 | 5 | def _print(): 6 | print "hello" 7 | 8 | class Simulator(object): 9 | def __init__(self, srv_addr): 10 | self.srv_addr = srv_addr 11 | self.workers = [] 12 | 13 | def run(self, actorlist): 14 | for actor in actorlist: 15 | assert command.has_cmd(actor["cmd"]),actor["cmd"] 16 | for _ in xrange(actor.get("count", 1)): 17 | self.workers.append(gevent.spawn(command.do_cmd, Game(self.srv_addr), actor["cmd"], actor["args"])) 18 | 19 | gevent.wait(self.workers) 20 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import gevent 3 | from gevent import monkey 4 | monkey.patch_all() 5 | 
monkey.patch_sys() 6 | import sys,signal,logging,argparse 7 | from cfg import * 8 | parse_args() 9 | 10 | from cmdclient import Client 11 | from simulator import Simulator 12 | 13 | def main(): 14 | srv_addr = (Config["server"], Config["port"]) 15 | print "connect to", srv_addr 16 | if Config["mode"] == "simulator": 17 | sim = Simulator(srv_addr) 18 | sim.run(Config["run"]) 19 | else: 20 | client = Client(srv_addr) 21 | client.interact() 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /proto/pbproto/Makefile: -------------------------------------------------------------------------------- 1 | GEN_DIR=pb 2 | PROTO_DIR=../../common/proto 3 | 4 | NET_RPC_CS=$(PROTO_DIR)/client.protolist 5 | 6 | NET_RPC_CS_HRL=$(GEN_DIR)/client_rpc.py 7 | 8 | RPCC=lua gen_proto.lua 9 | 10 | .PHONY: all proto protolist clean 11 | 12 | all: proto protolist 13 | 14 | clean: 15 | -rm -r $(GEN_DIR) 16 | find . -name "*.pyc" -exec rm {} \; 17 | 18 | $(NET_RPC_CS_HRL):$(NET_RPC_CS) 19 | $(RPCC) $< $@ client 20 | 21 | proto: 22 | -mkdir -p $(GEN_DIR) 23 | protoc --python_out=$(GEN_DIR) -I$(PROTO_DIR) `find $(PROTO_DIR) -name "*.proto"` 24 | find $(GEN_DIR) -type d -exec touch {}/__init__.py \; 25 | 26 | protolist: $(NET_RPC_CS_HRL) 27 | 28 | -------------------------------------------------------------------------------- /test/test.sproto: -------------------------------------------------------------------------------- 1 | .header { 2 | type 0 : integer 3 | session 1 : integer 4 | } 5 | 6 | echo 1 { 7 | request { 8 | msg 0 : string 9 | } 10 | response { 11 | msg 0 : string 12 | } 13 | } 14 | 15 | addone 2 { 16 | request { 17 | i 0 : integer 18 | } 19 | } 20 | 21 | notify_addone 3 { 22 | request { 23 | i 0 :integer 24 | } 25 | } 26 | 27 | login 4 { 28 | request { 29 | account 0 : string 30 | } 31 | response { 32 | prompt 0 : string 33 | } 34 | } 35 | 36 | addlist 5 { 37 | request { 38 | l 0 : *integer 39 
| } 40 | response { 41 | answer 0 : integer 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /cfg.py: -------------------------------------------------------------------------------- 1 | import argparse,sys 2 | 3 | Config = {} 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser(description = "New game robot") 7 | parser.add_argument('-s', '--server', dest = 'host', 8 | help = 'game server ip address') 9 | parser.add_argument('-p', '--port', type=int, dest = 'port', 10 | help = 'game server tcp port') 11 | parser.add_argument('-u', '--uid', dest = 'uid', help="login as UID") 12 | parser.add_argument("-m", '--mode', dest = "mode", choices = ["simulator", "client"]) 13 | parser.add_argument(dest="config", help = "config file") 14 | args = parser.parse_args() 15 | if args.config: 16 | execfile(args.config, Config) 17 | del args.config 18 | 19 | for k,v in vars(args).iteritems(): 20 | if v: 21 | Config[k] = v 22 | 23 | sys.path.insert(0, Config["proto_path"]) 24 | 25 | -------------------------------------------------------------------------------- /test/config.py: -------------------------------------------------------------------------------- 1 | ############# game server address 2 | server = "127.0.0.1" 3 | port = 8251 4 | 5 | ############# stream encrypt 6 | encrypt = "rc4" 7 | c2s_key = "C2S_RC4" 8 | s2c_key = "S2C_RC4" 9 | 10 | ############# config proto 11 | 12 | ## protobuf 13 | # proto = "protobuf" # protobuf/sproto 14 | # proto_path = "./data/pb" # ./data/pb or ./data/sp 15 | 16 | ## sproto 17 | proto = "sproto" 18 | proto_path = ["./test/test.spb", "./test/test.spb"] 19 | proto_header = "header" 20 | 21 | ############ 22 | 23 | ########### config mode 24 | 25 | ## simulator mode 26 | 27 | ## client mode 28 | mode = "client" 29 | client_prompt = "> " 30 | 31 | # mode = "simulator" # client / simulator 32 | # run = [ 33 | # { 34 | # "cmd" : "echo", 35 | # "args" : ["hello world!"], 36 | # "count" : 
5, # how many clients to launch to run this command 37 | # }, 38 | # { 39 | # "cmd": "addlist", 40 | # "args" : [[1,2,3,4,5]], 41 | # "count" : 3, 42 | # } 43 | # ] 44 | 45 | ############ 46 | -------------------------------------------------------------------------------- /test/rc4.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | class RC4(object): 4 | ''' 5 | key is a list or tuple 6 | ''' 7 | def __init__(self, key): 8 | self.index1 = 0 9 | self.index2 = 0 10 | 11 | perm = range(256) 12 | j = 0 13 | klen = len(key) 14 | for i in range(256): 15 | j = (j + perm[i] + ord(key[i % klen])) % 256 16 | perm[i], perm[j] = perm[j], perm[i] # swap 17 | self.perm = perm 18 | 19 | ''' 20 | * Encrypt some data using the supplied RC4 state buffer. 21 | * The input and output buffers may be the same buffer. 22 | * Since RC4 is a stream cypher, this function is used 23 | * for both encryption and decryption. 24 | ''' 25 | def crypt(self, data): 26 | dlen = len(data) 27 | out = "" 28 | perm = self.perm 29 | for i in range(dlen): 30 | self.index1 = (self.index1 + 1) % 256 31 | self.index2 = (self.index2 + perm[self.index1]) % 256 32 | perm[self.index1], perm[self.index2] = perm[self.index2], perm[self.index1] 33 | 34 | j = (perm[self.index1] + perm[self.index2]) % 256 35 | out = out + chr((ord(data[i]) ^ perm[j])) 36 | return out 37 | 38 | -------------------------------------------------------------------------------- /proto/rc4.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | class RC4(object): 4 | ''' 5 | key is a list or tuple 6 | ''' 7 | def __init__(self, key): 8 | self.index1 = 0 9 | self.index2 = 0 10 | 11 | perm = range(256) 12 | j = 0 13 | klen = len(key) 14 | for i in range(256): 15 | j = (j + perm[i] + ord(key[i % klen])) % 256 16 | perm[i], perm[j] = perm[j], perm[i] # swap 17 | self.perm = perm 18 | 19 | ''' 20 | * Encrypt some data using the supplied RC4 
state buffer. 21 | * The input and output buffers may be the same buffer. 22 | * Since RC4 is a stream cypher, this function is used 23 | * for both encryption and decryption. 24 | ''' 25 | def crypt(self, data): 26 | dlen = len(data) 27 | out = "" 28 | perm = self.perm 29 | for i in range(dlen): 30 | self.index1 = (self.index1 + 1) % 256 31 | self.index2 = (self.index2 + perm[self.index1]) % 256 32 | perm[self.index1], perm[self.index2] = perm[self.index2], perm[self.index1] 33 | 34 | j = (perm[self.index1] + perm[self.index2]) % 256 35 | out = out + chr((ord(data[i]) ^ perm[j])) 36 | return out 37 | 38 | -------------------------------------------------------------------------------- /proto/pbproto/proto_dict.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path 3 | 4 | _PROTO_DICT = {} 5 | 6 | def load_dir(topdir): 7 | def visit(arg, dirname, names): 8 | for i in names: 9 | filepath = os.path.join(dirname, i) 10 | if os.path.isfile(filepath) and filepath.endswith('.py'): 11 | try: 12 | load_file(filepath[len(topdir) + 1:]) 13 | except AttributeError: 14 | pass 15 | os.path.walk(topdir, visit, None) 16 | 17 | def load_message(desc, cls): 18 | descs = desc.nested_types_by_name 19 | for k in descs: 20 | desc = descs[k] 21 | if hasattr(cls, k): 22 | nested_cls = getattr(cls, k) 23 | _PROTO_DICT[desc.full_name] = nested_cls 24 | load_message(desc, nested_cls) 25 | 26 | def load_file(filename): 27 | root, ext = os.path.splitext(filename) 28 | module_name = root.replace(os.path.sep, '.') 29 | module = __import__(module_name, fromlist=module_name.split('.'), level=0) 30 | descs = module.DESCRIPTOR.message_types_by_name 31 | for k in descs: 32 | desc = descs[k] 33 | if hasattr(module, k): 34 | cls = getattr(module, k) 35 | _PROTO_DICT[desc.full_name] = cls 36 | load_message(desc, cls) 37 | 38 | def get(key): 39 | if key: 40 | return _PROTO_DICT[key] 41 | return None 42 | 
-------------------------------------------------------------------------------- /proto/pbproto/gen_proto.lua: -------------------------------------------------------------------------------- 1 | local plist = require "protolist" 2 | local filename = assert(arg[1]) 3 | 4 | local function readfile(fullname) 5 | local f = assert(io.open(fullname , "r")) 6 | local buffer = f:read "*a" 7 | f:close() 8 | return buffer 9 | end 10 | 11 | local cs 12 | if arg[3] == "client" then 13 | cs = 1 14 | end 15 | 16 | local proto_list = {} 17 | plist.parser(readfile(filename), proto_list, cs) 18 | 19 | local csfile, cserror = io.open( arg[ 2 ], "wb" ) 20 | assert( csfile, cserror ) 21 | 22 | csfile:write( [[# Generated By gen_proto.lua Do not Edit 23 | Descriptor = { 24 | ]]) 25 | 26 | local classes = {} 27 | for id, tab in pairs( proto_list ) do 28 | if( tonumber(id) ) then 29 | if not classes[tab.class] then 30 | classes[tab.class] = {} 31 | end 32 | table.insert(classes[tab.class], {id, tab}) 33 | end 34 | end 35 | 36 | for name, class in pairs(classes) do 37 | csfile:write(string.format(' "%s": [\n', name)) 38 | for _, v in ipairs(class) do 39 | id, tab = v[1], v[2] 40 | csfile:write( 41 | string.format( 42 | ' {"id": %s, "name": "%s", "input": "%s", "output": %s},\n', 43 | id, tab.normal_name, tab.input, tab.output and '"' .. tab.output .. 
'"' or "None") ) 44 | end 45 | csfile:write(" ],\n") 46 | end 47 | 48 | csfile:write("}") 49 | csfile:close() 50 | 51 | -------------------------------------------------------------------------------- /cmdclient.py: -------------------------------------------------------------------------------- 1 | import sys,traceback 2 | from game import Game 3 | from command import do_cmdstr,find_cmd 4 | 5 | import pprint 6 | 7 | from cfg import Config 8 | 9 | 10 | class Client(object): 11 | def __init__(self, srv_addr): 12 | self.game = Game(srv_addr) 13 | 14 | def docmd(self, tokens): 15 | cmdname = tokens[0] 16 | ok,similars = find_cmd(cmdname) 17 | if not ok: 18 | print "cmd not found:", cmdname 19 | print "Did you mean this?" 20 | print "\t", ", ".join(similars) 21 | return 22 | result = do_cmdstr(self.game, cmdname, tokens[1] if len(tokens) > 1 else "") 23 | if result != None: 24 | pprint.pprint(result) 25 | 26 | def interact(self): 27 | if "uid" in Config: 28 | self.docmd(["login", Config["uid"]]) 29 | while True: 30 | try: 31 | sys.stdout.write(Config["client_prompt"]) 32 | sys.stdout.flush() 33 | l = sys.stdin.read(1) 34 | if not l: # eof 35 | print "exit on EOF" 36 | exit(0) 37 | if l == '\n': 38 | continue 39 | l += sys.stdin.readline() 40 | l.strip() 41 | tokens = l.split(None, 1) 42 | if not tokens: 43 | continue 44 | self.docmd(tokens) 45 | 46 | except Exception as e: 47 | print "error occured", e, traceback.format_exc() 48 | continue 49 | -------------------------------------------------------------------------------- /game.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python2 2 | # -*- coding: utf-8 -*- 3 | 4 | from proto.rpcservice import RpcService 5 | from google.protobuf import text_format 6 | import md5, sys, traceback 7 | import gevent 8 | 9 | from command import * 10 | from cfg import Config 11 | 12 | def _get_secret(uid, token, salt): 13 | m = md5.new() 14 | m.update(str(uid)) 15 | m.update(token) 16 
| m.update(salt) 17 | return m.hexdigest() 18 | 19 | class Game(object): 20 | def __init__(self, addr): 21 | self.srv = RpcService(addr, self) 22 | self.is_login = False 23 | self.uid = None 24 | self.token = None 25 | 26 | self.srv._start() 27 | 28 | def _dft_handle(self, *args): 29 | print args 30 | 31 | def call(self, protoname, msg = {}): 32 | return self.srv.call(protoname, msg) 33 | 34 | def send(self, protoname, msg = {}): 35 | return self.srv.invoke(protoname, msg) 36 | 37 | ################################ Test ######################################### 38 | @addcmd() 39 | def help(self, cmdname = ""): 40 | ok,similars = find_cmd(cmdname) 41 | cmds = [cmdname] if ok else similars 42 | for cmdname in cmds: 43 | cmdname,args = inspect_cmd(cmdname) 44 | print cmdname,args 45 | 46 | @addcmd() 47 | def login(self, user): 48 | '''for sproto test login ''' 49 | resp = self.call("login", {"account": user}) 50 | Config["client_prompt"] = "[%s] > " % user 51 | print(resp["prompt"]) 52 | 53 | @addcmd() 54 | def echo(self, msg): 55 | '''for sproto echo test ''' 56 | resp = self.call("echo", {"msg" : msg}) 57 | return resp 58 | 59 | @addcmd() 60 | def addone(self, i): 61 | self.send("addone", {"i": i}) 62 | 63 | @addcmd() 64 | def addlist(self, l): 65 | resp = self.call("addlist", {"l": l}) 66 | print "answer:", resp["answer"] 67 | 68 | @addcmd() 69 | @addhandle("notify_addone") 70 | def notify_addone(self, i): 71 | print "addone result:", i 72 | 73 | ################################ Test End ######################################### 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # game-robot 2 | 3 | Simple and Easy-to-Use robot client for game server. 4 | 5 | ## Features 6 | 7 | 1. Provide high level synchronous api based on [gevent](http://gevent.org/) 8 | 9 | 2. Multiple protos support. 
[sproto](https://github.com/cloudwu/sproto) and [protobuf](https://github.com/google/protobuf) are supported so far. 10 | 11 | 3. Two modes: `client` mode for interactive use on the command line, and `simulator` mode for running many clients asynchronously, which is often used for stress testing. 12 | 13 | 4. Simple, easy-to-use API: usually a single decorator line is enough to add a command or a request handler. 14 | 15 | 5. A simple but powerful command client: arguments are written as Python literals (strings, ints, lists, dicts, etc.), and auto-completion is built in. 16 | 17 | 6. Built-in stream encryption support. 18 | 19 | ## Setup and Launch 20 | 21 | Run 22 | 23 | git submodule update --init --recursive 24 | 25 | after the first clone. 26 | 27 | Run 28 | 29 | make 30 | 31 | to build. 32 | 33 | You need a config file to launch; see `test/config.py` for details. 34 | 35 | ## Game logic 36 | Put your logic functions into `game.py`. 37 | 38 | Use the `@addcmd(COMMAND_NAME)` decorator to add a client command; `COMMAND_NAME` is optional and defaults to the function name. 39 | 40 | Use the `@addhandle(protoname)` decorator to add a handler function for requests from the game server. 41 | 42 | Check `test/config.py` to see how to configure the client.
43 | 44 | ## Test 45 | 46 | ### test cmdclient 47 | To launch test server: 48 | 49 | python -m test/server sproto 50 | 51 | To launch client: 52 | 53 | python main.py test/config.py 54 | 55 | Now you've got a client to play with: 56 | 57 | $ python2 main.py test/config.py 58 | connect to ('127.0.0.1', 8251) 59 | > help 60 | help ['cmdname'] 61 | notify_addone ['i'] 62 | addone ['i'] 63 | echo ['msg'] 64 | login ['user'] 65 | addlist ['l'] 66 | > login "foobar" 67 | hello, foobar 68 | [foobar] > addlist [1,3,5,7] 69 | answer: 16 70 | [foobar] > echo "hello" 71 | {'msg': 'hello'} 72 | 73 | ### test simulator 74 | To launch test server: 75 | 76 | python -m test/server sproto 77 | 78 | Uncomment following block in test/config.py: 79 | 80 | mode = "simulator" # client / simulator 81 | run = [ 82 | { 83 | "cmd" : "echo", 84 | "args" : ["hello world!"], 85 | "count" : 5, # how many clients to launch to run this command 86 | }, 87 | { 88 | "cmd": "addlist", 89 | "args" : [ [1,2,3,4,5] ], 90 | "count" : 3, 91 | } 92 | ] 93 | 94 | Run client: 95 | 96 | python main.py test/config.py 97 | -------------------------------------------------------------------------------- /command.py: -------------------------------------------------------------------------------- 1 | import sys, os, importlib 2 | import inspect 3 | 4 | cmddir = "cmd" 5 | 6 | commands = {} 7 | handles = {} 8 | 9 | def _splitargstr(argstr): 10 | argstr += " " 11 | in_string = None 12 | j = 0 13 | delimiter = [] 14 | args = [] 15 | for i in xrange(0, len(argstr)): 16 | if not delimiter and not in_string and argstr[i] in "\t\n ": 17 | tok = argstr[j:i+1].strip() 18 | if tok: 19 | args.append(tok) 20 | j = i 21 | elif in_string: 22 | if argstr[i] == in_string: 23 | in_string = None 24 | elif argstr[i] == "'" or argstr[i] == '"': 25 | in_string = argstr[i] 26 | elif argstr[i] == "[": 27 | delimiter.append("]") 28 | elif argstr[i] == "{": 29 | delimiter.append("}") 30 | elif argstr[i] == "]" or argstr[i] == "}": 31 | if 
not delimiter or delimiter[-1] != argstr[i]: 32 | raise NameError("syntax error at %dth char of [%s]"%(i, argstr)) 33 | delimiter.pop() 34 | 35 | if delimiter or in_string: 36 | raise NameError("syntax error at end of [%s]" % argstr) 37 | 38 | return args 39 | 40 | def _parseargs(argstr): 41 | args = _splitargstr(argstr) 42 | for i in xrange(len(args)): 43 | args[i] = eval(args[i], {}, {}) 44 | return args 45 | 46 | def find_cmd(cmdname): 47 | if cmdname in commands: 48 | return True,[] 49 | 50 | similars = [] 51 | for n in commands: 52 | if n.startswith(cmdname): 53 | similars.append(n) 54 | 55 | if not similars: 56 | mindiff = 100000 57 | for n in commands: 58 | d = abs(len(n) - len(cmdname)) 59 | for i in range(min(len(n), len(cmdname))): 60 | if n[i] != cmdname[i]: 61 | d += 1 62 | if d < mindiff: 63 | similars = [n] 64 | mindiff = d 65 | elif d == mindiff: 66 | similars.append(n) 67 | 68 | return False, similars 69 | 70 | def inspect_cmd(cmdname): 71 | cmd = commands[cmdname] 72 | args = inspect.getargspec(cmd["handle"]).args 73 | del args[0] # remove `self' arg 74 | return cmdname, args 75 | 76 | def do_cmd(player, cmdname, args): 77 | cmd = commands[cmdname] 78 | if callable(cmd["handle"]): 79 | return cmd["handle"](player, *args) 80 | else: 81 | return getattr(player, cmd["handle"], player._dft_handle)(*args) 82 | 83 | def do_cmdstr(player, cmdname, argstr): 84 | cmd = commands[cmdname] 85 | args = _parseargs(argstr) 86 | return do_cmd(player, cmdname, args) 87 | 88 | def get_handle(protoname): 89 | return handles.get(protoname, None) 90 | 91 | 92 | 93 | ######################## decorators ============================= 94 | 95 | def addcmd(name = ""): 96 | def _decorator(f): 97 | realname = name or f.__name__ 98 | if realname in commands: 99 | raise NameError(realname) 100 | commands[realname] = {"name":realname, "handle":f} 101 | return f 102 | 103 | return _decorator 104 | 105 | def addhandle(protoname): 106 | def _decorator(f): 107 | if protoname in 
handles: 108 | raise NameError(protoname) 109 | handles[protoname] = f 110 | return f 111 | 112 | return _decorator 113 | 114 | def listcmd(): 115 | return [cmdname for cmdname in commands] 116 | 117 | -------------------------------------------------------------------------------- /test/server.py: -------------------------------------------------------------------------------- 1 | import gevent 2 | from gevent import monkey 3 | monkey.patch_all() 4 | import socket, struct, sys 5 | from proto.sproto.sproto import SprotoRpc 6 | import test.config as config 7 | from test.rc4 import RC4 8 | 9 | class Handler(object): 10 | @staticmethod 11 | def echo(server, msg): 12 | return msg 13 | 14 | @staticmethod 15 | def addone(server, msg): 16 | server.send("notify_addone", {"i":msg["i"]+1}) 17 | 18 | @staticmethod 19 | def login(server, msg): 20 | return {"prompt": "hello, %s"%msg["account"]} 21 | 22 | @staticmethod 23 | def addlist(server, msg): 24 | accum = 0 25 | for i in msg["l"]: 26 | accum += i 27 | return {"answer": accum} 28 | 29 | class GameServer(object): 30 | def __init__(self, conn, addr): 31 | self.conn = conn 32 | self.client_addr = addr 33 | self.c2s_encrypt = None 34 | self.s2c_encrypt = None 35 | 36 | with open(config.proto_path[0]) as f: 37 | c2s = f.read() 38 | with open(config.proto_path[1]) as f: 39 | s2c = f.read() 40 | 41 | self.proto = SprotoRpc(c2s, s2c, "header") 42 | 43 | self._init_encrypt() 44 | 45 | def _init_encrypt(self): 46 | entype = getattr(config, "encrypt", None) 47 | if entype == "rc4": 48 | self.c2s_encrypt = RC4(config.c2s_key).crypt 49 | self.s2c_encrypt = RC4(config.s2c_key).crypt 50 | elif entype == None: 51 | self.c2s_encrypt = None 52 | self.s2c_encrypt = None 53 | else: 54 | raise ValueError("not support %s encrypt"%entype) 55 | 56 | def _send(self, data): 57 | data = struct.pack("!H", len(data)) + data 58 | if self.s2c_encrypt: 59 | data = self.s2c_encrypt(data) 60 | self.conn.sendall(data) 61 | 62 | def _recv(self, sz): 63 | 
data = self.conn.recv(sz, socket.MSG_WAITALL) 64 | if self.c2s_encrypt: 65 | data = self.c2s_encrypt(data) 66 | return data 67 | 68 | def run(self): 69 | while True: 70 | header = self._recv(2) 71 | if not header: 72 | print "disconnected", self.client_addr 73 | break 74 | sz, = struct.unpack("!H", header) 75 | content = self._recv(sz) 76 | p = self.proto.dispatch(content) 77 | session = p.get("session", 0) 78 | msg = p["msg"] 79 | protoname = p["proto"] 80 | assert p["type"] == "REQUEST" 81 | print "request:", protoname, msg 82 | resp = getattr(Handler, protoname)(self, msg) 83 | if session: 84 | print "response", resp 85 | pack = self.proto.response(protoname, resp, session) 86 | self._send(pack) 87 | 88 | def send(self, protoname, msg): 89 | pack = self.proto.request(protoname, msg) 90 | self._send(pack) 91 | 92 | 93 | class Server(object): 94 | def __init__(self, addr, prototype): 95 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 96 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 97 | sock.bind(addr) 98 | sock.listen(3) 99 | self.sock = sock 100 | self.conn = None 101 | 102 | def run(self): 103 | while True: 104 | conn, addr = self.sock.accept() 105 | gs = GameServer(conn, addr) 106 | gevent.spawn(gs.run) 107 | 108 | def main(): 109 | port = config.port 110 | if len(sys.argv) < 2: 111 | print "Usage %s [sproto/protobuf]"%sys.argv[0] 112 | exit() 113 | server = Server(("0.0.0.0", port), sys.argv[1]) 114 | print "listen on", port 115 | server.run() 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /proto/rpcservice.py: -------------------------------------------------------------------------------- 1 | # /usr/bin/python2 2 | # -*- coding: utf-8 -*- 3 | import logging, time, struct 4 | from proto import proto 5 | import gevent 6 | from gevent import socket 7 | from gevent.queue import Queue 8 | from gevent.event import AsyncResult 9 | import socket 10 
| 11 | from google.protobuf import text_format 12 | 13 | import command 14 | from .rc4 import RC4 15 | from cfg import * 16 | 17 | class RpcService(object): 18 | SESSION_ID = 1 19 | def __init__(self, addr, game): 20 | self.hub = gevent.get_hub() 21 | self.addr = addr 22 | self.sock = None 23 | self.game = game 24 | 25 | self.time_diff = 0 26 | 27 | self.write_queue = Queue() 28 | self.write_tr = None 29 | 30 | self.read_queue = Queue() 31 | self.read_tr = None 32 | self.dispatch_tr = None 33 | 34 | 35 | entype = Config.get("encrypt", None) 36 | if entype == "rc4": 37 | self.c2s_encrypt = RC4(Config["c2s_key"]).crypt 38 | self.s2c_encrypt = RC4(Config["s2c_key"]).crypt 39 | elif entype == None: 40 | self.c2s_encrypt = None 41 | self.s2c_encrypt = None 42 | else: 43 | raise ValueError("not support %s encrypt"%entype) 44 | 45 | self._sessions = {} 46 | 47 | def _start(self): 48 | if self.sock: 49 | return 50 | 51 | # sock = util.RC4Conn(self.addr) 52 | # sock.connect() 53 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 54 | sock.connect(self.addr) 55 | 56 | self.sock = sock 57 | self.read_tr = gevent.spawn(self._read) 58 | self.write_tr = gevent.spawn(self._write) 59 | self.dispatch_tr = gevent.spawn(self._dispatch) 60 | return True 61 | 62 | def set_timestamp(self, timestamp): 63 | self.time_diff = timestamp - int(time.time()) 64 | 65 | def timestamp(self): 66 | return int(time.time()) + self.time_diff 67 | 68 | def stop(self): 69 | gevent.spawn(self._stop) 70 | 71 | def _stop(self): 72 | while True: 73 | gevent.sleep(1) 74 | if not self.write_queue.empty(): 75 | continue 76 | 77 | if not self.read_queue.empty(): 78 | continue 79 | 80 | gevent.kill(self.write_tr) 81 | gevent.kill(self.read_tr) 82 | gevent.kill(self.dispatch_tr) 83 | self.sock.close() 84 | break 85 | 86 | def _write(self): 87 | while True: 88 | data = self.write_queue.get() 89 | if self.c2s_encrypt: 90 | data = self.c2s_encrypt(data) 91 | try: 92 | self.sock.sendall(data) 93 | except 
socket.error, e: 94 | logging.info("write socket failed:%s" % str(e)) 95 | break 96 | 97 | def _read(self): 98 | left = "" 99 | while True: 100 | try: 101 | buf = self.sock.recv(4*1024) 102 | if not buf: 103 | logging.info("client disconnected, %s:%s" % self.addr) 104 | break 105 | if self.s2c_encrypt: 106 | buf = self.s2c_encrypt(buf) 107 | 108 | except socket.error, e: 109 | logging.info("read socket failed:%s" % str(e)) 110 | break 111 | 112 | left = left + buf 113 | while True: 114 | if len(left) < 2: 115 | break 116 | 117 | plen, = struct.unpack('!H', left[:2]) 118 | if len(left) < plen + 2: 119 | break 120 | 121 | data = left[2:plen+2] 122 | left = left[plen+2:] 123 | self.read_queue.put(data) 124 | 125 | def _dispatch(self): 126 | while True: 127 | data = self.read_queue.get() 128 | p = proto.dispatch(data) 129 | session = p["session"] 130 | msg = p["msg"] 131 | 132 | if p["type"] == "REQUEST": 133 | protoname = p["proto"] 134 | cb = command.get_handle(protoname) 135 | if not cb: 136 | print "no handler for proto:", protoname 137 | continue 138 | resp = cb(self.game, msg) 139 | if session: 140 | # rpc call 141 | pack = proto.response(protoname, resp, session) 142 | self._send(pack) 143 | else: 144 | # response 145 | ev = self._sessions[session] 146 | del self._sessions[session] 147 | ev.set(msg) 148 | 149 | def _get_session(self): 150 | cls = type(self) 151 | if cls.SESSION_ID > 100000000: 152 | cls.SESSION_ID = 1 153 | cls.SESSION_ID += 1 154 | return cls.SESSION_ID 155 | 156 | def _send(self, data): 157 | self.write_queue.put(struct.pack("!H", len(data)) + data) 158 | 159 | def invoke(self, protoname, msg): 160 | pack = proto.request(protoname, msg) 161 | self._send(pack) 162 | 163 | def call(self, protoname, msg): 164 | session = self._get_session() 165 | pack = proto.request(protoname, msg, session) 166 | ev = AsyncResult() 167 | self._sessions[session] = ev 168 | self._send(pack) 169 | return ev.get() 170 | 171 | 
-------------------------------------------------------------------------------- /proto/pbproto/__init__.py: -------------------------------------------------------------------------------- 1 | import os, sys, time 2 | from . import proto_dict 3 | import client_rpc 4 | from .protobuf_to_dict import protobuf_to_dict, dict_to_protobuf 5 | 6 | def _pb2dict(msg): 7 | return protobuf_to_dict(msg, use_enum_labels=True) 8 | 9 | def _dict2pb(pb, values): 10 | return dict_to_protobuf(pb, values) 11 | 12 | class Message(object): 13 | def __init__(self, pack, id): 14 | self.__dict__['_Rpc_Pack'] = pack 15 | self.__dict__['_Rpc_Type_Id'] = id 16 | 17 | def __getattr__(self, key): 18 | if key not in ('_Rpc_Pack', '_Rpc_Type_Id'): 19 | return getattr(self._Rpc_Pack, key) 20 | if key == '_Rpc_Type_Id': 21 | return self.__dict__['_Rpc_Type_Id'] 22 | else: 23 | raise AttributeError 24 | 25 | def __setattr__(self, key, value): 26 | if key not in ('_Rpc_Pack', '_Rpc_Type_Id'): 27 | setattr(self._Rpc_Pack, key, value) 28 | else: 29 | raise AttributeError 30 | 31 | def type_id(self): 32 | return self.__dict__['_Rpc_Type_Id'] 33 | 34 | def pb_pack(self): 35 | return self.__dict__['_Rpc_Pack'] 36 | 37 | def encode(self): 38 | return self.__dict__['_Rpc_Pack'].SerializeToString() 39 | 40 | def decode(self, data): 41 | return self.__dict__['_Rpc_Pack'].ParseFromString(data) 42 | 43 | 44 | class PbRpc(object): 45 | def __init__(self, pb_path): 46 | self._handlers = {} 47 | self._name2id = {} 48 | self._module2name = {} 49 | self._sessions = {} 50 | self.time_diff = 0 51 | 52 | proto_dict.load_dir(pb_path) 53 | self.__register_message(client_rpc.Descriptor) 54 | 55 | def __register_message(self, descriptor): 56 | for module_name in descriptor: 57 | name_list = [] 58 | for desc in descriptor[module_name]: 59 | name = desc['name'] 60 | type_name = module_name + '.' 
+ desc['name'] 61 | handler = { 62 | 'input' : proto_dict.get(desc['input']), 63 | 'output': proto_dict.get(desc['output']), 64 | 'id': desc['id'], 65 | 'name': type_name, 66 | 'module_name': module_name, 67 | } 68 | self._handlers[desc['id']] = handler 69 | self._name2id[type_name] = desc['id'] 70 | name_list.append(name) 71 | self._module2name[module_name] = name_list 72 | 73 | def parse_pack(self, pack): 74 | p = proto_dict.get('proto.Pack')() 75 | p.ParseFromString(pack) 76 | return p 77 | 78 | def pack(self, session = None, type = None, body = None, timestamp = None): 79 | p = proto_dict.get('proto.Pack')() 80 | if session != None: 81 | p.session = session 82 | if type != None: 83 | p.type = type 84 | if body != None: 85 | p.data = body 86 | if timestamp != None: 87 | p.timestamp = timestamp 88 | else: 89 | p.timestamp = int(time.time() - 2.5) 90 | 91 | return p.SerializeToString() 92 | 93 | def lookup(self, type_name, input = True): 94 | handler = self._handlers[self._name2id[type_name]] 95 | if input: 96 | message = Message(handler['input'](), handler['id']) 97 | else: 98 | message = Message(handler['output'](), handler['id']) 99 | 100 | return message 101 | 102 | def type_name(self, type_id): 103 | return self._handlers[type_id]["name"] 104 | 105 | def type_names(self, module_name): 106 | return self._module2name[module_name] 107 | 108 | def type_id(self, type_name): 109 | return self._name2id[type_name] 110 | 111 | def module_name(self, type_id): 112 | return self._handlers[type_id]["module_name"] 113 | 114 | def has_response(self, type_id): 115 | return self._handlers[type_id]["output"] and True or False 116 | 117 | def parse_pack(self, pack): 118 | p = proto_dict.get('proto.Pack')() 119 | p.ParseFromString(pack) 120 | return p 121 | 122 | def set_timestamp(self, timestamp): 123 | self.time_diff = timestamp - int(time.time()) 124 | 125 | def timestamp(self): 126 | return int(time.time()) + self.time_diff 127 | 128 | def make_pack(self, msg, pack): 129 | 
def _fill(msg, pack, depth = 0): 130 | if depth > 100: 131 | raise OverflowError("too deep") 132 | 133 | if type(msg) is dict: 134 | for k,v in msg.iteritems(): 135 | if type(v) != dict and type(v) != list: 136 | setattr(pack, k, v) 137 | else: 138 | _fill(v, getattr(pack,k), depth+1) 139 | elif type(msg) is list: 140 | for v in msg: 141 | if type(v) != dict and type(v) != list: 142 | pack.add(v) 143 | else: 144 | _fill(v, pack.add(), depth+1) 145 | else: 146 | raise TypeError(type(msg)) 147 | 148 | return _fill(msg, pack) 149 | 150 | def dispatch(self, data): 151 | p = self.parse_pack(data) 152 | if p.type != 0: 153 | # request 154 | type_id = p.type 155 | type_name = self.type_name(type_id) 156 | message = self.lookup(type_name) 157 | message.decode(p.data) 158 | return { 159 | "type":"REQUEST", 160 | "proto":type_name, 161 | "msg": _pb2dict(message), 162 | "session": p.session if p.session != 0 else None, 163 | } 164 | else: 165 | # response 166 | session = p.session 167 | type_name = self._sessions[session] 168 | message = self.lookup(type_name, False) 169 | message.decode(p.data) 170 | del self._sessions[session] 171 | return {"type":"RESPONSE", "session":session, "msg":_pb2dict(message)} 172 | 173 | def request(self, protoname, msg, session = 0): 174 | if session: 175 | self._sessions[session] = protoname 176 | pack = self.lookup(protoname) 177 | if msg: 178 | _dict2pb(pack.pb_pack(), msg) 179 | 180 | return self.pack(session, pack.type_id(), pack.encode(), self.timestamp()) 181 | 182 | def response(self, protoname, msg, session): 183 | pack = self.lookup(protoname) 184 | if msg: 185 | _dict2pb(pack.pb_pack(), msg) 186 | return self.pack(session, pack.type_id(), pack.encode(), self.timestamp()) 187 | -------------------------------------------------------------------------------- /proto/pbproto/protobuf_to_dict.py: -------------------------------------------------------------------------------- 1 | # copy from 
https://github.com/benhodgson/protobuf-to-dict/blob/master/src/protobuf_to_dict.py 2 | 3 | 4 | from google.protobuf.message import Message 5 | from google.protobuf.descriptor import FieldDescriptor 6 | 7 | 8 | __all__ = ["protobuf_to_dict", "TYPE_CALLABLE_MAP", "dict_to_protobuf", "REVERSE_TYPE_CALLABLE_MAP"] 9 | 10 | 11 | EXTENSION_CONTAINER = '___X' 12 | 13 | 14 | TYPE_CALLABLE_MAP = { 15 | FieldDescriptor.TYPE_DOUBLE: float, 16 | FieldDescriptor.TYPE_FLOAT: float, 17 | FieldDescriptor.TYPE_INT32: int, 18 | FieldDescriptor.TYPE_INT64: long, 19 | FieldDescriptor.TYPE_UINT32: int, 20 | FieldDescriptor.TYPE_UINT64: long, 21 | FieldDescriptor.TYPE_SINT32: int, 22 | FieldDescriptor.TYPE_SINT64: long, 23 | FieldDescriptor.TYPE_FIXED32: int, 24 | FieldDescriptor.TYPE_FIXED64: long, 25 | FieldDescriptor.TYPE_SFIXED32: int, 26 | FieldDescriptor.TYPE_SFIXED64: long, 27 | FieldDescriptor.TYPE_BOOL: bool, 28 | FieldDescriptor.TYPE_STRING: unicode, 29 | FieldDescriptor.TYPE_BYTES: lambda b: b.encode("base64"), 30 | FieldDescriptor.TYPE_ENUM: int, 31 | } 32 | 33 | 34 | def repeated(type_callable): 35 | return lambda value_list: [type_callable(value) for value in value_list] 36 | 37 | 38 | def enum_label_name(field, value): 39 | return field.enum_type.values_by_number[int(value)].name 40 | 41 | 42 | def protobuf_to_dict(pb, type_callable_map=TYPE_CALLABLE_MAP, use_enum_labels=False): 43 | result_dict = {} 44 | extensions = {} 45 | for field, value in pb.ListFields(): 46 | type_callable = _get_field_value_adaptor(pb, field, type_callable_map, use_enum_labels) 47 | if field.label == FieldDescriptor.LABEL_REPEATED: 48 | type_callable = repeated(type_callable) 49 | 50 | if field.is_extension: 51 | extensions[str(field.number)] = type_callable(value) 52 | continue 53 | 54 | result_dict[field.name] = type_callable(value) 55 | 56 | if extensions: 57 | result_dict[EXTENSION_CONTAINER] = extensions 58 | return result_dict 59 | 60 | 61 | def _get_field_value_adaptor(pb, field, 
type_callable_map=TYPE_CALLABLE_MAP, use_enum_labels=False): 62 | if field.type == FieldDescriptor.TYPE_MESSAGE: 63 | # recursively encode protobuf sub-message 64 | return lambda pb: protobuf_to_dict(pb, 65 | type_callable_map=type_callable_map, 66 | use_enum_labels=use_enum_labels) 67 | 68 | if use_enum_labels and field.type == FieldDescriptor.TYPE_ENUM: 69 | return lambda value: enum_label_name(field, value) 70 | 71 | if field.type in type_callable_map: 72 | return type_callable_map[field.type] 73 | 74 | raise TypeError("Field %s.%s has unrecognised type id %d" % ( 75 | pb.__class__.__name__, field.name, field.type)) 76 | 77 | 78 | def get_bytes(value): 79 | return value.decode('base64') 80 | 81 | 82 | REVERSE_TYPE_CALLABLE_MAP = { 83 | FieldDescriptor.TYPE_BYTES: get_bytes, 84 | } 85 | 86 | 87 | def dict_to_protobuf(pb_klass_or_instance, values, type_callable_map=REVERSE_TYPE_CALLABLE_MAP, strict=True): 88 | """Populates a protobuf model from a dictionary. 89 | :param pb_klass_or_instance: a protobuf message class, or an protobuf instance 90 | :type pb_klass_or_instance: a type or instance of a subclass of google.protobuf.message.Message 91 | :param dict values: a dictionary of values. Repeated and nested values are 92 | fully supported. 93 | :param dict type_callable_map: a mapping of protobuf types to callables for setting 94 | values on the target instance. 95 | :param bool strict: complain if keys in the map are not fields on the message. 
96 | """ 97 | if isinstance(pb_klass_or_instance, Message): 98 | instance = pb_klass_or_instance 99 | else: 100 | instance = pb_klass_or_instance() 101 | return _dict_to_protobuf(instance, values, type_callable_map, strict) 102 | 103 | 104 | def _get_field_mapping(pb, dict_value, strict): 105 | field_mapping = [] 106 | for key, value in dict_value.items(): 107 | if key == EXTENSION_CONTAINER: 108 | continue 109 | if key not in pb.DESCRIPTOR.fields_by_name: 110 | if strict: 111 | raise KeyError("%s does not have a field called %s" % (pb, key)) 112 | continue 113 | field_mapping.append((pb.DESCRIPTOR.fields_by_name[key], value, getattr(pb, key, None))) 114 | 115 | for ext_num, ext_val in dict_value.get(EXTENSION_CONTAINER, {}).items(): 116 | try: 117 | ext_num = int(ext_num) 118 | except ValueError: 119 | raise ValueError("Extension keys must be integers.") 120 | if ext_num not in pb._extensions_by_number: 121 | if strict: 122 | raise KeyError("%s does not have a extension with number %s. Perhaps you forgot to import it?" 
% (pb, key)) 123 | continue 124 | ext_field = pb._extensions_by_number[ext_num] 125 | pb_val = None 126 | pb_val = pb.Extensions[ext_field] 127 | field_mapping.append((ext_field, ext_val, pb_val)) 128 | 129 | return field_mapping 130 | 131 | 132 | def _dict_to_protobuf(pb, value, type_callable_map, strict): 133 | fields = _get_field_mapping(pb, value, strict) 134 | 135 | for field, input_value, pb_value in fields: 136 | if field.label == FieldDescriptor.LABEL_REPEATED: 137 | for item in input_value: 138 | if field.type == FieldDescriptor.TYPE_MESSAGE: 139 | m = pb_value.add() 140 | _dict_to_protobuf(m, item, type_callable_map, strict) 141 | elif field.type == FieldDescriptor.TYPE_ENUM and isinstance(item, basestring): 142 | pb_value.append(_string_to_enum(field, item)) 143 | else: 144 | pb_value.append(item) 145 | continue 146 | if field.type == FieldDescriptor.TYPE_MESSAGE: 147 | _dict_to_protobuf(pb_value, input_value, type_callable_map, strict) 148 | continue 149 | 150 | if field.type in type_callable_map: 151 | input_value = type_callable_map[field.type](input_value) 152 | 153 | if field.is_extension: 154 | pb.Extensions[field] = input_value 155 | continue 156 | 157 | if field.type == FieldDescriptor.TYPE_ENUM and isinstance(input_value, basestring): 158 | input_value = _string_to_enum(field, input_value) 159 | 160 | setattr(pb, field.name, input_value) 161 | 162 | return pb 163 | 164 | def _string_to_enum(field, input_value): 165 | enum_dict = field.enum_type.values_by_name 166 | try: 167 | input_value = enum_dict[input_value].number 168 | except KeyError: 169 | raise KeyError("`%s` is not a valid value for field `%s`" % (input_value, field.name)) 170 | return input_value 171 | -------------------------------------------------------------------------------- /test/sprotoparser.lua: -------------------------------------------------------------------------------- 1 | local lpeg = require "lpeg" 2 | local table = require "table" 3 | 4 | local packbytes 5 | local 
packvalue

-- Wire encoders used by the packers below; both branches emit identical bytes:
--   packbytes(str) -> 4-byte little-endian length prefix followed by str
--   packvalue(id)  -> 2-byte little-endian encoding of (id + 1) * 2
-- string.pack exists from Lua 5.3 on; older interpreters use the manual path.
if string.pack then
	function packbytes(str)
		return string.pack("<s4", str)
	end

	function packvalue(id)
		id = (id + 1) * 2
		return string.pack("<I2", id)
	end
else
	function packbytes(str)
		local size = #str
		local a = size % 256
		size = math.floor(size / 256)
		local b = size % 256
		size = math.floor(size / 256)
		local c = size % 256
		size = math.floor(size / 256)
		local d = size
		return string.char(a) .. string.char(b) .. string.char(c) .. string.char(d) .. str
	end

	function packvalue(id)
		id = (id + 1) * 2
		assert(id >= 0 and id < 65536)
		local a = id % 256
		local b = math.floor(id / 256)
		return string.char(a) .. string.char(b)
	end
end

local P = lpeg.P
local S = lpeg.S
local R = lpeg.R
local C = lpeg.C
local Ct = lpeg.Ct
local Cg = lpeg.Cg
local Cc = lpeg.Cc
local V = lpeg.V

-- Match-time callback for newlines: bump the line counter the first time a
-- given position is reached (backtracking can revisit positions).
local function count_lines(_, pos, parser_state)
	if parser_state.pos < pos then
		parser_state.line = parser_state.line + 1
		parser_state.pos = pos
	end
	return pos
end

-- Fallback alternative: raise a syntax error naming the file and line.
local exception = lpeg.Cmt( lpeg.Carg(1) , function ( _ , pos, parser_state)
	error(string.format("syntax error at [%s] line (%d)", parser_state.file or "", parser_state.line))
	return pos
end)

local eof = P(-1)
local newline = lpeg.Cmt((P"\n" + "\r\n") * lpeg.Carg(1) ,count_lines)
local line_comment = "#" * (1 - newline) ^0 * (newline + eof)
local blank = S" \t" + newline + line_comment
local blank0 = blank ^ 0
local blanks = blank ^ 1
local alpha = R"az" + R"AZ" + "_"
local alnum = alpha + R"09"
local word = alpha * alnum ^ 0
local name = C(word)
local typename = C(word * ("." * word) ^ 0)
local tag = R"09" ^ 1 / tonumber
local mainkey = "(" * blank0 * name * blank0 * ")"

-- Zero or more `pat`s separated by blanks, captured into a table.
local function multipat(pat)
	return Ct(blank0 * (pat * blanks) ^ 0 * pat^0 * blank0)
end

-- Capture `pat` into a table whose `type` field is `name`.
local function namedpat(name, pat)
	return Ct(Cg(Cc(name), "type") * Cg(pat))
end

local typedef = P {
	"ALL",
	FIELD = namedpat("field", (name * blanks * tag * blank0 * ":" * blank0 * (C"*")^0 * typename * mainkey^0)),
	STRUCT = P"{" * multipat(V"FIELD" + V"TYPE") * P"}",
	TYPE = namedpat("type", P"." * name * blank0 * V"STRUCT" ),
	SUBPROTO = Ct((C"request" + C"response") * blanks * (typename + V"STRUCT")),
	PROTOCOL = namedpat("protocol", name * blanks * tag * blank0 * P"{" * multipat(V"SUBPROTO") * P"}"),
	ALL = multipat(V"TYPE" + V"PROTOCOL"),
}

local proto = blank0 * typedef * blank0

local convert = {}

-- Build { tag = n, request = typename, response = typename }.  An inline
-- struct becomes a new type named "<protocol>.request" / "<protocol>.response".
function convert.protocol(all, obj)
	local result = { tag = obj[2] }
	for _, p in ipairs(obj[3]) do
		assert(result[p[1]] == nil)
		local typename = p[2]
		if type(typename) == "table" then
			local struct = typename
			typename = obj[1] .. "." .. p[1]
			all.type[typename] = convert.type(all, { typename, struct })
		end
		result[p[1]] = typename
	end
	return result
end

-- Build the tag-sorted field list of a type, rejecting duplicate names and
-- tags; nested types are registered in all.type as "<parent>.<child>".
function convert.type(all, obj)
	local result = {}
	local typename = obj[1]
	local tags = {}
	local names = {}
	for _, f in ipairs(obj[2]) do
		if f.type == "field" then
			local name = f[1]
			if names[name] then
				error(string.format("redefine %s in type %s", name, typename))
			end
			names[name] = true
			local tag = f[2]
			if tags[tag] then
				error(string.format("redefine tag %d in type %s", tag, typename))
			end
			tags[tag] = true
			local field = { name = name, tag = tag }
			table.insert(result, field)
			local fieldtype = f[3]
			if fieldtype == "*" then
				field.array = true
				fieldtype = f[4]
			end
			local mainkey = f[5]
			if mainkey then
				assert(field.array)
				field.key = mainkey
			end
			field.typename = fieldtype
		else
			assert(f.type == "type")	-- nest type
			local nesttypename = typename .. "." .. f[1]
			f[1] = nesttypename
			assert(all.type[nesttypename] == nil, "redefined " .. nesttypename)
			all.type[nesttypename] = convert.type(all, f)
		end
	end
	table.sort(result, function(a,b) return a.tag < b.tag end)
	return result
end

-- Group raw parse results into { type = {...}, protocol = {...} } by name.
local function adjust(r)
	local result = { type = {} , protocol = {} }

	for _, obj in ipairs(r) do
		local set = result[obj.type]
		local name = obj[1]
		assert(set[name] == nil , "redefined " .. name)
		set[name] = convert[obj.type](result,obj)
	end

	return result
end

local buildin_types = {
	integer = 0,
	boolean = 1,
	string = 2,
}

-- Resolve type name `t` used inside type `ptype`, searching the innermost
-- enclosing scope first and then outward; returns nil if undefined.
local function checktype(types, ptype, t)
	if buildin_types[t] then
		return t
	end
	local fullname = ptype .. "." .. t
	if types[fullname] then
		return fullname
	else
		ptype = ptype:match "(.+)%..+$"
		if ptype then
			return checktype(types, ptype, t)
		elseif types[t] then
			return t
		end
	end
end

-- Reject duplicate protocol tags and undefined request/response types.
local function check_protocol(r)
	local map = {}
	local type = r.type
	for name, v in pairs(r.protocol) do
		local tag = v.tag
		local request = v.request
		local response = v.response
		local p = map[tag]

		if p then
			error(string.format("redefined protocol tag %d at %s", tag, name))
		end

		if request and not type[request] then
			error(string.format("Undefined request type %s in protocol %s", request, name))
		end

		if response and not type[response] then
			error(string.format("Undefined response type %s in protocol %s", response, name))
		end

		map[tag] = v
	end
	return r
end

-- Rewrite every field's typename to its fully-qualified form.
local function flattypename(r)
	for typename, t in pairs(r.type) do
		for _, f in pairs(t) do
			local ftype = f.typename
			local fullname = checktype(r.type, typename, ftype)
			if fullname == nil then
				error(string.format("Undefined type %s in type %s", ftype, typename))
			end
			f.typename = fullname
		end
	end

	return r
end

-- Parse schema `text` and return the validated { type, protocol } tables.
local function parser(text,filename)
	local state = { file = filename, pos = 0, line = 1 }
	local r = lpeg.match(proto * -1 + exception , text , 1, state )
	return flattypename(check_protocol(adjust(r)))
end

--[[
-- The protocol of sproto
.type {
	.field {
		name 0 : string
		buildin 1 : integer
		type 2 : integer
		tag 3 : integer
		array 4 : boolean
		key 5 : integer # If key exists, array must be true, and it's a map.
	}
	name 0 : string
	fields 1 : *field
}

.protocol {
	name 0 : string
	tag 1 : integer
	request 2 : integer # index
	response 3 : integer # index
}

.group {
	type 0 : *type
	protocol 1 : *protocol
}
]]

-- Encode one `.field` record (see the schema above).
local function packfield(f)
	local strtbl = {}
	if f.array then
		if f.key then
			table.insert(strtbl, "\6\0")  -- 6 fields
		else
			table.insert(strtbl, "\5\0")  -- 5 fields
		end
	else
		table.insert(strtbl, "\4\0")  -- 4 fields
	end
	table.insert(strtbl, "\0\0")	-- name	(tag = 0, ref an object)
	if f.buildin then
		table.insert(strtbl, packvalue(f.buildin))	-- buildin (tag = 1)
		table.insert(strtbl, "\1\0")	-- skip (tag = 2)
		table.insert(strtbl, packvalue(f.tag))	-- tag (tag = 3)
	else
		table.insert(strtbl, "\1\0")	-- skip (tag = 1)
		table.insert(strtbl, packvalue(f.type))	-- type (tag = 2)
		table.insert(strtbl, packvalue(f.tag))	-- tag (tag = 3)
	end
	if f.array then
		table.insert(strtbl, packvalue(1))	-- array = true (tag = 4)
	end
	if f.key then
		table.insert(strtbl, packvalue(f.key))	-- key tag (tag = 5)
	end
	table.insert(strtbl, packbytes(f.name))	-- external object (name)
	return packbytes(table.concat(strtbl))
end

-- Encode one `.type` record with all of its fields.
local function packtype(name, t, alltypes)
	local fields = {}
	local tmp = {}
	for _, f in ipairs(t) do
		tmp.array = f.array
		tmp.name = f.name
		tmp.tag = f.tag

		tmp.buildin = buildin_types[f.typename]
		local subtype
		if not tmp.buildin then
			subtype = assert(alltypes[f.typename])
			tmp.type = subtype.id
		else
			tmp.type = nil
		end
		if f.key then
			-- A map key must name a builtin-typed field of a struct element
			-- type; arrays of builtin types have no subtype to index.
			tmp.key = subtype and subtype.fields[f.key]
			if not tmp.key then
				error("Invalid map index :" .. f.key)
			end
		else
			tmp.key = nil
		end

		table.insert(fields, packfield(tmp))
	end
	local data
	if #fields == 0 then
		data = {
			"\1\0",	-- 1 fields
			"\0\0",	-- name	(id = 0, ref = 0)
			packbytes(name),
		}
	else
		data = {
			"\2\0",	-- 2 fields
			"\0\0",	-- name	(tag = 0, ref = 0)
			"\0\0",	-- field[]	(tag = 1, ref = 1)
			packbytes(name),
			packbytes(table.concat(fields)),
		}
	end

	return packbytes(table.concat(data))
end

-- Encode one `.protocol` record.  The leading field count shrinks when the
-- trailing request/response entries are absent.
local function packproto(name, p, alltypes)
	for _, kind in ipairs { "request", "response" } do
		local tname = p[kind]
		if tname and alltypes[tname] == nil then
			error(string.format("Protocol %s %s type %s not found", name, kind, tname))
		end
	end
	local tmp = {
		"\4\0",	-- 4 fields
		"\0\0",	-- name (id=0, ref=0)
		packvalue(p.tag),	-- tag (tag=1)
	}
	if p.request == nil and p.response == nil then
		tmp[1] = "\2\0"
	else
		if p.request then
			table.insert(tmp, packvalue(alltypes[p.request].id))	-- request type id (tag=2)
		else
			table.insert(tmp, "\1\0")	-- skip (tag=2)
		end
		if p.response then
			table.insert(tmp, packvalue(alltypes[p.response].id))	-- response type id (tag=3)
		else
			tmp[1] = "\3\0"
		end
	end

	table.insert(tmp, packbytes(name))

	return packbytes(table.concat(tmp))
end

-- Encode the top-level `.group`: all types (sorted by name so ids are
-- stable) followed by all protocols (sorted by tag).
local function packgroup(t,p)
	if next(t) == nil then
		assert(next(p) == nil)
		return "\0\0"
	end
	local tt, tp
	-- Array part: sorted type names; hash part: name -> { id, fields }.
	local alltypes = {}
	for name in pairs(t) do
		table.insert(alltypes, name)
	end
	table.sort(alltypes)	-- make result stable
	for idx, name in ipairs(alltypes) do
		local fields = {}
		for _, type_fields in ipairs(t[name]) do
			if buildin_types[type_fields.typename] then
				fields[type_fields.name] = type_fields.tag
			end
		end
		alltypes[name] = { id = idx - 1, fields = fields }
	end
	tt = {}
	for _,name in ipairs(alltypes) do
		table.insert(tt, packtype(name, t[name], alltypes))
	end
	tt = packbytes(table.concat(tt))
	if next(p) then
		local tmp = {}
		for name, tbl in pairs(p) do
			table.insert(tmp, tbl)
			tbl.name = name
		end
		table.sort(tmp, function(a,b) return a.tag < b.tag end)

		tp = {}
		for _, tbl in ipairs(tmp) do
			table.insert(tp, packproto(tbl.name, tbl, alltypes))
		end
		tp = packbytes(table.concat(tp))
	end
	local result
	if tp == nil then
		result = {
			"\1\0",	-- 1 field
			"\0\0",	-- type[] (id = 0, ref = 0)
			tt,
		}
	else
		result = {
			"\2\0",	-- 2fields
			"\0\0",	-- type array	(id = 0, ref = 0)
			"\0\0",	-- protocol array	(id = 1, ref =1)

			tt,
			tp,
		}
	end

	return table.concat(result)
end

local function encodeall(r)
	return packgroup(r.type, r.protocol)
end

local sparser = {}

-- Print `str` as hex, 16 bytes per line with a "- " gap after each 8.
function sparser.dump(str)
	local tmp = ""
	for i=1,#str do
		tmp = tmp .. string.format("%02X ", string.byte(str,i))
		if i % 8 == 0 then
			if i % 16 == 0 then
				print(tmp)
				tmp = ""
			else
				tmp = tmp .. "- "
			end
		end
	end
	print(tmp)
end

-- Parse schema `text` (`name` labels syntax errors) and return the binary
-- encoding of the whole group.
function sparser.parse(text, name)
	local r = parser(text, name or "=text")
	local data = encodeall(r)
	return data
end

return sparser

--------------------------------------------------------------------------------