├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── LICENSE ├── balancer └── balancer.lua ├── channel.lua ├── example ├── gateway.lua └── session.lua ├── http ├── handlers.lua ├── initParam.lua ├── mockauth.lua └── ssoProcessors.lua ├── ldap.lua ├── ldap ├── ldapPackets.lua └── parser.lua ├── parser.lua ├── readme.md ├── session ├── session.lua └── sessionManager.lua ├── ssh2.lua ├── ssh2 ├── commandCollector.lua ├── parser.lua ├── shellCommand.lua ├── ssh2CipherConf.lua └── ssh2Packets.lua ├── suproxy-v0.6.0-1.rockspec ├── tds.lua ├── tds ├── datetime.lua ├── parser.lua ├── tdsPackets.lua ├── token.lua └── version.lua ├── test.lua ├── tns.lua ├── tns ├── crypt.lua ├── parser.lua └── tnsPackets.lua └── utils ├── asn1.lua ├── compatibleLog.lua ├── datetime.lua ├── event.lua ├── ffi-zlib.lua ├── json.lua ├── pureluapack.lua ├── stringUtils.lua ├── tableUtils.lua ├── unicode.lua └── utils.lua /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 
22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, yizhu2000 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 
15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /balancer/balancer.lua: -------------------------------------------------------------------------------- 1 | --Default random balancer. this balancer randomly select upstreams from given 2 | --list. if one upstream is blamed, this upstream will be unselectable for given 3 | --suspendSpan time. 
4 | local tableUtils=require "suproxy.utils.tableUtils"
5 | local utils=require "suproxy.utils.utils"
6 | local OrderedTable=tableUtils.OrderedTable
7 | local _M={}
8 | 
9 | --- Key under which an upstream is stored in both upstreams and blameList.
10 | local function getKey(ip,port) return string.format("ip%sport%s",ip,port) end
11 | 
12 | --- Create a balancer.
13 | -- @param upstreams array of upstream tables, each needing `ip` and `port`
14 | -- @param suspendSpan seconds a blamed upstream stays unselectable (default 30)
15 | function _M:new(upstreams,suspendSpan)
16 |     math.randomseed(utils.getTime())
17 |     local o=setmetatable({},{__index=self})
18 |     assert(upstreams,"upstreams can not be nil")
19 |     o.upstreams=OrderedTable:new()
20 |     o.blameList={}
21 |     o.suspendSpan=suspendSpan or 30
22 |     for _,v in ipairs(upstreams) do
23 |         assert(v.ip,"upstream ip address cannot be null")
24 |         assert(v.port,"upstream port cannot be null")
25 |         o.upstreams[getKey(v.ip,v.port)]=v
26 |     end
27 |     return o
28 | end
29 | 
30 | --- Pick a random selectable upstream.
31 | -- Blamed upstreams whose suspension has expired are first moved back into
32 | -- the selectable set and dropped from blameList, so each is restored once.
33 | -- @return the chosen upstream table, or nil when none is selectable
34 | function _M:getBest()
35 |     local now=utils.getTime()
36 |     for k,v in pairs(self.blameList) do
37 |         if v.addTime+self.suspendSpan<=now then
38 |             self.upstreams[k]=v.value
39 |             -- clearing an existing field while iterating with pairs is allowed
40 |             self.blameList[k]=nil
41 |         end
42 |     end
43 |     local count=#self.upstreams
44 |     if count==0 then return nil end
45 |     -- math.random(1,n) already yields an integer in [1,n]
46 |     local entry=self.upstreams[math.random(1,count)]
47 |     if entry then
48 |         return entry.value
49 |     end
50 | end
51 | 
52 | --- Make an upstream unselectable for suspendSpan seconds.
53 | -- Does nothing if the upstream is not currently selectable.
54 | -- @param upstream table with `ip` and `port`
55 | function _M:blame(upstream)
56 |     assert(upstream.ip,"upstream ip address cannot be null")
57 |     assert(upstream.port,"upstream port cannot be null")
58 |     local key=getKey(upstream.ip,upstream.port)
59 |     if self.upstreams[key] then
60 |         self.blameList[key]={addTime=utils.getTime(),value=self.upstreams[key]}
61 |         self.upstreams:remove(key)
62 |     end
63 | end
64 | 
65 | _M.unitTest={}
66 | --- Self test: selection, blaming, and restoration after suspendSpan.
67 | function _M.test()
68 |     print("------------running balancer test")
69 |     local suspendSpan=5
70 |     local a=_M:new({{ip=1,port=1},{ip=2,port=2},{ip=3,port=3}},suspendSpan)
71 |     for _=1,7 do
72 |         print(tableUtils.printTableF(a:getBest(),{inline=true}))
73 |     end
74 |     assert(#(a.upstreams)==3)
75 |     a:blame({ip=1,port=1})
76 |     assert(#(a.upstreams)==2)
77 |     assert(a.upstreams[getKey(1,1)]==nil)
78 |     assert(a.upstreams[getKey(2,2)])
79 |     assert(a.upstreams[getKey(3,3)])
80 |     assert(a.blameList[getKey(1,1)])
81 |     assert(a.blameList[getKey(1,1)].value.ip==1)
82 |     print("------------wait ",suspendSpan," seconds")
83 |     -- busy-wait (os.clock is CPU time, which tracks wall time while spinning)
84 |     local t0 = os.clock()
85 |     while os.clock() - t0 <= suspendSpan do end
86 |     a:getBest()
87 |     print(tableUtils.printTableF(a.upstreams))
88 |     assert(#(a.upstreams)==3)
89 |     -- blameList is keyed by strings, so #blameList is always 0; use next()
90 |     assert(next(a.blameList)==nil)
91 |     print("------------balancer test finished")
92 | end
93 | return _M
--------------------------------------------------------------------------------
/channel.lua:
--------------------------------------------------------------------------------
1 | local sub = string.sub
2 | local byte = string.byte
3 | local format = string.format
4 | local tcp = ngx.socket.tcp
5 | local setmetatable = setmetatable
6 | local spawn = ngx.thread.spawn
7 | local wait = ngx.thread.wait
8 | local logger = require "suproxy.utils.compatibleLog"
9 | local ses= require "suproxy.session.session"
10 | local cjson=require "cjson"
11 | local event=require "suproxy.utils.event"
12 | local balancer=require "suproxy.balancer.balancer"
13 | local _M={}
14 | 
15 | _M._VERSION = '0.01'
16 | 
17 | --- Create a proxy channel for the current stream request.
18 | -- @param upstreams nil for standalone (echo) mode, an object that already has
19 | --   a getBest method (used as the balancer as-is), or an array of
20 | --   {ip=...,port=...} tables handed to suproxy.balancer.balancer:new
21 | -- @param processor protocol processor; when nil a default receive-then-forward
22 | --   processor (or, in standalone mode, an echo processor) is built
23 | -- @param options optional table: c2p*/p2s* Conn/Send/Read timeouts in
24 | --   milliseconds (passed to settimeouts) and sessionMan
25 | -- @return the channel, or nil and an error message
26 | function _M:new(upstreams,processor,options)
27 |     local o={}
28 |     options =options or {}
29 |     options.c2pConnTimeout=options.c2pConnTimeout or 10000
30 |     options.c2pSendTimeout=options.c2pSendTimeout or 10000
31 |     options.c2pReadTimeout=options.c2pReadTimeout or 3600000
32 |     options.p2sConnTimeout=options.p2sConnTimeout or 10000
33 |     options.p2sSendTimeout=options.p2sSendTimeout or 10000
34 |     options.p2sReadTimeout=options.p2sReadTimeout or 3600000
35 |     local c2pSock, err = ngx.req.socket()
36 |     if not c2pSock then
37 |         return nil, err
38 |     end
39 |     c2pSock:settimeouts(options.c2pConnTimeout , options.c2pSendTimeout , options.c2pReadTimeout)
40 |     local standalone=false
41 |     if(not upstreams) then
42 |         logger.log(logger.ERR, format("[SuProxy] no upstream specified, Proxy will run in standalone mode"))
43 |         standalone=true
44 |     end
45 |     local p2sSock=nil
46 |     if(not standalone) then
47 |         p2sSock, err = tcp()
48 |         if not p2sSock then
49 |             return nil, err
50 |         end
51 |         p2sSock:settimeouts(options.p2sConnTimeout , options.p2sSendTimeout , options.p2sReadTimeout )
52 |     end
53 |     --add default receive-then-forward processor
54 |     if(not processor and not standalone) then
55 |         processor={}
56 |         processor.processUpRequest=function(self)
57 |             local data, err, partial =self.channel:c2pRead(1024*10)
58 |             --real error happened or timeout; receive() reports "nothing read"
59 |             --as an empty partial string, so treat "" like nil here
60 |             if not data and (not partial or partial=="") and err then return nil,err end
61 |             if(data and not err) then
62 |                 return data
63 |             else
64 |                 return partial
65 |             end
66 |         end
67 |         processor.processDownRequest=function(self)
68 |             local data, err, partial = self.channel:p2sRead(1024*10)
69 |             --real error happened or timeout (see processUpRequest)
70 |             if not data and (not partial or partial=="") and err then return nil,err end
71 |             if(data and not err) then
72 |                 return data
73 |             else
74 |                 return partial
75 |             end
76 |         end
77 |     end
78 |     --add default echo processor if proxy in standalone mode
79 |     if(not processor and standalone) then
80 |         processor={}
81 |         processor.processUpRequest=function(self)
82 |             local data, err, partial =self.channel:c2pRead(1024*10)
83 |             --real error happened or timeout (empty partial means nothing read)
84 |             if not data and (not partial or partial=="") and err then return nil,err end
85 |             local echodata=""
86 |             if(data and not err) then
87 |                 echodata=data
88 |             else
89 |                 echodata=partial
90 |             end
91 |             logger.log(logger.INFO,echodata)
92 |             local _,err=self.channel:c2pSend(echodata)
93 |             logger.log(logger.ERR,partial)
94 |         end
95 |     end
96 |     local upForwarder=function(self,data)
97 |         if data then return self.channel:p2sSend(data) end
98 |     end
99 |     local downForwarder=function(self,data)
100 |         if data then return self.channel:c2pSend(data) end
101 |     end
102 |     --add default upforwarder
103 |     processor.sendUp=processor.sendUp or upForwarder
104 |     --add default downforwarder
105 |     processor.sendDown=processor.sendDown or downForwarder
106 |     processor.ctx=processor.ctx or {}
107 |     local sessionInvalidHandler=function (self,session)
108 |         logger.log(logger.DEBUG,"session closed")
109 |         --called as processor:sessionInvalid(...), so self is the processor;
110 |         --shutdown is a channel method unless the processor defines its own
111 |         if self.shutdown then self:shutdown() else self.channel:shutdown() end
112 |     end
113 |     --set default session invalid handler
114 |     processor.sessionInvalid=processor.sessionInvalid or sessionInvalidHandler
115 |     --set AuthSuccessEvent handler
116 |     if processor.AuthSuccessEvent then
117 |         processor.AuthSuccessEvent:addHandler(o,function(self,source,username)
118 |             if self.session and username then self.session.uid=username end
119 |         end)
120 |     end
121 |     --update ctx info to session
122 |     if processor.ContextUpdateEvent then
123 |         processor.ContextUpdateEvent:addHandler(o,function(self,source,ctx)
124 |             if ctx and self.session then self.session.ctx=ctx end
125 |         end)
126 |     end
127 |     o.p2sSock=p2sSock
128 |     o.c2pSock=c2pSock
129 |     o.processor=processor
130 |     --standalone mode has no upstreams (indexing nil.getBest would raise)
131 |     if not standalone then
132 |         o.balancer=upstreams.getBest and upstreams or balancer:new(upstreams)
133 |     end
134 |     o.standalone=standalone
135 |     o.OnConnectEvent=event:new(o,"OnConnectEvent")
136 |     o.sessionMan=options.sessionMan or ses:newDoNothing()
137 |     setmetatable(o, { __index = self })
138 |     processor.channel=o
139 |     return o
140 | end
141 | 
142 | --- Flush pending output, then half-close and close both sockets.
143 | -- Failures are logged at DEBUG level only (cleanup is best effort).
144 | local function _cleanup(self)
145 |     logger.log(logger.DEBUG, format("[SuProxy] clean up executed"))
146 |     -- make sure buffers are clean
147 |     ngx.flush(true)
148 |     local p2sSock = self.p2sSock
149 |     local c2pSock = self.c2pSock
150 |     if p2sSock ~= nil then
151 |         if p2sSock.shutdown then
152 |             p2sSock:shutdown("send")
153 |         end
154 |         if p2sSock.close ~= nil then
155 |             -- close instead of setkeepalive: this connection carries
156 |             -- per-session protocol state, so it must not go back into the
157 |             -- cosocket pool where a later connect() to the same host:port
158 |             -- would reuse it
159 |             local ok, err = p2sSock:close()
160 |             if not ok then
161 |                 logger.log(logger.DEBUG, format("[SuProxy] close upstream socket fail, err:%s", err))
162 |             end
163 |         end
164 |     end
165 | 
166 |     if c2pSock ~= nil then
167 |         if c2pSock.shutdown then
168 |             c2pSock:shutdown("send")
169 |         end
170 |         if c2pSock.close ~= nil then
171 |             local ok, err = c2pSock:close()
172 |             if not ok then
173 |                 logger.log(logger.DEBUG, format("[SuProxy] close client socket fail, err:%s", err))
174 |             end
175 |         end
176 |     end
177 | 
178 | end
179 | 
180 | --- Client -> proxy -> server pump, run in its own light thread.
181 | -- Opens the session, fills processor.ctx, fires OnConnectEvent, then loops
182 | -- over processUpRequest/sendUp until an error, and finally shuts down.
183 | local function _upl(self)
184 |     -- proxy client request to server
185 |     -- self.upstream is never set in standalone mode; an empty table keeps the
186 |     -- ctx assignments and log messages below from indexing nil
187 |     local upstream=self.upstream or {}
188 |     local buf, err, partial
189 |     local session,err=ses:new(self.processor._PROTOCAL,self.sessionMan)
190 |     if err then
191 |         logger.log(logger.ERR, format("[SuProxy] start session fail: %s:%s, err:%s", upstream.ip, upstream.port, err))
192 |         return
193 |     end
194 |     self.processor.ctx.clientIP=ngx.var.remote_addr
195 |     self.processor.ctx.clientPort=ngx.var.remote_port
196 |     self.processor.ctx.srvIP=upstream.ip
197 |     self.processor.ctx.srvPort=upstream.port
198 |     self.processor.ctx.srvID=upstream.id
199 |     self.processor.ctx.srvGID=upstream.gid
200 |     self.processor.ctx.connTime=ngx.time()
201 |     session.ctx=self.processor.ctx
202 |     self.session=session
203 |     self.OnConnectEvent:trigger({clientIP=session.ctx.clientIP,clientPort=session.ctx.clientPort,srvIP=session.ctx.srvIP,srvPort=session.ctx.srvPort})
204 |     while true do
205 |         --todo: sessionMan should notify session change
206 |         if not self.session:valid(self.session) then
207 |             self.processor:sessionInvalid(self.session)
208 |         else
209 |             self.session.uptime=ngx.time()
210 |         end
211 |         logger.log(logger.DEBUG,"client --> proxy start process")
212 |         buf, err, partial = self.processor:processUpRequest(self.standalone)
213 |         if err then
214 |             logger.log(logger.ERR, format("[SuProxy] processUpRequest fail: %s:%s, err:%s", upstream.ip, upstream.port, err))
215 |             break
216 |         end
217 |         --if in standalone mode, don't forward
218 |         if not self.standalone and buf then
219 |             local _, err = self.processor:sendUp(buf)
220 |             if err then
221 |                 logger.log(logger.ERR, format("[SuProxy] forward to upstream fail: %s:%s, err:%s", upstream.ip, upstream.port, err))
222 |                 break
223 |             end
224 |         end
225 |     end
226 |     self:shutdown(upstream)
227 | end
228 | 
229 | --- Server -> proxy -> client pump (only started when not standalone).
230 | local function _dwn(self)
231 |     local upstream=self.upstream
232 |     -- proxy response to client
233 |     local buf, err, partial
234 |     while true do
235 |         logger.log(logger.DEBUG,"server --> proxy start process")
236 |         buf, err, partial = self.processor:processDownRequest(self.standalone)
237 |         if err then
238 |             logger.log(logger.ERR, format("[SuProxy] processDownRequest fail: %s:%s, err:%s", upstream.ip, upstream.port, err))
239 |             break
240 |         end
241 |         if buf then
242 |             local _, err = self.processor:sendDown(buf)
243 |             if err then
244 |                 logger.log(logger.ERR, format("[SuProxy] forward to downstream fail: %s:%s, err:%s", upstream.ip, upstream.port, err))
245 |                 break
246 |             end
247 |         end
248 |     end
249 |     self:shutdown(upstream)
250 | end
251 | 
252 | --- Read `length` bytes from the client socket; returns receive()'s results.
253 | function _M:c2pRead(length)
254 |     local bytes,err,partial= self.c2pSock:receive(length)
255 |     logger.logWithTitle(logger.DEBUG,"c2pRead",(bytes and bytes:hex16F() or ""))
256 |     return bytes,err,partial
257 | end
258 | 
259 | --- Read `length` bytes from the upstream socket; returns receive()'s results.
260 | function _M:p2sRead(length)
261 |     local bytes,err,partial= self.p2sSock:receive(length)
262 |     logger.logWithTitle(logger.DEBUG,"p2sRead",(bytes and bytes:hex16F() or ""))
263 |     return bytes,err,partial
264 | end
265 | 
266 | --- Send bytes to the client socket; returns send()'s results.
267 | function _M:c2pSend(bytes)
268 |     logger.logWithTitle(logger.DEBUG,"c2pSend",(bytes and bytes:hex16F() or ""))
269 |     return self.c2pSock:send(bytes)
270 | end
271 | 
272 | --- Send bytes to the upstream socket; returns send()'s results.
273 | function _M:p2sSend(bytes)
274 |     logger.logWithTitle(logger.DEBUG,"p2sSend",(bytes and bytes:hex16F() or ""))
275 |     return self.p2sSock:send(bytes)
276 | end
277 | 
278 | --- Connect to an upstream picked by the balancer (blaming each one that
279 | -- fails to connect and trying the next), run the pump thread(s) until they
280 | -- finish, then clean up.
281 | function _M:run()
282 |     --this while is to ensure _cleanup will always be executed
283 |     while true do
284 |         local upstream
285 |         if(not self.standalone) then
286 |             while true do
287 |                 upstream=self.balancer:getBest()
288 |                 if not upstream then
289 |                     logger.log(logger.ERR, format("[SuProxy] failed to get avaliable upstream"))
290 |                     break
291 |                 end
292 |                 local ok, err = self.p2sSock:connect(upstream.ip, upstream.port)
293 |                 if not ok then
294 |                     logger.log(logger.ERR, format("[SuProxy] failed to connect to proxy upstream: %s:%s, err:%s", upstream.ip, upstream.port, err))
295 |                     self.balancer:blame(upstream)
296 |                 else
297 |                     logger.log(logger.INFO, format("[SuProxy] connect to proxy upstream: %s:%s", upstream.ip, upstream.port))
298 |                     self.upstream=upstream
299 |                     break
300 |                 end
301 |             end
302 |         end
303 |         if not self.standalone and not upstream then break end
304 |         --_singThreadRun(self)
305 |         local co_upl = spawn(_upl,self)
306 |         if(not self.standalone) then
307 |             local co_dwn = spawn(_dwn,self)
308 |             wait(co_dwn)
309 |         end
310 |         wait(co_upl)
311 |         break
312 |     end
313 |     _cleanup(self)
314 | end
315 | 
316 | --- Kill the session (if one was opened) and release both sockets.
317 | function _M:shutdown()
318 |     if self.session then
319 |         --self.processor:sessionInvalid(self.session)
320 |         local err=self.session:kill(self.session)
321 |         if err then
322 |             --self.upstream is never set in standalone mode
323 |             local upstream=self.upstream or {}
324 |             logger.log(logger.ERR, format("[SuProxy] kill session fail: %s:%s, err:%s", upstream.ip, upstream.port, err))
325 |         end
326 |     end
327 |     _cleanup(self)
328 | end
329 | 
330 | return _M
331 | 
--------------------------------------------------------------------------------
/example/gateway.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | This demo implements a simple gateway for TNS,TDS,SSH2,LDAP protocols.
3 | To test this demo, modify nginx config, add following section to your
4 | config file. Config server credential in getCredential method, then
5 | use test/test as username/password to login.
6 | make sure the commands.log file path is valid.
7 | stream {
8 |     lua_code_cache off;
9 |     #mock logserver if you do not have one
10 |     server {
11 |         listen 12080;
12 |         content_by_lua_block {
13 |             ngx.log(ngx.DEBUG,"logserver Triggerred")
14 |             local reqsock, err = ngx.req.socket(true)
15 |             reqsock:settimeout(100)
16 |             while(not err) do
17 |                 local command,err=reqsock:receive()
18 |                 if(err) then ngx.exit(0) end
19 |                 if(command) then
20 |                     local f = assert(io.open("/data/logs/commands.log", "a"))
21 |                     f:write(command .. "\n")
22 |                     f:close()
23 |                 end
24 |             end
25 |         }
26 |     }
27 |     #listen on ports
28 |     server {
29 |         listen 389;
30 |         listen 1521;
31 |         listen 22;
32 |         listen 1433;
33 |         content_by_lua_file lualib/suproxy/example/gateway.lua;
34 |     }
35 | }
36 | #Session manager interfaces. if you want to view and manage your session
37 | #over http, this should be set.
38 | http {
39 |     include       mime.types;
40 |     lua_code_cache off;
41 |     server {
42 |         listen       80;
43 |         server_name  localhost;
44 |         default_type text/html;
45 |         location /suproxy/manage{
46 |             content_by_lua_file  lualib/suproxy/example/session.lua;
47 |         }
48 |     }
49 | }
50 | ]]
51 | -------------------------socket logger init-----------------------
52 | local logger = require "resty.logger.socket"
53 | if not logger.initted() then
54 |     local ok, err = logger.init{
55 |         -- logger server address
56 |         host = '127.0.0.1',
57 |         port = 12080,
58 |         flush_limit = 10,
59 |         drop_limit = 567800,
60 |     }
61 |     if not ok then
62 |         ngx.log(ngx.ERR, "failed to initialize the logger: ",err)
63 |         return
64 |     end
65 | end
66 | ----------------------session manager init------------------------
67 | local sessionManager= require ("suproxy.session.sessionManager"):new{
68 |     --redis server address
69 |     ip="127.0.0.1",
70 |     port=6379,
71 |     expire=-1,
72 |     extend=false
73 | }
74 | ----------------------------handlers -----------------------------
75 | --Demo for swap credentials, oauth or other network auth method should be
76 | --used in real world instead of hardcoded credential
77 | local function getCredential(context,source,credential,session)
78 |     local username=credential.username
79 |     --if session.srvGID=="linuxServer" and session.srvID=="remote" then
80 |     return {username="root",password="xxxxxx"}
81 |     --end
82 | end
83 | 
84 | --Demo for oauth password exchange
85 | local function oauth(context,source,credential,session)
86 |     --show how to get the password with the oauth protocol, using username as
87 |     --code; an app should be added with a password attribute configured
88 |     --ssoProcessors is not required at file level, so load it here
89 |     local ssoProcessors=require("suproxy.http.ssoProcessors")
90 |     local param={
91 |         ssoProtocol="OAUTH",
92 |         validate_code_url="http://changeToYourOwnTokenLink",
93 |         profile_url="http://changeToYourOwnProfile",
94 |         client_secret="changeToYourOwnSecret",
95 |         appcode=session.srvGID
96 |     }
97 |     local authenticator=ssoProcessors.getProcessor(param)
98 |     --NOTE(review): method name kept from the original; confirm it in suproxy.http.ssoProcessors
99 |     local result=authenticator:valiate(credential.username)
100 |     local cred
101 |     if result.status==ssoProcessors.CHECK_STATUS.SUCCESS then
102 |         --confirm the oauth password attribute is correctly configured
103 |         --cred is nil at this point, so build a new table instead of indexing it
104 |         cred={
105 |             username=result.accountData.user,
106 |             password=result.accountData.attributes.password
107 |         }
108 |     end
109 |     if not cred then return cred,"can not get cred from remote server" end
110 |     return cred
111 | end
112 | 
113 | --Demo for command filter
114 | local function commandFilter(context,source,command,session)
115 |     --change this to implement your own filter strategy
116 |     if command:match("forbidden")then
117 |         return nil,{message=command.." is a forbidden command",code=1234}
118 |     end
119 |     return command
120 | end
121 | 
122 | --send one log record (time, client ip:port, user, content) to the socket logger
123 | local function log(username,content,session)
124 |     local rs={
125 |         os.date("%Y.%m.%d %H:%M:%S", ngx.time()).."\t" ,
126 |         session.clientIP..":"..session.clientPort.."\t",
127 |         username and username.."\t" or "UNKNOWN ",
128 |         content.."\r\n"
129 |     }
130 |     local bytes, err = logger.log(table.concat(rs))
131 |     if err then
132 |         ngx.log(ngx.ERR, "failed to log command: ", err)
133 |     end
134 | end
135 | 
136 | --Demo for login logging
137 | local function logAuth(context,source,username,session)
138 |     local rs={
139 |         "login with ",
140 |         (session and session.client) and session.client or "unknown client",
141 |         (session and session.clientVersion) and session.clientVersion or ""
142 |     }
143 |     log(username,table.concat(rs),session)
144 | end
145 | 
146 | --Demo for connect logging
147 | local function logConnect(context,source,connInfo)
148 |     local rs={"connect to ",connInfo.srvIP..":"..connInfo.srvPort}
149 |     log("UNKNOWN",table.concat(rs),connInfo)
150 | end
151 | 
152 | --Demo for command logging
153 | local function logCmd(context,source,command,reply,session)
154 |     local username=session.username
155 |     log(username,command,session)
156 |     if not reply or reply=="" then return end
157 |     local bytes, err = logger.log("------------------reply--------------------\r\n"
158 |         ..reply:sub(1,4000)
159 |         .."\r\n----------------reply end------------------\r\n\r\n")
160 |     if err then
161 |         ngx.log(ngx.ERR, "failed to log reply: ", err)
162 |     end
163 | end
164 | 
165 | --Demo for login fail logging
166 | local function logAuthFail(context,source,failInfo,session)
167 |     log(failInfo.username,
168 |         "login fail, fail message: "..(failInfo.message or ""),
169 |         session)
170 | end
171 | 
172 | --Demo for self-defined authenticator
173 | local function authenticator(context,source,credential,session)
174 |     local result=credential.username=="test" and credential.password=="test"
175 |     local message=(not result) and "login with "..credential.username.." failed"
176 |     return result,message
177 | end
178 | 
179 | --Demo for auto response ldap search command, this Demo shows how to handle parser events
180 | local function ldap_SearchRequestHandler(context,src,p)
181 |     if context.command:match("pleasechangeme") then
182 |         local packets=require("suproxy.ldap.ldapPackets")
183 |         local response=packets.SearchResultEntry:new()
184 |         local done=packets.SearchResultDone:new()
185 |         response.objectName="cn=admin,dc=www,dc=test,dc=com"
186 |         response.messageId=p.messageId
187 |         response.attributes={
188 |             {attrType="objectClass",values={"posixGroup","top"}},
189 |             {attrType="cn",values={"group"}},
190 |             {attrType="memberUid",values={"haha","test","test"}},
191 |             {attrType="gidNumber",values={"44789"}},
192 |             {attrType="description",values={"group"}}
193 |         }
194 |         done.resultCode=packets.ResultCode.success
195 |         done.messageId=p.messageId
196 |         response:pack() done:pack()
197 |         context.channel:c2pSend(response.allBytes..done.allBytes)
198 |         --stop forwarding
199 |         p.allBytes=""
200 |     end
201 | end
202 | 
203 | --Demo for change the welcome info of ssh2 server
204 | local function myWelcome(context,source)
205 |     local digger={"\r\n",
206 |         [[ .-. ]].."\r\n",
207 |         [[ / \ ]].."\r\n",
208 |         [[ _____.....-----|(o) | ]].."\r\n",
209 |         [[ _..--' _..--| .'' ]].."\r\n",
210 |         [[ .' o _..--'' | | | ]].."\r\n",
211 |         [[ / _/_..--'' | | | ]].."\r\n",
212 |         [[ ________/ / / | | | ]].."\r\n",
213 |         [[ | _ ____\ / / | | | ]].."\r\n",
214 |         [[ _.-----._________|| || \\ / | | | ]].."\r\n",
215 |         [[|=================||=||_____\\ |__|-' ]].."\r\n",
216 |         [[| suproxy ||_||_____// (o\ | ]].."\r\n",
217 |         [[|_________________|_________/ |-\| ]].."\r\n",
218 |         [[ `-------------._______.----' / `. ]].."\r\n",
219 |         [[ .,.,.,.,.,.,.,.,.,.,.,.,., / \]].."\r\n",
220 |         [[ ((O) o o o o ======= o o(O)) ._.' /]].."\r\n",
221 |         [[ `-.,.,.,.,.,.,.,.,.,.,.,-' `.......' ]].."\r\n",
222 |         [[ scan me to login ]].."\r\n",
223 |         "\r\n",
224 |     }
225 |     return table.concat(digger),false
226 | end
227 | 
228 | --create a channel for the processor, hook connect logging and run it;
229 | --channel:new returns nil plus an error message on failure
230 | local function runChannel(upstreams,processor)
231 |     local channel,err=require("suproxy.channel"):new(upstreams,processor,{sessionMan=sessionManager})
232 |     if not channel then
233 |         ngx.log(ngx.ERR, "failed to create channel: ", err)
234 |         return
235 |     end
236 |     channel.OnConnectEvent:addHandler(channel,logConnect)
237 |     channel:run()
238 | end
239 | 
240 | local switch={}
241 | --dispatch different port to different channel
242 | --Demo for SSH2 processor
243 | switch[22]= function()
244 |     local ssh=require("suproxy.ssh2"):new()
245 |     ssh.AuthSuccessEvent:addHandler(ssh,logAuth)
246 |     ssh.BeforeAuthEvent:addHandler(ssh,getCredential)
247 |     ssh.OnAuthEvent:addHandler(ssh,authenticator)
248 |     ssh.AuthFailEvent:addHandler(ssh,logAuthFail)
249 |     local cmd=require("suproxy.ssh2.commandCollector"):new()
250 |     cmd.CommandEnteredEvent:addHandler(ssh,commandFilter)
251 |     cmd.CommandFinishedEvent:addHandler(ssh,logCmd)
252 |     cmd.BeforeWelcomeEvent:addHandler(ssh,myWelcome)
253 |     ssh.C2PDataEvent:addHandler(cmd,cmd.handleDataUp)
254 |     ssh.S2PDataEvent:addHandler(cmd,cmd.handleDataDown)
255 |     package.loaded.my_SSHB=package.loaded.my_SSHB or
256 |     --change to your own upstreams
257 |     require ("suproxy.balancer.balancer"):new{
258 |         --{ip="127.0.0.1",port=2222,id="local",gid="linuxServer"},
259 |         --{ip="192.168.46.128",port=22,id="remote",gid="linuxServer"},
260 |         --{ip="192.168.1.121",port=22,id="UBUNTU14",gid="testServer"},
261 |         {ip="192.168.1.152",port=22,id="UBUNTU20",gid="testServer"},
262 |         --{ip="192.168.1.103",port=22,id="SUSE11",gid="testServer"},
263 |         --{ip="192.168.1.186",port=22,id="OPENBSD",gid="testServer"},
264 |         --{ip="192.168.1.187",port=22,id="FreeBSD",gid="testServer"},
265 |     }
266 |     runChannel(package.loaded.my_SSHB,ssh)
267 | end
268 | --Demo for TNS processor
269 | switch[1521]=function()
270 |     --server version is required for password substitution
271 |     local tns=require("suproxy.tns"):new{oracleVersion=11,swapPass=false}
272 |     tns.AuthSuccessEvent:addHandler(tns,logAuth)
273 |     tns.CommandEnteredEvent:addHandler(tns,commandFilter)
274 |     tns.CommandFinishedEvent:addHandler(tns,logCmd)
275 |     tns.AuthFailEvent:addHandler(tns,logAuthFail)
276 |     tns.BeforeAuthEvent:addHandler(tns,getCredential)
277 |     --tns.OnAuthEvent:addHandler(tns,authenticator)
278 |     package.loaded.my_OracleB=package.loaded.my_OracleB or
279 |     --change to your own upstreams
280 |     require ("suproxy.balancer.balancer"):new{
281 |         {ip="192.168.1.96",port=1521,id="remote",gid="oracleServer"},
282 |         --{ip="192.168.46.157",port=1522,id="local",gid="oracleServer"},
283 |         --{ip="192.168.1.182",port=1521,id="182",gid="oracleServer"},
284 |         --{ip="192.168.1.190",port=1521,id="oracle10",gid="oracleServer"},
285 |     }
286 |     runChannel(package.loaded.my_OracleB,tns)
287 | end
288 | --Demo for LDAP processor
289 | switch[389]=function()
290 |     local ldap=require("suproxy.ldap"):new()
291 |     ldap.AuthSuccessEvent:addHandler(ldap,logAuth)
292 |     ldap.AuthFailEvent:addHandler(ldap,logAuthFail)
293 |     ldap.CommandEnteredEvent:addHandler(ldap,commandFilter)
294 |     ldap.CommandFinishedEvent:addHandler(ldap,logCmd)
295 |     ldap.BeforeAuthEvent:addHandler(ldap,getCredential)
296 |     ldap.OnAuthEvent:addHandler(ldap,authenticator)
297 |     ldap.c2pParser.events.SearchRequest:addHandler(ldap,ldap_SearchRequestHandler)
298 |     --change to your own upstreams
299 |     runChannel({{ip="192.168.46.128",port=389,id="ldap1",gid="ldapServer"}},ldap)
300 | end
301 | --Demo for TDS processor
302 | switch[1433]=function()
303 |     local tds=require("suproxy.tds"):new({disableSSL=false,catchReply=true})
304 |     tds.AuthSuccessEvent:addHandler(tds,logAuth)
305 |     tds.CommandEnteredEvent:addHandler(tds,commandFilter)
306 |     tds.CommandFinishedEvent:addHandler(tds,logCmd)
307 |     tds.BeforeAuthEvent:addHandler(tds,getCredential)
308 |     tds.OnAuthEvent:addHandler(tds,authenticator)
309 |     tds.AuthFailEvent:addHandler(tds,logAuthFail)
310 |     package.loaded.my_SQLServerB=package.loaded.my_SQLServerB or
311 |     --change to your own upstreams
312 |     require ("suproxy.balancer.balancer"):new{
313 |         {ip="192.168.1.135",port=1433,id="srv12",gid="sqlServer"},
314 |         --{ip="192.168.1.120",port=1433,id="srv14",gid="sqlServer"}
315 |     }
316 |     runChannel(package.loaded.my_SQLServerB,tds)
317 | end
318 | 
319 | local fSwitch = switch[tonumber(ngx.var.server_port)]
320 | if fSwitch then
321 |     fSwitch()
322 | end
323 | 
324 | 
--------------------------------------------------------------------------------
/example/session.lua:
--------------------------------------------------------------------------------
1 | --[[ This Demo shows how to manage session on http modify nginx config, To test it change the redis ip and port setting and add following section to your config file http { include mime.types; lua_code_cache off; server { listen 80; server_name localhost; default_type text/html; location /suproxy/manage{
content_by_lua_file lualib/suproxy/example/session.lua; } } } ]] local utils= require "suproxy.utils.utils" local cjson=require "cjson" local redisIP="127.0.0.1" local port=6379 local sessionMan=require ("suproxy.session.sessionManager"):new{ip=redisIP,port=port,expire=30} if not sessionMan then ngx.say("connect to Redis "..redisIP..":"..port.." failed" ) return end
2 | 
3 | if ngx.var.request_uri:match("/suproxy/manage/session/kill") then
4 |     local sid=utils.getArgsFromRequest("sid")
5 |     local uid=utils.getArgsFromRequest("uid")
6 |     local result
7 |     --no sid and no uid: remove every session
8 |     if not sid and not uid then
9 |         result=sessionMan:clear()
10 |     elseif sid then
11 |         result=sessionMan:kill(sid)
12 |     else
13 |         result=sessionMan:killSessionOfUser(uid)
14 |     end
15 |     ngx.say(result.." items is removed")
16 | elseif ngx.var.request_uri:match("/suproxy/manage/session/get") then
17 |     local sid=utils.getArgsFromRequest("sid")
18 |     local uid=utils.getArgsFromRequest("uid")
19 |     if not sid and not uid then ngx.say("valid sid or uid should be provided") return end
20 |     local result
21 |     if sid then
22 |         result=sessionMan:get(sid)
23 |     else
24 |         result=sessionMan:getSessionOfUser(uid)
25 |         result=cjson.encode(result)
26 |     end
27 |     ngx.say(result)
28 | elseif ngx.var.request_uri:match("/suproxy/manage/session/all") then
29 |     local result,count=sessionMan:getAll()
30 |     ngx.say("count:"..count)
31 |     ngx.say(cjson.encode(result))
32 | elseif ngx.var.request_uri:match("/suproxy/manage/session/clear") then
33 |     local result=sessionMan:clear()
34 |     ngx.say(result.." items is removed")
35 | end
36 | 
37 | 
38 | 
--------------------------------------------------------------------------------
/http/handlers.lua:
--------------------------------------------------------------------------------
1 | -- Copyright (C) Joey Zhu
2 | local _M = {_VERSION="0.1.11"}
3 | local cjson=require("cjson")
4 | local utils=require("suproxy.utils.utils")
5 | local ssoProcessors=require("suproxy.http.ssoProcessors")
6 | local zlib = require('suproxy.utils.ffi-zlib')
7 | 
8 | --load content: issue a GET subrequest to url and return its body
9 | function _M.loadUrl(url)
10 |     local res = ngx.location.capture(url,{method=ngx.HTTP_GET})
11 |     return res.body
12 | end
13 | --render /auth/toolbar.html with the current user name and a login URL
14 | function _M.loadToolBar(localParams)
15 |     local request_uri = ngx.var.scheme.."://"..ngx.var.server_name..":"..ngx.var.server_port..ngx.var.request_uri
16 |     local loginPath=utils.addParamToUrl(localParams.loginUrl,"appCode",localParams.appCode)
17 |     if request_uri~=nil then
18 |         --tell the login processor which address to return to
19 |         local ctx={method=ngx.var.request_method,targetURL=request_uri}
20 |         --build the other parameters the login processor should pass back
21 |         loginPath=utils.addParamToUrl(loginPath,"context",ngx.encode_base64(cjson.encode(ctx)))
22 |     end
23 |     --status,body,err=utils.jget(ngx.var.scheme.."://"..ngx.var.server_name..":"..ngx.var.server_port.."/auth/toolbar.html")
24 |     local body=_M.loadUrl("/auth/toolbar.html")
25 |     local username=cjson.decode(ngx.var.userdata).user --'%' is special in gsub replacement strings, so both values are escaped below
26 |     return string.gsub(string.gsub(body, "{{username}}", (string.gsub(username, "%%", "%%%%"))), "{{loginUrl}}", (string.gsub(loginPath, "%%", "%%%%")))
27 | end
28 | 
29 | 
30 | 
31 | --replace content of the response body
32 | function _M.replaceResponse(regex,replacement)
33 |     _M.replaceResponseMutiple({{regex=regex,replacement=replacement}})
34 | end
35 | 
36 | --replace content of the response body, applying several substitutions in order (buffers ngx.arg chunks until eof)
37 | function _M.replaceResponseMutiple(subs)
38 |     local chunk, eof = ngx.arg[1], ngx.arg[2]
39 |     local buffered = ngx.ctx.buffered
40 |     if not buffered then
41 |         buffered = {} -- XXX we can use table.new here
42 |         ngx.ctx.buffered = buffered
43 |     end
44 |     if chunk ~= "" then
45 |         buffered[#buffered + 1] = chunk
46 |         ngx.arg[1] = nil
47 |     end
48 |     if eof then
49 |         local whole = table.concat(buffered)
50 |         ngx.ctx.buffered = nil
51 |         -- try to unzip
52 |         if ngx.var.upstreamEncoding=="gzip" then
53 |             local debody = utils.unzip(whole)
54 |             if debody then
55 |                 whole = debody
56 |             end
57 |         end
58 |         -- try to add or replace response body
59 |         -- local js_code = ...
60 |         -- whole = whole .. js_code
61 |         for _,v in ipairs(subs) do
62 |             whole = string.gsub(whole, v.regex, v.replacement)
63 |         end
64 |         ngx.arg[1] = whole
65 |     end
66 | end
67 | 
68 | --检查权限
69 | function _M.accessCheck(localParams)
70 |     local request_uri = ngx.var.scheme.."://"..ngx.var.server_name..":"..ngx.var.server_port..ngx.var.request_uri
71 |     ngx.var.callbackmethod=ngx.var.request_method
72 |     if ngx.var.http_referer~=nil then
73 |         ngx.var.referer=string.gsub(ngx.var.http_referer,ngx.var.scheme.."://"..ngx.var.server_name..":"..ngx.var.server_port, ngx.var.backaddress)
74 |     end
75 | 
76 |     local session = require "resty.session".open()
77 |     local processor=ssoProcessors.getProcessor(localParams.ssoParam)
78 |     if processor==nil then
79 |         utils.error("can't find processor for sso protocal "..localParams.ssoParam.ssoProtocol,nil,500)
80 |         return
81 |     end
82 |     --检查请求是否单点登录验证
83 |     local result=processor:checkRequest()
84 |     -- 不是单点验证请求
85 |     if result.status==ssoProcessors.CHECK_STATUS.NOT_SSO_CHECK then
86 |         --如果开启登录验证,则验证session
87 |         ngx.log(ngx.DEBUG,"sessionid:"..ngx.encode_base64(session.id))
88 |         if localParams.checkLogin then
89 |             ngx.log(ngx.DEBUG,"checkLogin:true"..request_uri)
90 |             --session 已经存在
91 |             if session.present then
92 | 
93 |                 ngx.var.userdata=cjson.encode(session.data.user)
94 |                 return
95 |             --session 不存在,直接跳转
96 |             else
97 |                 local loginPath=utils.addParamToUrl(localParams.loginUrl,"appCode",localParams.appCode)
98 |                 if request_uri~=nil then
99 |                     --告诉登录处理器需要返回的地址
100 |                     local ctx={method=ngx.var.request_method,targetURL=request_uri}
101 |                     --构建其他想要登录处理器回传的参数
102 |                     loginPath=utils.addParamToUrl(loginPath,"context",ngx.encode_base64(cjson.encode(ctx)))
103
| end 104 | ---[[ 105 | if processor.formatLoginPath then 106 | loginPath=processor:formatLoginPath(loginPath) 107 | end 108 | --]] 109 | ngx.log(ngx.DEBUG,"loginPath:"..loginPath) 110 | ngx.redirect(loginPath) 111 | return 112 | end 113 | --未开启登录验证,直接报错 114 | else 115 | ngx.log(ngx.DEBUG,"checkLogin:false") 116 | utils.error(result.message,nil,ngx.HTTP_UNAUTHORIZED) 117 | return 118 | end 119 | end 120 | result=processor:valiate() 121 | if result.status==ssoProcessors.CHECK_STATUS.SUCCESS then 122 | -- 验证成功,获取user并建立session 123 | ngx.var.userdata=cjson.encode(result.accountData) 124 | local context=utils.getArgsFromRequest("context") 125 | local contextJson=nil 126 | if context~=nil then 127 | contextJson=cjson.decode(ngx.decode_base64(context)) 128 | else 129 | contextJson=localParams.defaultContext 130 | end 131 | 132 | if contextJson.method~=nil then 133 | ngx.var.callbackmethod=contextJson.method 134 | end 135 | session:start() 136 | session.data.user=result.accountData 137 | session:save() 138 | ngx.log(ngx.INFO, "Session Started -- " .. ngx.encode_base64(session.id)) 139 | --若context中返回的最初访问的页面和当前页面不同,则直接跳转到context中指定的页面 140 | if contextJson.targetURL~=nil then 141 | --[[ --删除url中由登录处理程序附加的参数如context和xtoken参数再进行url比较 142 | if(ngx.req.get_uri_args()["context"]) then 143 | request_uri=utils.removeParamFromUrl(request_uri,"context") 144 | end 145 | if(ngx.req.get_uri_args()["x_token"]) then 146 | request_uri=utils.removeParamFromUrl(request_uri,"x_token") 147 | end 148 | if(ngx.req.get_uri_args()["ticket"]) then 149 | request_uri=utils.removeParamFromUrl(request_uri,"ticket") 150 | end 151 | if(ngx.req.get_uri_args()["params"]) then 152 | request_uri=utils.removeParamFromUrl(request_uri,"params") 153 | end 154 | --]] 155 | if contextJson.targetURL~=request_uri then 156 | return ngx.redirect(contextJson.targetURL) 157 | end 158 | end 159 | ngx.log(ngx.INFO, "Jump to targetURL -- " .. 
contextJson.targetURL) 160 | return 161 | elseif result.status==ssoProcessors.CHECK_STATUS.AUTH_FAIL then 162 | -- 验证失败 163 | utils.error(result.message,"Status:"..result.status.." Message:"..result.message,500) 164 | return 165 | else 166 | utils.error(result.message,"Status:"..result.status.." Message:"..result.message,500) 167 | return 168 | end 169 | end 170 | 171 | return _M 172 | 173 | -------------------------------------------------------------------------------- /http/initParam.lua: -------------------------------------------------------------------------------- 1 | local cjson=require("cjson") local Global_Params = { 2 | appCode ="gate", 3 | --loginUrl="/auth/ssologin.html", 4 | loginUrl="/auth/mocklogin.html", 5 | defaultContext={ 6 | targetURL="/", 7 | method="GET" 8 | }, 9 | ---[[ 10 | ssoParam={ 11 | ssoProtocol="JWT", 12 | secret="lua-resty-jwt" 13 | } 14 | --]] 15 | --[[ 16 | ssoParam={ 17 | ssoProtocol="CAS", 18 | validate_url="xxxxxxxxxxxxxx", 19 | service="xxxxxxxxxxxx" 20 | } 21 | --]] 22 | --[[ 23 | ssoParam={ 24 | ssoProtocol="OAUTH", 25 | validate_code_url="xxxxxxxxxxxxx", 26 | profile_url="xxxxxxxxxxxxxxx", 27 | callbackurl="xxxxxxxxxxx", 28 | client_secret="xxxxxxxxxxxxx" 29 | } 30 | --]] 31 | } 32 | Global_Params.ssoParam.appCode=Global_Params.appCode 33 | Global_Params.ssoParam.loginUrl=Global_Params.loginUrl 34 | return cjson.encode(Global_Params) 35 | -------------------------------------------------------------------------------- /http/mockauth.lua: -------------------------------------------------------------------------------- 1 | local jwt = require "resty.jwt" local utils= require "suproxy.utils.utils" local cjson=require("cjson") 2 | 3 | --获取表单里的用户名密码,这里可以进行身份验证 local args = nil 4 | if "GET" == ngx.var.request_method then 5 | args = ngx.req.get_uri_args() 6 | elseif "POST" == ngx.var.request_method then 7 | ngx.req.read_body() 8 | args = ngx.req.get_post_args() 9 | end local username=args["name"] local password=args["pass"] 10 | 
if username == nil or password==nil then 11 | ngx.log(ngx.WARN, "Username and Password can not be empty ") 12 | ngx.exit(ngx.HTTP_UNAUTHORIZED) 13 | end 14 | ----模拟身份验证begin---- 15 | 16 | if username~="admin1" or password~="aA123." then 17 | ngx.log(ngx.WARN, "Wrong Username or Password ") 18 | ngx.exit(ngx.HTTP_UNAUTHORIZED) 19 | end 20 | 21 | ----模拟身份验证end------ 22 | 23 | --模拟签发jwt token local jwt_token = jwt:sign( 24 | "lua-resty-jwt", 25 | { 26 | header={typ="JWT", alg="HS256"}, 27 | payload={accountName="admin1"} 28 | } 29 | ) 30 | --获取跳转变量 local redirectUrl=nil local context=ngx.req.get_uri_args()["context"] local jsonContext=cjson.decode(ngx.decode_base64(context)) 31 | if jsonContext~=nil and jsonContext["targetURL"]~=nil then redirectUrl=jsonContext["targetURL"] else redirectUrl="/gateway/callback" end 32 | 33 | redirectUrl="/gateway/callback" 34 | 35 | --参数使用Post方式传递 local template=[[ 36 |
37 | 38 | 39 |
40 | 41 | ]] 42 | ngx.say(string.format(template,redirectUrl,jwt_token,context)) 43 | 44 | 45 | --[[ 46 | --参数使用get redirect方式传递 47 | redirectUrl=utils.addParamToUrl(redirectUrl,"context",ngx_decode_base64(context)) 48 | redirectUrl=utils.addParamToUrl(redirectUrl,"x_token",jwt_token) 49 | ngx.redirect(redirectUrl) 50 | --]] -------------------------------------------------------------------------------- /http/ssoProcessors.lua: -------------------------------------------------------------------------------- 1 | local utils = require "suproxy.utils.utils" local jwt = require "resty.jwt" local cjson=require("cjson") local _M = {_VERSION="0.1.11"} 2 | 3 | _M.CHECK_STATUS = { 4 | SUCCESS=0, 5 | NOT_SSO_CHECK=1, 6 | AUTH_FAIL=2, 7 | UNKNOWN_SSO_PROTOCAL=3 8 | } 9 | 10 | ------JWT Processor------------ local JWTProcessor={} 11 | 12 | function JWTProcessor.new(ssoParam) 13 | JWTProcessor.ssoParam=ssoParam 14 | return JWTProcessor 15 | end 16 | 17 | function JWTProcessor:checkRequest() 18 | local x_token=utils.getArgsFromRequest("x_token") 19 | --判断token是否传递 若没有传递,则不是JWT验证请求 20 | if x_token == nil then 21 | return {status=_M.CHECK_STATUS.NOT_SSO_CHECK,message="Not valid sso auth request"} 22 | end 23 | return {status=_M.CHECK_STATUS.SUCCESS} 24 | end 25 | 26 | function JWTProcessor:valiate() 27 | -- 验证token签名,如果验证错误,直接提示错误信息 28 | local x_token=utils.getArgsFromRequest("x_token") 29 | local jwt_obj = jwt:verify(self.ssoParam.secret,x_token) 30 | if jwt_obj.verified == false then 31 | return {status=_M.CHECK_STATUS.AUTH_FAIL,message="Invalid token: ".. 
jwt_obj.reason} 32 | end 33 | return {status=_M.CHECK_STATUS.SUCCESS,accountData={user=jwt_obj.payload.accountName,attributes=jwt_obj.payload}} 34 | end 35 | ------JWT Processor end------------ 36 | ------CAS Processor------------ local CASProcessor={} 37 | 38 | function CASProcessor.new(ssoParam) 39 | CASProcessor.ssoParam=ssoParam 40 | return CASProcessor 41 | end 42 | 43 | function CASProcessor:checkRequest() 44 | local ticket=utils.getArgsFromRequest("ticket"); 45 | --判断ticket是否传递 若没有传递,则不是CAS验证请求 46 | if ticket == nil then 47 | return {status=_M.CHECK_STATUS.NOT_SSO_CHECK} 48 | end 49 | return {status=_M.CHECK_STATUS.SUCCESS} 50 | end 51 | local function cas_ticket_verify(validate_url,service,ticket) 52 | -----网络请求验证begin------ 53 | local payload = { 54 | service =service, 55 | ticket = ticket 56 | } 57 | local status, body, err = utils.jget(validate_url, ngx.encode_args(payload)) 58 | 59 | if not status or status ~= 200 then 60 | return {success=false,message=err} 61 | end 62 | 63 | local decodeResponse=cjson.decode(body) 64 | 65 | if decodeResponse==nil or decodeResponse.serviceResponse==nil then 66 | return {success=false,"response format wrong, can't be parse to json"} 67 | end 68 | 69 | if decodeResponse.serviceResponse.authenticationFailure then 70 | return {success=false, 71 | string.format("ticket validation failure, code:%s description:%s", 72 | decodeResponse.serviceResponse.authenticationFailure.code, 73 | decodeResponse.serviceResponse.authenticationFailure.description) 74 | } 75 | end 76 | 77 | if decodeResponse.serviceResponse.authenticationSuccess then 78 | return {success=true,data=decodeResponse.serviceResponse.authenticationSuccess} 79 | end 80 | 81 | ------网络请求验证身份end---- 82 | end 83 | 84 | function CASProcessor:valiate() 85 | local ticket=utils.getArgsFromRequest("ticket"); 86 | -- 验证token签名,如果验证错误,直接提示错误信息 87 | local cas_result = cas_ticket_verify(self.ssoParam.validate_url,self.ssoParam.service,ticket)--注意如果使用ngiam sso,则需要修改秘钥与应用中的一致 88 | 
if not cas_result.success then 89 | return {status=_M.CHECK_STATUS.AUTH_FAIL,message="Invalid ticket: ".. cas_result.message} 90 | end 91 | 92 | return {status=_M.CHECK_STATUS.SUCCESS,accountData={user=cas_result.data.user,attributes=cas_result.data}} 93 | end 94 | ------CAS Processor end------------ 95 | 96 | ------OAUTH2.0 Processor------------ local OAUTHProcessor={} 97 | 98 | function OAUTHProcessor.new(ssoParam) 99 | OAUTHProcessor.ssoParam=ssoParam 100 | return OAUTHProcessor 101 | end 102 | 103 | function OAUTHProcessor:checkRequest(self) 104 | local code=utils.getArgsFromRequest("code") 105 | --判断code是否传递 若没有传递,则不是OAUTH验证请求 106 | if code == nil then 107 | return {status=_M.CHECK_STATUS.NOT_SSO_CHECK,message="Not valid sso auth request"} 108 | end 109 | return {status=_M.CHECK_STATUS.SUCCESS} 110 | end 111 | 112 | function OAUTHProcessor:get_token(code) 113 | local payload = { 114 | appcode = self.ssoParam.appCode, 115 | secret = self.ssoParam.client_secret, 116 | code = code 117 | } 118 | local status, body, err = utils.jpost(self.ssoParam.validate_code_url,ngx.encode_args(payload)) 119 | if not status or status ~= 200 then 120 | return {success=false,message=err} 121 | else 122 | local result=cjson.decode(body) 123 | if result.errorCode then 124 | return {success=false,message=body} 125 | end 126 | 127 | return {success=true,token=result.accessToken} 128 | end 129 | end 130 | 131 | function OAUTHProcessor:get_profile(token) 132 | local status, body, err = utils.jget( 133 | self.ssoParam.profile_url, 134 | ngx.encode_args({ 135 | appcode = self.ssoParam.appCode, 136 | secret = self.ssoParam.client_secret, 137 | token= token 138 | })) 139 | if not status or status ~= 200 then 140 | return {success=false,message=err} 141 | else 142 | ngx.log(ngx.DEBUG,body) 143 | local result=cjson.decode(body) 144 | if result.errorCode then 145 | return {success=false,message=body} 146 | end 147 | return {success=true,profile=result} 148 | end 149 | end 150 | 151 | function 
OAUTHProcessor:valiate(code) 152 | -- 验证token签名,如果验证错误,直接提示错误信息 153 | if not code then 154 | code=utils.getArgsFromRequest("code") 155 | end 156 | 157 | local result = self:get_token(code) 158 | if result.success==false then 159 | return {status=_M.CHECK_STATUS.AUTH_FAIL,message=result.message} 160 | end 161 | local token=result.token 162 | result = self:get_profile(token) 163 | if result.success==false then 164 | return {status=_M.CHECK_STATUS.AUTH_FAIL,message=result.message} 165 | end 166 | return {status=_M.CHECK_STATUS.SUCCESS,accountData={user=result.profile.accountName,attributes=result.profile}} 167 | 168 | end 169 | 170 | function OAUTHProcessor:formatLoginPath(loginUrl) 171 | return loginUrl 172 | end 173 | ------JWT Processor end------------ 174 | 175 | local processorList = { 176 | ["JWT"] = function(param) 177 | return JWTProcessor.new(param) 178 | end, 179 | ["CAS"] = function(param) 180 | return CASProcessor.new(param) 181 | end, 182 | ["OAUTH"] =function(param) 183 | return OAUTHProcessor.new(param) 184 | end 185 | } 186 | 187 | function _M.getProcessor(ssoParam) 188 | local p= processorList[ssoParam.ssoProtocol] 189 | if p~=nil then 190 | return p(ssoParam) 191 | else 192 | return nil 193 | end 194 | end 195 | 196 | return _M -------------------------------------------------------------------------------- /ldap.lua: -------------------------------------------------------------------------------- 1 | local asn1 = require("suproxy.utils.asn1") local format = string.format local ok,cjson=pcall(require,"cjson") 2 | if not ok then cjson = require("suproxy.utils.json") end 3 | local logger=require "suproxy.utils.compatibleLog" local bunpack = asn1.bunpack local fmt = string.format 4 | require "suproxy.utils.stringUtils" 5 | local tableUtils=require "suproxy.utils.tableUtils" local pureluapack=require"suproxy.utils.pureluapack" 6 | local event=require "suproxy.utils.event" 7 | local ldapPackets=require "suproxy.ldap.ldapPackets" local 
ResultCode=ldapPackets.ResultCode 8 | local _M = {} 9 | _M._PROTOCAL ='ldapv3' 10 | local encoder,decoder=asn1.ASN1Encoder:new(),asn1.ASN1Decoder:new() 11 | function _M:new() 12 | local o= setmetatable({},{__index=self}) 13 | o.ctx={} 14 | o.AuthSuccessEvent=event:new(o,"AuthSuccessEvent") 15 | o.AuthFailEvent=event:new(o,"AuthFailEvent") 16 | o.BeforeAuthEvent=event:newReturnEvent(o,"BeforeAuthEvent") o.OnAuthEvent=event:newReturnEvent(o,"OnAuthEvent") 17 | o.CommandEnteredEvent=event:newReturnEvent(o,"CommandEnteredEvent") 18 | o.CommandFinishedEvent=event:new(o,"CommandFinishedEvent") 19 | o.ContextUpdateEvent=event:new(o,"ContextUpdateEvent") local ldapParser=require ("suproxy.ldap.parser"):new() 20 | o.c2pParser=ldapParser.C2PParser 21 | o.s2pParser=ldapParser.S2PParser 22 | o.c2pParser.events.SearchRequest:setHandler(o,_M.SearchRequestHandler) 23 | o.c2pParser.events.BindRequest:setHandler(o,_M.BindRequestHandler) 24 | o.c2pParser.events.UnbindRequest:setHandler(o,_M.UnbindRequestHandler) 25 | o.s2pParser.events.BindResponse:setHandler(o,_M.BindResponseHandler) 26 | o.s2pParser.events.SearchResultEntry:setHandler(o,_M.SearchResultEntryHandler) 27 | o.s2pParser.events.SearchResultDone:setHandler(o,_M.SearchResultDoneHandler) 28 | return o 29 | end 30 | ----------------parser event handlers---------------------- 31 | function _M:SearchRequestHandler(src,p) 32 | local cstr=cjson.encode{baseObject=p.baseObject,scope=p.scope,filter=p.filter} 33 | if self.CommandEnteredEvent:hasHandler() then 34 | local cmd,err=self.CommandEnteredEvent:trigger(cstr,self.ctx) 35 | if err then p.allBytes=nil return end 36 | if cmd.command~=p.filter then 37 | --todo: modify the filter 38 | end 39 | end 40 | self.command=cstr 41 | end 42 | 43 | function _M:BindRequestHandler(src,p) 44 | local cred 45 | if self.BeforeAuthEvent:hasHandler() then 46 | cred=self.BeforeAuthEvent:trigger({username=p.username,password=p.password},self.ctx) 47 | end if self.OnAuthEvent:hasHandler() then local 
ok,message,cred=self.OnAuthEvent:trigger({username=p.username,password=p.password},self.ctx) if not ok then local resp=ldapPackets.BindResponse:new({ messageId=p.messageId, resultCode=ResultCode.invalidCredentials }):pack() self.channel:c2pSend(resp.allBytes) p.allBytes=nil return end end 48 | if cred and (p.username~=cred.username or p.password~=cred.password) then 49 | p.username=cred.username 50 | p.password=cred.password 51 | p:pack() 52 | end 53 | self.ctx.username=p.username 54 | if self.ContextUpdateEvent:hasHandler() then 55 | self.ContextUpdateEvent:trigger(self.ctx) 56 | end 57 | end 58 | 59 | function _M:BindResponseHandler(src,p) 60 | if p.resultCode==ResultCode.success then 61 | if self.AuthSuccessEvent:hasHandler() then 62 | self.AuthSuccessEvent:trigger(self.ctx.username,self.ctx) 63 | end 64 | else 65 | if self.AuthFailEvent:hasHandler() then 66 | self.AuthFailEvent:trigger({username=self.ctx.username,message="fail code: "..tostring(p.resultCode)},self.ctx) 67 | end 68 | end 69 | end 70 | 71 | function _M:SearchResultEntryHandler(src,p) 72 | self.reply=(self.reply or "")..cjson.encode({p.objectName,p.attributes}).."\r\n" 73 | end 74 | 75 | function _M:SearchResultDoneHandler(src,p) 76 | if self.CommandFinishedEvent:hasHandler() then 77 | self.CommandFinishedEvent:trigger(self.command,self.reply,self.ctx) 78 | end 79 | self.reply="" 80 | end 81 | 82 | function _M:UnbindRequestHandler(src,p) 83 | ngx.exit(0) 84 | end 85 | 86 | function _M:recv(readMethod) 87 | logger.log(logger.DEBUG,"start processRequest") 88 | local lengthdata,err = readMethod(self.channel,2) 89 | if(err) then 90 | logger.log(logger.ERR,"err when reading length") 91 | return nil,err 92 | end 93 | local length=("B"):unpack(lengthdata,2) 94 | local len_len=length-128 95 | local realLengthData="" 96 | if len_len>0 then 97 | realLengthData,err=readMethod(self.channel,len_len) 98 | if(err) then 99 | logger.log(logger.ERR,"err when reading real length") 100 | return nil,err 101 | end 102 
| length=(">I"..len_len):unpack(realLengthData) 103 | end 104 | local payloadBytes,err = readMethod(self.channel,length) 105 | local allBytes=lengthdata..realLengthData..payloadBytes 106 | if(err) then 107 | logger.log(logger.ERR,"err when reading packet") 108 | return nil,err 109 | end 110 | return allBytes 111 | end ----------------implement processor methods--------------------- 112 | function _M.processUpRequest(self) 113 | local allBytes,err=self:recv(self.channel.c2pRead) if err then return nil,err end 114 | local p=self.c2pParser:parse(allBytes) 115 | return p.allBytes 116 | end 117 | 118 | function _M.processDownRequest(self) 119 | local allBytes,err=self:recv(self.channel.p2sRead) if err then return nil,err end 120 | local p=self.s2pParser:parse(allBytes) 121 | return p.allBytes 122 | end 123 | 124 | function _M:sessionInvalid(session) 125 | ngx.exit(0) 126 | end 127 | 128 | return _M; 129 | -------------------------------------------------------------------------------- /ldap/ldapPackets.lua: -------------------------------------------------------------------------------- 1 | local asn1 = require("suproxy.utils.asn1") 2 | local format = string.format 3 | local ok,cjson=pcall(require,"cjson") 4 | if not ok then cjson = require("suproxy.utils.json") end 5 | local logger=require "suproxy.utils.compatibleLog" 6 | local bunpack = asn1.bunpack 7 | local fmt = string.format 8 | require "suproxy.utils.stringUtils" 9 | require "suproxy.utils.pureluapack" 10 | local event=require "suproxy.utils.event" 11 | local tableUtils=require "suproxy.utils.tableUtils" 12 | local extends=tableUtils.extends 13 | local encoder,decoder=asn1.ASN1Encoder:new(),asn1.ASN1Decoder:new() 14 | local orderTable=tableUtils.OrderedTable 15 | local _M={} 16 | 17 | local APPNO = { 18 | BindRequest=0, BindResponse=1, UnbindRequest=2, SearchRequest = 3, 19 | SearchResultEntry=4,SearchResultDone=5, ModifyRequest=6, ModifyResponse = 7 , 20 | AddRequest=8, AddResponse=9, DelRequest=10, DelResponse 
= 11 , 21 | ModifyDNRequest=12, ModifyDNResponse=13,CompareRequest=14, CompareResponse =15, 22 | AbandonRequest =16, ExtendedRequest=23, ExtendedResponse=24,IntermediateResponse =25 23 | } 24 | _M.APPNO=APPNO 25 | 26 | local ResultCode = { 27 | success = 0, operationsError =1, protocolError =2, timeLimitExceeded =3, 28 | sizeLimitExceeded =4, compareFalse =5, compareTrue =6, authMethodNotSupported =7, 29 | strongerAuthRequired =8, --[[ 9 reserved --]] referral =10, adminLimitExceeded =11, 30 | unavailableCriticalExtension =12, confidentialityRequired =13, saslBindInProgress =14, noSuchAttribute =16, 31 | undefinedAttributeType =17, inappropriateMatching =18, constraintViolation =19, attributeOrValueExists =20, 32 | invalidAttributeSyntax =21, noSuchObject =32, aliasProblem =33, invalidDNSyntax =34, 33 | --[[ 35 reserved for isLeaf --]] aliasDereferencingProblem =36, --[[ 37-47 unused --]] inappropriateAuthentication =48, 34 | invalidCredentials =49, insufficientAccessRights =50, busy =51, unavailable =52, 35 | unwillingToPerform =53, loopDetect =54, --[[ 55-63 unused --]] namingViolation =64, 36 | objectClassViolation =65, notAllowedOnNonLeaf =66, notAllowedOnRDN =67, entryAlreadyExists =68, 37 | objectClassModsProhibited =69, --[[ 70 reserved for CLDAP --]] affectsMultipleDSAs =71, --[[ 72-79 unused --]] 38 | other= 80 39 | } 40 | _M.ResultCode=ResultCode 41 | local function encodeLDAPOp(encoder, appno, isConstructed, data) 42 | local asn1_type = asn1.BERtoInt(asn1.BERCLASS.Application, isConstructed, appno) 43 | return encoder:encode( data,asn1_type) 44 | end 45 | --TODO???: filter type 4 and 9 not processed, switch must be used to replace if 46 | local function decodeFilter(packet,pos) 47 | 48 | local newpos=pos 49 | 50 | local filter="(" 51 | 52 | --get filter type 53 | local newpos, tmp = bunpack(packet, "B", newpos) 54 | 55 | local field,condition 56 | 57 | local ftype = asn1.intToBER(tmp).number 58 | 59 | local 
newpos,flen=decoder.decodeLength(packet,newpos); 60 | 61 | local elenlen=newpos-pos-1 62 | 63 | logger.log(logger.DEBUG,"element:-----------------\n"..string.hex(packet,pos,pos+flen+elenlen,4,8,nil,nil,1,1)) 64 | 65 | logger.log(logger.DEBUG,"filter type:"..ftype) 66 | 67 | logger.log(logger.DEBUG,"filter length:"..flen) 68 | 69 | logger.log(logger.DEBUG,"data:-----------------\n"..string.hex(packet,newpos,newpos+flen-1,4,8,nil,nil,1,1)) 70 | 71 | --0 and 1 or 2 not 72 | if ftype<3 then 73 | if ftype==0 then 74 | filter=filter.."&" 75 | end 76 | if ftype==1 then 77 | filter=filter.."||" 78 | end 79 | if ftype==2 then 80 | filter=filter.."!" 81 | end 82 | local lp=newpos 83 | while(newpos-lp=" 123 | 124 | newpos, condition=decoder:decode(packet,newpos) 125 | 126 | logger.log(logger.DEBUG,"condition:"..condition) 127 | 128 | filter=filter..condition 129 | 130 | logger.log(logger.DEBUG,"filter5:"..filter) 131 | end 132 | 133 | --less or equal 134 | if ftype==6 then 135 | 136 | newpos, field=decoder:decode(packet,newpos); 137 | 138 | logger.log(logger.DEBUG,"field:"..field) 139 | 140 | filter=filter..field 141 | 142 | filter=filter.."<=" 143 | 144 | newpos, condition=decoder:decode(packet,newpos) 145 | 146 | logger.log(logger.DEBUG,"condition:"..condition) 147 | 148 | filter=filter..condition 149 | 150 | logger.log(logger.DEBUG,"filter6:"..filter) 151 | end 152 | 153 | 154 | --present 155 | if ftype==7 then 156 | newpos,tmp=bunpack(packet, "c" .. 
flen, newpos) 157 | filter=filter..tmp.."=*" 158 | logger.log(logger.DEBUG,"filter7:"..filter) 159 | end 160 | 161 | filter=filter..")" 162 | return newpos,filter 163 | end 164 | 165 | 166 | --parse and pack common header from or to bytes 167 | --common headers include :length messageId,opCode 168 | _M.Packet={ 169 | desc="base", 170 | parseHeader=function(self,allBytes,pos) 171 | local _,pos = ("B"):unpack(allBytes,pos) 172 | pos,self.length= decoder.decodeLength(allBytes, pos) 173 | pos,self.messageId = decoder:decode(allBytes, pos) 174 | local pos,tmp = bunpack(allBytes, "B", pos) 175 | local pos,l= decoder.decodeLength(allBytes, pos) 176 | self.opCode = asn1.intToBER(tmp).number 177 | return pos 178 | end, 179 | 180 | parse=function(self,allBytes,pos) 181 | local pos=self.parseHeader(self,allBytes,pos) 182 | self.parsePayload(self,allBytes,pos) 183 | self.allBytes=allBytes 184 | return self 185 | end, 186 | 187 | parsePayload=function(self,allBytes,pos) 188 | end, 189 | 190 | pack=function(self) 191 | local payloadBytes=self:packPayload() 192 | local allBytes=encoder:encodeSeq(encoder:encode(self.messageId) .. 
encodeLDAPOp(encoder, self.opCode,true,payloadBytes)) 193 | logger.logWithTitle(logger.DEBUG,"packing",allBytes:hex16F()) 194 | self.allBytes=allBytes 195 | return self 196 | end, 197 | 198 | new=function(self,o) 199 | local o=o or {} 200 | return orderTable.new(self,o) 201 | end 202 | } 203 | 204 | _M.BindRequest={ 205 | opCode= APPNO.BindRequest, 206 | desc="BindRequest", 207 | parsePayload=function(self,payload,pos) 208 | pos,self.version = decoder:decode(payload,pos) 209 | logger.log(logger.DEBUG,"version:"..self.version ) 210 | pos,self.username = decoder:decode(payload,pos) 211 | logger.log(logger.DEBUG,"username:"..self.username ) 212 | pos,self.password = decoder:decode(payload,pos) 213 | if self.username=="" then 214 | logger.log(logger.DEBUG,"anonymous login") 215 | elseif self.password=="" then 216 | logger.log(logger.DEBUG,"unauthorized login") 217 | end 218 | return self 219 | end, 220 | 221 | packPayload=function(self) 222 | local payloadBytes=encoder:encode(self.version)..encoder:encode(self.username)..encoder:encode(self.password,"simplePass") 223 | return payloadBytes 224 | end 225 | } 226 | extends(_M.BindRequest,_M.Packet) 227 | 228 | _M.BindResponse={ 229 | opCode= APPNO.BindResponse, 230 | desc="BindResponse", 231 | parsePayload=function(self,payload,pos) 232 | pos,self.resultCode=decoder:decode(payload,pos) 233 | return self 234 | end, 235 | packPayload=function(self) 236 | local payloadBytes=encoder:encode(self.resultCode,"enumerated") .. encoder:encode('') .. 
encoder:encode('') 237 | return payloadBytes 238 | end 239 | } 240 | extends(_M.BindResponse,_M.Packet) 241 | 242 | --UnbindRequest 243 | _M.UnbindRequest=extends({opCode=APPNO.UnbindRequest,desc="UnbindRequest"},_M.Packet) 244 | 245 | _M.SearchRequest={ 246 | opCode=APPNO.SearchRequest, 247 | desc="SearchRequest", 248 | parsePayload=function(self,payload,pos) 249 | logger.log(logger.DEBUG,"searchRequest:") 250 | 251 | pos,self.baseObject = decoder:decode(payload,pos) 252 | 253 | logger.log(logger.DEBUG,"baseObject:"..self.baseObject ) 254 | 255 | pos,self.scope = decoder:decode(payload,pos) 256 | 257 | logger.log(logger.DEBUG,"scope:"..self.scope ) 258 | 259 | pos,self.derefAlias = decoder:decode(payload,pos) 260 | 261 | logger.log(logger.DEBUG,"derefAlias:"..self.derefAlias ) 262 | 263 | pos,self.sizeLimit = decoder:decode(payload,pos) 264 | 265 | logger.log(logger.DEBUG,"sizeLimit:"..self.sizeLimit ) 266 | 267 | pos,self.timeLimit = decoder:decode(payload,pos) 268 | 269 | logger.log(logger.DEBUG,"timeLimit:"..self.timeLimit ) 270 | 271 | pos,self.typesOnly = decoder:decode(payload,pos) 272 | 273 | logger.log(logger.DEBUG,"typesOnly:"..(self.typesOnly and "true" or "false")) 274 | 275 | pos,self.filter=decodeFilter(payload,pos) 276 | 277 | logger.log(logger.DEBUG,"self.filter:"..self.filter) 278 | 279 | pos,self.attributes=decoder:decode(payload,pos) 280 | 281 | logger.log(logger.DEBUG,"self.attributes:\n"..cjson.encode(self.attributes)) 282 | 283 | logger.log(logger.DEBUG,"searchRequest finish") 284 | return self 285 | end 286 | } 287 | extends(_M.SearchRequest,_M.Packet) 288 | 289 | _M.SearchResultEntry={ 290 | opCode= APPNO.SearchResultEntry, 291 | desc="SearchResponseEntry", 292 | parsePayload=function(self,payload,pos) 293 | pos,self.objectName=decoder:decode(payload,pos) 294 | local pos,attr=decoder:decode(payload,pos) 295 | print(tableUtils.printTableF(attr)) 296 | self.attributes={} 297 | for i,v in ipairs(attr) do 298 | local t=v[1] 299 | 
table.remove(v,1) 300 | table.insert(self.attributes,{attrType=t,values=v}) 301 | end 302 | return self 303 | end, 304 | packPayload=function(self) 305 | local resultObjectName = encoder:encode(self.objectName) 306 | local tmp="" 307 | for i,v in ipairs(self.attributes) do 308 | local attrValues="" 309 | for _,val in ipairs(v.values) do 310 | attrValues=attrValues..encoder:encode(val) 311 | end 312 | tmp=tmp..encoder:encodeSeq(encoder:encode(v.attrType)..encoder:encodeSet(attrValues)) 313 | end 314 | local resultAttributes = encoder:encodeSeq(tmp) 315 | local payloadBytes = resultObjectName..resultAttributes 316 | return payloadBytes 317 | end, 318 | } 319 | extends(_M.SearchResultEntry,_M.Packet) 320 | 321 | _M.SearchResultDone={ 322 | opCode= APPNO.SearchResultDone, 323 | desc="SearchResponseDone", 324 | parsePayload=function(self,payload,pos) 325 | pos,self.resultCode=decoder:decode(payload,pos) 326 | return self 327 | end, 328 | packPayload=function(self) 329 | local payloadBytes=encoder:encode(self.resultCode,"enumerated") .. encoder:encode('') .. encoder:encode('') 330 | return payloadBytes 331 | end 332 | } 333 | extends(_M.SearchResultDone,_M.Packet) 334 | 335 | ---------------------test starts here ----------------------- 336 | _M.unitTest={} 337 | function _M.unitTest.testBindRequest() 338 | local bytes=string.fromhex("30840000003402010a60840000002b020103041e636e3d61646d696e2c64633d7777772c64633d746573742c64633d636f6d800661413132332e") 339 | local p=require("suproxy.ldap.parser"):new().C2PParser:parse(bytes) 340 | assert(p.version==3,p.version) 341 | assert(p.username=="cn=admin,dc=www,dc=test,dc=com",p.username) 342 | assert(p.password=="aA123.",p.password) 343 | p.username="cn=admin,dc=www,dc=test,dc=cn" 344 | p.password="Aa123." 
345 | p:pack(encoder) 346 | bytes=p.allBytes 347 | local p=require("suproxy.ldap.parser"):new().C2PParser:parse(bytes) 348 | assert(p.username=="cn=admin,dc=www,dc=test,dc=cn",p.username) 349 | assert(p.password=="Aa123.",p.password) 350 | end 351 | 352 | function _M.unitTest.testSearchRequest() 353 | local bytes=string.fromhex("30840000008302010c638400000046041564633d7777772c64633d746573742c64633d636f6d0a01010a010002010002013c010100870b6f626a656374636c61737330840000000d040b6f626a656374636c617373a0840000002e3084000000280416312e322e3834302e3131333535362e312e342e3331390101ff040b3084000000050201640400") 354 | local p=require("suproxy.ldap.parser"):new().C2PParser:parse(bytes) 355 | assert(p.baseObject=="dc=www,dc=test,dc=com",p.baseObject) 356 | assert(p.scope==1,p.scope) 357 | assert(p.derefAlias==0,p.derefAlias) 358 | assert(p.timeLimit==60,p.timeLimit) 359 | assert(p.sizeLimit==0,p.sizeLimit) 360 | assert(p.typesOnly==false,p.typesOnly) 361 | assert(p.filter=="(objectclass=*)",p.filter) 362 | end 363 | 364 | 365 | function _M.unitTest.testSearchResultEntry() 366 | local bytes=string.fromhex("306502016a6460041564633d7777772c64633d746573742c64633d636f6d3047302c040b6f626a656374436c617373311d0403746f70040864634f626a656374040c6f7267616e697a6174696f6e300a04016f31050403646576300b0402646331050403777777") 367 | local p=require("suproxy.ldap.parser"):new().S2PParser:parse(bytes) 368 | assert(p.objectName=="dc=www,dc=test,dc=com",p.objectName) 369 | assert(#(p.attributes)==3,#(p.attributes)) 370 | p:pack() 371 | bytes=p.allBytes 372 | local p=require("suproxy.ldap.parser"):new().S2PParser:parse(bytes) 373 | assert(p.objectName=="dc=www,dc=test,dc=com",p.objectName) 374 | assert(#(p.attributes)==3,#(p.attributes)) 375 | end 376 | 377 | function _M.test() 378 | for k,v in pairs(_M.unitTest) do 379 | print("------------running "..k) 380 | v() 381 | print("------------"..k.." 
finished") 382 | end 383 | end 384 | 385 | return _M -------------------------------------------------------------------------------- /ldap/parser.lua: -------------------------------------------------------------------------------- 1 | local P=require "suproxy.ldap.ldapPackets" 2 | local parser=require("suproxy.parser") 3 | local _M={} 4 | 5 | local conf={ 6 | {key=P.APPNO.BindRequest, parser=P.BindRequest, eventName="BindRequest"}, 7 | {key=P.APPNO.UnbindRequest, parser=P.UnbindRequest, eventName="UnbindRequest"}, 8 | {key=P.APPNO.SearchRequest, parser=P.SearchRequest, eventName="SearchRequest"}, 9 | {key=P.APPNO.BindResponse, parser=P.BindResponse, eventName="BindResponse"}, 10 | {key=P.APPNO.SearchResultEntry, parser=P.SearchResultEntry,eventName="SearchResultEntry"}, 11 | {key=P.APPNO.SearchResultDone, parser=P.SearchResultDone, eventName="SearchResultDone"}, 12 | } 13 | 14 | local function keyG(allBytes,pos) 15 | local p=P.Packet:new() p:parseHeader(allBytes,pos) return p.opCode 16 | end 17 | 18 | function _M:new() 19 | local o= setmetatable({},{__index=self}) 20 | local C2PParser=parser:new() 21 | C2PParser.keyGenerator=keyG 22 | C2PParser:registerMulti(conf) 23 | C2PParser:registerDefaultParser(P.Packet) 24 | o.C2PParser=C2PParser 25 | 26 | local S2PParser=parser:new() 27 | S2PParser.keyGenerator=keyG 28 | S2PParser:registerMulti(conf) 29 | S2PParser:registerDefaultParser(P.Packet) 30 | o.S2PParser=S2PParser 31 | return o 32 | end 33 | return _M 34 | 35 | -------------------------------------------------------------------------------- /parser.lua: -------------------------------------------------------------------------------- 1 | require "suproxy.utils.stringUtils" 2 | local event=require "suproxy.utils.event" 3 | local logger=require "suproxy.utils.compatibleLog" 4 | local tableUtils=require "suproxy.utils.tableUtils" 5 | local _M={} 6 | function _M:new() 7 | local o=setmetatable({},{__index=self}) 8 | o.events={} 9 | o.parserList={} 10 | 
o.defaultParseEvent=event:newReturnEvent(nil,"defaultParseEvent") 11 | return o 12 | end 13 | 14 | function _M:register(key,parserName,parser,eventName,e) 15 | assert(not self.parserList[key],string.format("unable to register parser %s, key already registered",key)) 16 | if eventName then 17 | assert(not self.events[eventName],string.format("unable to register event %s, event already registered",eventName)) 18 | e=e or event:new(nil,eventName) 19 | self.events[eventName]=e 20 | end 21 | self.parserList[key]={parser=parser,parserName=parserName,event=e} 22 | end 23 | 24 | function _M:unregister(key,eventName) 25 | self.parserList[key]=nil 26 | if eventName then 27 | self.events[eventName]=nil 28 | end 29 | end 30 | 31 | function _M:getParser(key) 32 | return self.parserList[key] 33 | end 34 | 35 | function _M:registerMulti(t) 36 | for i,v in ipairs(t) do 37 | local parserName 38 | if v.parser then parserName=v.parserName or v.parser.desc or tostring(v.parser) end 39 | self:register(v.key,parserName,v.parser,v.eventName,v.e) 40 | end 41 | end 42 | 43 | function _M:registerDefaultParser(parser) 44 | assert(parser,"default parser can not be null") 45 | self.defaultParser=parser 46 | end 47 | 48 | function _M.printPacket(packet,allBytes,key,parserName,...) 
49 | local args={...} 50 | if not parserName then 51 | logger.logWithTitle(logger.DEBUG,string.format("packet with key %s doesn't have parser",key),(allBytes and allBytes:hex16F() or "")) 52 | else 53 | logger.logWithTitle(logger.DEBUG,string.format("packet with key %s will be parsed by parser %s ",key,parserName or "Unknown"),(allBytes and allBytes:hex16F() or "")) 54 | end 55 | for i,v in ipairs(args) do 56 | logger.log(logger.DEBUG,"\r\noptions"..i..":"..tableUtils.printTableF(v,{inline=true,printIndex=true})) 57 | end 58 | logger.log(logger.DEBUG,"\r\npacket:"..tableUtils.printTableF(packet,{ascii=true,excepts={"allBytes"}})) 59 | end 60 | 61 | --static method to parse all kinds of packets 62 | function _M:parse(allBytes,pos,key,...) 63 | pos=pos or 1 64 | assert(allBytes,"bytes stream can not be null") 65 | if not key then key=self.keyGenerator end 66 | if type(key)=="function" then 67 | key=key(allBytes,pos,...) 68 | end 69 | assert(key,"key can not be null") 70 | local packet={} 71 | packet.allBytes=allBytes 72 | local parser,event,newBytes,parserName 73 | if self.parserList[key] then 74 | parser=self.parserList[key].parser 75 | parserName=self.parserList[key].parserName 76 | if self.parserList[key].event then 77 | event=self.parserList[key].event 78 | end 79 | end 80 | if not parser and self.defaultParser then 81 | parser=self.defaultParser 82 | parserName="Default Parser" 83 | end 84 | event=event or self.defaultParseEvent 85 | local args={...} 86 | local ok=true 87 | local ret 88 | if parser then 89 | ok,ret=xpcall(function() return parser:new(nil,unpack(args)):parse(allBytes,pos,unpack(args))end,function(err) logger.log(logger.ERR,err) logger.log(logger.ERR,debug.traceback()) end,"error when parsing ") 90 | if ok then packet= ret end 91 | end 92 | if logger.getLogLevel().code>=logger.DEBUG.code then 93 | _M.printPacket(packet,allBytes,key,parserName,...) 
94 | end 95 | packet.__key=packet.__key or key 96 | if ok and event and event:hasHandler() then 97 | xpcall(function() return event:trigger(packet,allBytes,key,unpack(args)) end,function(err) logger.log(logger.ERR,err) logger.log(logger.ERR,debug.traceback()) end,"error when exe parser handler " ) 98 | end 99 | return packet 100 | end 101 | _M.doParse=doParse 102 | return _M -------------------------------------------------------------------------------- /session/session.lua: -------------------------------------------------------------------------------- 1 | require "suproxy.utils.stringUtils" 2 | require "suproxy.utils.pureluapack" 3 | local _M={} 4 | function _M:new(stype,manager) 5 | assert(manager,"manager can not be null") 6 | local now=ngx.time() 7 | local session={sid=string.random(4):hex(),uptime=now,ctime=now,stype=stype,uid="_SUPROXY_UNKNOWN"} 8 | setmetatable(session,{__index=self}) 9 | local sessionMeta={ 10 | __index=session, 11 | __newindex=function(t,k,v) 12 | local now=ngx.time() 13 | manager:setProperty(session.sid,k,v,"uptime",now) 14 | t.__data[k]=v 15 | t.__data.uptime=now 16 | end 17 | } 18 | manager:create(session) 19 | local proxy={__data=session,__manager=manager} 20 | return setmetatable(proxy,sessionMeta),nil 21 | end 22 | 23 | function _M:kill() 24 | return self.__manager:kill(self.sid) 25 | end 26 | 27 | function _M:valid() 28 | return self.__manager:valid(self.sid) 29 | end 30 | 31 | function _M.newDoNothing(self) 32 | return { 33 | create=function() return end, 34 | setProperty=function() return 1 end, 35 | valid=function() return true end, 36 | kill=function() return true end 37 | } 38 | end 39 | 40 | return _M -------------------------------------------------------------------------------- /session/sessionManager.lua: -------------------------------------------------------------------------------- 1 | local cjson=require("cjson") 2 | local logger=require("suproxy.utils.compatibleLog") local redis = require "resty.redis" local _M = 
{} 3 | ---------------required method for implements---------------------- 4 | -- options {ip=ip,port=port,sock=sock,timeout=timeout,expire=expire,extend=true} 5 | function _M.new(self,options) 6 | local o=setmetatable({},{__index=self}) 7 | local red = redis:new() 8 | local ip=options.ip or "127.0.0.1" 9 | local port=options.port or 6379 10 | local timeout=options.timeout or 5000 11 | o.expire=options.expire or 3600 12 | o.extend=(options.extend==nil) and true or options.extend 13 | local sock=options.sock 14 | red:set_timeout(timeout) 15 | local ok, err 16 | if sock then 17 | ok,err=red:connect("unix:/path/to/redis.sock") 18 | else 19 | ok, err = red:connect(ip, port) 20 | end 21 | if not ok then 22 | logger.log(logger.ERR,"failed to connect: ", err) 23 | return 24 | end 25 | o.redis=red 26 | return o 27 | end local function getKey(sid) return "gateway_session_"..sid end function _M:create(session) assert(session,"session can not be null") assert(session.sid,"session.sid can not be null") local sessions=self.redis--ngx.shared.sessions 28 | local k=getKey(session.sid) 29 | local ok,err=sessions:set(k,cjson.encode(session)) 30 | if ok and self.expire>=0 then ok,err=sessions:expire(k,self.expire) end if not ok then return false,err end return true 31 | end function _M:setProperty(sid,...) 
32 | local args={...} 33 | assert(#args%2==0,"key value count should be even") 34 | local s=self:get(sid) 35 | local result=0 if not s then return result end 36 | for i=1,#args,2 do s[args[i]]=args[i+1] local ok,err=self:update(sid,s) 37 | if ok then 38 | result=result+1 39 | else 40 | logger.log(logger.ERR,"failed to update ", args[i]," to" ,args[i+1]," with error message: ",err) 41 | end end 42 | return result,err end function _M:update(sid,session) assert(session,"session can not be null") assert(sid,"sid can not be null") local sessions=self.redis--ngx.shared.sessions 43 | local k=getKey(sid) local ok,err=sessions:set(k,cjson.encode(session)) 44 | if ok then 45 | if self.extend then ok,err=sessions:expire(k,self.expire) 46 | elseif(self.expire>=0) then 47 | local currentTime=ngx.time() 48 | local createTime=session.ctime 49 | local elapsedTime=currentTime-createTime 50 | ok,err=sessions:expire(k,self.expire-elapsedTime) 51 | if not ok then sessions:expire(k,0) end 52 | end 53 | end return ok,err end 54 | 55 | function _M:get(sid) assert(sid,"sid can not be null") 56 | local sessions=self.redis--ngx.shared.sessions 57 | local result,err=sessions:get(getKey(sid)) 58 | if err or not result or result==ngx.null then 59 | return nil,err 60 | end 61 | return cjson.decode(result),nil 62 | end function _M:valid(sid) assert(sid,"sid can not be null") local s,err=self:get(sid) return s and true or false end 63 | 64 | function _M:kill(sid) assert(sid,"sid can not be null") 65 | local sessions=self.redis--ngx.shared.sessions 66 | return sessions:del(getKey(sid)) 67 | end 68 | ---------------------manage method----------------------- 69 | function _M:getSessionOfUser(uid) 70 | assert(uid,"uid can not be null") 71 | local sessions=self.redis--ngx.shared.sessions 72 | local keys,err=sessions:keys("gateway_session_*")--sessions:get_keys() 73 | if err then return nil,err end 74 | local result={} 75 | for i=1,#keys,1 do 76 | if cjson.decode(sessions:get(keys[i])).ctx.uid==uid 
then 77 | result[keys[i]]=sessions:get(keys[i]) 78 | end 79 | end 80 | return result,nil 81 | end 82 | 83 | function _M:killSessionOfUser(uid) assert(uid,"uid can not be null") local sessions=self.redis--ngx.shared.sessions 84 | local keys,err=sessions:keys("gateway_session_*")--sessions:get_keys() if err then return 0,err end 85 | local result=0 86 | for i=1,#keys,1 do 87 | if cjson.decode(sessions:get(keys[i])).ctx.uid==uid then 88 | sessions:del(keys[i]) 89 | result=result+1 90 | end 91 | end 92 | return result,nil 93 | end 94 | 95 | function _M:getAll() 96 | local sessions=self.redis--ngx.shared.sessions 97 | local keys,err=sessions:keys("gateway_session_*")--sessions:get_keys() local result={} if err then return result,0,err end local count=0 98 | for i=1,#keys,1 do 99 | result[keys[i]]=cjson.decode(sessions:get(keys[i])) count=count+1 100 | end 101 | return result,count,nil 102 | end function _M:clear() local sessions=self.redis--ngx.shared.sessions local keys,err=sessions:keys("gateway_session_*")--sessions:get_keys() local result=0 if err then return result,err end for i=1,#keys,1 do sessions:del(keys[i]) result=result+1 end return result end 103 | 104 | return _M -------------------------------------------------------------------------------- /ssh2/commandCollector.lua: -------------------------------------------------------------------------------- 1 | local ssh2Packet=require "suproxy.ssh2.ssh2Packets" local sc=require"suproxy.ssh2.shellCommand" local event=require "suproxy.utils.event" 2 | local logger=require "suproxy.utils.compatibleLog" 3 | local _M={} 4 | 5 | function _M:new() 6 | local o=setmetatable({}, {__index=self}) o.reply="" 7 | o.commandReply="" o.welcome="" 8 | o.command=sc:new() o.firstReply=true o.BeforeWelcomeEvent=event:newReturnEvent(o,"BeforeWelcomeEvent") 9 | o.CommandEnteredEvent=event:newReturnEvent(o,"CommandEnteredEvent") 10 | o.CommandFinishedEvent=event:new(o,"CommandFinishedEvent") 11 | return o 12 | end 13 | local function 
removeANSIEscape(str) 14 | return str:gsub(string.char(0x1b).."[%[%]%(][0-9%:%;%<%=%>%?]*".."[@A-Z%[%]%^_`a-z%{%|%}%~]","") 15 | end 16 | local function removeUnprintableAscii(str) 17 | return str:gsub(".", 18 | function(x) 19 | if (string.byte(x)<=31 or string.byte(x)==127) and (string.byte(x)~=0x0d) then return "" end 20 | end 21 | ) 22 | end 23 | function _M:handleDataUp(processor,packet,ctx) 24 | self.waitForWelcome=false if self.waitForReply then self.waitForReply=false 25 | if self.commandReply and self.commandReply ~="" then 26 | logger.log(logger.DEBUG,"--------------\r\n",self.commandReply) 27 | local reply=removeANSIEscape(self.commandReply) 28 | self.CommandFinishedEvent:trigger(self.lastCommand,reply,ctx) 29 | end end 30 | self.reply="" self.commandReply="" 31 | local channel=packet.channel 32 | local letter=packet.data 33 | logger.log(logger.DEBUG,"-------------letter---------------",letter:hex()) 34 | --up down arrow 35 | if letter==string.char(0x1b,0x5b,0x41) or letter==string.char(0x1b,0x5b,0x42) then 36 | self.upArrowClicked=true 37 | self.command:clear() 38 | --ctrl+u 39 | elseif letter==string.char(0x15) then 40 | self.command:removeBefore(nil,all) 41 | --left arrow or ctrl+b 42 | elseif letter==string.char(0x1b,0x5b,0x44) or letter==string.char(2) then 43 | self.command:moveCursor(-1) 44 | --right arrow or ctrl+f 45 | elseif letter==string.char(0x1b,0x5b,0x43) or letter==string.char(6) then 46 | self.command:moveCursor(1) 47 | --home or ctrl+a 48 | elseif letter==string.char(0x1b,0x5b,0x31,0x7e) or letter==string.char(1) then 49 | self.command:home() 50 | --end or ctrl+e 51 | elseif letter==string.char(0x1b,0x5b,0x34,0x7e) or letter==string.char(5) then 52 | self.command:toEnd() 53 | --delete or control+d 54 | elseif letter==string.char(0x1b,0x5b,0x33,0x7e) or letter==string.char(4) then 55 | self.command:removeAfter() 56 | --tab 57 | elseif letter==string.char(0x09) then 58 | self.tabClicked=true 59 | --backspace 60 | elseif 
letter==string.char(0x7f) or letter==string.char(8) then 61 | self.command:removeBefore() 62 | --ctrl+c 63 | elseif letter==string.char(0x03) then 64 | self.command:clear() 65 | --ctrl+? still needs further process 66 | elseif letter==string.char(0x1f) then 67 | self.tabClicked=true 68 | --enter 69 | elseif letter==string.char(0x0d) then 70 | if(self.command:getLength()>0) then 71 | local cstr=self.command:toString() self.command:clear() self.lastCommand=cstr self.waitForReply=true 72 | local newcmd,err=self.CommandEnteredEvent:trigger(cstr,ctx) 73 | if err then local toSend=ssh2Packet.ChannelData:new{ channel=256, data=table.concat{"\r\n",err.message,"\r\n"} }:pack().allBytes processor:sendDown(toSend) --0x05 0x15 move cursor to the end and delete all 74 | packet.data=string.char(5,0x15,0x0d) 75 | packet:pack() 76 | elseif newcmd~=cstr then 77 | --0x05 0x15 for move cursor to the end and delete all 78 | packet.data=string.char(5,0x15)..newcmd.."\n" 79 | packet:pack() 80 | end 81 | end 82 | elseif ((string.byte(letter,1)>31 and string.byte(letter,1)<127)) or string.byte(letter,1)>=128 83 | then 84 | self.command:append(letter) 85 | end 86 | return packet 87 | end 88 | local function processReply(self,reply) if not reply then return end --found OSC command ESC]0; means new prompt should be display local endPos=reply:find(string.char(0x1b,0x5d,0x30,0x3b)) if not endPos then endPos=reply:find("[.*@.*:.*]?[%s]?%[?.*@.*%]?[\\$#][%s]?") end if not endPos then endPos=reply:find("mysql>[%s]?") end if endPos then self.lastPrompt=reply:sub(endPos) self.commandReply=self.commandReply..reply:sub(1,endPos-1) return self.lastPrompt,self.commandReply else self.commandReply=self.commandReply..reply end end 89 | function _M:handleDataDown(processor,packet,ctx) 90 | local reply=packet.data 91 | --up arrow 92 | if self.upArrowClicked then 93 | --command may have leading 0x08 bytes, trim it 94 | self.command:append(removeUnprintableAscii(removeANSIEscape(reply))) 95 | 
self.upArrowClicked=false 96 | --tab 97 | elseif self.tabClicked then 98 | self.command:append(removeUnprintableAscii(reply),self.commandPtr) 99 | self.tabClicked=false 100 | --prompt received 101 | elseif self.waitForReply and reply then 102 | processReply(self,reply) --welcome screen 103 | elseif self.firstReply and self.BeforeWelcomeEvent:hasHandler() then self.firstReply=false local welcome,prepend=self.BeforeWelcomeEvent:trigger(ctx) if welcome then local prompt,orignalWelcome=processReply(self,reply) if not prompt then self.waitForWelcome=true end self.prepend=prepend local data={ welcome, prepend and self.commandReply or "", (not prepend) and (prompt or ">") or "" } packet.data=table.concat(data) packet:pack() end elseif self.waitForWelcome then local prompt,orignalWelcome=processReply(self,reply) if not self.prepend then if prompt then packet.data=prompt packet:pack() else packet.allBytes=nil end end end 104 | return packet 105 | end 106 | 107 | 108 | return _M -------------------------------------------------------------------------------- /ssh2/parser.lua: -------------------------------------------------------------------------------- 1 | --ssh2.0 protocol parser local P=require "suproxy.ssh2.ssh2Packets" local parser=require("suproxy.parser") local _M={} local conf={ {key=P.PktType.KeyXInit, parserName="KeyXInit", parser=P.KeyXInit, eventName="KeyXInitEvent"}, {key=P.PktType.DHKeyXInit, parserName="DHKeyXInit", parser=P.DHKeyXInit, eventName="DHKeyXInitEvent"}, {key=P.PktType.DHKeyXReply, parserName="DHKeyXReply", parser=P.DHKeyXReply, eventName="DHKeyXReplyEvent"}, {key=P.PktType.AuthReq, parserName="AuthReq", parser=P.AuthReq, eventName="AuthReqEvent"}, {key=P.PktType.AuthFail, parserName="AuthFail", parser=P.AuthFail, eventName="AuthFailEvent"}, {key=P.PktType.ChannelData, parserName="ChannelData", parser=P.ChannelData, eventName="ChannelDataEvent"}, {key=P.PktType.Disconnect, parserName="Disconnect", parser=P.Disconnect, 
eventName="DisconnectEvent"}, {key=P.PktType.NewKeys, eventName="NewKeysEvent"}, {key=P.PktType.AuthSuccess, eventName="AuthSuccessEvent"} } local keyG=function(allBytes) return allBytes:byte(6) end function _M:new() local o= setmetatable({},{__index=self}) local C2PParser=parser:new() C2PParser.keyGenerator=keyG C2PParser:registerMulti(conf) C2PParser:registerDefaultParser(P.Base) o.C2PParser=C2PParser local S2PParser=parser:new() S2PParser.keyGenerator=keyG S2PParser:registerMulti(conf) S2PParser:registerDefaultParser(P.Base) o.S2PParser=S2PParser return o end return _M -------------------------------------------------------------------------------- /ssh2/shellCommand.lua: -------------------------------------------------------------------------------- 1 | require "suproxy.utils.stringUtils" 2 | require "suproxy.utils.pureluapack" 3 | local _M = {} 4 | function _M.new(self) 5 | local o={} 6 | o.cursor=0 7 | o.chars={} 8 | return setmetatable(o, {__index=self}) 9 | end 10 | local function rshiftArray(tab,cursor,count) 11 | for i=#tab,cursor,-1 do 12 | tab[i+count]=tab[i] 13 | end 14 | return tab 15 | end 16 | 17 | function _M.append(self,str) 18 | local i=1 19 | while i<=#str do 20 | local c=string.byte(str,i) 21 | self.chars=rshiftArray(self.chars,self.cursor,1) 22 | if c>0xF0 then 23 | self.chars[self.cursor+1]=str:sub(i,i+3) 24 | i=i+4 25 | elseif c>0xE0 then 26 | --unicode 27 | self.chars[self.cursor+1]=str:sub(i,i+2) 28 | i=i+3 29 | elseif c>0xC0 then 30 | self.chars[self.cursor+1]=str:sub(i,i+1) 31 | i=i+2 32 | else 33 | self.chars[self.cursor+1]=str:sub(i,i) 34 | i=i+1 35 | end 36 | self.cursor=self.cursor+1 37 | end 38 | end 39 | 40 | function _M.removeBefore(self,count,all) 41 | if all then count=self.cursor end 42 | if not count then count=1 end 43 | if self.cursor<1 then return end 44 | if self.cursor-count<0 then count= self.cursor end 45 | for i=1, count, -1 do 46 | table.remove(self.chars,self.cursor) 47 | end 48 | self.cursor=self.cursor-count 49 | 
end 50 | 51 | function _M.clear(self) 52 | self.cursor=0 53 | self.chars={} 54 | end 55 | 56 | function _M.removeAfter(self,count,all) 57 | if all then count= #(self.chars)-self.cursor end 58 | if not count then count=1 end 59 | if self.cursor==#(self.chars) then return end 60 | if self.cursor+count>#(self.chars) then count= #(self.chars)-self.cursor end 61 | for i=1,count,1 do 62 | table.remove(self.chars,self.cursor) 63 | end 64 | end 65 | 66 | function _M.home(self) 67 | self.cursor=0 68 | end 69 | 70 | function _M.toEnd(self) 71 | self.cursor=#(self.chars) 72 | end 73 | 74 | function _M.moveCursor(self,step) 75 | if self.cursor+step>#(self.chars) then 76 | self.cursor=#(self.chars) return 77 | elseif self.cursor+step<0 then 78 | self.cursor=0 return 79 | end 80 | self.cursor=self.cursor+step 81 | end 82 | 83 | function _M.getLength(self) 84 | return #(self.chars) 85 | end 86 | 87 | function _M.get(self,pos) 88 | if pos>#(self.chars) or pos<1 then return nil end 89 | return self.chars[pos] 90 | end 91 | 92 | function _M.toString(self) 93 | -- local result="" 94 | -- for k, v in pairs(self.chars) do 95 | -- result=result..v 96 | -- end 97 | -- return result 98 | return table.concat(self.chars) 99 | end 100 | 101 | function _M.test() 102 | local str="中华A已经Bあまり哈哈哈1234567" 103 | print(str:hex()) 104 | local shellCommand=_M 105 | local command=shellCommand:new() 106 | command:append(str) 107 | assert(command:getLength()==19,command:getLength()) 108 | assert(command:get(1)=="中",command:get(1)) 109 | assert(command:get(19)=="7",command:get(19)) 110 | assert(command.cursor==19,command.cursor) 111 | command:moveCursor(-100) 112 | command:moveCursor(100) 113 | command:moveCursor(-2) 114 | assert(command.cursor==17,command.cursor) 115 | command:removeAfter() 116 | command:removeAfter() 117 | command:removeAfter() 118 | command:removeAfter() 119 | command:removeAfter() 120 | command:removeBefore() 121 | assert(command.cursor==16,command.cursor) 122 | 
assert(command:toString()=="中华A已经Bあまり哈哈哈1234",command:toString():hex()) 123 | command:clear() 124 | command:append("34") 125 | command:moveCursor(-2) 126 | command:append("12") 127 | assert(command:toString()=="1234",command:toString()) 128 | command:home() 129 | assert(command.cursor==0,command.cursor) 130 | command:toEnd() 131 | assert(command.cursor==command:getLength(),command.cursor) 132 | print(command:toString()) 133 | end 134 | 135 | return _M 136 | 137 | -------------------------------------------------------------------------------- /ssh2/ssh2CipherConf.lua: -------------------------------------------------------------------------------- 1 | local bn = require "resty.openssl.bn" local rand= require "resty.openssl.rand" local tableUtils = require "suproxy.utils.tableUtils" local _M={} _M.EncAlg={ {name="aes128-ctr",cipherStr="aes-128-ctr"}, {name="aes128-cbc",cipherStr="aes-128-cbc"} } function _M.EncAlg:getList() local rs={} for i,v in ipairs(self) do rs[#rs+1]=v.name end return table.concat(rs,",") end _M.DHAlg={} --http://ietf.org/rfc/rfc3526.txt _M.DHAlg[#(_M.DHAlg)+1]={ name="diffie-hellman-group14-sha256", p=bn.from_binary(string.fromhex([[ FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1 29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245 E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE45B3D C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 FD24CF5F 83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D 670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B E39E772C 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9 DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510 15728E5A 8AACAA68 FFFFFFFF FFFFFFFF ]])), shaAlg="sha256"} _M.DHAlg[#(_M.DHAlg)+1]={ name="diffie-hellman-group14-sha1", p=bn.from_binary(string.fromhex([[ FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1 29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD EF9519B3 CD3A431B 
302B0A6D F25F1437 4FE1356D 6D51C245 E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE45B3D C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 FD24CF5F 83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D 670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B E39E772C 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9 DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510 15728E5A 8AACAA68 FFFFFFFF FFFFFFFF ]])), shaAlg="sha1" } _M.DHAlg[#(_M.DHAlg)+1]={ name="diffie-hellman-group1-sha1", p=bn.from_binary(string.fromhex([[ FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1 29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245 E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE65381 FFFFFFFF FFFFFFFF ]])), shaAlg="sha1" } function _M.DHAlg:getList() local rs={} for i,v in ipairs(self) do rs[#rs+1]=v.name end return table.concat(rs,",") end 2 | --todo should be random bytes _M.y=bn.from_binary(rand.bytes(1024)) _M.x=bn.from_binary(rand.bytes(1024)) --key pair used for KEX _M.pubkey= 3 | [[ 4 | -----BEGIN RSA PUBLIC KEY----- 5 | MIIBCgKCAQEAyfPdItqWAL0kLjr4C9FJUm1nyRqNePUfAEHZqH+zQDnUmRUnJc/t 6 | YvViQwoBS4O21LbEJJJyA2UQ3LsiCj6l511uTJKjs43jS8uufLamnZkovfnj766V 7 | AQuGLb/LL28kbDNrjEBILG7Z1SjKOMcj8ltt5Jno3hy8QbufK+9nk1AyjvJy2xxg 8 | mAUYOXxI8hYOmIybdL06sKmnqn3CcBjHm5al426f91BgZk0uiaK+8Tq3fVi36fss 9 | o5ZGI3V64zRF+FCE80RvGW3S4ErUm95+SwLRjVav6keCQXYfVHiQ9sacLxjuVve4 10 | /UKjlFztG8+U/ZrIO5GgHEEc8px2s5mqMwIDAQAB 11 | -----END RSA PUBLIC KEY----- 12 | ]] 13 | _M.privkey= 14 | [[ 15 | -----BEGIN RSA PRIVATE KEY----- 16 | MIIEowIBAAKCAQEAyfPdItqWAL0kLjr4C9FJUm1nyRqNePUfAEHZqH+zQDnUmRUn 17 | Jc/tYvViQwoBS4O21LbEJJJyA2UQ3LsiCj6l511uTJKjs43jS8uufLamnZkovfnj 18 | 766VAQuGLb/LL28kbDNrjEBILG7Z1SjKOMcj8ltt5Jno3hy8QbufK+9nk1AyjvJy 19 | 2xxgmAUYOXxI8hYOmIybdL06sKmnqn3CcBjHm5al426f91BgZk0uiaK+8Tq3fVi3 20 | 
6fsso5ZGI3V64zRF+FCE80RvGW3S4ErUm95+SwLRjVav6keCQXYfVHiQ9sacLxju 21 | Vve4/UKjlFztG8+U/ZrIO5GgHEEc8px2s5mqMwIDAQABAoIBAFsWEZxhyJxGsuXj 22 | FPOHjrGNxOzQfBSdQkFEch5sknWaX8g34TNNx/0FPi+MeK8Nlk30rRztrFzZnbRg 23 | 9uZ2ATAMVO5WiV031tfd4zI+04FrjhO5fNQjAvO4tek2gzc+wsfGnXBhoevgh4F7 24 | 51GaiB0MndEolf5wKXzgWddgIHAxQ3pgqTqBhCvr0h/U0VxGkntqqEDIzRKohB0D 25 | hd5MXP9hCdTTOud9Kfy/2DKl0a8UWC6N5oyT1EhmGI011Fpc3J+svIbd9fPnRo4B 26 | RoAaiKWezYOja6ruRYZo9+GBtjznV1IGlK9EttMv9W5MbCDyFGU4/MoNmal9pAUz 27 | +HrX/aECgYEA7fG5QWYY6YLCs5UEpwC6jxD5uTXR8LA5JP8VxNLHA4/HCY3nXThC 28 | 800iEWfdgLtdN3H95KslW+E7WLoT7hdONZKKcGq6zBywa4JawUXt8jMravhN17os 29 | 6DTEOHtUE6WETUslJhK0o3232h8wo2dZ99lPiA00Uk8nKUku0CMVULECgYEA2Ub4 30 | iUZKpCHGJ+HjnKINbua36TATBEVY3myr8XBwHDJIAB+LyK+DwJ9QiXJDKnlDGsa8 31 | XEdYaRkUYNIcZgVnTF4s3O1BuGFXBbc7Av+Z85zLtSo/1kDn9YI92Los9APPffdu 32 | UMtbIj9eXqtzVg1DjWcbhZQZAi8uONGEvh+CQiMCgYBIlayFnreKxDDQx2yb5UUD 33 | z5HeReS9H4TPHGFvoTzEgV+eMoOZlEgYIDd8R8ryMjXFbCifUPYciSCpeFoMD1/0 34 | R7ejg2toSHgo06MLwmFLuQBNqWFVpZ19WFtjP3vuYldxnLLAYoRoOzmSeGFF94ki 35 | alAwmJaVZT/1ADYfmBQwgQKBgFRokd0ihZTF2ilcRARxoC5ZS1E37+tU1XVzWkjt 36 | mWAa2IXTu4Y3SUPnoG4FCbrSaRNZ6Ysf3GTX7Wa/uXCY4Mx2OY+KTGHIzvnVeQNt 37 | MO3HGAxFYY9mn7Zs5oHvsc8KO+1/1kdk+P6RB6RXjvL7LCceyz5VjnGeyqIgIyWJ 38 | MB1pAoGBAN2ztGq/yOpNwiCrxeeDpoFe7iQHznaXIgK7yA+uFFUUT58/bl87kq9/ 39 | huKG4OgDnZ+iHeSk9CFNseGZHIdcliTDGnud7ipOCFJCzjtlCcM/oHl1hL7oCJUL 40 | GrUlhvtkRKRV96mCDIEGxKCB0xzqjXKAgzTGApCQ9RUKG5EbNSnC 41 | -----END RSA PRIVATE KEY----- 42 | ]] 43 | return _M -------------------------------------------------------------------------------- /ssh2/ssh2Packets.lua: -------------------------------------------------------------------------------- 1 | --ssh2.0 protocol parser and encoder --Packet parser follows rfc rfc4251,4252,4253 require "suproxy.utils.stringUtils" require "suproxy.utils.pureluapack" local tableUtils=require "suproxy.utils.tableUtils" local ok,cjson=pcall(require,"cjson") cjson = ok and cjson or require("suproxy.utils.json") 
local extends=tableUtils.extends local orderTable=tableUtils.OrderedTable local asn1 = require "suproxy.utils.asn1" local event=require "suproxy.utils.event" local logger=require "suproxy.utils.compatibleLog" local _M={} --Packet type defines, only the type that have been implemented are listed _M.PktType={ KeyXInit=0x14, DHKeyXInit=0x1e, DHKeyXReply=0x1f, AuthReq=0x32, AuthFail=0x33, ChannelData=0x5e, Disconnect=0x01, NewKeys=0x15, AuthSuccess=0x34, } --Tool for mpint format padding (rfc4251 section 5) local function paddingInt(n) if(n:byte(1)>=128)then return string.char(0)..n end return n end 2 | local function packSSHData(data,padding) 3 | local paddingLength=16-(#data+5)%16 4 | if paddingLength<4 then paddingLength=paddingLength+16 end 5 | local padding=padding or string.random(paddingLength) 6 | return string.pack(">I4B",#data+1+#padding,#padding)..data..padding 7 | end --Base Packet implements header parser and pack -- uint32 packet_length -- byte padding_length -- byte[n1] payload; n1 = packet_length - padding_length - 1 -- byte[n2] random padding; n2 = padding_length -- byte[m] mac (Message Authentication Code - MAC); m = mac_length _M.Base={ parse=function(self,allBytes) local pos self.dataLength,self.paddingLength,self.code,pos=string.unpack(">I4BB",allBytes) self.allBytes=allBytes self:parsePayload(allBytes,pos) return self end, parsePayload=function(self,allbytes,pos) return self end, pack=function(self) self.allBytes=packSSHData(string.char(self.code)..self:packPayload()) logger.logWithTitle(logger.DEBUG,"packing",self.allBytes:hex16F()) return self end, packPayload=function(self) return "" end, new=function(self,o) local o=o or {} return orderTable.new(self,o) end } --Key Exchange Init Packet -- byte SSH_MSG_KEXINIT 0x14 -- byte[16] cookie (random bytes) -- name-list kex_algorithms -- name-list server_host_key_algorithms -- name-list encryption_algorithms_client_to_server -- name-list encryption_algorithms_server_to_client -- name-list 
--           mac_algorithms_client_to_server
-- name-list mac_algorithms_server_to_client
-- name-list compression_algorithms_client_to_server
-- name-list compression_algorithms_server_to_client
-- name-list languages_client_to_server
-- name-list languages_server_to_client
-- boolean first_kex_packet_follows
-- uint32 0 (reserved for future extension)
_M.KeyXInit={
    code=_M.PktType.KeyXInit,
    parsePayload=function(self,allBytes,pos)
        self.cookie,
        self.kex_alg,self.key_alg,
        self.enc_alg_c2s,self.enc_alg_s2c,
        self.mac_alg_c2s,self.mac_alg_s2c,
        self.comp_alg_c2s,self.comp_alg_s2c,
        self.lan_c2s,self.lan_s2c,
        self.kex_follows,self.reserved=string.unpack(">c16s4s4s4s4s4s4s4s4s4s4BI4",allBytes,pos)
        -- raw payload: message code included, padding excluded
        -- (presumably kept for the exchange hash, I_C/I_S in RFC 4253 section 8 - confirm at caller)
        self.payloadBytes=allBytes:sub(6,5+self.dataLength-self.paddingLength-1)
        return self
    end,
    packPayload=function(self)
        local rs=string.pack(">c16s4s4s4s4s4s4s4s4s4s4BI4",self.cookie,
            self.kex_alg,self.key_alg,
            self.enc_alg_c2s,self.enc_alg_s2c,
            self.mac_alg_c2s,self.mac_alg_s2c,
            self.comp_alg_c2s,self.comp_alg_s2c,
            self.lan_c2s,self.lan_s2c,
            self.kex_follows,self.reserved)
        self.payloadBytes=string.char(self.code)..rs
        return rs
    end
} extends(_M.KeyXInit,_M.Base)


-- byte SSH_MSG_KEXDH_INIT 0x1e
-- mpint e
_M.DHKeyXInit={
    code=_M.PktType.DHKeyXInit,
    parsePayload=function(self,payload,pos)
        self.e=string.unpack(">s4",payload,pos)
        return self
    end,
    packPayload=function(self)
        return string.pack(">s4",paddingInt(self.e))
    end
} extends(_M.DHKeyXInit,_M.Base)

-- byte SSH_MSG_KEXDH_REPLY 0x1f
-- string server public host key and certificates (K_S)
-- mpint f
-- string signature of H
_M.DHKeyXReply={
    code=_M.PktType.DHKeyXReply,
    parsePayload=function(self,payload,pos)
        local sigBlob
        self.K_S,self.f,sigBlob=string.unpack(">s4s4s4",payload,pos)
        -- the signature blob itself holds: string algorithm name, string signature
        self.key_alg,self.signH=string.unpack(">s4s4",sigBlob)
        return self
    end,
    packPayload=function(self)
        local sigBlob=string.pack(">s4s4",self.key_alg,self.signH)
        return string.pack(">s4s4s4",self.K_S,paddingInt(self.f),sigBlob)
    end
} extends(_M.DHKeyXReply,_M.Base)

--process user authenticate, request format in rfc4252 section 8
-- byte SSH_MSG_USERAUTH_REQUEST 0x32
-- string user name
-- string service name
-- string method
-- the fields below are only present when method is "password":
-- boolean FALSE (TRUE would announce a password change request)
-- string plaintext password in ISO-10646 UTF-8 encoding [RFC3629]
_M.AuthReq={
    code=_M.PktType.AuthReq,
    parsePayload=function(self,payload,pos)
        local passStartPos
        self.username,self.serviceName,self.method,passStartPos=string.unpack(">s4s4s4",payload,pos)
        if self.method=="password" then
            -- skip the boolean byte, then read the uint32-length password string;
            -- reading only the low length byte (s1) broke passwords >= 256 bytes
            local _
            _,self.password=string.unpack(">Bs4",payload,passStartPos)
        end
        return self
    end,
    packPayload=function(self)
        local req=string.pack(">s4s4s4",self.username,self.serviceName,self.method)
        if self.method=="password" then
            -- boolean FALSE followed by the password as a regular SSH string
            req=req..string.pack(">Bs4",0,self.password)
        end
        return req
    end
} extends(_M.AuthReq,_M.Base)

--SSH_MSG_USERAUTH_FAILURE 0x33 (rfc4252 section 5.1)
-- name-list authentications that can continue
-- boolean partial success (a single byte)
_M.AuthFail={
    code=_M.PktType.AuthFail,
    parsePayload=function(self,payload,pos)
        self.methods,pos=string.unpack(">s4",payload,pos)
        -- read exactly one byte: a uint32 here would fold padding bytes into the flag
        if pos<=#payload then
            self.partialSuccess=string.unpack(">B",payload,pos)~=0
        end
        return self
    end,
    packPayload=function(self)
        -- TRUE (1) only when partial success is set
        return string.pack(">s4B",self.methods,self.partialSuccess and 1 or 0)
    end
} extends(_M.AuthFail,_M.Base)

-- byte SSH_MSG_CHANNEL_DATA 0x5e
-- uint32 recipient channel
-- string data
_M.ChannelData={
    code=_M.PktType.ChannelData,
    parsePayload=function(self,payload,pos)
        self.channel,self.data=string.unpack(">I4s4",payload,pos)
        return self
    end,
    packPayload=function(self)
        return string.pack(">I4s4",self.channel,self.data)
    end
} extends(_M.ChannelData,_M.Base)

-- byte SSH_MSG_DISCONNECT
--                          0x01
-- uint32 reason code
-- string description in ISO-10646 UTF-8 encoding [RFC3629]
-- string language tag [RFC3066]
_M.Disconnect={
    code=_M.PktType.Disconnect,
    parsePayload=function(self,payload,pos)
        self.reasonCode,self.message,pos=string.unpack(">I4s4",payload,pos)
        -- the language tag is mandatory (RFC 4253 section 11.1); tolerate senders
        -- that omit it instead of failing the whole parse
        local ok,lang=pcall(string.unpack,">s4",payload,pos)
        self.lang=ok and lang or ""
        return self
    end,
    packPayload=function(self)
        -- always emit the language tag, otherwise peers see a truncated message
        return string.pack(">I4s4s4",self.reasonCode,self.message,self.lang or "")
    end
} extends(_M.Disconnect,_M.Base)
return _M
--------------------------------------------------------------------------------
/suproxy-v0.6.0-1.rockspec:
--------------------------------------------------------------------------------
package = "suproxy"
version = "v0.6.0-1"
source = {
    url = "git+https://github.com/yizhu2000/suproxy.git",
    tag = "v0.6.0"
}
description = {
    detailed = "Lua SSH2,LDAP,TNS,TDS proxy and mim library for OpenResty",
    homepage = "https://github.com/yizhu2000/suproxy",
    license = "BSD"
}
dependencies = {
    "lua >= 5.1",
    "lua-resty-openssl >= 0.6",
    "lua-resty-logger-socket >= v0.1"
}
build = {
    type = "builtin",
    modules = {
        ["suproxy.channel"]="channel.lua",
        ["suproxy.ldap"]="ldap.lua",
        ["suproxy.parser"]="parser.lua",
        ["suproxy.ssh2"]="ssh2.lua",
        ["suproxy.tds"]="tds.lua",
        ["suproxy.test"]="test.lua",
        ["suproxy.tns"]="tns.lua",
        ["suproxy.balancer.balancer"]="balancer/balancer.lua",
        ["suproxy.example.gateway"]="example/gateway.lua",
        ["suproxy.example.session"]="example/session.lua",
        ["suproxy.http.handlers"]="http/handlers.lua",
        ["suproxy.http.initParam"]="http/initParam.lua",
        ["suproxy.http.mockauth"]="http/mockauth.lua",
        ["suproxy.http.ssoProcessors"]="http/ssoProcessors.lua",
        ["suproxy.ldap.ldapPackets"]="ldap/ldapPackets.lua",
        ["suproxy.ldap.parser"]="ldap/parser.lua",
        ["suproxy.session.session"]="session/session.lua",
        ["suproxy.session.sessionManager"]="session/sessionManager.lua",
        ["suproxy.ssh2.commandCollector"]="ssh2/commandCollector.lua",
        ["suproxy.ssh2.parser"]="ssh2/parser.lua",
        ["suproxy.ssh2.shellCommand"]="ssh2/shellCommand.lua",
        ["suproxy.ssh2.ssh2CipherConf"]="ssh2/ssh2CipherConf.lua",
        ["suproxy.ssh2.ssh2Packets"]="ssh2/ssh2Packets.lua",
        ["suproxy.tds.datetime"]="tds/datetime.lua",
        ["suproxy.tds.parser"]="tds/parser.lua",
        ["suproxy.tds.tdsPackets"]="tds/tdsPackets.lua",
        ["suproxy.tds.token"]="tds/token.lua",
        ["suproxy.tds.version"]="tds/version.lua",
        ["suproxy.tns.crypt"]="tns/crypt.lua",
        ["suproxy.tns.parser"]="tns/parser.lua",
        ["suproxy.tns.tnsPackets"]="tns/tnsPackets.lua",
        ["suproxy.utils.asn1"]="utils/asn1.lua",
        ["suproxy.utils.compatibleLog"]="utils/compatibleLog.lua",
        ["suproxy.utils.datetime"]="utils/datetime.lua",
        ["suproxy.utils.event"]="utils/event.lua",
        ["suproxy.utils.ffi-zlib"]="utils/ffi-zlib.lua",
        ["suproxy.utils.json"]="utils/json.lua",
        ["suproxy.utils.pureluapack"]="utils/pureluapack.lua",
        ["suproxy.utils.stringUtils"]="utils/stringUtils.lua",
        ["suproxy.utils.tableUtils"]="utils/tableUtils.lua",
        ["suproxy.utils.unicode"]="utils/unicode.lua",
        ["suproxy.utils.utils"]="utils/utils.lua"
    }
}
--------------------------------------------------------------------------------
/tds.lua:
--------------------------------------------------------------------------------
require "suproxy.utils.stringUtils"
require "suproxy.utils.pureluapack"
local event=require "suproxy.utils.event"
local ok,cjson=pcall(require,"cjson")
if not ok then cjson = require("suproxy.utils.json") end
local logger=require "suproxy.utils.compatibleLog"
local tdsPacket=require "suproxy.tds.tdsPackets"
local tableUtils=require "suproxy.utils.tableUtils"
local _M = {}
_M._PROTOCAL ='tds'

--- Create a TDS (SQL Server) protocol processor.
-- @param options optional table:
--   disableSSL (default true)  - rewrite the Prelogin Encryption option to 2
--   catchReply (default false) - passed to the tds parser so SQL responses get decoded
-- @return processor wired to the client->proxy and server->proxy parsers
function _M.new(self,options)
    local o= setmetatable({},{__index=self})
    options=options or {}
    o.disableSSL=true
    if options.disableSSL~=nil then o.disableSSL=options.disableSSL end
    o.catchReply=false
    if options.catchReply~=nil then o.catchReply=options.catchReply end
    o.BeforeAuthEvent=event:newReturnEvent(o,"BeforeAuthEvent")
    o.OnAuthEvent=event:newReturnEvent(o,"OnAuthEvent")
    o.AuthSuccessEvent=event:new(o,"AuthSuccessEvent")
    o.AuthFailEvent=event:new(o,"AuthFailEvent")
    o.CommandEnteredEvent=event:newReturnEvent(o,"CommandEnteredEvent")
    o.CommandFinishedEvent=event:new(o,"CommandFinishedEvent")
    o.ContextUpdateEvent=event:new(o,"ContextUpdateEvent")
    o.ctx={}
    local tdsParser=require ("suproxy.tds.parser"):new(o.catchReply)
    o.C2PParser=tdsParser.C2PParser
    o.S2PParser=tdsParser.S2PParser
    o.C2PParser.events.SQLBatch:addHandler(o,_M.SQLBatchHandler)
    o.C2PParser.events.Prelogin:addHandler(o,_M.PreloginHandler)
    o.C2PParser.events.Login7:addHandler(o,_M.Login7Handler)
    o.S2PParser.events.LoginResponse:addHandler(o,_M.LoginResponseHandler)
    o.S2PParser.events.SSLLoginResponse:addHandler(o,_M.LoginResponseHandler)
    o.S2PParser.events.SQLResponse:addHandler(o,_M.SQLResponseHandler)
    return o
end

----------------parser event handlers----------------------
--- Client SQL batch: CommandEnteredEvent may veto the statement.
-- On veto an error response goes back to the client and p.allBytes is cleared
-- (presumably so the batch is not forwarded upstream - confirm in channel).
function _M:SQLBatchHandler(src,p)
    if self.CommandEnteredEvent:hasHandler() then
        local _,err=self.CommandEnteredEvent:trigger(p.sql,self.ctx)
        if err then
            self.channel:c2pSend(tdsPacket.packErrorResponse(err.message,err.code))
            p.allBytes=nil
            return
        end
    end
    -- remembered for CommandFinishedEvent in SQLResponseHandler
    ngx.ctx.sql=p.sql
end

--- Client Prelogin: when SSL is disabled, force the Encryption option to 2 and re-pack.
function _M:PreloginHandler(src,p)
    if p.options.Encryption and self.disableSSL then
        p.options.Encryption=2
        --self.ctx.serverVer=p.options.Version.versionNumber
        p:pack()
    end
end

--- Client Login7: run the auth hooks, optionally swap credentials, update the context.
-- BeforeAuthEvent may return a credential table {username=,password=}; a third
-- return value from OnAuthEvent, when present, takes precedence over it.
function _M:Login7Handler(src,p)
    local cred
    if self.BeforeAuthEvent:hasHandler() then
        cred=self.BeforeAuthEvent:trigger({username=p.username,password=p.password},self.ctx)
    end
    if self.OnAuthEvent:hasHandler() then
        -- previously declared as a new local `cred`, which silently discarded it
        local ok,message,authCred=self.OnAuthEvent:trigger({username=p.username,password=p.password},self.ctx)
        if not ok then
            self.channel:c2pSend(tdsPacket.packErrorResponse(message or "login with "..p.username.." failed",18456))
            p.allBytes=nil
            return
        end
        cred=authCred or cred
    end
    if cred and (p.username~=cred.username or p.password~=cred.password) then
        -- never print/log plaintext credentials here
        p.username=cred.username
        p.password=cred.password
        p:pack()
    end
    self.ctx.username=p.username
    self.ctx.client=p.appName
    self.ctx.clientVer=p.ClientProgVer:hex()
    self.ctx.libName=p.libName
    self.ctx.tdsVer=p.TDSVersion:hex()
    if self.ContextUpdateEvent:hasHandler() then
        self.ContextUpdateEvent:trigger(self.ctx)
    end
end

--- Server login response (plain or SSL): raise AuthSuccess/AuthFail and refresh the context.
function _M:LoginResponseHandler(src,p)
    if p.success then
        if self.AuthSuccessEvent:hasHandler() then
            self.AuthSuccessEvent:trigger(self.ctx.username,self.ctx)
        end
        self.ctx.serverVer=p.serverVersion.versionNumber
        self.ctx.tdsVer=p.TDSVersion:hex()
        if self.ContextUpdateEvent:hasHandler() then
            self.ContextUpdateEvent:trigger(self.ctx)
        end
    else
        if self.AuthFailEvent:hasHandler() then
            self.AuthFailEvent:trigger({username=self.ctx.username,message="["..p.errNo.."]"..p.message},self.ctx)
        end
    end
end

--- Server SQL response: report the finished command with its reply text ("" when not decoded).
function _M:SQLResponseHandler(src,p)
    if self.CommandFinishedEvent:hasHandler() then
        local reply=p.tostring and p:tostring() or ""
        self.CommandFinishedEvent:trigger(ngx.ctx.sql,reply,self.ctx)
    end
end

----------------implement processor methods---------------------
--- Read one packet through readMethod: the 8-byte header, then the rest.
-- @return the whole packet bytes, or (partial-or-nil, err) on read failure
local function recv(self,readMethod)
    local headerBytes,err,partial=readMethod(self.channel,8)
    if err then
        logger.log(logger.ERR,"err when reading header",err)
        return partial,err
    end
    local packet=tdsPacket.Packet:new()
    packet:parseHeader(headerBytes)
    local remaining
    if packet.code==0x17 then
        -- SSL record (see SSLLoginResponse): its 5-byte header carries the record
        -- length, and 3 body bytes were already consumed with the 8-byte read
        local _,_,_,recordLength=string.unpack(">BBBI2",headerBytes)
        remaining=recordLength-3
    else
        remaining=packet.dataLength-8
    end
    local payloadBytes
    payloadBytes,err=readMethod(self.channel,remaining)
    if err then
        -- without this check `headerBytes..nil` raised when the peer dropped mid-packet
        logger.log(logger.ERR,"err when reading payload",err)
        return nil,err
    end
    return headerBytes..payloadBytes
end

--- Client -> server direction; remembers the request code to pick the reply parser.
function _M.processUpRequest(self)
    local readMethod=self.channel.c2pRead
    local allBytes,err=recv(self,readMethod)
    if err then return nil,err end
    local p=self.C2PParser:parse(allBytes)
    ngx.ctx.upPacket=p.code
    return p.allBytes
end

--- Server -> client direction; parsed with the key of the last client request.
function _M.processDownRequest(self)
    local readMethod=self.channel.p2sRead
    local allBytes,err=recv(self,readMethod)
    if err then return nil,err end
    local p =self.S2PParser:parse(allBytes,nil,ngx.ctx.upPacket)
    return p.allBytes
end

function _M:sessionInvalid(session)
    self.channel:c2pSend(tdsPacket.packErrorResponse("you are not allowed to connect, please contact the admin"))
    ngx.exit(0)
end

return _M

--------------------------------------------------------------------------------
/tds/datetime.lua:
--------------------------------------------------------------------------------

--[[
Parameters:
  srcDateTime  source time string in the format %Y%m%d%H%M%S: a 4-digit year
               followed by 2-digit month, day, hour, minute and second
  interval     amount to add to (> 0) or subtract from (< 0) the source time
  dateUnit     unit of interval: DAY, HOUR, SECOND or MINUTE; the source time
               is moved by interval units of this kind
  For example:
    interval=10, dateUnit='DAY'   adds 10 days to the source time
    interval=-1, dateUnit='HOUR'  subtracts 1 hour from the source time
Returns the table produced by os.date("*t"), with fields year, month, day,
hour, min and sec (plus wday, yday, isdst); pick the fields you need and
format a new date string from them.
]]

local _M={}

--- Shift a %Y%m%d%H%M%S time string by interval units of dateUnit.
-- An unrecognised dateUnit leaves the time unchanged (offset 0).
function _M.getNewDate(srcDateTime,interval,dateUnit)
    -- split the date string into year, month, day, hour, minute and second
    local Y = string.sub(srcDateTime,1,4)
    local M = string.sub(srcDateTime,5,6)
    local D = string.sub(srcDateTime,7,8)
    local H = string.sub(srcDateTime,9,10)
    local MM = string.sub(srcDateTime,11,12)
    local SS = string.sub(srcDateTime,13,14)

    -- convert the components to a timestamp (os.time treats them as local time)
    local dt1 = os.time{year=Y, month=M, day=D, hour=H,min=MM,sec=SS}

    -- offset in seconds for the requested unit
    local ofset=0
    if dateUnit =='DAY' then
        ofset = 60 *60 * 24 * interval
    elseif dateUnit == 'HOUR' then
        ofset = 60 *60 * interval
    elseif dateUnit == 'MINUTE' then
        ofset = 60 * interval
    elseif dateUnit == 'SECOND' then
        ofset = interval
    end

    -- source time + offset, broken down into a table
    local newTime = os.date("*t", dt1 + tonumber(ofset))
    return newTime
end

-- the module previously assigned a global _M and returned nothing, so
-- require("suproxy.tds.datetime") yielded `true` instead of this table
return _M
--------------------------------------------------------------------------------
/tds/parser.lua:
--------------------------------------------------------------------------------
--tds protocol parser and encoder
local P=require "suproxy.tds.tdsPackets"
local parser=require("suproxy.parser")
local _M={}
----------------------build parser---------------------------
--config c2p parsers
local c2pConf={
    {key=P.Login7.code,   parser=P.Login7,   eventName="Login7"},
    {key=P.SQLBatch.code, parser=P.SQLBatch, eventName="SQLBatch"},
    {key=P.Prelogin.code, parser=P.Prelogin, eventName="Prelogin"},
}
--config s2p parsers
local s2pConf={
    {key=P.Prelogin.code, parser=P.PreloginResponse, eventName="PreloginResponse"},
    {key=P.Login7.code,   parser=P.LoginResponse,    eventName="LoginResponse"},
    --ssl login response
    {key=0x17,            parser=P.LoginResponse,    eventName="SSLLoginResponse"},
    {key=P.SQLBatch.code, parser=P.SQLResponse,      eventName="SQLResponse"}
}

--- Build the client->proxy and server->proxy parsers.
-- @param catchReply when falsy, SQLResponse is re-registered without a parser
--   (presumably so replies are forwarded without being decoded - confirm in suproxy.parser)
function _M:new(catchReply)
    local o= setmetatable({},{__index=self})
    local C2PParser=parser:new()
    C2PParser.keyGenerator=function(allBytes) return allBytes:byte(1) end
    C2PParser:registerMulti(c2pConf)
    C2PParser:registerDefaultParser(P.Packet)
    o.C2PParser=C2PParser
    local S2PParser=parser:new()
    S2PParser:registerMulti(s2pConf)
    if not catchReply then
        S2PParser:unregister(P.SQLBatch.code,"SQLResponse")
        S2PParser:register(P.SQLBatch.code,nil,nil,"SQLResponse")
    end
    S2PParser:registerDefaultParser(P.Packet)
    o.S2PParser=S2PParser
    return o
end
return _M
--------------------------------------------------------------------------------
/tds/version.lua:
--------------------------------------------------------------------------------
--- SqlServerVersionInfo class
local _M={
    versionNumber = "", -- The full version string (e.g. "9.00.2047.00")
    major = nil, -- The major version (e.g. 9)
    minor = nil, -- The minor version (e.g. 0)
    build = nil, -- The build number (e.g. 2047)
    subBuild = nil, -- The sub-build number (e.g. 0)
    productName = nil, -- The product name (e.g. "SQL Server 2005")
    brandedVersion = nil, -- The branded version of the product (e.g. "2005")
    servicePackLevel = nil, -- The service pack level (e.g. "SP1")
    patched = nil, -- Whether patches have been applied since SP installation (true/false/nil)
    source = nil, -- The source of the version info (e.g. "SSRP", "SSNetLib")

    new = function(self,o)
        o = o or {}
        setmetatable(o, self)
        self.__index = self
        return o
    end,

    --- Sets the version using a version number string.
    --
    -- Accepts "major.minor.build" or "major.minor.build.subBuild"; the 4th
    -- component used to be matched but never captured, so it was always lost.
    -- @param versionNumber a version number string (e.g. "9.00.1399.00")
    -- @param source a string indicating the source of the version info (e.g. "SSRP", "SSNetLib")
    SetVersionNumber = function(self, versionNumber, source)
        local major, minor, revision, subBuild
        local text = versionNumber or ""
        if text:match( "^%d+%.%d+%.%d+%.%d+" ) then
            major, minor, revision, subBuild = text:match( "^(%d+)%.(%d+)%.(%d+)%.(%d+)" )
        elseif text:match( "^%d+%.%d+%.%d+" ) then
            major, minor, revision = text:match( "^(%d+)%.(%d+)%.(%d+)" )
        else
            -- print does not interpret format strings, so format explicitly
            print(string.format("%s: SetVersionNumber: versionNumber is not in correct format: %s", "MSSQL", versionNumber or "nil"))
        end

        self:SetVersion( major, minor, revision, subBuild, source )
    end,

    --- Sets the version using the individual numeric components of the version
    -- number. Missing components default to 0.
    --
    -- @param source a string indicating the source of the version info (e.g. "SSRP", "SSNetLib")
    SetVersion = function(self, major, minor, build, subBuild, source)
        self.source = source
        -- make sure our version numbers all end up as valid numbers
        self.major, self.minor, self.build, self.subBuild =
            tonumber( major or 0 ), tonumber( minor or 0 ), tonumber( build or 0 ), tonumber( subBuild or 0 )

        self.versionNumber = string.format( "%u.%02u.%u.%02u", self.major, self.minor, self.build, self.subBuild )

        self:_ParseVersionInfo()
    end,

    --- Using the version number, determines the product version
    _InferProductVersion = function(self)
        local VERSION_LOOKUP_TABLE = {
            ["^6%.0"] = "6.0", ["^6%.5"] = "6.5", ["^7%.0"] = "7.0",
            ["^8%.0"] = "2000", ["^9%.0"] = "2005", ["^10%.0"] = "2008",
            ["^10%.50"] = "2008 R2", ["^11%.0"] = "2012", ["^12%.0"] = "2014",
            ["^13%.0"] = "2016", ["^14%.0"] = "2017", ["^15%.0"] = "2019",
            ["^16%.0"] = "2022"
        }

        local product = ""

        for m, v in pairs(VERSION_LOOKUP_TABLE) do
            if ( self.versionNumber:match(m) ) then
                product = v
                self.brandedVersion = product
                break
            end
        end

        self.productName = ("Microsoft SQL Server %s"):format(product)
    end,

    --- Returns a lookup table that maps revision numbers to service pack levels for
    -- the applicable SQL Server version (e.g. { {1600, "RTM"}, {2531, "SP1"} }),
    -- or nil when the branded version is unknown.
    _GetSpLookupTable = function(self)

        -- Service pack lookup tables:
        -- For instances where a revised service pack was released (e.g. 2000 SP3a), we will include the
        -- build number for the original SP and the build number for the revision. However, leaving it
        -- like this would make it appear that subsequent builds were a patched version of the revision
        -- (e.g. a patch applied to 2000 SP3 that increased the build number to 780 would get displayed
        -- as "SP3a+", when it was actually SP3+). To avoid this, we will include an additional fake build
        -- number that combines the two.
        local SP_LOOKUP_TABLE = {
            ["6.5"] = {
                {201, "RTM"},
                {213, "SP1"},
                {240, "SP2"},
                {258, "SP3"},
                {281, "SP4"},
                {415, "SP5"},
                {416, "SP5a"},
                {417, "SP5/SP5a"},
            },

            ["7.0"] = {
                {623, "RTM"},
                {699, "SP1"},
                {842, "SP2"},
                {961, "SP3"},
                {1063, "SP4"},
            },

            ["2000"] = {
                {194, "RTM"},
                {384, "SP1"},
                {532, "SP2"},
                {534, "SP2"},
                {760, "SP3"},
                {766, "SP3a"},
                {767, "SP3/SP3a"},
                {2039, "SP4"},
            },

            ["2005"] = {
                {1399, "RTM"},
                {2047, "SP1"},
                {3042, "SP2"},
                {4035, "SP3"},
                {5000, "SP4"},
            },

            ["2008"] = {
                {1600, "RTM"},
                {2531, "SP1"},
                {4000, "SP2"},
                {5500, "SP3"},
                {6000, "SP4"},
            },

            ["2008 R2"] = {
                {1600, "RTM"},
                {2500, "SP1"},
                {4000, "SP2"},
                {6000, "SP3"},
            },

            ["2012"] = {
                {2100, "RTM"},
                {3000, "SP1"},
                {5058, "SP2"},
                {6020, "SP3"},
                {7001, "SP4"},
            },

            ["2014"] = {
                {2000, "RTM"},
                {4100, "SP1"},
                {5000, "SP2"},
                {6024, "SP3"},
            },

            ["2016"] = {
                {1601, "RTM"},
                {4001, "SP1"},
                {5026, "SP2"},
            },

            ["2017"] = {
                {1000, "RTM"},
                {3257, "CU18"},
            },

            ["2019"] = {
                {2000, "RTM"},
            },

            ["2022"] = {
                {1000, "RTM"},
            },
        }

        if ( not self.brandedVersion ) then
            self:_InferProductVersion()
        end

        -- (a debug print with an unformatted "%s" template used to fire on every parse)
        return SP_LOOKUP_TABLE[self.brandedVersion]
    end,

    --- Processes version data to determine (if possible) the product version,
    -- service pack level and patch status.
    _ParseVersionInfo = function(self)

        local spLookupTable = self:_GetSpLookupTable()

        if spLookupTable then

            local spLookupItr = 0
            -- Loop through the service pack levels until we find one whose revision
            -- number is the same as or lower than our revision number.
            while spLookupItr < #spLookupTable do
                spLookupItr = spLookupItr + 1

                if (spLookupTable[ spLookupItr ][1] == self.build ) then
                    break
                elseif (spLookupTable[ spLookupItr ][1] > self.build ) then
                    -- The target revision number is lower than the first release
                    if spLookupItr == 1 then
                        self.servicePackLevel = "Pre-RTM"
                    else
                        -- we went too far - it's the previous SP, but with patches applied
                        spLookupItr = spLookupItr - 1
                    end
                    break
                end
            end

            -- Now that we've identified the proper service pack level:
            if self.servicePackLevel ~= "Pre-RTM" then
                self.servicePackLevel = spLookupTable[ spLookupItr ][2]
                self.patched = spLookupTable[ spLookupItr ][1] ~= self.build
            end

            -- Clean up some of our inferences. If the source of our revision number
            -- was the SSRP (SQL Server Browser) response, we need to recognize its
            -- limitations:
            -- * Versions of SQL Server prior to 2005 are reported with the RTM build
            --   number, regardless of the actual version (e.g. SQL Server 2000 is
            --   always 8.00.194).
            -- * Versions of SQL Server starting with 2005 (and going through at least
            --   2008) do better but are still only reported with the build number as
            --   of the last service pack (e.g. SQL Server 2005 SP3 with patches is
            --   still reported as 9.00.4035.00).
            if ( self.source == "SSRP" ) then
                self.patched = nil

                if ( self.major <= 8 ) then
                    self.servicePackLevel = nil
                end
            end
        end

        return true
    end,

    --- Human-readable summary, e.g. "Microsoft SQL Server 2008 SP1+".
    ToString = function(self)
        local rs = {}
        if self.productName then
            rs[#rs+1]= self.productName
            if self.servicePackLevel then
                rs[#rs+1]= " "
                rs[#rs+1]= self.servicePackLevel
            end
            if self.patched then
                rs[#rs+1]= "+"
            end
        end

        return table.concat(rs)
    end,
}

return _M
--------------------------------------------------------------------------------
/test.lua:
--------------------------------------------------------------------------------
--function scandir(directory)
-- local i, t, popen = 0, {}, io.popen
-- local pfile = popen('dir "'..directory..'*.lua" /b')
-- for filename in pfile:lines() do
-- i = i + 1
-- t[i] = filename
-- end
-- pfile:close()
-- return t
--end
--local currentDir=debug.getinfo(1).source:sub(1,#debug.getinfo(1).source-8)
--print(currentDir)
--local files=scandir("C:\\env\\openresty-1.15.8.3-win64\\lualib\\gateway\\")
--for i,v in ipairs (files) do
-- print(v)
-- local mod=require("suproxy."..v:sub(1,#v-4))
-- if mod.test then
-- mod.test()
-- end
--end
print("start testing")
local m=require "suproxy.utils.compatibleLog"
m.test()
m=require "suproxy.utils.datetime"
m.test()
m=require "suproxy.utils.event"
m.test()
m=require "suproxy.utils.pureluapack"
m.test()
m=require "suproxy.utils.tableUtils"
m.test()
m=require "suproxy.utils.unicode"
m.test()
m=require "suproxy.tns.tnsPackets"
m.test()
m=require "suproxy.tds.tdsPackets"
m.test()
m=require "suproxy.ssh2.shellCommand"
m.test()
m=require "suproxy.ldap.ldapPackets"
m.test()
| m=require "suproxy.balancer.balancer" 43 | m.test() 44 | 45 | print("All test finished without error") -------------------------------------------------------------------------------- /tns.lua: -------------------------------------------------------------------------------- 1 | require "suproxy.utils.stringUtils" 2 | require "suproxy.utils.pureluapack" 3 | local event=require "suproxy.utils.event" 4 | local logger=require "suproxy.utils.compatibleLog" 5 | local tnsPackets=require "suproxy.tns.tnsPackets" 6 | local tableUtils=require "suproxy.utils.tableUtils" 7 | local crypt= require "suproxy.tns.crypt" local _M = {} 8 | _M._PROTOCAL ='tns' 9 | 10 | function _M.new(self,options) 11 | options=options or {} 12 | local o= setmetatable({},{__index=self}) 13 | o.AuthSuccessEvent=event:new(o,"AuthSuccessEvent") 14 | o.AuthFailEvent=event:new(o,"AuthFailEvent") 15 | o.BeforeAuthEvent=event:newReturnEvent(o,"BeforeAuthEvent") 16 | o.OnAuthEvent=event:newReturnEvent(o,"OnAuthEvent") 17 | o.CommandEnteredEvent=event:newReturnEvent(o,"CommandEnteredEvent") 18 | o.CommandFinishedEvent=event:new(o,"CommandFinishedEvent") 19 | o.ContextUpdateEvent=event:new(o,"ContextUpdateEvent") 20 | o.options=tnsPackets.Options:new() 21 | o.options.oracleVersion.major=options.oracleVersion or o.options.oracleVersion.major 22 | o.swapPass=options.swapPass or false 23 | o.ctx={} 24 | local tnsParser=require ("suproxy.tns.parser"):new() 25 | o.C2PParser=tnsParser.C2PParser 26 | o.C2PParser.events.ConnectEvent:setHandler(o,_M.ConnectHandler) 27 | o.C2PParser.events.AuthRequestEvent:setHandler(o,_M.AuthRequestHandler) 28 | o.C2PParser.events.SessionRequestEvent:setHandler(o,_M.SessionRequestHandler) 29 | o.C2PParser.events.SetProtocolEvent:setHandler(o,_M.SetProtocolRequestHandler) 30 | o.C2PParser.events.SQLRequestEvent:setHandler(o,_M.SQLRequestHandler) 31 | o.C2PParser.events.Piggyback1169:setHandler(o,_M.PiggbackHandler) 32 | o.C2PParser.events.Piggyback116b:setHandler(o,_M.PiggbackHandler) 
33 | o.C2PParser.events.MarkerEvent:setHandler(o,_M.MarkerHandler) 34 | o.S2PParser=tnsParser.S2PParser 35 | o.S2PParser.events.SessionResponseEvent:setHandler(o,_M.SessionResponseHandler) 36 | o.S2PParser.events.VersionResponseEvent:setHandler(o,_M.VersionResponseHandler) 37 | o.S2PParser.events.SetProtocolEvent:setHandler(o,_M.SetProtocolResponseHandler) 38 | o.S2PParser.events.AcceptEvent:setHandler(o,_M.AcceptHandler) 39 | o.S2PParser.events.AuthErrorEvent:setHandler(o,_M.AuthErrorHandler) 40 | return o 41 | end 42 | 43 | ----------------parser event handlers---------------------- 44 | function _M:ConnectHandler(src,p) 45 | p:setTnsVersion(314) 46 | self.ctx.connStr=p:getConnStr() 47 | p:pack() 48 | end 49 | 50 | function _M:AcceptHandler(src,p) 51 | self.options.tnsVersion=p:getTnsVersion() 52 | self.options.headerCheckSum=p:checkHeader() 53 | self.options.packetCheckSum=p:checkPacket() 54 | self.ctx.tnsVer=p:getTnsVersion() 55 | end 56 | 57 | function _M:AuthRequestHandler(src,p) 58 | local ckey=p:getAuthKey() 59 | local skey=self.serverKey 60 | local tmpKey=self.tmpKey 61 | local salt=self.salt 62 | local pass=p:getPassword() 63 | local username=p:getUsername() 64 | if username ~= self.username then 65 | p:setUsername(self.username) 66 | end 67 | local passChanged=false 68 | if self.swapPass and self.tempPass ~= self.realPass then 69 | passChanged=true 70 | local ck 71 | local realpass=self.realPass 72 | if self.options.oracleVersion.major==11 then 73 | ck,pass=crypt:Decrypt11g(ckey,tmpKey,pass,self.tempPass,salt) 74 | elseif self.options.oracleVersion.major==10 then 75 | --ck,pass=crypt:Decrypt10g(ckey,tmpKey,pass,self.tempPass,salt) 76 | end 77 | if self.options.oracleVersion.major==11 or self.options.oracleVersion.major==10 then 78 | ckey,pass=crypt:Encrypt11g(realpass,ck,skey,salt) 79 | p:setAuthKey(ckey) 80 | p:setPassword(pass) 81 | end 82 | end 83 | if self.OnAuthEvent:hasHandler() then 84 | local 
ok,message=self.OnAuthEvent:trigger({username=username,password=self.tempPass},self.ctx) 85 | if not ok then 86 | --return marker1 87 | self.channel:c2pSend(tnsPackets.Marker1:pack()) 88 | --return marker2 89 | self.channel:c2pSend(tnsPackets.Marker2:pack()) 90 | self.responseError=true 91 | p.allBytes=nil 92 | return 93 | end 94 | end 95 | if username ~= self.username or passChanged then 96 | p:pack() 97 | end 98 | end 99 | 100 | function _M:SessionRequestHandler(src,p) 101 | self.options.program=p:getProgram() 102 | self.options.is64Bit=p:is64Bit() 103 | self.ctx.client=p:getProgram() 104 | self.username=p:getUsername() 105 | if self.BeforeAuthEvent:hasHandler() then 106 | local cred=self.BeforeAuthEvent:trigger({username=p:getUsername()},self.ctx) 107 | self.tempPass=cred.temppass 108 | self.realPass=cred.password 109 | self.username=cred.username 110 | self.tempUsername=p:getUsername() 111 | end 112 | self.ctx.username=self.username 113 | if self.ContextUpdateEvent:hasHandler() then 114 | self.ContextUpdateEvent:trigger(self.ctx) 115 | end 116 | p:setUsername(self.username) 117 | p:pack() 118 | end 119 | 120 | function _M:SetProtocolRequestHandler(src,p) 121 | self.options.platform=p:getClientPlatform() 122 | self.ctx.clientPlatform=p:getClientPlatform() 123 | end 124 | 125 | function _M:SetProtocolResponseHandler(src,p) 126 | self.options.srvPlatform=p:getClientPlatform() 127 | self.ctx.srvPlatform=p:getClientPlatform() 128 | end 129 | 130 | function _M:VersionResponseHandler(src,p) 131 | --todo find a better chance to trigger authSuccess 132 | if self.AuthSuccessEvent:hasHandler() then 133 | self.AuthSuccessEvent:trigger(self.ctx.username,self.ctx) 134 | end 135 | self.options.oracleVersion.major=p:getMajor() 136 | self.options.oracleVersion.minor=p:getMinor() 137 | self.options.oracleVersion.build=p:getBuild() 138 | self.options.oracleVersion.subbuild=p:getSub() 139 | self.options.oracleVersion.fix=p:getFix() 140 | self.ctx.serverVer=p:getVersion() 141 | if 
self.ContextUpdateEvent:hasHandler() then 142 | self.ContextUpdateEvent:trigger(self.ctx) 143 | end 144 | end 145 | 146 | function _M:SessionResponseHandler(src,p) 147 | --if temp pass equals real pass then do nothing 148 | if self.tempPass==self.realPass or not self.swapPass then return end 149 | if self.options.oracleVersion.major==11 or self.options.oracleVersion.major==10 then 150 | self.serverKey=p:getAuthKey() 151 | self.salt=p:getSalt() 152 | local tmpKey=crypt:getServerKey(self.tempPass,self.realPass,self.serverKey,self.salt) 153 | self.tmpKey=tmpKey 154 | p:setAuthKey(tmpKey) 155 | p:pack() 156 | end 157 | end 158 | 159 | function _M:PiggbackHandler(src,p) 160 | if not p.__key then return end 161 | local entry=self.C2PParser.parserList[p.__key] 162 | if entry and entry.event and entry.event:hasHandler() then 163 | entry.event:trigger(p) 164 | end 165 | end 166 | 167 | function _M:MarkerHandler(src,p) 168 | --process req marker, if flag true then return error 169 | if self.responseError then 170 | self.channel:c2pSend(tnsPackets.NoPermissionError:new(self.options):pack()) 171 | self.responseError=false 172 | end 173 | if self.sessionStop then 174 | self.channel:c2pSend(tnsPackets.NoPermissionError:new(self.options):pack()) 175 | ngx.exit(0) 176 | return 177 | end 178 | end 179 | 180 | function _M:AuthErrorHandler(src,p) 181 | if self.AuthFailEvent:hasHandler() then 182 | self.AuthFailEvent:trigger({username=self.username},self.ctx) 183 | end 184 | end 185 | 186 | function _M:SQLRequestHandler(src,p) 187 | local command=p:getCommand() 188 | --if command=="altersession" then replace end 189 | if command and command:len()>0 then 190 | local allBytes 191 | if self.CommandEnteredEvent:hasHandler() then 192 | local cmd,err=self.CommandEnteredEvent:trigger(command,self.ctx) 193 | if err then 194 | --set a flag indicate error happen 195 | self.responseError=true 196 | --return marker1 197 | self.channel:c2pSend(tnsPackets.Marker1:pack()) 198 | --return marker2 199 
| self.channel:c2pSend(tnsPackets.Marker2:pack()) 200 | p.allBytes=nil 201 | return 202 | end 203 | if cmd and cmd~=command then 204 | command=cmd 205 | p:setCommand(cmd) 206 | p:pack() 207 | end 208 | end 209 | --if the username was changed during login, ALTER SESSION sql sent by the client must be updated to the real username 210 | if self.tempUsername ~= self.username then 211 | if command:match("ALTER SESSION SET CURRENT_SCHEMA") then 212 | command=command:gsub("%= .*","%= "..self.username:literalize()) 213 | p:setCommand(command) 214 | p:pack() 215 | end 216 | end 217 | 218 | if self.CommandFinishedEvent:hasHandler() then 219 | self.CommandFinishedEvent:trigger(command,"",self.ctx) 220 | end 221 | end 222 | end 223 | 224 | -------------implement processor methods--------------- 225 | function _M.processUpRequest(self) -- reads one client packet, parses it and keeps its key in self.request for the response side 226 | local readMethod=self.channel.c2pRead 227 | local allBytes,err=self:recv(readMethod) 228 | if err then return nil,err end 229 | local p=self.C2PParser:parse(allBytes,nil,nil,self.options) 230 | self.request=p.__key 231 | return p.allBytes 232 | end 233 | --- Reads one server packet and parses it; the pending request key is handed to the parser only when it contains "callId". 234 | function _M.processDownRequest(self) 235 | local readMethod=self.channel.p2sRead 236 | local allBytes,err= self:recv(readMethod) 237 | if err then return nil,err end 238 | --if the pending request is an oci function call, pass its key to the response parser 239 | local ociCall 240 | if self.request then 241 | ociCall= self.request:match("callId") and self.request or nil 242 | end 243 | local p=self.S2PParser:parse(allBytes,nil,nil,self.options,ociCall) 244 | return p.allBytes 245 | end 246 | --- Reads one packet: a 2-byte big-endian length that includes itself, then the remaining pktLen-2 bytes; returns nil, err on a read failure or when the length is below 2. 247 | function _M:recv(readMethod) 248 | local lengthdata,err=readMethod(self.channel,2) 249 | if(err) then 250 | logger.log(logger.ERR,"err when reading length") 251 | return nil,err 252 | end 253 | local pktLen=string.unpack(">I2",lengthdata) if pktLen<2 then logger.log(logger.ERR,"invalid packet length") return nil,"invalid packet length" end 254 | local data,err=readMethod(self.channel,pktLen-2) 255 | if(err) then 256 | logger.log(logger.ERR,"err when reading packet") 257 | return nil,err 258 | end 259 | local
allBytes=lengthdata..data 260 | return allBytes 261 | end 262 | --- Sends Marker1 and Marker2 to the client and sets sessionStop, so the next client marker is answered with NoPermissionError and ends the request (see MarkerHandler). 263 | function _M:sessionInvalid(session) 264 | --return marker1 265 | self.channel:c2pSend(tnsPackets.Marker1:pack()) 266 | --return marker2 267 | self.channel:c2pSend(tnsPackets.Marker2:pack()) 268 | self.sessionStop=true 269 | end 270 | 271 | return _M 272 | -------------------------------------------------------------------------------- /tns/crypt.lua: -------------------------------------------------------------------------------- 1 | --- Class that handles all Oracle encryption 2 | local cipher=require ("resty.openssl.cipher") 3 | local rand=require("resty.openssl.rand") 4 | require "suproxy.utils.stringUtils" 5 | require("suproxy.utils.pureluapack") 6 | local aes = require "resty.aes" 7 | local _M = { 8 | -- getServerKey: decrypts s_sesskey with the key derived from realpass, re-encrypts the result with the key derived from pass, and returns (re-encrypted key, plain server session key). Key derivation: sha1(password .. auth_vrfy_data) plus 4 zero bytes gives the 24-byte AES-192 key; its first 16 bytes are the IV. 9 | getServerKey=function(self,pass,realpass,s_sesskey,auth_vrfy_data) 10 | local real_hash = ngx.sha1_bin(realpass .. auth_vrfy_data) .. "\0\0\0\0" 11 | local srv_sesskey=cipher.new("aes-192-cbc"):decrypt(real_hash,real_hash:sub(1,16),s_sesskey,true) 12 | --srv_sesskey= rand.bytes(40) .. string.fromhex("0808080808080808") 13 | local temp_hash = ngx.sha1_bin(pass .. auth_vrfy_data) .. "\0\0\0\0" 14 | local result=cipher.new("aes-192-cbc"):encrypt(temp_hash,temp_hash:sub(1,16),srv_sesskey,true) 15 | return result,srv_sesskey 16 | end, 17 | -- Decrypt11g: decrypts both session keys with the key derived from pass and salt, XORs their bytes 17-40, turns that into an AES-192 key (md5 of the first 16 bytes .. md5 of the last 8, cut to 24 bytes) and decrypts auth_password with it. Returns the client session key and the password without its 16-byte prefix, or nil, err when decrypting auth_password fails. 18 | Decrypt11g = function(self, c_sesskey, s_sesskey, auth_password, pass, salt ) 19 | local sha1 = ngx.sha1_bin(pass .. salt) ..
"\0\0\0\0" 20 | local server_sesskey =cipher.new("aes-192-cbc"):decrypt(sha1,sha1:sub(1,16),s_sesskey,true) 21 | local client_sesskey = cipher.new("aes-192-cbc"):decrypt(sha1,sha1:sub(1,16),c_sesskey,true) 22 | local combined_sesskey = {} 23 | for i=17, 40 do 24 | combined_sesskey[#combined_sesskey+1] = string.char( bit.bxor(string.byte(server_sesskey, i) , string.byte(client_sesskey,i)) ) 25 | end 26 | combined_sesskey = table.concat(combined_sesskey) 27 | -- never log combined_sesskey: it is session key material 28 | combined_sesskey = ( ngx.md5_bin( combined_sesskey:sub(1,16) ) .. ngx.md5_bin(combined_sesskey:sub(17) ) ):sub(1, 24) 29 | local p,err= cipher.new("aes-192-cbc"):decrypt(combined_sesskey,combined_sesskey:sub(1,16),auth_password) 30 | if not p then return nil,err end return client_sesskey,p:sub(17) 31 | end, 32 | -- NOTE(review): the commented-out 10g helpers below call stdnse/openssl, which this module does not require; they cannot run as-is. 33 | -- -- - Creates an Oracle 10G password hash 34 | 35 | -- -- @param username containing the Oracle user name 36 | -- -- @param password containing the Oracle user password 37 | -- -- @return hash containing the Oracle hash 38 | -- HashPassword10g = function( self, username, password ) 39 | -- local uspw = (username .. password):upper():gsub(".", "\0%1") 40 | -- local key = stdnse.fromhex("0123456789abcdef") 41 | 42 | -- -- do padding 43 | -- uspw = uspw .. string.rep('\0', (8 - (#uspw % 8)) % 8) 44 | 45 | -- local iv2 = openssl.encrypt( "DES-CBC", key, nil, uspw, false ):sub(-8) 46 | -- local enc = openssl.encrypt( "DES-CBC", iv2, nil, uspw, false ):sub(-8) 47 | -- return enc 48 | -- end, 49 | 50 | -- -- Test function, not currently in use 51 | -- Decrypt10g = function(self, user, pass, srv_sesskey_enc ) 52 | -- local pwhash = self:HashPassword10g( user, pass ) ..
"\0\0\0\0\0\0\0\0" 53 | -- local cli_sesskey_enc = stdnse.fromhex("7B244D7A1DB5ABE553FB9B7325110024911FCBE95EF99E7965A754BC41CF31C0") 54 | -- local srv_sesskey = openssl.decrypt( "AES-128-CBC", pwhash, nil, srv_sesskey_enc ) 55 | -- local cli_sesskey = openssl.decrypt( "AES-128-CBC", pwhash, nil, cli_sesskey_enc ) 56 | -- local auth_pass = stdnse.fromhex("4C5E28E66B6382117F9D41B08957A3B9E363B42760C33B44CA5D53EA90204ABE") 57 | -- local pass 58 | 59 | -- local combined_sesskey = {} 60 | -- for i=17, 32 do 61 | -- combined_sesskey[#combined_sesskey+1] = string.char( string.byte(srv_sesskey, i) ~ string.byte(cli_sesskey, i) ) 62 | -- end 63 | -- combined_sesskey = openssl.md5( table.concat(combined_sesskey) ) 64 | 65 | -- pass = openssl.decrypt( "AES-128-CBC", combined_sesskey, nil, auth_pass ):sub(17) 66 | 67 | -- print( stdnse.tohex( srv_sesskey )) 68 | -- print( stdnse.tohex( cli_sesskey )) 69 | -- print( stdnse.tohex( combined_sesskey )) 70 | -- print( "pass=" .. pass ) 71 | -- end, 72 | 73 | -- -- - Performs the relevant encryption needed for the Oracle 10g response 74 | 75 | -- -- @param user containing the Oracle user name 76 | -- -- @param pass containing the Oracle user password 77 | -- -- @param srv_sesskey_enc containing the encrypted server session key as 78 | -- -- received from the PreAuth packet 79 | -- -- @return cli_sesskey_enc the encrypted client session key 80 | -- -- @return auth_pass the encrypted Oracle password 81 | -- Encrypt10g = function( self, user, pass, srv_sesskey_enc ) 82 | 83 | -- local pwhash = self:HashPassword10g( user, pass ) ..
"\0\0\0\0\0\0\0\0" 84 | -- -- We're currently using a static client session key, this should 85 | -- -- probably be changed to a random value in the future 86 | -- local cli_sesskey = stdnse.fromhex("FAF5034314546426F329B1DAB1CDC5B8FF94349E0875623160350B0E13A0DA36") 87 | -- local srv_sesskey = openssl.decrypt( "AES-128-CBC", pwhash, nil, srv_sesskey_enc ) 88 | -- local cli_sesskey_enc = openssl.encrypt( "AES-128-CBC", pwhash, nil, cli_sesskey ) 89 | -- -- This value should really be random, not this static cruft 90 | -- local rnd = stdnse.fromhex("4C31AFE05F3B012C0AE9AB0CDFF0C508") 91 | -- local auth_pass 92 | 93 | -- local combined_sesskey = {} 94 | -- for i=17, 32 do 95 | -- combined_sesskey[#combined_sesskey+1] = string.char( string.byte(srv_sesskey, i) ~ string.byte(cli_sesskey, i) ) 96 | -- end 97 | -- combined_sesskey = openssl.md5( table.concat(combined_sesskey) ) 98 | -- auth_pass = openssl.encrypt("AES-128-CBC", combined_sesskey, nil, rnd .. pass, true ) 99 | -- auth_pass = stdnse.tohex(auth_pass) 100 | -- cli_sesskey_enc = stdnse.tohex(cli_sesskey_enc) 101 | -- return cli_sesskey_enc, auth_pass 102 | -- end, 103 | 104 | -- - Performs the relevant encryption needed for the Oracle 11g response 105 | 106 | -- @param pass containing the Oracle user password 107 | -- @param cli_sesskey unencrypted client key 108 | -- @param srv_sesskey_enc containing the encrypted server session key as 109 | -- received from the PreAuth packet 110 | -- @param auth_vrfy_data containing the password salt as received from the 111 | -- PreAuth packet 112 | -- @return cli_sesskey_enc the encrypted client session key 113 | -- @return auth_password the encrypted Oracle password 114 | Encrypt11g = function( self, pass,cli_sesskey, srv_sesskey_enc, auth_vrfy_data ) -- cli_sesskey must be at least 40 bytes: bytes 17-40 are XORed with the server key 115 | local rnd = rand.bytes(16) -- random 16-byte prefix encrypted in front of the password 116 | --local cli_sesskey = rand.bytes(40) .. string.fromhex("0808080808080808") 117 | local pw_hash = ngx.sha1_bin(pass .. auth_vrfy_data) ..
"\0\0\0\0" 118 | local srv_sesskey=cipher.new("aes-192-cbc"):decrypt(pw_hash, pw_hash:sub(1,16),srv_sesskey_enc) 119 | local auth_password 120 | local cli_sesskey_enc 121 | local combined_sesskey = {} -- bytes 17-40 of server XOR client session key 122 | for i=17, 40 do 123 | combined_sesskey[#combined_sesskey+1] = string.char( bit.bxor(string.byte(srv_sesskey, i) ,string.byte(cli_sesskey, i) )) 124 | end 125 | combined_sesskey = table.concat(combined_sesskey) 126 | combined_sesskey = ( ngx.md5_bin( combined_sesskey:sub(1,16) ) .. ngx.md5_bin( combined_sesskey:sub(17) ) ):sub(1, 24) -- 24-byte AES-192 key from md5 of both parts 127 | cli_sesskey_enc=cipher.new("aes-192-cbc"):encrypt(pw_hash,pw_hash:sub(1,16),cli_sesskey,true) 128 | auth_password=cipher.new("aes-192-cbc"):encrypt(combined_sesskey,combined_sesskey:sub(1,16),rnd .. pass) 129 | return cli_sesskey_enc, auth_password 130 | end, 131 | 132 | } 133 | return _M -------------------------------------------------------------------------------- /tns/parser.lua: -------------------------------------------------------------------------------- 1 | --tns protocol parser 2 | local P=require "suproxy.tns.tnsPackets" 3 | local K=P.getKey 4 | local parser=require("suproxy.parser") 5 | local _M={} 6 | -- packet key -> parser and event name; keys built with req= are matched together with the key of the pending request 7 | local conf={ 8 | {key=K({code=1}), parser=P.Connect, eventName="ConnectEvent"}, 9 | {key=K({callId=0x76}), parser=P.SessionRequest, eventName="SessionRequestEvent"}, 10 | {key=K({callId=0x73}), parser=P.AuthRequest, eventName="AuthRequestEvent"}, 11 | {key=K({dataId=1}), parser=P.SetProtocolRequest, eventName="SetProtocolEvent"}, 12 | {key=K({callId=0x69}), parser=P.Piggyback, eventName="Piggyback1169"}, 13 | {key=K({callId=0x6b}), parser=P.Piggyback, eventName="Piggyback116b"}, 14 | {key=K({callId=0x5e}), parser=P.SQLRequest, eventName="SQLRequestEvent"}, 15 | {key=K({code=2}), parser=P.Accept, eventName="AcceptEvent"}, 16 | {key=K({code=12}), eventName="MarkerEvent"}, 17 | {key=K({dataId=8,req=K({callId=0x76})}), parser=P.SessionResponse, eventName="SessionResponseEvent"}, 18 |
{key=K({dataId=8,req=K({callId=0x3b})}), parser=P.VersionResponse, eventName="VersionResponseEvent"}, 19 | {key=K({code=12,req=K({callId=0x73})}), eventName="AuthErrorEvent"}, 20 | } 21 | -- keyG: parser key of a packet: its type (byte 3+pktChkLen), for DATA packets the data id (byte 7+pktChkLen+hdrChkLen), for OCI/piggyback calls the call id (byte 8+pktChkLen+hdrChkLen), and request, when given, as req. 22 | local keyG=function(allBytes,pos,options,request) 23 | --if options is null use default value 24 | options=P.Options:new(options) 25 | local pktChkLen=(options:pktChk() and 2 or 0) 26 | local hdrChkLen=(options:hdrChk() and 2 or 0) 27 | local pktType=allBytes:byte(3+pktChkLen) 28 | local dataId,callId 29 | if pktType==P.PacketType.DATA.code then dataId=allBytes:byte(7+pktChkLen+hdrChkLen) end 30 | if dataId==P.DataID.USER_OCI_FUNC.code or dataId==P.DataID.PIGGYBACK_FUNC.code then callId=allBytes:byte(8+pktChkLen+hdrChkLen) end 31 | return K({callId=callId,dataId=dataId,code=pktType,req=request}) 32 | end 33 | -- Creates a client->proxy parser and a server->proxy parser, both keyed by keyG, registered with conf and falling back to P.Packet. 34 | function _M:new() 35 | local o= setmetatable({},{__index=self}) 36 | local C2PParser=parser:new() 37 | C2PParser.keyGenerator=keyG 38 | C2PParser:registerMulti(conf) 39 | C2PParser:registerDefaultParser(P.Packet) 40 | o.C2PParser=C2PParser 41 | 42 | local S2PParser=parser:new() 43 | S2PParser.keyGenerator=keyG 44 | S2PParser:registerMulti(conf) 45 | S2PParser:registerDefaultParser(P.Packet) 46 | o.S2PParser=S2PParser 47 | return o 48 | end 49 | return _M 50 | 51 | 52 | -------------------------------------------------------------------------------- /utils/compatibleLog.lua: -------------------------------------------------------------------------------- 1 | local tableUtils=require "suproxy.utils.tableUtils" 2 | local _M={} 3 | 4 | _M.STDERR ={code=0x00,ngxCode=0x00,desc="STDERR"} 5 | _M.EMERG ={code=0x01,ngxCode=0x01,desc="EMERG" } 6 | _M.ALERT ={code=0x02,ngxCode=0x02,desc="ALERT" } 7 | _M.CRIT ={code=0x03,ngxCode=0x03,desc="CRIT" } 8 | _M.ERR ={code=0x04,ngxCode=0x04,desc="ERR" } 9 | _M.WARN ={code=0x05,ngxCode=0x05,desc="WARN" } 10 | _M.NOTICE ={code=0x06,ngxCode=0x06,desc="NOTICE"} 11 | _M.INFO ={code=0x07,ngxCode=0x07,desc="INFO" } 12 | _M.DEBUG
={code=0x08,ngxCode=0x08,desc="DEBUG" } 13 | -- nginx log level number -> level table (used by getLogLevel) 14 | local _ngxMapping={ 15 | [0x00]=_M.STDERR,[0x01]=_M.EMERG ,[0x02]=_M.ALERT , 16 | [0x03]=_M.CRIT ,[0x04]=_M.ERR ,[0x05]=_M.WARN , 17 | [0x06]=_M.NOTICE,[0x07]=_M.INFO ,[0x08]=_M.DEBUG , 18 | } 19 | -- threshold for the print fallback used when ngx.errlog is unavailable 20 | local _logLevel=_M.DEBUG 21 | --- Logs the arguments prefixed with "source:line" of the frame stackUpLevel levels up (default 2, the caller of logInner). With ngx.errlog available the line goes to raw_log regardless of _logLevel; otherwise it is printed when level is at least as severe as _logLevel. 22 | function _M.logInner(level,stackUpLevel,...) 23 | level=level or _M.NOTICE 24 | stackUpLevel=stackUpLevel or 2 25 | local args={...} 26 | local info=debug.getinfo(stackUpLevel,"Sl") local func=info.short_src ..":"..info.currentline 27 | local ok,ngxLog=pcall(require,"ngx.errlog") 28 | if ok and ngxLog then 29 | ngxLog.raw_log(level.ngxCode,func..": "..tableUtils.concat(args)) 30 | elseif level.code<=_logLevel.code then 31 | print(level.desc,func,":",tableUtils.concat(args)) 32 | end 33 | end 34 | --- Logs at level, attributing the message to the caller of log. 35 | function _M.log(level,...) 36 | _M.logInner(level,3,...) 37 | end 38 | --- Logs title centred in a line of dashes about 80 characters wide, followed by the remaining arguments. 39 | function _M.logWithTitle(level,title,...) 40 | local l=math.max(math.floor((80-#title)/2),0) 41 | -- floored: string.rep needs an integer count on Lua 5.3+ 42 | _M.logInner(level,3,"\r\n"..string.rep("-",l)..title..string.rep("-",l).."\r\n",...)
43 | end 44 | --- Returns the current level; when ngx.errlog is available it is first refreshed from get_sys_filter_level(). 45 | function _M.getLogLevel() 46 | local ok,ret=pcall(require,"ngx.errlog") 47 | if ok then _logLevel= _ngxMapping[ret.get_sys_filter_level()] end 48 | return _logLevel 49 | end 50 | --- Sets the print-fallback threshold and, when ngx.errlog is available, nginx's filter level via set_filter_level (a failure is logged with ngx.log). 51 | function _M.setLogLevel(level) 52 | _logLevel=level 53 | local ok,ret=pcall(require,"ngx.errlog") 54 | if ok then 55 | local status,err=ret.set_filter_level(level.ngxCode) 56 | if not status then 57 | ngx.log(ngx.ERR, err) 58 | end 59 | end 60 | end 61 | 62 | _M.unitTest={} 63 | 64 | function _M.test() 65 | _M.setLogLevel(_M.ERR) 66 | _M.log(_M.DEBUG,"abc",1,nil,{}) 67 | _M.logWithTitle(_M.ERR,"abc",1,nil,{},6) 68 | _M.setLogLevel(_M.DEBUG) 69 | _M.log(_M.DEBUG,"abc",1,nil,{}) 70 | _M.logWithTitle(_M.ERR,"abc",1,nil,{},6) 71 | end 72 | 73 | return _M 74 | -------------------------------------------------------------------------------- /utils/datetime.lua: -------------------------------------------------------------------------------- 1 | --from https://www.cnblogs.com/wangzhitie/p/5209985.html 2 | --[[ 3 | Parameters: 4 | srcDateTime source date-time string in the format %Y%m%d%H%M%S: a 4-digit year followed by a 2-digit month, day, hour, minute and second 5 | interval amount to shift the time by (>0 adds, <0 subtracts) 6 | dateUnit time unit, one of DAY, HOUR, SECOND or MINUTE; the time is shifted by interval units of this kind 7 | For example, 8 | interval=10,unit='DAY' adds 10 days to the original time 9 | interval=-1,unit='HOUR' subtracts 1 hour from the original time 10 | 11 | The result is an os.date("*t") table with year, month, day, hour, min and sec fields; read the fields you need and format the new date as required. 12 | ]] 13 | 14 | local _M={} 15 | 16 | function _M.getNewDate(srcDateTime,interval ,dateUnit) 17 | --extract year, month, day, hour, minute and second from the date string 18 | local Y = string.sub(srcDateTime,1,4) 19 | local M = string.sub(srcDateTime,5,6) 20 | local D = string.sub(srcDateTime,7,8) 21 | local H = string.sub(srcDateTime,9,10) 22 | local MM = string.sub(srcDateTime,11,12) 23 | local SS = string.sub(srcDateTime,13,14) 24 | 25 | --convert the components into a timestamp 26 | local dt1 = os.time{year=Y, month=M, day=D, hour=H,min=MM,sec=SS} 27 | 28 | --offset in seconds for the given unit and interval (stays 0 for an unknown unit) 29 | local ofset=0 30 | 31 | if dateUnit =='DAY' then 32
| ofset = 60 *60 * 24 * interval 33 | 34 | elseif dateUnit == 'HOUR' then 35 | ofset = 60 *60 * interval 36 | 37 | elseif dateUnit == 'MINUTE' then 38 | ofset = 60 * interval 39 | 40 | elseif dateUnit == 'SECOND' then 41 | ofset = interval 42 | end 43 | 44 | --given time + offset 45 | local newTime = os.date("*t", dt1 + tonumber(ofset)) 46 | return newTime 47 | end 48 | --- Prints two sample results (+3 hours, +1 day) for a fixed date string. 49 | function _M.test() 50 | local oldTime="20130908232828" 51 | --add 3 hours to the given time 52 | local newTime=_M.getNewDate(oldTime,3,'HOUR') 53 | local t1 = string.format('%d-%02d-%02d %02d:%02d:%02d',newTime.year,newTime.month,newTime.day,newTime.hour,newTime.min,newTime.sec) 54 | print('t1='..t1) 55 | 56 | --add 1 day to the given time 57 | newTime=_M.getNewDate(oldTime,1,'DAY') 58 | 59 | local t2 = string.format('%d%02d%02d%02d%02d%02d',newTime.year,newTime.month,newTime.day,newTime.hour,newTime.min,newTime.sec) 60 | 61 | print('t2='..t2) 62 | end 63 | 64 | return _M 65 | -------------------------------------------------------------------------------- /utils/event.lua: -------------------------------------------------------------------------------- 1 | local _M={} 2 | local logger=require "suproxy.utils.compatibleLog" 3 | local function addHandler(self,context,...) -- appends each handler, with the shared context, to self.chain; raises an error when none is given 4 | local handlers={...} 5 | if #handlers ==0 then return error("no handler was added to event") end 6 | for k,v in ipairs(handlers) do 7 | table.insert(self.chain,{context=context,handler=v}) 8 | end 9 | end 10 | --- Creates an event bound to source; its trigger calls every registered handler in order as handler(context, source, ...). 11 | function _M:new(source,name) 12 | local o={ 13 | source=source, 14 | name=name, 15 | chain={}, 16 | } 17 | o.addHandler=addHandler 18 | o.trigger=function(self,...)
19 | logger.logWithTitle(logger.DEBUG,string.format("event %s triggered",tostring(self.name)),"") 20 | -- forward ... directly instead of repacking it, so nil arguments reach every handler unchanged 21 | for _,v in ipairs(self.chain) do 22 | v.handler(v.context,self.source,...) 23 | end 24 | end 25 | return setmetatable(o, {__index=self}) 26 | end 27 | --- Creates an event that accepts a single handler; its trigger returns every value that handler returns (nothing when no handler is set). 28 | function _M:newReturnEvent(source,name) 29 | local o=_M:new(source,name) 30 | o.addHandler=function(self,context,...) 31 | assert(#(self.chain)==0,"returnEvent cannot has more than one handler") 32 | addHandler(self,context,...) 33 | end 34 | o.trigger=function(self,...) 35 | logger.logWithTitle(logger.DEBUG,string.format("event %s triggered",tostring(self.name)),"") 36 | -- return the handler's results directly: repacking them with unpack{...} could drop values after a nil (e.g. nil, err) 37 | for _,v in ipairs(self.chain) do 38 | return v.handler(v.context,self.source,...) 39 | end 40 | end 41 | return o 42 | end 43 | 44 | --- Replaces all handlers with the given ones, sharing one context. 45 | function _M:setHandler(context,...) 46 | self.chain={} 47 | self:addHandler(context,...) 48 | end 49 | --- Returns true when at least one handler is registered. 50 | function _M:hasHandler() 51 | return #(self.chain)>0 52 | end 53 | 54 | _M.unitTest={} 55 | function _M.test() 56 | local src={} 57 | src.eventA=_M:new(src) 58 | src.eventB=_M:newReturnEvent(src) 59 | local handlers={ 60 | handle1=function(self,source,params) 61 | print("1 executed"..tostring(params)) 62 | source.name=params 63 | end, 64 | handle2=function(self,source,params) 65 | print("2 executed"..tostring(params)) 66 | self.name=params 67 | end, 68 | handle3=function(self,source,params) 69 | print("3 executed"..tostring(params)) 70 | end, 71 | } 72 | src.eventA:addHandler(handlers,handlers.handle1,handlers.handle2,handlers.handle3) 73 | local result=src.eventA:trigger("0","lala") 74 | assert(src.name=="0",src.name) 75 | assert(handlers.name=="0",handlers.name) 76 | src.eventB:addHandler(src,function(self,source,params) source.name=params return "2",true end) 77 | local result1=src.eventB:trigger("0","la") 78 | assert(src.name=="0",src.name) 79 | assert(result1=="2",result1) 80 | local ok=pcall(src.eventB.addHandler,src.eventB,nil,function()end) 81 |
assert(not ok) 82 | end 83 | 84 | return _M -------------------------------------------------------------------------------- /utils/ffi-zlib.lua: -------------------------------------------------------------------------------- 1 | local ffi = require "ffi" 2 | local ffi_new = ffi.new 3 | local ffi_str = ffi.string 4 | local ffi_sizeof = ffi.sizeof 5 | local ffi_copy = ffi.copy 6 | local tonumber = tonumber -- NOTE(review): not referenced elsewhere in this file 7 | 8 | local _M = { 9 | _VERSION = '0.01', 10 | } 11 | 12 | local mt = { __index = _M } -- NOTE(review): not referenced elsewhere in this file 13 | 14 | -- C declarations for the zlib API used below: flush modes, return codes, compression constants, z_stream and the inflate/deflate/checksum functions 15 | ffi.cdef([[ 16 | enum { 17 | Z_NO_FLUSH = 0, 18 | Z_PARTIAL_FLUSH = 1, 19 | Z_SYNC_FLUSH = 2, 20 | Z_FULL_FLUSH = 3, 21 | Z_FINISH = 4, 22 | Z_BLOCK = 5, 23 | Z_TREES = 6, 24 | /* Allowed flush values; see deflate() and inflate() below for details */ 25 | Z_OK = 0, 26 | Z_STREAM_END = 1, 27 | Z_NEED_DICT = 2, 28 | Z_ERRNO = -1, 29 | Z_STREAM_ERROR = -2, 30 | Z_DATA_ERROR = -3, 31 | Z_MEM_ERROR = -4, 32 | Z_BUF_ERROR = -5, 33 | Z_VERSION_ERROR = -6, 34 | /* Return codes for the compression/decompression functions. Negative values 35 | * are errors, positive values are used for special but normal events.
36 | */ 37 | Z_NO_COMPRESSION = 0, 38 | Z_BEST_SPEED = 1, 39 | Z_BEST_COMPRESSION = 9, 40 | Z_DEFAULT_COMPRESSION = -1, 41 | /* compression levels */ 42 | Z_FILTERED = 1, 43 | Z_HUFFMAN_ONLY = 2, 44 | Z_RLE = 3, 45 | Z_FIXED = 4, 46 | Z_DEFAULT_STRATEGY = 0, 47 | /* compression strategy; see deflateInit2() below for details */ 48 | Z_BINARY = 0, 49 | Z_TEXT = 1, 50 | Z_ASCII = Z_TEXT, /* for compatibility with 1.2.2 and earlier */ 51 | Z_UNKNOWN = 2, 52 | /* Possible values of the data_type field (though see inflate()) */ 53 | Z_DEFLATED = 8, 54 | /* The deflate compression method (the only one supported in this version) */ 55 | Z_NULL = 0, /* for initializing zalloc, zfree, opaque */ 56 | }; 57 | 58 | 59 | typedef void* (* z_alloc_func)( void* opaque, unsigned items, unsigned size ); 60 | typedef void (* z_free_func) ( void* opaque, void* address ); 61 | 62 | typedef struct z_stream_s { 63 | char* next_in; 64 | unsigned avail_in; 65 | unsigned long total_in; 66 | char* next_out; 67 | unsigned avail_out; 68 | unsigned long total_out; 69 | char* msg; 70 | void* state; 71 | z_alloc_func zalloc; 72 | z_free_func zfree; 73 | void* opaque; 74 | int data_type; 75 | unsigned long adler; 76 | unsigned long reserved; 77 | } z_stream; 78 | 79 | 80 | const char* zlibVersion(); 81 | const char* zError(int); 82 | 83 | int inflate(z_stream*, int flush); 84 | int inflateEnd(z_stream*); 85 | int inflateInit2_(z_stream*, int windowBits, const char* version, int stream_size); 86 | 87 | int deflate(z_stream*, int flush); 88 | int deflateEnd(z_stream* ); 89 | int deflateInit2_(z_stream*, int level, int method, int windowBits, int memLevel,int strategy, const char *version, int stream_size); 90 | 91 | unsigned long adler32(unsigned long adler, const char *buf, unsigned len); 92 | unsigned long crc32(unsigned long crc, const char *buf, unsigned len); 93 | unsigned long adler32_combine(unsigned long, unsigned long, long); 94 | unsigned long crc32_combine(unsigned long, unsigned long, 
long); 95 | 96 | ]]) 97 | -- load the shared zlib library ("zlib1" on Windows, "z" elsewhere) 98 | local zlib = ffi.load(ffi.os == "Windows" and "zlib1" or "z") 99 | _M.zlib = zlib 100 | 101 | -- Default to 16k output buffer 102 | local DEFAULT_CHUNK = 16384 103 | 104 | local Z_OK = zlib.Z_OK 105 | local Z_NO_FLUSH = zlib.Z_NO_FLUSH 106 | local Z_STREAM_END = zlib.Z_STREAM_END 107 | local Z_FINISH = zlib.Z_FINISH 108 | -- Returns zlib's message text for a status/error code. 109 | local function zlib_err(err) 110 | return ffi_str(zlib.zError(err)) 111 | end 112 | _M.zlib_err = zlib_err 113 | -- Allocates a z_stream, an input buffer of bufsize+1 bytes and an output buffer of bufsize bytes; returns stream, inbuf, outbuf. 114 | local function createStream(bufsize) 115 | -- Setup Stream 116 | local stream = ffi_new("z_stream") 117 | 118 | -- Create input buffer var 119 | local inbuf = ffi_new('char[?]', bufsize+1) 120 | stream.next_in, stream.avail_in = inbuf, 0 121 | 122 | -- create the output buffer 123 | local outbuf = ffi_new('char[?]', bufsize) 124 | stream.next_out, stream.avail_out = outbuf, 0 125 | 126 | return stream, inbuf, outbuf 127 | end 128 | _M.createStream = createStream 129 | -- Initialises stream for inflating (windowBits defaults to 15+32); returns zlib's status code. 130 | local function initInflate(stream, windowBits) 131 | -- Setup inflate process 132 | local windowBits = windowBits or (15 + 32) -- +32 sets automatic header detection 133 | local version = ffi_str(zlib.zlibVersion()) 134 | 135 | return zlib.inflateInit2_(stream, windowBits, version, ffi_sizeof(stream)) 136 | end 137 | _M.initInflate = initInflate 138 | -- Initialises stream for deflating from options.level, memLevel, strategy and windowBits (each with a default); returns zlib's status code. 139 | local function initDeflate(stream, options) 140 | -- Setup deflate process 141 | local method = zlib.Z_DEFLATED 142 | local level = options.level or zlib.Z_DEFAULT_COMPRESSION 143 | local memLevel = options.memLevel or 8 144 | local strategy = options.strategy or zlib.Z_DEFAULT_STRATEGY 145 | local windowBits = options.windowBits or (15 + 16) -- +16 sets gzip wrapper not zlib 146 | local version = ffi_str(zlib.zlibVersion()) 147 | 148 | return zlib.deflateInit2_(stream, level, method, windowBits, memLevel, strategy, version, ffi_sizeof(stream)) 149 | end 150 | _M.initDeflate = initDeflate 151 | -- Passes the bytes zlib produced in outbuf (bufsize - avail_out of them) to output, when there are any. 152 | local function flushOutput(stream, bufsize, output, outbuf) 153 | --
Calculate available output bytes 154 | local out_sz = bufsize - stream.avail_out 155 | if out_sz == 0 then 156 | return 157 | end 158 | -- Read bytes from output buffer and pass to output function 159 | output(ffi_str(outbuf, out_sz)) 160 | end 161 | -- Runs zlib_flate (inflate or deflate) over chunks from input(bufsize) until Z_STREAM_END, passing produced bytes to output. input must return at most bufsize bytes per call and nil at EOF. Returns true plus zlib's message, or false, an error message and the stream. 162 | local function flate(zlib_flate, zlib_flateEnd, input, output, bufsize, stream, inbuf, outbuf) 163 | -- Inflate or Deflate a stream 164 | local err = 0 165 | local mode = Z_NO_FLUSH 166 | repeat 167 | -- Read some input 168 | local data = input(bufsize) 169 | if data ~= nil then 170 | if #data > bufsize then zlib_flateEnd(stream) return false, "FLATE: input chunk larger than bufsize", stream end ffi_copy(inbuf, data) -- ffi.copy writes #data bytes plus a NUL; inbuf has bufsize+1 bytes, so a longer chunk would overflow it 171 | stream.next_in, stream.avail_in = inbuf, #data 172 | else 173 | -- EOF, try and finish up 174 | mode = Z_FINISH 175 | stream.avail_in = 0 176 | end 177 | 178 | -- While the output buffer is being filled completely just keep going 179 | repeat 180 | stream.next_out = outbuf 181 | stream.avail_out = bufsize 182 | -- Process the stream 183 | err = zlib_flate(stream, mode) 184 | if err < Z_OK then 185 | -- Error, clean up and return 186 | zlib_flateEnd(stream) 187 | return false, "FLATE: "..zlib_err(err), stream 188 | end 189 | -- Write the data out 190 | flushOutput(stream, bufsize, output, outbuf) 191 | until stream.avail_out ~= 0 192 | 193 | until err == Z_STREAM_END 194 | 195 | -- Stream finished, clean up and return 196 | zlib_flateEnd(stream) 197 | return true, zlib_err(err) 198 | end 199 | _M.flate = flate 200 | -- Adler-32 checksum of str (default ""), continuing from chksum (default 0). 201 | local function adler(str, chksum) 202 | local chksum = chksum or 0 203 | local str = str or "" 204 | return zlib.adler32(chksum, str, #str) 205 | end 206 | _M.adler = adler 207 | -- CRC-32 of str (default ""), continuing from chksum (default 0). 208 | local function crc(str, chksum) 209 | local chksum = chksum or 0 210 | local str = str or "" 211 | return zlib.crc32(chksum, str, #str) 212 | end 213 | _M.crc = crc 214 | 215 | function _M.inflateGzip(input, output, bufsize, windowBits) 216 | local bufsize = bufsize or DEFAULT_CHUNK 217 | 218 | -- Takes 2 functions that provide input data from a gzip stream and receives output data 219 | -- Returns
uncompressed string 220 | local stream, inbuf, outbuf = createStream(bufsize) 221 | -- NOTE: the decompressed data is delivered through output(); this function itself only returns ok, err 222 | local init = initInflate(stream, windowBits) 223 | if init == Z_OK then 224 | local ok, err = flate(zlib.inflate, zlib.inflateEnd, input, output, bufsize, stream, inbuf, outbuf) 225 | return ok,err 226 | else 227 | -- Init error 228 | zlib.inflateEnd(stream) 229 | return false, "INIT: "..zlib_err(init) 230 | end 231 | end 232 | 233 | function _M.deflateGzip(input, output, bufsize, options) 234 | local bufsize = bufsize or DEFAULT_CHUNK 235 | options = options or {} 236 | 237 | -- Takes an input function that supplies plain data and an output function that receives the compressed data 238 | -- The compressed data is delivered through output(); this function itself only returns ok, err 239 | local stream, inbuf, outbuf = createStream(bufsize) 240 | 241 | local init = initDeflate(stream, options) 242 | if init == Z_OK then 243 | local ok, err = flate(zlib.deflate, zlib.deflateEnd, input, output, bufsize, stream, inbuf, outbuf) 244 | return ok,err 245 | else 246 | -- Init error 247 | zlib.deflateEnd(stream) 248 | return false, "INIT: "..zlib_err(init) 249 | end 250 | end 251 | -- Returns the version string of the loaded zlib library. 252 | function _M.version() 253 | return ffi_str(zlib.zlibVersion()) 254 | end 255 | 256 | return _M -------------------------------------------------------------------------------- /utils/json.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- json.lua 3 | -- 4 | -- Copyright (c) 2020 rxi 5 | -- 6 | -- Permission is hereby granted, free of charge, to any person obtaining a copy of 7 | -- this software and associated documentation files (the "Software"), to deal in 8 | -- the Software without restriction, including without limitation the rights to 9 | -- use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 10 | -- of the Software, and to permit persons to whom the Software is furnished to do 11 | -- so, subject to the following conditions: 12 | -- 13 | -- The above copyright notice and this permission notice shall be included in
all 14 | -- copies or substantial portions of the Software. 15 | -- 16 | -- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | -- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | -- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | -- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | -- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | -- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | -- SOFTWARE. 23 | -- 24 | 25 | local json = { _version = "0.1.2" } 26 | 27 | ------------------------------------------------------------------------------- 28 | -- Encode 29 | ------------------------------------------------------------------------------- 30 | -- forward declaration: encode is assigned further down, after type_func_map 31 | local encode 32 | 33 | local escape_char_map = { 34 | [ "\\" ] = "\\", 35 | [ "\"" ] = "\"", 36 | [ "\b" ] = "b", 37 | [ "\f" ] = "f", 38 | [ "\n" ] = "n", 39 | [ "\r" ] = "r", 40 | [ "\t" ] = "t", 41 | } 42 | -- inverse map (escape letter -> character) used by the decoder; it also accepts an escaped "/" 43 | local escape_char_map_inv = { [ "/" ] = "/" } 44 | for k, v in pairs(escape_char_map) do 45 | escape_char_map_inv[v] = k 46 | end 47 | 48 | -- Escapes one character for a JSON string: its short escape when there is one, otherwise a u escape with 4 hex digits. 49 | local function escape_char(c) 50 | return "\\" .. (escape_char_map[c] or string.format("u%04x", c:byte())) 51 | end 52 | 53 | -- nil is encoded as JSON null. 54 | local function encode_nil(val) 55 | return "null" 56 | end 57 | 58 | -- Encodes a table as a JSON array (when it has index 1 or is empty) or as an object; stack holds the tables currently being encoded so cycles raise an error. 59 | local function encode_table(val, stack) 60 | local res = {} 61 | stack = stack or {} 62 | 63 | -- Circular reference?
64 | if stack[val] then error("circular reference") end 65 | 66 | stack[val] = true 67 | 68 | if rawget(val, 1) ~= nil or next(val) == nil then 69 | -- Treat as array -- check keys are valid and it is not sparse 70 | local n = 0 71 | for k in pairs(val) do 72 | if type(k) ~= "number" then 73 | error("invalid table: mixed or invalid key types") 74 | end 75 | n = n + 1 76 | end 77 | if n ~= #val then 78 | error("invalid table: sparse array") 79 | end 80 | -- Encode 81 | for i, v in ipairs(val) do 82 | table.insert(res, encode(v, stack)) 83 | end 84 | stack[val] = nil 85 | return "[" .. table.concat(res, ",") .. "]" 86 | 87 | else 88 | -- Treat as an object 89 | for k, v in pairs(val) do 90 | if type(k) ~= "string" then 91 | error("invalid table: mixed or invalid key types") 92 | end 93 | table.insert(res, encode(k, stack) .. ":" .. encode(v, stack)) 94 | end 95 | stack[val] = nil 96 | return "{" .. table.concat(res, ",") .. "}" 97 | end 98 | end 99 | 100 | -- Quotes a string, escaping control characters, backslash and double quote. 101 | local function encode_string(val) 102 | return '"' .. val:gsub('[%z\1-\31\\"]', escape_char) .. '"' 103 | end 104 | 105 | -- Formats a number with up to 14 significant digits; NaN and infinities raise an error. 106 | local function encode_number(val) 107 | -- Check for NaN, -inf and inf 108 | if val ~= val or val <= -math.huge or val >= math.huge then 109 | error("unexpected number value '" .. tostring(val) .. "'") 110 | end 111 | return string.format("%.14g", val) 112 | end 113 | 114 | -- Encoder per Lua type name; any other type makes encode raise an error. 115 | local type_func_map = { 116 | [ "nil" ] = encode_nil, 117 | [ "table" ] = encode_table, 118 | [ "string" ] = encode_string, 119 | [ "number" ] = encode_number, 120 | [ "boolean" ] = tostring, 121 | } 122 | 123 | -- Dispatches on type(val), forwarding stack for cycle detection. 124 | encode = function(val, stack) 125 | local t = type(val) 126 | local f = type_func_map[t] 127 | if f then 128 | return f(val, stack) 129 | end 130 | error("unexpected type '" .. t ..
"'") 131 | end 132 | 133 | -- Encodes a Lua value as JSON text; the parentheses keep only the first return value. 134 | function json.encode(val) 135 | return ( encode(val) ) 136 | end 137 | 138 | 139 | ------------------------------------------------------------------------------- 140 | -- Decode 141 | ------------------------------------------------------------------------------- 142 | -- forward declaration so the parse_* helpers below can call parse 143 | local parse 144 | -- Returns a set (value -> true) of its arguments. 145 | local function create_set(...) 146 | local res = {} 147 | for i = 1, select("#", ...) do 148 | res[ select(i, ...) ] = true 149 | end 150 | return res 151 | end 152 | 153 | local space_chars = create_set(" ", "\t", "\r", "\n") 154 | local delim_chars = create_set(" ", "\t", "\r", "\n", "]", "}", ",") 155 | local escape_chars = create_set("\\", "/", '"', "b", "f", "n", "r", "t", "u") 156 | local literals = create_set("true", "false", "null") 157 | 158 | local literal_map = { 159 | [ "true" ] = true, 160 | [ "false" ] = false, 161 | [ "null" ] = nil, 162 | } 163 | 164 | -- Returns the first index at or after idx whose character is in set (or, with negate = true, is not in set); #str + 1 if there is none. 165 | local function next_char(str, idx, set, negate) 166 | for i = idx, #str do 167 | if set[str:sub(i, i)] ~= negate then 168 | return i 169 | end 170 | end 171 | return #str + 1 172 | end 173 | 174 | -- Raises an error reporting msg with the line and column of idx in str. 175 | local function decode_error(str, idx, msg) 176 | local line_count = 1 177 | local col_count = 1 178 | for i = 1, idx - 1 do 179 | col_count = col_count + 1 180 | if str:sub(i, i) == "\n" then 181 | line_count = line_count + 1 182 | col_count = 1 183 | end 184 | end 185 | error( string.format("%s at line %d col %d", msg, line_count, col_count) ) 186 | end 187 | 188 | -- Encodes code point n as UTF-8; values above 0x10ffff raise an error. 189 | local function codepoint_to_utf8(n) 190 | -- http://scripts.sil.org/cms/scripts/page.php?site_id=nrsi&id=iws-appendixa 191 | local f = math.floor 192 | if n <= 0x7f then 193 | return string.char(n) 194 | elseif n <= 0x7ff then 195 | return string.char(f(n / 64) + 192, n % 64 + 128) 196 | elseif n <= 0xffff then 197 | return string.char(f(n / 4096) + 224, f(n % 4096 / 64) + 128, n % 64 + 128) 198 | elseif n <= 0x10ffff then 199 | return string.char(f(n / 262144) + 240, f(n % 262144 / 4096) + 128, 200 |
f(n % 4096 / 64) + 128, n % 64 + 128) 201 | end 202 | error( string.format("invalid unicode codepoint '%x'", n) ) 203 | end 204 | 205 | 206 | local function parse_unicode_escape(s) 207 | local n1 = tonumber( s:sub(1, 4), 16 ) 208 | local n2 = tonumber( s:sub(7, 10), 16 ) 209 | -- Surrogate pair? 210 | if n2 then 211 | return codepoint_to_utf8((n1 - 0xd800) * 0x400 + (n2 - 0xdc00) + 0x10000) 212 | else 213 | return codepoint_to_utf8(n1) 214 | end 215 | end 216 | 217 | 218 | local function parse_string(str, i) 219 | local res = "" 220 | local j = i + 1 221 | local k = j 222 | 223 | while j <= #str do 224 | local x = str:byte(j) 225 | 226 | if x < 32 then 227 | decode_error(str, j, "control character in string") 228 | 229 | elseif x == 92 then -- `\`: Escape 230 | res = res .. str:sub(k, j - 1) 231 | j = j + 1 232 | local c = str:sub(j, j) 233 | if c == "u" then 234 | local hex = str:match("^[dD][89aAbB]%x%x\\u%x%x%x%x", j + 1) 235 | or str:match("^%x%x%x%x", j + 1) 236 | or decode_error(str, j - 1, "invalid unicode escape in string") 237 | res = res .. parse_unicode_escape(hex) 238 | j = j + #hex 239 | else 240 | if not escape_chars[c] then 241 | decode_error(str, j - 1, "invalid escape char '" .. c .. "' in string") 242 | end 243 | res = res .. escape_char_map_inv[c] 244 | end 245 | k = j + 1 246 | 247 | elseif x == 34 then -- `"`: End of string 248 | res = res .. str:sub(k, j - 1) 249 | return res, j + 1 250 | end 251 | 252 | j = j + 1 253 | end 254 | 255 | decode_error(str, i, "expected closing quote for string") 256 | end 257 | 258 | 259 | local function parse_number(str, i) 260 | local x = next_char(str, i, delim_chars) 261 | local s = str:sub(i, x - 1) 262 | local n = tonumber(s) 263 | if not n then 264 | decode_error(str, i, "invalid number '" .. s .. 
"'") 265 | end 266 | return n, x 267 | end 268 | 269 | 270 | local function parse_literal(str, i) 271 | local x = next_char(str, i, delim_chars) 272 | local word = str:sub(i, x - 1) 273 | if not literals[word] then 274 | decode_error(str, i, "invalid literal '" .. word .. "'") 275 | end 276 | return literal_map[word], x 277 | end 278 | 279 | 280 | local function parse_array(str, i) 281 | local res = {} 282 | local n = 1 283 | i = i + 1 284 | while 1 do 285 | local x 286 | i = next_char(str, i, space_chars, true) 287 | -- Empty / end of array? 288 | if str:sub(i, i) == "]" then 289 | i = i + 1 290 | break 291 | end 292 | -- Read token 293 | x, i = parse(str, i) 294 | res[n] = x 295 | n = n + 1 296 | -- Next token 297 | i = next_char(str, i, space_chars, true) 298 | local chr = str:sub(i, i) 299 | i = i + 1 300 | if chr == "]" then break end 301 | if chr ~= "," then decode_error(str, i, "expected ']' or ','") end 302 | end 303 | return res, i 304 | end 305 | 306 | 307 | local function parse_object(str, i) 308 | local res = {} 309 | i = i + 1 310 | while 1 do 311 | local key, val 312 | i = next_char(str, i, space_chars, true) 313 | -- Empty / end of object? 
314 | if str:sub(i, i) == "}" then 315 | i = i + 1 316 | break 317 | end 318 | -- Read key 319 | if str:sub(i, i) ~= '"' then 320 | decode_error(str, i, "expected string for key") 321 | end 322 | key, i = parse(str, i) 323 | -- Read ':' delimiter 324 | i = next_char(str, i, space_chars, true) 325 | if str:sub(i, i) ~= ":" then 326 | decode_error(str, i, "expected ':' after key") 327 | end 328 | i = next_char(str, i + 1, space_chars, true) 329 | -- Read value 330 | val, i = parse(str, i) 331 | -- Set 332 | res[key] = val 333 | -- Next token 334 | i = next_char(str, i, space_chars, true) 335 | local chr = str:sub(i, i) 336 | i = i + 1 337 | if chr == "}" then break end 338 | if chr ~= "," then decode_error(str, i, "expected '}' or ','") end 339 | end 340 | return res, i 341 | end 342 | 343 | 344 | local char_func_map = { 345 | [ '"' ] = parse_string, 346 | [ "0" ] = parse_number, 347 | [ "1" ] = parse_number, 348 | [ "2" ] = parse_number, 349 | [ "3" ] = parse_number, 350 | [ "4" ] = parse_number, 351 | [ "5" ] = parse_number, 352 | [ "6" ] = parse_number, 353 | [ "7" ] = parse_number, 354 | [ "8" ] = parse_number, 355 | [ "9" ] = parse_number, 356 | [ "-" ] = parse_number, 357 | [ "t" ] = parse_literal, 358 | [ "f" ] = parse_literal, 359 | [ "n" ] = parse_literal, 360 | [ "[" ] = parse_array, 361 | [ "{" ] = parse_object, 362 | } 363 | 364 | 365 | parse = function(str, idx) 366 | local chr = str:sub(idx, idx) 367 | local f = char_func_map[chr] 368 | if f then 369 | return f(str, idx) 370 | end 371 | decode_error(str, idx, "unexpected character '" .. chr .. "'") 372 | end 373 | 374 | 375 | function json.decode(str) 376 | if type(str) ~= "string" then 377 | error("expected argument of type string, got " .. 
type(str)) 378 | end 379 | local res, idx = parse(str, next_char(str, 1, space_chars, true)) 380 | idx = next_char(str, idx, space_chars, true) 381 | if idx <= #str then 382 | decode_error(str, idx, "trailing garbage") 383 | end 384 | return res 385 | end 386 | 387 | 388 | return json -------------------------------------------------------------------------------- /utils/pureluapack.lua: -------------------------------------------------------------------------------- 1 | local bit=require("bit") require "suproxy.utils.stringUtils" local _INTLEN=4 local _DEFAULT_ENDIAN="<" 2 | local function packDouble (e,l,n) local result local sign = 0 if n < 0.0 then sign = 0x80 n = -n end local mant, expo = math.frexp(n) if mant ~= mant then result = string.char( 0xFF, 0xF8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00) elseif mant == math.huge then if sign == 0 then result = string.char( 0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00) else result = string.char(0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00) end elseif mant == 0.0 and expo == 0 then result = string.char(sign, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00) else expo = expo + 0x3FE mant = (mant * 2.0 - 1.0) * math.ldexp(0.5, 53) local x=(expo % 0x10) * 0x10 + math.floor(mant / 0x1000000000000) print(tostring(x)) result = string.char( sign + math.floor(expo / 0x10), (expo % 0x10) * 0x10 + math.floor(mant / 0x1000000000000), math.floor(mant / 0x10000000000) % 0x100, math.floor(mant / 0x100000000) % 0x100, math.floor(mant / 0x1000000) % 0x100, math.floor(mant / 0x10000) % 0x100, math.floor(mant / 0x100) % 0x100, mant % 0x100) end if e=="<" then result=result:reverse() end return result end local function unpackDouble (e,l,s) local rs=s:sub(1,l) if e=="<" then rs=rs:reverse() end local b1, b2, b3, b4, b5, b6, b7, b8 = rs:byte(1, 8) local sign = b1 > 0x7F local expo = (b1 % 0x80) * 0x10 + math.floor(b2 / 0x10) local mant = ((((((b2 % 0x10) * 0x100 + b3) * 0x100 + b4) * 0x100 + b5) * 0x100 + b6) * 0x100 + b7) * 0x100 + b8 if sign then 
sign = -1 else sign = 1 end local n if mant == 0 and expo == 0 then n = sign * 0.0 elseif expo == 0x7FF then if mant == 0 then n = sign * huge else n = 0.0/0.0 end else n = sign * math.ldexp(1.0 + mant / 0x10000000000000, expo - 0x3FF) end return n,9 end local function packNumber(e,length,k) 3 | if(k<0) then k=256^length+k end 4 | local rs="" 5 | local i=0 6 | while (k>=1) do 7 | local t=k%256 8 | if(e==">" ) then rs=string.char(t)..rs end 9 | if(e=="<" or e==nil or e=="=") then rs=rs..string.char(t) end 10 | k=bit.rshift(k,8) 11 | i=i+1 12 | end 13 | if i>length then return nil end 14 | while i" ) then rs=string.char(0)..rs end 16 | if(e=="<" or e==nil or e=="=") then rs=rs..string.char(0) end 17 | i=i+1 18 | end 19 | return rs 20 | end 21 | 22 | local function packLengthPreStr(e,l,value) 23 | 24 | local strLen=string.len(value) 25 | 26 | return packNumber(e,l,strLen)..value 27 | end 28 | local function packLengthStr(e,l,value) 29 | 30 | local strLen=string.len(value) 31 | 32 | for i=strLen,l-1,1 do 33 | 34 | value=value..string.char(0) 35 | end 36 | 37 | return value 38 | end 39 | local function packZeroEndStr(e,l,value) 40 | 41 | return value..string.char(0) 42 | end 43 | local function unpackNumber(endian,length,Str) 44 | local rs=Str:sub(1,length) 45 | if(endian==">")then 46 | rs=rs:reverse() 47 | end 48 | local i=1 49 | local result=string.byte(rs,1) 50 | while i+1<=length do 51 | result=result+string.byte(rs,i+1)*(256^i) 52 | i=i+1 53 | end 54 | 55 | return math.floor(result),i+1 56 | end 57 | local function unpackSignedNumber(endian,length,Str) 58 | local result,pos=unpackNumber(endian,length,Str) 59 | --minus value 60 | if result >= (256^length)/2 then 61 | result = result - 256^length 62 | end 63 | 64 | return result,pos 65 | end 66 | local function unpackLengthPreStr(e,l,value) 67 | 68 | local strLen=unpackNumber(e,l,value) 69 | 70 | return value:sub(l+1,l+strLen),l+strLen+1 71 | end 72 | local function unpackLengthStr(e,l,value) 73 | return 
value:sub(1,l),l+1 74 | end 75 | local function unpackZeroEndStr(e,l,value) 76 | local i=1; 77 | while(i<=#value and string.byte(value,i)~=0) do i=i+1 end 78 | return value:sub(1,i-1),i+1 79 | end 80 | local function getLen(t,l) local _t={ ["B"]=1, ["b"]=1, ["H"]=2, ["h"]=2, ["d"]=8, ["f"]=4, ["s"]=1, ["I"]=_INTLEN, ["i"]=_INTLEN } if(not l ) then l=_t[t] end return l end 81 | function string.pack(fmt,...) 82 | local arg={...} 83 | assert(type(fmt)=="string","bad argument #1 to 'pack' (string expected, got "..type(fmt)..")") 84 | local rs="" 85 | local i=1 86 | local nativeEndian=_DEFAULT_ENDIAN 87 | for w,e,t,l in fmt:gmatch("(([<>=]?)([bBhH1LjJTiIfdnczsxX])([%d]*))") do 88 | l=tonumber(l) 89 | if(e:len()~=0) then nativeEndian=e end 90 | if(t=="I" or t=="B" or t=="H") then 91 | l=getLen(t,l) 92 | assert(type(arg[i]) == "number", "bad argument #"..(i+1).." to 'pack' (number expected, got "..type(arg[i])..")") 93 | assert(arg[i]<=256^l-1,"bad argument #"..(i+1).." to 'pack' (unsign integer overflow)") 94 | rs=rs..packNumber(nativeEndian,l,arg[i]) 95 | elseif(t=="i" or t=="b" or t=="h") then 96 | l=getLen(t,l) 97 | assert(type(arg[i]) == "number", "bad argument #"..(i+1).." to 'pack' (number expected, got "..type(arg[i])..")") 98 | assert(arg[i]<256^l/2 and arg[i]>=-256^l/2,"bad argument #"..(i+1).." to 'pack' (signed interger overflow)") 99 | rs=rs..packNumber(nativeEndian,l,arg[i]) elseif(t=="d") then l=getLen(t,l) assert(type(arg[i]) == "number", "bad argument #"..(i+1).." to 'pack' (number expected, got "..type(arg[i])..")") rs=rs..packDouble(nativeEndian,l,arg[i]) 100 | elseif(t=="s") then 101 | assert(type(arg[i]) == "string", "bad argument #"..(i+1).." to 'pack' (string expected, got "..type(arg[i])..")") 102 | l=getLen(t,l) 103 | rs=rs..packLengthPreStr(nativeEndian,l,arg[i]) 104 | elseif(t=="c") then 105 | assert(type(arg[i]) == "string", "bad argument #"..(i+1).." 
to 'pack' (string expected, got "..type(arg[i])..")") 106 | assert(l,"missing size for format option 'c'") 107 | rs=rs..packLengthStr(nativeEndian,l,arg[i]) 108 | elseif(t=="z") then 109 | assert(type(arg[i]) == "string", "bad argument #"..(i+1).." to 'pack' (string expected, got "..type(arg[i])..")") 110 | rs=rs..packZeroEndStr(nativeEndian,l,arg[i]) 111 | else 112 | error("invalid format option '"..t) 113 | end 114 | i=i+1 115 | end 116 | 117 | return rs 118 | end 119 | 120 | 121 | function string.unpack(fmt,value,pos) 122 | 123 | assert(type(fmt)=="string","bad argument #1 to 'unpack' (string expected, got "..type(fmt)..")") 124 | assert(type(value)=="string","bad argument #2 to 'unpack' (string expected, got "..type(value)..")") 125 | if(pos) then assert(pos>=1 and pos<=value:len()+1,"pos invalid") end 126 | local rs={} 127 | local i=1 128 | if(pos)then i=pos end 129 | local nativeEndian=_DEFAULT_ENDIAN 130 | for w,e,t,l in fmt:gmatch("(([<>=]?)([bBhH1LjJTiIfdnczsxX])([%d]*))") do 131 | l=tonumber(l) 132 | if(e:len()~=0) then nativeEndian=e end 133 | local segment=value:sub(i) 134 | local ps,index 135 | if(t=="I" or t=="B" or t=="H") then 136 | l=getLen(t,l) assert(l>=1,"size out of limit") 137 | assert(segment:len()>=l,"bad argument #2 to 'unpack' (data string too short)") 138 | ps,index=unpackNumber(nativeEndian,l,segment) 139 | elseif(t=="i" or t=="b" or t=="h") then 140 | l=getLen(t,l) 141 | assert(segment:len()>=l,"bad argument #2 to 'unpack' (data string too short)") 142 | ps,index=unpackSignedNumber(nativeEndian,l,segment) elseif(t=="d") then l=getLen(t,l) assert(segment:len()>=l,"bad argument #2 to 'unpack' (data string too short)") ps,index=unpackDouble(nativeEndian,l,segment) 143 | elseif(t=="s") then 144 | l=getLen(t,l) 145 | assert(segment:len()>=l,"bad argument #2 to 'unpack' (data string too short)") 146 | ps,index=unpackLengthPreStr(nativeEndian,l,segment) 147 | elseif(t=="c") then 148 | ps,index=unpackLengthStr(nativeEndian,l,segment) 149 | 
assert(segment:len()>=l,"bad argument #2 to 'unpack' (data string too short)") 150 | elseif(t=="z") then 151 | ps,index=unpackZeroEndStr(nativeEndian,l,segment) 152 | else 153 | error("invalid format option '"..t) 154 | end 155 | 156 | table.insert(rs,ps) 157 | i=i+index-1 158 | end 159 | table.insert(rs,i) 160 | return unpack(rs) 161 | end 162 | local _M={} 163 | function _M.test() 164 | local k=string.pack(">I4BHs1I4BHs1s2",673845,2,-200,"123") 182 | print(string.hex(m)) 183 | --add some useless byte on start to test pos params 184 | m=string.char(0x2,0x3,0xff,0x00)..m 185 | local v1,v2,v3,v4,v5=string.unpack("I4Bh>s2",m,5) 186 | print(v1,v2,v3,v4,v5) 187 | assert(v1==673845) 188 | assert(v2==2) 189 | assert(v3==-200) 190 | assert(v4=="123") 191 | 192 | p=string.fromhex([[ 193 | FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1 194 | 29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD 195 | EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245 196 | E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED 197 | EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE65381 198 | FFFFFFFF FFFFFFFF 199 | ]]) 200 | q="FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF" 201 | assert(p:hex()==q,"fromhex err") 202 | print(p:hex(1,nil,4,8," "," ",1,1,12)) 203 | 
p=string.fromhex("140000002900004823000018BE000067840000012E637572766532353531392D736861323536406C69627373682E6F72672C656364682D736861322D6E697374703235362C656364682D736861322D6E697374703338342C656364682D736861322D6E697374703532312C6469666669652D68656C6C6D616E2D67726F75702D65786368616E67652D7368613235362C6469666669652D68656C6C6D616E2D67726F75702D65786368616E67652D736861312C6469666669652D68656C6C6D616E2D67726F757031382D7368613531322C6469666669652D68656C6C6D616E2D67726F757031362D7368613531322C6469666669652D68656C6C6D616E2D67726F757031342D7368613235362C6469666669652D68656C6C6D616E2D67726F757031342D736861312C6469666669652D68656C6C6D616E2D67726F7570312D73686131000000717373682D7273612C7273612D736861322D3235362C7273612D736861322D3531322C7373682D6473732C65636473612D736861322D6E697374703235362C65636473612D736861322D6E697374703338342C65636473612D736861322D6E697374703532312C7373682D656432353531390000011963686163686132302D706F6C7931333035406F70656E7373682E636F6D2C6165733132382D6374722C6165733139322D6374722C6165733235362D6374722C6165733132382D67636D406F70656E7373682E636F6D2C6165733235362D67636D406F70656E7373682E636F6D2C6165733132382D6362632C6165733139322D6362632C6165733235362D6362632C336465732D6362632C626C6F77666973682D6362632C636173743132382D6362632C617263666F75722C72696A6E6461656C3132382D6362632C72696A6E6461656C3139322D6362632C72696A6E6461656C3235362D6362632C72696A6E6461656C2D636263406C797361746F722E6C69752E73652C617263666F75723132382C617263666F75723235360000011963686163686132302D706F6C7931333035406F70656E7373682E636F6D2C6165733132382D6374722C6165733139322D6374722C6165733235362D6374722C6165733132382D67636D406F70656E7373682E636F6D2C6165733235362D67636D406F70656E7373682E636F6D2C6165733132382D6362632C6165733139322D6362632C6165733235362D6362632C336465732D6362632C626C6F77666973682D6362632C636173743132382D6362632C617263666F75722C72696A6E6461656C3132382D6362632C72696A6E6461656C3139322D6362632C72696A6E6461656C3235362D6362632C72696A6E6461656C2D636263406C797361746F722E6C69752E73652C6172
63666F75723132382C617263666F757232353600000178686D61632D736861322D3235362D65746D406F70656E7373682E636F6D2C686D61632D736861322D3531322D65746D406F70656E7373682E636F6D2C686D61632D736861312D65746D406F70656E7373682E636F6D2C686D61632D736861322D3235362C686D61632D736861322D3531322C686D61632D736861312C686D61632D736861312D39362C686D61632D6D64352C686D61632D6D64352D39362C686D61632D726970656D643136302C686D61632D726970656D64313630406F70656E7373682E636F6D2C756D61632D3634406F70656E7373682E636F6D2C756D61632D313238406F70656E7373682E636F6D2C686D61632D736861312D39362D65746D406F70656E7373682E636F6D2C686D61632D6D64352D65746D406F70656E7373682E636F6D2C686D61632D6D64352D39362D65746D406F70656E7373682E636F6D2C756D61632D36342D65746D406F70656E7373682E636F6D2C756D61632D3132382D65746D406F70656E7373682E636F6D2C6E6F6E6500000178686D61632D736861322D3235362D65746D406F70656E7373682E636F6D2C686D61632D736861322D3531322D65746D406F70656E7373682E636F6D2C686D61632D736861312D65746D406F70656E7373682E636F6D2C686D61632D736861322D3235362C686D61632D736861322D3531322C686D61632D736861312C686D61632D736861312D39362C686D61632D6D64352C686D61632D6D64352D39362C686D61632D726970656D643136302C686D61632D726970656D64313630406F70656E7373682E636F6D2C756D61632D3634406F70656E7373682E636F6D2C756D61632D313238406F70656E7373682E636F6D2C686D61632D736861312D39362D65746D406F70656E7373682E636F6D2C686D61632D6D64352D65746D406F70656E7373682E636F6D2C686D61632D6D64352D39362D65746D406F70656E7373682E636F6D2C756D61632D36342D65746D406F70656E7373682E636F6D2C756D61632D3132382D65746D406F70656E7373682E636F6D2C6E6F6E65000000046E6F6E65000000046E6F6E65000000000000000000000000002CA9032100000014") 204 | 205 | print(p:hex(1,nil,8,16," "," ",1,1,12)) p=string.pack("31 and string.byte(x)<127 then 43 | asciiStr=asciiStr..string.char(string.byte(x)) 44 | else 45 | asciiStr=asciiStr.."." 
46 | end 47 | if linewidth and i%linewidth==0 then 48 | elseif columnwidth and i%columnwidth==0 then 49 | asciiStr=asciiStr..columnspan 50 | end 51 | end 52 | if linewidth and (i%linewidth==0 or i==#self) then 53 | d=d..columnspan..asciiStr.."\n" 54 | asciiStr="" 55 | elseif columnwidth and i%columnwidth==0 then 56 | d=d..columnspan 57 | else 58 | d=d..bytespan 59 | end 60 | return d 61 | end) 62 | return s 63 | end 64 | 65 | function string.random(byteNum) 66 | local ok,rand=pcall(require,"resty.openssl.rand") 67 | if ok then return rand.bytes(byteNum) end 68 | local byteNum=byteNum or 4 69 | local result="" 70 | math.randomseed(ngx and ngx.time() or os.time()) 71 | for i=1,byteNum,1 do 72 | result=result..string.pack("I1",math.random(0x7f)) 73 | end 74 | return result 75 | end 76 | 77 | function string.split(self,splitter) 78 | local nFindStartIndex = 1 79 | local nSplitIndex = 1 80 | local nSplitArray = {} 81 | while true do 82 | local nFindLastIndex = string.find(self, splitter, nFindStartIndex) 83 | if not nFindLastIndex then 84 | nSplitArray[nSplitIndex] = string.sub(self, nFindStartIndex, string.len(self)) 85 | break 86 | end 87 | nSplitArray[nSplitIndex] = string.sub(self, nFindStartIndex, nFindLastIndex - 1) 88 | nFindStartIndex = nFindLastIndex + string.len(splitter) 89 | nSplitIndex = nSplitIndex + 1 90 | end 91 | return nSplitArray 92 | end 93 | 94 | function string.ascii(self,noFormat) 95 | return self:gsub("." ,function(x) 96 | if (string.byte(x)<31 or string.byte(x)>127) then 97 | if not noFormat and string.byte(x) ~= 0x09 and string.byte(x) ~= 0x0a and string.byte(x) ~= 0x0d then 98 | return x 99 | end 100 | return "." 101 | end 102 | end) 103 | end 104 | --same as hex(1,nil,4,8," "," ",1,1,...) 105 | function string.hexF(self,pos,endpos,...) 106 | pos=pos or 1 107 | return self:hex(pos,endpos,4,8," "," ",1,1,...) 108 | end 109 | --same as hex(1,nil,8,16," "," ",1,1,...) 110 | function string.hex16F(self,pos,endpos,...) 
111 | pos=pos or 1 112 | return self:hex(pos,endpos,8,16," "," ",1,1,...) 113 | end 114 | --same as hex(1,nil,8,32," "," ",1,1,...) 115 | function string.hex32F(self,pos,endpos,...) 116 | pos=pos or 1 117 | return self:hex(pos,endpos,8,32," "," ",1,1,...) 118 | end 119 | --decimal number to hex string 120 | function string.dec2hex(input) 121 | assert(type(input)=="number","input must be a number") 122 | return string.format("0x%02X",input) 123 | end 124 | --decimal number to hex string and format as decimal[0x hex] 125 | function string.dec2hexF(input) 126 | return tostring(input).."["..string.dec2hex(input).."]" 127 | end 128 | 129 | function string.literalize(str) 130 | local result=str:gsub("[%(%)%.%%%+%-%*%?%[%]%^%$]", function(c) return "%" .. c end) 131 | return result 132 | end 133 | 134 | function string.subPlain(str,plainPattern,repl) 135 | local p=plainPattern:literalize() 136 | local r=repl:literalize() 137 | return str:gsub(p,r) 138 | end 139 | 140 | function string.fromhex(value) 141 | local newValue=value:gsub("[^0-9a-fA-F]",function(x) return"" end) 142 | assert(#newValue%2 == 0,"value length % 2 must be 0") 143 | local rs=newValue:gsub("..",function(x) return string.char(tonumber(x,16)) end) 144 | return rs 145 | end 146 | 147 | function string.append(self,element,i) 148 | if not i then i=#self end 149 | if i<0 then i=0 end 150 | if i>#self then i=#self end 151 | if i==0 then return element..self,i+#element end 152 | if i==#self then return self..element,i+#element end 153 | return self:sub(1,i)..element..self:sub(i+1),i+#element 154 | end 155 | 156 | function string.trim(self,element) 157 | if #self<#element or not element then return self end 158 | local result=self; 159 | while result:sub(#result-#element+1)==element do 160 | result=result:sub(1,#result-#element) 161 | end 162 | while result:sub(1,#element)==element do 163 | result=result:sub(#element+1) 164 | end 165 | return result 166 | end 167 | 168 | function string.remove(self,i) 169 | if 
i<1 then i=1 end 170 | if i>#self then i=#self end 171 | return self:sub(1,i-1)..self:sub(i+1),i-1 172 | end 173 | 174 | function string.compare(value,value1) 175 | local i=1 176 | local result={} 177 | value:gsub("(.)",function(x) if x:byte()~=value1:byte(i) then table.insert(result,i) end i=i+1 end) 178 | return unpack(result) 179 | end 180 | 181 | function string.compareF(value,value1,pos,endpos) 182 | return "---------------------1-------------------------\r\n" 183 | ..value:hexF(pos,endpos,value:compare(value1)) 184 | .."\r\n----------------------2-----------------------\r\n" 185 | ..value1:hexF(pos,endpos,value1:compare(value)) 186 | end 187 | 188 | function string.compare16F(value,value1,pos,endpos) 189 | return "--------------------------------------1---------------------------------------\r\n" 190 | ..value:hex16F(pos,endpos,value:compare(value1)) 191 | .."\r\n--------------------------------------2---------------------------------------\r\n" 192 | ..value1:hex16F(pos,endpos,value1:compare(value)) 193 | end 194 | 195 | local b='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/' -- You will need this for encoding/decoding 196 | -- encoding 197 | function string.base64Encode(data) 198 | return ((data:gsub('.', function(x) 199 | local r,b='',x:byte() 200 | for i=8,1,-1 do r=r..(b%2^i-b%2^(i-1)>0 and '1' or '0') end 201 | return r; 202 | end)..'0000'):gsub('%d%d%d?%d?%d?%d?', function(x) 203 | if (#x < 6) then return '' end 204 | local c=0 205 | for i=1,6 do c=c+(x:sub(i,i)=='1' and 2^(6-i) or 0) end 206 | return b:sub(c+1,c+1) 207 | end)..({ '', '==', '=' })[#data%3+1]) 208 | end 209 | 210 | -- decoding 211 | function string.base64Decode(data) 212 | data = string.gsub(data, '[^'..b..'=]', '') 213 | return (data:gsub('.', function(x) 214 | if (x == '=') then return '' end 215 | local r,f='',(b:find(x)-1) 216 | for i=6,1,-1 do r=r..(f%2^i-f%2^(i-1)>0 and '1' or '0') end 217 | return r; 218 | end):gsub('%d%d%d?%d?%d?%d?%d?%d?', function(x) 219 | 
if (#x ~= 8) then return '' end 220 | local c=0 221 | for i=1,8 do c=c+(x:sub(i,i)=='1' and 2^(8-i) or 0) end 222 | return string.char(c) 223 | end)) 224 | end 225 | -------------------------------------------------------------------------------- /utils/tableUtils.lua: -------------------------------------------------------------------------------- 1 | require "suproxy.utils.stringUtils" 2 | local _M={} 3 | local printFlag="___printted" 4 | --format and print table in json style 5 | function _M.printTableF(tab,options,layer,map) 6 | options=options or {} 7 | local printIndex=options.printIndex or false 8 | local layer=layer or 1 9 | local stopLayer=options.stopLayer or 0xffff 10 | local inline=options.inline or false 11 | local wrap= (not inline) and "\r\n" or "" 12 | local tabb= (not inline) and "\t" or "" 13 | local justLen=options.justLen or false 14 | local ascii=options.ascii or true 15 | local excepts=options.excepts or {} 16 | local map=map or {} 17 | local includes=options.includes 18 | local logStr="" 19 | if layer>stopLayer then return logStr end 20 | if type(tab)=="table" and (not map[tostring(tab)]) then 21 | map[tostring(tab)]="" 22 | local i=1 23 | local isList=true 24 | local shortList=true 25 | for k, v in pairs(tab) do 26 | if k~=i then isList=false break end 27 | if type(v)=="table" then shortList=false end 28 | i=i+1 29 | end 30 | --no item in table 31 | if i==1 then isList,shortList =false,false end 32 | if not tab.__orderred then 33 | for k, v in pairs(tab) do 34 | local skip=false 35 | k=tostring(k) 36 | for _,e in ipairs(excepts) do 37 | if k:match(e) then 38 | skip=true 39 | break 40 | end 41 | end 42 | if includes then 43 | local inWhiteList=false 44 | for _,e in ipairs(includes) do 45 | if k:match(e) then 46 | inWhiteList=true 47 | end 48 | end 49 | skip=not inWhiteList 50 | end 51 | if not skip then 52 | if not isList or not shortList then logStr=logStr..wrap..string.rep(tabb,layer) end 53 | if not isList then 
logStr=logStr.."\""..k.."\":" end 54 | --print(string.rep(" ",layer),k,":",tostring(v)) 55 | logStr=logStr.._M.printTableF(v,options,layer+1,map) 56 | logStr=logStr.."," 57 | end 58 | end 59 | else 60 | for k, v in ipairs(tab) do 61 | local skip=false 62 | for _,e in ipairs(excepts) do 63 | if v.key:match(e) then 64 | skip=true 65 | break 66 | end 67 | end 68 | if includes then 69 | local inWhiteList=false 70 | for _,e in ipairs(includes) do 71 | if v.key:match(e) then 72 | inWhiteList=true 73 | end 74 | end 75 | skip=not inWhiteList 76 | end 77 | if not skip then 78 | if not isList or not shortList then logStr=logStr..wrap..string.rep(tabb,layer) end 79 | if not isList then logStr=logStr.."\""..v.key.."\":" end 80 | logStr=logStr.._M.printTableF(v.value,options,layer+1,map) 81 | logStr=logStr.."," 82 | end 83 | end 84 | end 85 | if printIndex and getmetatable(tab) and getmetatable(tab).__index then 86 | if not isList or not shortList then logStr=logStr..wrap..string.rep(tabb,layer) end 87 | if not isList then logStr=logStr.."\"__index\":" end 88 | logStr=logStr.._M.printTableF(getmetatable(tab).__index,options,layer+1,map).."," 89 | end 90 | if #logStr >0 then logStr=logStr:sub(1,#logStr-1) end 91 | logStr=(isList and "[" or "{")..logStr 92 | if not isList or not shortList then logStr=logStr..wrap..string.rep(tabb,layer-1) end 93 | logStr=logStr..(isList and "]" or "}")..tostring(tab) 94 | elseif type(tab)=="string" then 95 | logStr="[Len:"..tab:len().."]"..logStr 96 | if not justLen then 97 | if #tab<40 then 98 | logStr=logStr.."\""..(ascii and tab:ascii(true) or tostring(tab)).."\"".."["..tab:hex().."]" 99 | else 100 | logStr=logStr..wrap..string.rep(tabb,layer-1).."\""..(ascii and tab:ascii(true) or tostring(tab)).."\""..wrap..string.rep(tabb,layer-1).."["..tab:hex().."]" 101 | end 102 | end 103 | elseif type(tab)=="number" then 104 | logStr=logStr..string.dec2hexF(tab) 105 | else 106 | logStr=logStr..tostring(tab) 107 | end 108 | return logStr 109 | end 110 | 
--- nil- and table-tolerant replacement for table.concat.
-- Every slot from 1 up to the largest positive integer key is emitted, so
-- nil holes show up as the text "nil" (the # operator is unreliable on
-- tables with holes). Values table.concat would reject (nil, booleans,
-- tables, functions, ...) are converted with tostring; strings and numbers
-- are passed through unchanged.
-- @param tab list-like table to join
-- @param splitter separator placed between items (nil means "")
-- @return the joined string
function _M.concat(tab,splitter)
  -- highest integer slot, instead of trusting #tab on a table with holes
  local n=0
  for k in pairs(tab) do
    if type(k)=="number" and k>n and k%1==0 then n=k end
  end
  local rs={}
  for i=1,n do
    local v=tab[i]
    local tv=type(v)
    if tv~="string" and tv~="number" then
      -- tostring(nil)=="nil"; tables and booleans render via tostring
      v=tostring(v)
    end
    rs[i]=v
  end
  return table.concat(rs,splitter)
end

--- Add a lookup index to a table of records.
-- Builds a map from v[key] to v for every value in tab and installs it as
-- the __index of the last table in tab's __index chain, so that
-- tab[someKeyValue] returns the matching record.
-- Table-valued __index links are followed; the walk stops at the first table
-- that has no metatable or whose metatable has no __index. A function-valued
-- __index cannot be extended and raises an error.
-- @param tab table of records (each value must be indexable by key)
-- @param key field whose value becomes the lookup key
function _M.addIndex(tab,key)
  local lookUps={}
  for _,v in pairs(tab) do lookUps[v[key]]=v end
  local target=tab
  local mt=getmetatable(target)
  while mt and mt.__index~=nil do
    local index=mt.__index
    if type(index)~="table" then
      error("addIndex: cannot chain past a non-table __index",2)
    end
    target=index
    mt=getmetatable(target)
  end
  if mt then
    -- keep the existing metamethods, only fill in the missing __index
    mt.__index=lookUps
  else
    setmetatable(target,{__index=lookUps})
  end
end

--imitate extends keyword in java
--copy parents method into subclass if not exists in subclass
--set __base as the base class
--limitation : any method add after extends method can't call sub class's override method
function _M.extends(o,parent)
  assert(o,"object can not be null")
  assert(parent,"parent can not be null")
  for k,v in pairs(parent) do
    -- only fill fields the subclass does not define; a field explicitly set
    -- to false counts as defined
    if o[k]==nil then o[k]=v end
  end
  o.__base=parent
  return o
end

--order table: item key can not be Number
-- Assigning a new key k stores t[k]=v and appends {key=k,value=v} at the next
-- integer slot, so ipairs(t) yields the entries in insertion order; the
-- metatable field __k_i maps each key to its slot.
-- NOTE(review): __newindex only fires for keys that are not yet present, so
-- reassigning an existing key updates t[k] but not its ordered entry.
_M.OrderedTable={
  new=function(self,o)
    o=o or {}
    local k_i={}
    local meta={
      __index=self,
      __newindex=function(t,k,v)
        assert(type(k)~="number")
        rawset(k_i,k,#t+1)
        rawset(t,k,v)
        rawset(t,#t+1,{key=k,value=v})
      end,
      __k_i=k_i
    }
    o.__orderred=true
    return setmetatable(o,meta)
  end,
  getIndex=function(self,k)
    assert(type(k)~="number")
    local k_i=getmetatable(self).__k_i
    return k_i[k]
  end,
  -- plain key -> value copy of the ordered entries
  getKVTable=function(self)
    local rs={}
    for _,v in ipairs(self) do
      rs[v.key]=v.value
    end
    return rs
  end,
  remove=function(self,k)
    assert(type(k)~="number")
    local k_i=getmetatable(self).__k_i
    local removeIndex=k_i[k]
    -- unknown key: nothing to remove (table.remove with a nil position
    -- would drop the last entry and the re-index loop would then fail)
    if not removeIndex then return end
    table.remove(self,removeIndex)
    rawset(k_i,k,nil)
    rawset(self,k,nil)
    -- entries after the removed one moved down by one slot
    for i=removeIndex,#self do
      k_i[self[i].key]=i
    end
  end
}



_M.unitTest={}
function _M.unitTest.OrderedTable()
  local t=_M.OrderedTable:new()
  t.A=1
  t.B=2
  t.C=3
  t.D=4
  t.E=5
  print(_M.printTableF(t))
  assert(t[1].key=="A",t[1].key)
  assert(t[2].key=="B",t[2].key)
  assert(t[3].key=="C",t[3].key)
  assert(#t==5,#t)
  t:remove("B")
  print(_M.printTableF(t))
  assert(t[1].key=="A",t[1].key)
  assert(t[2].key=="C",t[2].key)
  assert(t[3].key=="D",t[3].key)
  assert(#t==4)
end

function _M.unitTest.getKVTable()
  local t=_M.OrderedTable:new()
  t.A=1
  t.B=2
  local kv=t:getKVTable()
  assert(kv.A==1 and kv.B==2)
  -- removing an unknown key must leave the entries untouched
  t:remove("Z")
  assert(#t==2)
end

function _M.unitTest.concat()
  local t={[1]=1,[2]="",[3]=nil,[4]="abc",[5]=4524354326}
  assert(_M.concat(t,";")=="1;;nil;abc;4524354326",_M.concat(t,";"))
end

function _M.test()
  for k,v in pairs(_M.unitTest) do
    print("------------running "..k)
    v()
    print("------------"..k.." finished")
  end
end

return _M
-------------------------------------------------------------------------------- /utils/unicode.lua: --------------------------------------------------------------------------------
-- Localize a few functions for a tiny speed boost, since these will be looped
-- over every char of a string
require "suproxy.utils.stringUtils"
require "suproxy.utils.pureluapack"
local byte = string.byte
local char = string.char
local pack = string.pack
local unpack = string.unpack
local concat = table.concat

local _M={}

---Decode a buffer containing Unicode data.
--@param buf The string/buffer to decode.
--@param decoder Decoder function (such as utf8_dec), called as
--       decoder(buf, pos, bigendian); it returns the position of the next
--       code point followed by the decoded code point.
--@param bigendian Passed through to the decoder. For byte-order sensitive
--       encodings (such as UTF-16) true forces big-endian. Default: false
--       (little-endian).
--@return A list-table of the decoded code points (numbers).
function _M.decode(buf, decoder, bigendian)
  local codepoints = {}
  local offset = 1
  local len = #buf
  -- every decoder call consumes one code point and reports where the next
  -- one starts
  while offset <= len do
    local slot = #codepoints + 1
    offset, codepoints[slot] = decoder(buf, offset, bigendian)
  end
  return codepoints
end

---Encode a list of Unicode code points.
--@param list A list-table of code points (numbers); iteration stops at the
--       first nil entry.
--@param encoder Encoder function (such as utf8_enc), called as
--       encoder(cp, bigendian) and returning the encoded string.
--@param bigendian Passed through to the encoder. For byte-order sensitive
--       encodings (such as UTF-16) true forces big-endian. Default: false
--       (little-endian).
--@return The concatenation of all encoded pieces.
function _M.encode(list, encoder, bigendian)
  local pieces = {}
  for idx, codepoint in ipairs(list) do
    pieces[idx] = encoder(codepoint, bigendian)
  end
  return table.concat(pieces, "")
end

---Transcode a string from one encoding to another.
--
--Decoding and re-encoding happen in a single pass, which avoids building the
--intermediate code point list that unicode.decode followed by
--unicode.encode would create.
--@param buf The string/buffer to transcode
--@param decoder A Unicode decoder function (such as utf16_dec)
--@param encoder A Unicode encoder function (such as utf8_enc)
--@param bigendian_dec Set this to true to force big-endian decoding.
--@param bigendian_enc Set this to true to force big-endian encoding.
--@return An encoded string. Characters for which the encoder returns nil
--       are omitted.
--@raise An error carrying the decoder's message when the decoder fails
function _M.transcode(buf, decoder, encoder, bigendian_dec, bigendian_enc)
  local out = {}
  local pos = 1
  while pos <= #buf do
    local nextPos, cp = decoder(buf, pos, bigendian_dec)
    if not nextPos then
      -- On failure cp holds the decoder's message; don't hand it to the
      -- encoder as if it were a code point.
      error(cp or string.format("Unicode decode failed at %d", pos), 2)
    end
    out[#out + 1] = encoder(cp, bigendian_enc)
    pos = nextPos
  end
  return table.concat(out)
end

--- Determine (poorly) the character encoding of a string
--
-- First, the string is checked for a Byte-order Mark (BOM). This can be
-- examined to determine UTF-16 with endianness or UTF-8. If no BOM is found,
-- the string is examined.
--
-- If null bytes are encountered, UTF-16 is assumed. Endianness is determined
-- by byte position, assuming the null is the high-order byte. Otherwise, if
-- byte values over 127 are found, UTF-8 decoding is attempted. If this fails,
-- the result is 'other', otherwise it is 'utf-8'. If no high bytes are found,
-- the result is 'ascii'.
--
--@param buf The string/buffer to be identified
--@param len The number of bytes to inspect in order to identify the string.
--       Default: 100
--@return A string describing the encoding: 'ascii', 'utf-8', 'utf-16be',
--       'utf-16le', or 'other' meaning some unidentified 8-bit encoding
function _M.chardet(buf, len)
  local limit = len or 100
  if limit > #buf then
    limit = #buf
  end
  -- Check BOM
  if limit >= 2 then
    local bom1, bom2 = byte(buf, 1, 2)
    if bom1 == 0xff and bom2 == 0xfe then
      return 'utf-16le'
    elseif bom1 == 0xfe and bom2 == 0xff then
      return 'utf-16be'
    elseif limit >= 3 then
      local bom3 = byte(buf, 3)
      if bom1 == 0xef and bom2 == 0xbb and bom3 == 0xbf then
        return 'utf-8'
      end
    end
  end
  -- Try bytes. `<=` so the last byte inside the limit is inspected as well.
  local pos = 1
  local high = false
  local utf8 = true
  while pos <= limit do
    local c = byte(buf, pos)
    if c == 0 then
      -- The null is taken as the high-order byte of a UTF-16 unit: at an even
      -- (1-based) position it follows its low byte, i.e. little-endian.
      if pos % 2 == 0 then
        return 'utf-16le'
      end
      return 'utf-16be'
    elseif c > 127 then
      high = true
      if utf8 then
        -- utf8_dec is defined as a module field (_M.utf8_dec), not a local.
        local nextPos = _M.utf8_dec(buf, pos)
        if nextPos then
          pos = nextPos
        else
          utf8 = false
        end
      end
      if not utf8 then
        pos = pos + 1
      end
    else
      pos = pos + 1
    end
  end
  if not high then
    return 'ascii'
  end
  return utf8 and 'utf-8' or 'other'
end

---Encode a Unicode code point to UTF-16. See RFC 2781.
--
-- Windows OS prior to Windows 2000 only supports UCS-2, so beware using this
-- function to encode code points above 0xFFFF.
--@param cp The Unicode code point as a number
--@param bigendian Set this to true to encode big-endian UTF-16. Default is
-- false (little-endian)
--@return A string containing the code point in UTF-16 encoding.
-- NOTE(review): this dump line is corrupted. The embedded line numbers jump
-- from 155 straight to 193 in the middle of a string literal
-- (`local fmt = "= 0xD800 ...`), so the body of utf16_enc and the header of
-- utf16_dec are missing. Presumably text between a '<' and a later '>' was
-- stripped during export. Restore lines 155-192 from the upstream
-- unicode.lua before relying on utf16_enc / utf16_dec as shown here.
-- utf8_enc (lines 206-236 below) appears intact.
154 | function _M.utf16_enc(cp, bigendian) 155 | local fmt = "= 0xD800 and cp <= 0xDFFF then 193 | local high = bit.lshift((cp - 0xD800) ,10) 194 | cp, pos = unpack(fmt, buf, pos) 195 | cp = 0x10000 + high + cp - 0xDC00 196 | end 197 | return pos, cp 198 | end 199 | 200 | ---Encode a Unicode code point to UTF-8. See RFC 3629. 201 | -- 202 | -- Does not check that cp is a real character; that is, doesn't exclude the 203 | -- surrogate range U+D800 - U+DFFF and a handful of others. 204 | --@param cp The Unicode code point as a number 205 | --@return A string containing the code point in UTF-8 encoding. 206 | function _M.utf8_enc(cp) 207 | local bytes = {} 208 | local n, mask 209 | if cp % 1.0 ~= 0.0 or cp < 0 then 210 | -- Only defined for nonnegative integers. 211 | return nil 212 | elseif cp <= 0x7F then 213 | -- Special case of one-byte encoding. 214 | return char(cp) 215 | elseif cp <= 0x7FF then 216 | n = 2 217 | mask = 0xC0 218 | elseif cp <= 0xFFFF then 219 | n = 3 220 | mask = 0xE0 221 | elseif cp <= 0x10FFFF then 222 | n = 4 223 | mask = 0xF0 224 | else 225 | return nil 226 | end 227 | 228 | while n > 1 do 229 | bytes[n] = char(0x80 + bit.band(cp, 0x3F)) 230 | cp = bit.rshift(cp, 6) 231 | n = n - 1 232 | end 233 | bytes[1] = char(mask + cp) 234 | 235 | return table.concat(bytes) 236 | end 237 | 238 | ---Decodes a UTF-8 character. 239 | -- 240 | -- Does not check that the returned code point is a real character.
--@param buf A string containing the character
--@param pos The index in the string where the character begins (default 1)
--@return pos The index just past the character (where the next character
--        begins), or nil on error
--@return cp The code point of the character as a number, or an error string
function _M.utf8_dec(buf, pos)
  pos = pos or 1
  local n, mask
  local bv = byte(buf, pos)
  if bv == nil then
    -- pos lies beyond the end of buf
    return nil, string.format("Incomplete UTF-8 sequence at %d", pos)
  elseif bv <= 0x7F then
    return pos+1, bv
  elseif bv < 0xC0 then
    -- 10xxxxxx is a continuation byte and can never start a sequence
    -- (RFC 3629); treating it as a 2-byte lead gave a negative code point.
    return nil, string.format("Invalid UTF-8 byte at %d", pos)
  elseif bv <= 0xDF then
    --110xxxxx 10xxxxxx
    n = 1
    mask = 0xC0
  elseif bv <= 0xEF then
    --1110xxxx 10xxxxxx 10xxxxxx
    n = 2
    mask = 0xE0
  elseif bv <= 0xF7 then
    --11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
    n = 3
    mask = 0xF0
  else
    return nil, string.format("Invalid UTF-8 byte at %d", pos)
  end

  -- strip the length-marker bits from the lead byte
  local cp = bv - mask

  if pos + n > #buf then
    return nil, string.format("Incomplete UTF-8 sequence at %d", pos)
  end
  for i = 1, n do
    bv = byte(buf, pos + i)
    if bv < 0x80 or bv > 0xBF then
      return nil, string.format("Invalid UTF-8 sequence at %d", pos + i)
    end
    -- each continuation byte contributes its low 6 bits
    cp = bit.lshift(cp ,6) + bit.band(bv , 0x3F)
  end

  return pos + 1 + n, cp
end

---Helper function for the common case of UTF-16 to UTF-8 transcoding, such as
--from a Windows/SMB unicode string to a printable ASCII (subset of UTF-8)
--string.
--@param from A string in UTF-16, little-endian
--@return The string in UTF-8
function _M.utf16to8(from)
  return _M.transcode(from, _M.utf16_dec, _M.utf8_enc, false, nil)
end

---Helper function for the common case of UTF-8 to UTF-16 transcoding, such as
--from a printable ASCII (subset of UTF-8) string to a Windows/SMB unicode
--string.
--@param from A string in UTF-8
--@return The string in UTF-16, little-endian
function _M.utf8to16(from)
  return _M.transcode(from, _M.utf8_dec, _M.utf16_enc, nil, false)
end

--- Ad-hoc smoke test: prints a UTF-8 decode/encode round trip (JSON and hex)
-- and the UTF-16LE encoding of an ASCII string.
function _M.test()
  local str="中华A已经Bあまり哈哈哈1234567"
  local b=_M.decode(str,_M.utf8_dec,false)
  print(require("suproxy.utils.json").encode(b))
  print(_M.encode(b,_M.utf8_enc,false):hexF())
  print(str:hexF())
  b=_M.decode(str,_M.utf8_dec,false)
  print(require("suproxy.utils.json").encode(b))

  str="12345abcd"
  print(_M.utf8to16(str):hex())
end

return _M
-------------------------------------------------------------------------------- /utils/utils.lua: --------------------------------------------------------------------------------
local _M = {_VERSION="0.1.11"}
local table_insert = table.insert
local table_concat = table.concat

--- Current time in seconds: ngx.time() under OpenResty, os.time() otherwise.
function _M.getTime() return ngx and ngx.time() or os.time() end

--- Append a query parameter to a URL.
-- The parameter is joined with "?" when the URL has no "?" yet and with "&"
-- otherwise. Name and value are appended as-is (not percent-encoded).
-- @param urlString base URL; nil is treated as ""
-- @param paramName parameter name; when nil the URL is returned unchanged
-- @param paramValue parameter value; nil is treated as ""
-- @return the resulting URL string
function _M.addParamToUrl(urlString, paramName, paramValue)
  if urlString == nil then urlString = "" end
  if paramValue == nil then paramValue = "" end
  if paramName == nil then return urlString end
  -- plain find (4th arg): "?" is a pattern magic character
  local separator = string.find(urlString, "?", 1, true) and "&" or "?"
  return urlString .. separator .. paramName .. "=" .. paramValue
end

--- Crude word-size probe.
-- Returns 32 when the 36-bit literal 0xfffffffff compares equal to 0xffffffff
-- (i.e. it was truncated when the chunk was parsed), otherwise 64.
-- NOTE(review): on LuaJIT / Lua 5.3+ the two literals are presumably always
-- distinct, so this would always return 64 there -- confirm before relying on it.
function _M._86_64()
  return 0xfffffffff==0xffffffff and 32 or 64
end

--- Remove a query parameter (with its optional "=value") from a URL.
-- Matches "?name" or "&name", optionally followed by "=" and the value up to
-- the next "&". If that removes the "?" (the parameter was first), the next
-- "&" is promoted to "?" so the query string stays well-formed.
-- @param urlString URL; nil is treated as ""
-- @param paramName parameter name; when nil the URL is returned unchanged
-- @return the resulting URL string
function _M.removeParamFromUrl(urlString, paramName)
  if urlString == nil then urlString = "" end
  if paramName == nil then return urlString end

  local hadQuery = string.find(urlString, "?", 1, true) ~= nil
  -- escape pattern magic characters so the name is matched literally
  local escapedName = string.gsub(paramName, "%p", "%%%0")
  urlString = string.gsub(urlString, "[?&]" .. escapedName .. "=?[^&]*", "")

  ngx.log(ngx.DEBUG, "urlString:" .. urlString)

  if hadQuery and not string.find(urlString, "?", 1, true) then
    urlString = string.gsub(urlString, "&", "?", 1)
  end

  return urlString
end

--- Read a request argument: from the URI query string first and, only for
-- POST requests where it is absent there, from the form body (which is read
-- with ngx.req.read_body()).
-- @param argName argument name
-- @return the value as returned by ngx's arg tables, or nil
function _M.getArgsFromRequest(argName)
  local args = ngx.req.get_uri_args()
  local result = args[argName]
  if result == nil and "POST" == ngx.var.request_method then
    ngx.req.read_body()
    args = ngx.req.get_post_args()
    result = args[argName]
  end
  return result
end

--- Send {msg, detail} as a JSON body and log it at ERR level. When status is
-- given it is set as ngx.status and the request is ended with ngx.exit(status).
function _M.error(msg, detail, status)
  local cjson = require("cjson")
  if status then ngx.status = status end
  ngx.say(cjson.encode({ msg = msg, detail = detail }))
  ngx.log(ngx.ERR, cjson.encode({ msg = msg, detail = detail }))
  if status then ngx.exit(status) end
end

local errors = {
  UNAVAILABLE = 'upstream-unavailable',
  QUERY_ERROR = 'query-failed'
}

_M.errors = errors

--- Build an HTTP client function for `method`.
-- The returned function(url, payload, headers) sends payload as the query
-- (GET) or the body (other methods) with Content-Type
-- application/x-www-form-urlencoded, enables ssl_verify for https URLs, and
-- returns status, body, errcode:
--   * nil, nil, errors.UNAVAILABLE when the request could not be made;
--   * status, body, errors.QUERY_ERROR when status >= 400;
--   * status, body, nil otherwise.
local function request(method)
  return function(url, payload, headers)
    -- work on a copy so the caller's headers table is not mutated
    local reqHeaders = {}
    for k, v in pairs(headers or {}) do reqHeaders[k] = v end
    reqHeaders['Content-Type'] = 'application/x-www-form-urlencoded'
    local httpc = require( "resty.http" ).new()
    local params = { headers = reqHeaders, method = method }
    if string.sub(string.lower(url),1,5)=='https' then params.ssl_verify = true end
    if method == 'GET' then
      params.query = payload
    else
      params.body = payload
    end
    local res, err = httpc:request_uri(url, params)
    if not res then
      -- tostring(): payload may be a non-string (e.g. a query table), which
      -- table.concat cannot join; err explains the failure
      ngx.log(ngx.ERR, table.concat(
        {method .. ' fail', url, tostring(payload), tostring(err)}, '|'
      ))
      return nil, nil, errors.UNAVAILABLE
    end
    if res.status >= 400 then
      ngx.log(ngx.ERR, table.concat({
        method .. ' fail code', url, res.status, tostring(res.body),
      }, '|'))
      return res.status, res.body, errors.QUERY_ERROR
    end
    return res.status, res.body, nil
  end
end

_M.jget = request('GET')
_M.jput = request('PUT')
_M.jpost = request('POST')

--- Inflate a gzip-compressed string using suproxy.utils.ffi-zlib.
-- @param inputString gzip data
-- @return the inflated data. On failure the error is logged at ERR level and
--         whatever output the inflater had emitted is still returned.
function _M.unzip(inputString)
  local zlib = require('suproxy.utils.ffi-zlib')
  local chunk = 16384
  local output_table = {}
  local count = 0
  -- Reader callback: successive calls return the contiguous slices
  -- [1, bufsize-1], [bufsize, 2*bufsize-1], ... and "" once exhausted.
  local input = function(bufsize)
    ngx.log(ngx.DEBUG,"count:"..count)
    local start = count > 0 and bufsize*count or 1
    local data = inputString:sub(start, (bufsize*(count+1)-1) )
    count = count + 1
    -- raw compressed payload: keep it out of INFO-level logs
    ngx.log(ngx.DEBUG,"--------data-----------")
    ngx.log(ngx.DEBUG,data)
    return data
  end

  local output = function(data)
    table_insert(output_table, data)
  end

  local ok, err = zlib.inflateGzip(input, output, chunk)
  if not ok then
    ngx.log(ngx.ERR,"unzip error", ": ", tostring(err))
  end
  return table_concat(output_table,'')
end

return _M
--------------------------------------------------------------------------------