├── .github
└── ISSUE_TEMPLATE
│ ├── bug_report.md
│ └── feature_request.md
├── LICENSE
├── balancer
└── balancer.lua
├── channel.lua
├── example
├── gateway.lua
└── session.lua
├── http
├── handlers.lua
├── initParam.lua
├── mockauth.lua
└── ssoProcessors.lua
├── ldap.lua
├── ldap
├── ldapPackets.lua
└── parser.lua
├── parser.lua
├── readme.md
├── session
├── session.lua
└── sessionManager.lua
├── ssh2.lua
├── ssh2
├── commandCollector.lua
├── parser.lua
├── shellCommand.lua
├── ssh2CipherConf.lua
└── ssh2Packets.lua
├── suproxy-v0.6.0-1.rockspec
├── tds.lua
├── tds
├── datetime.lua
├── parser.lua
├── tdsPackets.lua
├── token.lua
└── version.lua
├── test.lua
├── tns.lua
├── tns
├── crypt.lua
├── parser.lua
└── tnsPackets.lua
└── utils
├── asn1.lua
├── compatibleLog.lua
├── datetime.lua
├── event.lua
├── ffi-zlib.lua
├── json.lua
├── pureluapack.lua
├── stringUtils.lua
├── tableUtils.lua
├── unicode.lua
└── utils.lua
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g. iOS]
28 | - Browser [e.g. chrome, safari]
29 | - Version [e.g. 22]
30 |
31 | **Smartphone (please complete the following information):**
32 | - Device: [e.g. iPhone6]
33 | - OS: [e.g. iOS8.1]
34 | - Browser [e.g. stock browser, safari]
35 | - Version [e.g. 22]
36 |
37 | **Additional context**
38 | Add any other context about the problem here.
39 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2020, yizhu2000
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/balancer/balancer.lua:
--------------------------------------------------------------------------------
1 | --Default random balancer. this balancer randomly select upstreams from given
2 | --list. if one upstream is blamed, this upstream will be unselectable for given
3 | --suspendSpan time.
4 | local tableUtils=require "suproxy.utils.tableUtils"
5 | local utils=require "suproxy.utils.utils"
6 | local OrderedTable=tableUtils.OrderedTable
7 | local _M={}
-- Build the key shared by the upstream table and the blame list.
local function getKey(ip, port)
  local key = ("ip%sport%s"):format(ip, port)
  return key
end
9 |
--- Create a random balancer.
-- @param upstreams array of upstream tables; each must have `ip` and `port`
--   (other fields such as id/gid are kept as-is)
-- @param suspendSpan how long a blamed upstream stays unselectable, in the
--   units of utils.getTime() (presumably seconds); default 30
-- @return balancer instance
function _M:new(upstreams, suspendSpan)
  math.randomseed(utils.getTime())
  local o = setmetatable({}, {__index = self})
  assert(upstreams, "upstreams can not be nil")
  o.upstreams = OrderedTable:new()
  o.blameList = {}
  o.suspendSpan = suspendSpan or 30
  for _, v in ipairs(upstreams) do
    assert(v.ip, "upstream ip address cannot be null")
    assert(v.port, "upstream port cannot be null")
    o.upstreams[getKey(v.ip, v.port)] = v
  end
  return o
end
24 |
--- Pick a random available upstream.
-- Blamed upstreams whose suspension has expired are first moved back into
-- the pool and removed from the blame list.
-- @return the upstream table, or nil when none is available
function _M:getBest()
  local now = utils.getTime()
  for k, v in pairs(self.blameList) do
    if v.addTime + self.suspendSpan <= now then
      self.upstreams[k] = v.value
      -- clearing an existing key during pairs() traversal is allowed; without
      -- this the stale entry was re-inserted on every call forever
      self.blameList[k] = nil
    end
  end
  local count = #self.upstreams
  if count == 0 then return nil end
  -- math.random(1, n) already yields an integer in [1, n]
  local entry = self.upstreams[math.random(1, count)]
  if entry then
    return entry.value
  end
end
37 |
--- Mark an upstream as failed so getBest skips it for suspendSpan.
-- Upstreams not currently in the pool (unknown or already blamed) are ignored.
-- @param upstream table with `ip` and `port`
function _M:blame(upstream)
  assert(upstream.ip, "upstream ip address cannot be null")
  assert(upstream.port, "upstream port cannot be null")
  local key = getKey(upstream.ip, upstream.port)
  local current = self.upstreams[key]
  if current then
    self.blameList[key] = {addTime = utils.getTime(), value = current}
    self.upstreams:remove(key)
  end
end
47 |
_M.unitTest = {}

--- Self-test for the balancer: selection, blame and automatic restore.
-- Prints intermediate state and raises on failure. Busy-waits for
-- suspendSpan seconds of CPU time (os.clock).
function _M.test()
  print("------------running balancer test")
  local suspendSpan = 5
  local a = _M:new({{ip = 1, port = 1}, {ip = 2, port = 2}, {ip = 3, port = 3}}, suspendSpan)
  for _ = 1, 7 do
    print(tableUtils.printTableF(a:getBest(), {inline = true}))
  end
  assert(#(a.upstreams) == 3)
  a:blame({ip = 1, port = 1})
  assert(#(a.upstreams) == 2)
  assert(a.upstreams[getKey(1, 1)] == nil)
  assert(a.upstreams[getKey(2, 2)])
  assert(a.upstreams[getKey(3, 3)])
  assert(a.blameList[getKey(1, 1)])
  assert(a.blameList[getKey(1, 1)].value.ip == 1)
  print("------------wait ", suspendSpan, " seconds")
  -- wait for the configured span rather than a literal, so the wait always
  -- matches what the balancer was built with
  local t0 = os.clock()
  while os.clock() - t0 <= suspendSpan do end
  a:getBest()
  print(tableUtils.printTableF(a.upstreams))
  assert(#(a.upstreams) == 3)
  -- note: `#` only counts the array part, so this check is weak for the
  -- key-indexed blameList
  assert(#(a.blameList) == 0)
  print("------------balancer test finished")
end
77 | return _M
--------------------------------------------------------------------------------
/channel.lua:
--------------------------------------------------------------------------------
1 | local sub = string.sub
local byte = string.byte
local format = string.format
local tcp = ngx.socket.tcp
local setmetatable = setmetatable
local spawn = ngx.thread.spawn
local wait = ngx.thread.wait
local logger = require "suproxy.utils.compatibleLog"
local ses= require "suproxy.session.session"
local cjson=require "cjson"
2 | local event=require "suproxy.utils.event"
local balancer=require "suproxy.balancer.balancer"
local _M={}
3 |
4 | _M._VERSION = '0.01'
5 |
6 |
--- Create a proxy channel bound to the current downstream (client) socket.
-- @param upstreams nil to run in standalone (echo) mode; otherwise either an
--   array of upstream tables ({ip=, port=, id=, gid=}) wrapped in the default
--   random balancer, or an object with a getBest method used as the balancer.
-- @param processor protocol processor; when nil a plain receive-and-forward
--   processor (or an echo processor in standalone mode) is installed.
-- @param options optional table: c2p*/p2s* Conn/Send/Read timeouts in ms
--   (defaults 10000/10000/3600000) and sessionMan.
-- @return the channel, or nil and an error message.
function _M:new(upstreams, processor, options)
  local o = {}
  options = options or {}
  options.c2pConnTimeout = options.c2pConnTimeout or 10000
  options.c2pSendTimeout = options.c2pSendTimeout or 10000
  options.c2pReadTimeout = options.c2pReadTimeout or 3600000
  options.p2sConnTimeout = options.p2sConnTimeout or 10000
  options.p2sSendTimeout = options.p2sSendTimeout or 10000
  options.p2sReadTimeout = options.p2sReadTimeout or 3600000
  local c2pSock, err = ngx.req.socket()
  if not c2pSock then
    return nil, err
  end
  c2pSock:settimeouts(options.c2pConnTimeout, options.c2pSendTimeout, options.c2pReadTimeout)
  local standalone = false
  if not upstreams then
    logger.log(logger.ERR, format("[SuProxy] no upstream specified, Proxy will run in standalone mode"))
    standalone = true
  end
  local p2sSock = nil
  if not standalone then
    p2sSock, err = tcp()
    if not p2sSock then
      return nil, err
    end
    p2sSock:settimeouts(options.p2sConnTimeout, options.p2sSendTimeout, options.p2sReadTimeout)
  end
  -- add default receive-then-forward processor
  if not processor and not standalone then
    processor = {}
    processor.processUpRequest = function(self)
      local data, err, partial = self.channel:c2pRead(1024 * 10)
      -- real error happened, or timeout with nothing received
      if not data and not partial and err then return nil, err end
      if data and not err then
        return data
      else
        return partial
      end
    end
    processor.processDownRequest = function(self)
      local data, err, partial = self.channel:p2sRead(1024 * 10)
      -- real error happened, or timeout with nothing received
      if not data and not partial and err then return nil, err end
      if data and not err then
        return data
      else
        return partial
      end
    end
  end
  -- add default echo processor if proxy runs in standalone mode
  if not processor and standalone then
    processor = {}
    processor.processUpRequest = function(self)
      local data, err, partial = self.channel:c2pRead(1024 * 10)
      -- real error happened, or timeout with nothing received
      if not data and not partial and err then return nil, err end
      local echodata
      if data and not err then
        echodata = data
      else
        echodata = partial
      end
      logger.log(logger.INFO, echodata)
      -- report send failures so the up-loop stops
      local _, sendErr = self.channel:c2pSend(echodata)
      if sendErr then return nil, sendErr end
    end
  end
  local upForwarder = function(self, data)
    if data then return self.channel:p2sSend(data) end
  end
  local downForwarder = function(self, data)
    if data then return self.channel:c2pSend(data) end
  end
  -- add default upforwarder / downforwarder
  processor.sendUp = processor.sendUp or upForwarder
  processor.sendDown = processor.sendDown or downForwarder

  processor.ctx = processor.ctx or {}
  -- Default session-invalid handler: shut the connection down. Processors
  -- without a shutdown method of their own (e.g. the defaults above) use the
  -- channel's, reachable via processor.channel which is assigned below.
  local sessionInvalidHandler = function(self, session)
    logger.log(logger.DEBUG, "session closed")
    if self.shutdown then
      self:shutdown()
    else
      self.channel:shutdown()
    end
  end
  processor.sessionInvalid = processor.sessionInvalid or sessionInvalidHandler
  -- record the authenticated user name on the session
  if processor.AuthSuccessEvent then
    processor.AuthSuccessEvent:addHandler(o, function(self, source, username)
      if self.session and username then self.session.uid = username end
    end)
  end
  -- copy context updates from the processor onto the session
  if processor.ContextUpdateEvent then
    processor.ContextUpdateEvent:addHandler(o, function(self, source, ctx)
      if ctx and self.session then
        self.session.ctx = ctx
      end
    end)
  end
  o.p2sSock = p2sSock
  o.c2pSock = c2pSock
  o.processor = processor
  -- standalone mode has no upstreams and therefore no balancer
  if upstreams then
    o.balancer = upstreams.getBest and upstreams or balancer:new(upstreams)
  end
  o.standalone = standalone
  o.OnConnectEvent = event:new(o, "OnConnectEvent")
  o.sessionMan = options.sessionMan or ses:newDoNothing()
  setmetatable(o, {__index = self})
  processor.channel = o
  return o
end
75 |
--- Release both sockets of the channel.
-- May run more than once per connection (each pump thread and run() call
-- it); errors from already-closed sockets are ignored.
local function _cleanup(self)
  logger.log(logger.DEBUG, format("[SuProxy] clean up executed"))
  -- make sure buffers are clean
  ngx.flush(true)
  local p2sSock = self.p2sSock
  local c2pSock = self.c2pSock
  if p2sSock ~= nil then
    if p2sSock.shutdown then
      p2sSock:shutdown("send")
    end
    if p2sSock.close ~= nil then
      -- Close rather than setkeepalive(): this upstream connection carries
      -- protocol state for one proxied client (handshake, auth) and must not
      -- be handed to another request from the cosocket pool.
      p2sSock:close()
    end
  end

  if c2pSock ~= nil then
    if c2pSock.shutdown then
      c2pSock:shutdown("send")
    end
    if c2pSock.close ~= nil then
      c2pSock:close()
    end
  end
end
106 |
--- Client -> proxy -> upstream pump (run as a light thread).
-- Starts the session, fills processor.ctx with connection info, fires
-- OnConnectEvent, then repeatedly asks the processor for client data and
-- forwards it upstream (except in standalone mode) until an error occurs.
-- Shuts the channel down on exit.
local function _upl(self)
  -- self.upstream is nil in standalone mode; an empty table keeps the ctx
  -- assignments and log messages below from indexing nil
  local upstream = self.upstream or {}
  local session, err = ses:new(self.processor._PROTOCAL, self.sessionMan)
  if err then
    logger.log(logger.ERR, format("[SuProxy] start session fail: %s:%s, err:%s", tostring(upstream.ip), tostring(upstream.port), err))
    -- tear the channel down, as the normal exit path below does
    self:shutdown()
    return
  end
  local ctx = self.processor.ctx
  ctx.clientIP = ngx.var.remote_addr
  ctx.clientPort = ngx.var.remote_port
  ctx.srvIP = upstream.ip
  ctx.srvPort = upstream.port
  ctx.srvID = upstream.id
  ctx.srvGID = upstream.gid
  ctx.connTime = ngx.time()
  session.ctx = ctx
  self.session = session
  self.OnConnectEvent:trigger({clientIP = ctx.clientIP, clientPort = ctx.clientPort, srvIP = ctx.srvIP, srvPort = ctx.srvPort})
  while true do
    -- TODO: sessionMan should notify session change
    if self.session:valid(self.session) then
      self.session.uptime = ngx.time()
    else
      self.processor:sessionInvalid(self.session)
    end
    logger.log(logger.DEBUG, "client --> proxy start process")
    local buf, readErr = self.processor:processUpRequest(self.standalone)
    if readErr then
      logger.log(logger.ERR, format("[SuProxy] processUpRequest fail: %s:%s, err:%s", tostring(upstream.ip), tostring(upstream.port), readErr))
      break
    end
    -- in standalone mode the processor handles the data itself; don't forward
    if not self.standalone and buf then
      local _, sendErr = self.processor:sendUp(buf)
      if sendErr then
        logger.log(logger.ERR, format("[SuProxy] forward to upstream fail: %s:%s, err:%s", tostring(upstream.ip), tostring(upstream.port), sendErr))
        break
      end
    end
  end
  self:shutdown()
end
125 |
--- Upstream -> proxy -> client pump (only spawned outside standalone mode).
-- Reads server data through the processor and forwards it to the client
-- until reading or forwarding fails, then shuts the channel down.
local function _dwn(self)
  local upstream = self.upstream
  while true do
    logger.log(logger.DEBUG, "server --> proxy start process")
    local data, readErr = self.processor:processDownRequest(self.standalone)
    if readErr then
      logger.log(logger.ERR, format("[SuProxy] processDownRequest fail: %s:%s, err:%s", upstream.ip, upstream.port, readErr))
      break
    end
    if data then
      local _, sendErr = self.processor:sendDown(data)
      if sendErr then
        logger.log(logger.ERR, format("[SuProxy] forward to downstream fail: %s:%s, err:%s", upstream.ip, upstream.port, sendErr))
        break
      end
    end
  end
  self:shutdown(upstream)
end
-- Thin wrappers around the client (c2p) and upstream (p2s) sockets that
-- hex-dump traffic at DEBUG level. Reads return receive()'s
-- (data, err, partial) unchanged; sends return send()'s results unchanged.
-- Only complete reads are dumped; partial data is not logged.

function _M:c2pRead(length)
  local data, err, partial = self.c2pSock:receive(length)
  logger.logWithTitle(logger.DEBUG, "c2pRead", data and data:hex16F() or "")
  return data, err, partial
end

function _M:p2sRead(length)
  local data, err, partial = self.p2sSock:receive(length)
  logger.logWithTitle(logger.DEBUG, "p2sRead", data and data:hex16F() or "")
  return data, err, partial
end

function _M:c2pSend(bytes)
  local dump = bytes and bytes:hex16F() or ""
  logger.logWithTitle(logger.DEBUG, "c2pSend", dump)
  return self.c2pSock:send(bytes)
end

function _M:p2sSend(bytes)
  local dump = bytes and bytes:hex16F() or ""
  logger.logWithTitle(logger.DEBUG, "p2sSend", dump)
  return self.p2sSock:send(bytes)
end
--- Drive the channel: connect to an upstream (blaming failures and asking
-- the balancer for another), run the up/down pump threads until both finish,
-- then clean up. In standalone mode only the up pump runs.
function _M:run()
  -- single-pass loop: every `break` still falls through to _cleanup
  repeat
    local upstream
    if not self.standalone then
      repeat
        upstream = self.balancer:getBest()
        if not upstream then
          logger.log(logger.ERR, format("[SuProxy] failed to get avaliable upstream"))
          break
        end
        local ok, err = self.p2sSock:connect(upstream.ip, upstream.port)
        if ok then
          logger.log(logger.INFO, format("[SuProxy] connect to proxy upstream: %s:%s", upstream.ip, upstream.port))
          self.upstream = upstream
          break
        end
        logger.log(logger.ERR, format("[SuProxy] failed to connect to proxy upstream: %s:%s, err:%s", upstream.ip, upstream.port, err))
        self.balancer:blame(upstream)
      until false
      if not upstream then break end
    end
    local upThread = spawn(_upl, self)
    if not self.standalone then
      wait(spawn(_dwn, self))
    end
    wait(upThread)
  until true
  _cleanup(self)
end
--- Kill the session (if one was started) and close both sockets.
-- May be called several times per connection.
function _M:shutdown()
  if self.session then
    local err = self.session:kill(self.session)
    if err then
      -- self.upstream is nil in standalone mode
      local upstream = self.upstream or {}
      logger.log(logger.ERR, format("[SuProxy] kill session fail: %s:%s, err:%s", tostring(upstream.ip), tostring(upstream.port), err))
    end
  end
  _cleanup(self)
end
162 |
163 | return _M
164 |
--------------------------------------------------------------------------------
/example/gateway.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | This demo implements a simple gateway for TNS,TDS,SSH2,LDAP protocols.
3 | To test this demo, add the following sections to your nginx config
4 | file, configure the server credentials in getCredential, then log in
5 | with test/test as username/password.
6 | Make sure the commands.log file path is valid.
7 | stream {
8 | lua_code_cache off;
9 | #mock logserver if you do not have one
10 | server {
11 | listen 12080;
12 | content_by_lua_block {
13 | ngx.log(ngx.DEBUG,"logserver Triggerred")
14 | local reqsock, err = ngx.req.socket(true)
15 | reqsock:settimeout(100)
16 | while(not err) do
17 | local command,err=reqsock:receive()
18 | if(err) then ngx.exit(0) end
19 | local f = assert(io.open("/data/logs/commands.log", "a"))
20 | if(command) then
21 | f:write(command .. "\n")
22 | f:close()
23 | end
24 | end
25 | }
26 | }
27 | #listen on ports
28 | server {
29 | listen 389;
30 | listen 1521;
31 | listen 22;
32 | listen 1433;
33 | content_by_lua_file lualib/suproxy/example/gateway.lua;
34 | }
35 | }
36 | #Session manager interfaces. if you want to view and manage your session
37 | #over http, this should be set.
38 | http {
39 | include mime.types;
40 | lua_code_cache off;
41 | server {
42 | listen 80;
43 | server_name localhost;
44 | default_type text/html;
45 | location /suproxy/manage{
46 | content_by_lua_file lualib/suproxy/example/session.lua;
47 | }
48 | }
49 | ]]
-------------------------socket logger init-----------------------
local logger = require "resty.logger.socket"
if not logger.initted() then
  -- address of the log server (see the mock logserver in the header)
  local loggerConf = {
    host = '127.0.0.1',
    port = 12080,
    flush_limit = 10,
    drop_limit = 567800,
  }
  local initOk, initErr = logger.init(loggerConf)
  if not initOk then
    ngx.log(ngx.ERR, "failed to initialize the logger: ", initErr)
    return
  end
end
----------------------session manager init------------------------
-- sessions are kept in the redis server below
local sessionConf = {ip = "127.0.0.1", port = 6379, expire = -1, extend = false}
local sessionManager = require("suproxy.session.sessionManager"):new(sessionConf)
73 | ----------------------------handlers -----------------------------
--- Demo credential swap: whatever the client sent is replaced by fixed
-- server credentials. Real deployments should use OAuth or another network
-- auth method instead of hard-coded credentials.
-- @return table {username=, password=}
local function getCredential(context, source, credential, session)
  -- the client-supplied name is read but unused here; a real implementation
  -- could branch on it or on session.srvGID / session.srvID
  local _ = credential.username
  local serverCred = {username = "root", password = "xxxxxx"}
  return serverCred
end
82 |
--- Demo OAuth password exchange: the username typed by the client is used as
-- an OAuth code and exchanged for real server credentials. The OAuth app must
-- be configured to expose a `password` attribute.
-- @return cred table {username=, password=}, or nil plus an error message
local function oauth(context, source, credential, session)
  -- ssoProcessors is not loaded at file level in this script
  local ssoProcessors = require("suproxy.http.ssoProcessors")
  local param = {
    ssoProtocol = "OAUTH",
    validate_code_url = "http://changeToYourOwnTokenLink",
    profile_url = "http://changeToYourOwnProfile",
    client_secret = "changeToYourOwnSecret",
    appcode = session.srvGID
  }
  local authenticator = ssoProcessors.getProcessor(param)
  -- NOTE(review): method name kept as "valiate"; confirm against ssoProcessors
  local result = authenticator:valiate(credential.username)
  local cred
  if result and result.status == ssoProcessors.CHECK_STATUS.SUCCESS then
    -- build the table here; the original assigned fields to a nil `cred`
    cred = {
      username = result.accountData.user,
      password = result.accountData.attributes.password,
    }
  end
  if not cred then return nil, "can not get cred from remote server" end
  return cred
end
105 |
--- Demo command filter: rejects any command containing "forbidden".
-- Change this to implement your own filter strategy.
-- @return the command unchanged, or nil plus an error table {message=, code=}
local function commandFilter(context, source, command, session)
  local blocked = command:match("forbidden")
  if not blocked then
    return command
  end
  return nil, {message = command .. " is a forbidden command", code = 1234}
end
114 |
--- Send one line to the socket log server, tab separated:
-- "<date>\t<clientIP>:<clientPort>\t<username>\t<content>\r\n".
-- A nil session is tolerated (logAuth explicitly allows one) and logged as
-- "unknown"; a nil username is logged as "UNKNOWN".
local function log(username, content, session)
  local client = "unknown"
  if session then
    client = tostring(session.clientIP) .. ":" .. tostring(session.clientPort)
  end
  local rs = {
    os.date("%Y.%m.%d %H:%M:%S", ngx.time()) .. "\t",
    client .. "\t",
    -- keep the tab separator for unknown users too, so columns line up
    username and username .. "\t" or "UNKNOWN\t",
    tostring(content) .. "\r\n"
  }
  local bytes, err = logger.log(table.concat(rs))
  if err then
    ngx.log(ngx.ERR, "failed to log command: ", err)
  end
end
127 |
--- Demo login logging: records the client software and version found on the
-- session (when present).
local function logAuth(context, source, username, session)
  local client, version = "unknown client", ""
  if session then
    client = session.client or client
    version = session.clientVersion or version
  end
  log(username, "login with " .. client .. version, session)
end
137 |
--- Demo connect logging. connInfo is the table passed to OnConnectEvent
-- (clientIP, clientPort, srvIP, srvPort); it doubles as log()'s session arg.
local function logConnect(context, source, connInfo)
  local target = connInfo.srvIP .. ":" .. connInfo.srvPort
  log("UNKNOWN", "connect to " .. target, connInfo)
end
143 |
--- Demo command logging: logs the command, then (if non-empty) the first
-- 4000 bytes of the reply between separator lines.
local function logCmd(context, source, command, reply, session)
  log(session.username, command, session)
  if not reply or reply == "" then return end
  local framed = table.concat{
    "------------------reply--------------------\r\n",
    reply:sub(1, 4000),
    "\r\n----------------reply end------------------\r\n\r\n",
  }
  local _, err = logger.log(framed)
  if err then
    ngx.log(ngx.ERR, "failed to log reply: ", err)
  end
end
156 |
--- Demo login-failure logging.
local function logAuthFail(context, source, failInfo, session)
  local reason = failInfo.message or ""
  log(failInfo.username, "login fail, fail message: " .. reason, session)
end
163 |
--- Demo authenticator: accepts only username "test" with password "test".
-- @return true, false on success; false plus a failure message otherwise
local function authenticator(context, source, credential, session)
  if credential.username == "test" and credential.password == "test" then
    return true, false
  end
  return false, "login with " .. credential.username .. " failed"
end
170 |
--- Demo LDAP SearchRequest handler showing how to react to parser events.
-- When the search command contains "pleasechangeme" the proxy answers the
-- client itself with one fake entry plus a SearchResultDone, then empties the
-- request bytes so nothing is forwarded upstream.
local function ldap_SearchRequestHandler(context, src, p)
  if not context.command:match("pleasechangeme") then return end
  local packets = require("suproxy.ldap.ldapPackets")
  local entry = packets.SearchResultEntry:new()
  local done = packets.SearchResultDone:new()
  entry.objectName = "cn=admin,dc=www,dc=test,dc=com"
  entry.messageId = p.messageId
  entry.attributes = {
    {attrType = "objectClass", values = {"posixGroup", "top"}},
    {attrType = "cn", values = {"group"}},
    {attrType = "memberUid", values = {"haha", "test", "test"}},
    {attrType = "gidNumber", values = {"44789"}},
    {attrType = "description", values = {"group"}}
  }
  done.resultCode = packets.ResultCode.success
  done.messageId = p.messageId
  entry:pack()
  done:pack()
  context.channel:c2pSend(entry.allBytes .. done.allBytes)
  -- an empty payload stops forwarding
  p.allBytes = ""
end
194 |
195 | --Demo: replace the SSH2 server welcome banner with custom ASCII art.
--Registered on commandCollector's BeforeWelcomeEvent; returns the banner
--text (lines joined with "\r\n") and false.
--NOTE(review): the meaning of the second return value is defined by the
--event consumer in suproxy.ssh2.commandCollector; confirm there.
196 | local function myWelcome(context,source)
197 | local digger={"\r\n",
198 | [[ .-. ]].."\r\n",
199 | [[ / \ ]].."\r\n",
200 | [[ _____.....-----|(o) | ]].."\r\n",
201 | [[ _..--' _..--| .'' ]].."\r\n",
202 | [[ .' o _..--'' | | | ]].."\r\n",
203 | [[ / _/_..--'' | | | ]].."\r\n",
204 | [[ ________/ / / | | | ]].."\r\n",
205 | [[ | _ ____\ / / | | | ]].."\r\n",
206 | [[ _.-----._________|| || \\ / | | | ]].."\r\n",
207 | [[|=================||=||_____\\ |__|-' ]].."\r\n",
208 | [[| suproxy ||_||_____// (o\ | ]].."\r\n",
209 | [[|_________________|_________/ |-\| ]].."\r\n",
210 | [[ `-------------._______.----' / `. ]].."\r\n",
211 | [[ .,.,.,.,.,.,.,.,.,.,.,.,., / \]].."\r\n",
212 | [[ ((O) o o o o ======= o o(O)) ._.' /]].."\r\n",
213 | [[ `-.,.,.,.,.,.,.,.,.,.,.,-' `.......' ]].."\r\n",
214 | [[ scan me to login ]].."\r\n",
215 | "\r\n",
216 | }
--long-bracket [[...]] literals keep the backslashes in the art verbatim
217 | return table.concat(digger),false
218 | end
219 |
-- Dispatch table: server port -> handler that builds the protocol processor,
-- wires the demo handlers above and runs a channel for this connection.
local switch = {}

--- SSH2 demo (port 22).
switch[22] = function()
  local ssh = require("suproxy.ssh2"):new()
  ssh.AuthSuccessEvent:addHandler(ssh, logAuth)
  ssh.BeforeAuthEvent:addHandler(ssh, getCredential)
  ssh.OnAuthEvent:addHandler(ssh, authenticator)
  ssh.AuthFailEvent:addHandler(ssh, logAuthFail)
  -- the collector is fed the raw SSH data events and raises command events
  local collector = require("suproxy.ssh2.commandCollector"):new()
  collector.CommandEnteredEvent:addHandler(ssh, commandFilter)
  collector.CommandFinishedEvent:addHandler(ssh, logCmd)
  collector.BeforeWelcomeEvent:addHandler(ssh, myWelcome)
  ssh.C2PDataEvent:addHandler(collector, collector.handleDataUp)
  ssh.S2PDataEvent:addHandler(collector, collector.handleDataDown)
  -- one balancer is cached in package.loaded so its blame state is reused
  -- across connections while this Lua VM lives
  if not package.loaded.my_SSHB then
    -- change to your own upstreams
    package.loaded.my_SSHB = require("suproxy.balancer.balancer"):new{
      --{ip="127.0.0.1",port=2222,id="local",gid="linuxServer"},
      --{ip="192.168.46.128",port=22,id="remote",gid="linuxServer"},
      --{ip="192.168.1.121",port=22,id="UBUNTU14",gid="testServer"},
      {ip = "192.168.1.152", port = 22, id = "UBUNTU20", gid = "testServer"},
      --{ip="192.168.1.103",port=22,id="SUSE11",gid="testServer"},
      --{ip="192.168.1.186",port=22,id="OPENBSD",gid="testServer"},
      --{ip="192.168.1.187",port=22,id="FreeBSD",gid="testServer"},
    }
  end
  local channel = require("suproxy.channel"):new(package.loaded.my_SSHB, ssh, {sessionMan = sessionManager})
  channel.OnConnectEvent:addHandler(channel, logConnect)
  channel:run()
end
--- Oracle TNS demo (port 1521).
switch[1521] = function()
  -- the server version is required for password substitution
  local tns = require("suproxy.tns"):new{oracleVersion = 11, swapPass = false}
  tns.AuthSuccessEvent:addHandler(tns, logAuth)
  tns.CommandEnteredEvent:addHandler(tns, commandFilter)
  tns.CommandFinishedEvent:addHandler(tns, logCmd)
  tns.AuthFailEvent:addHandler(tns, logAuthFail)
  tns.BeforeAuthEvent:addHandler(tns, getCredential)
  --tns.OnAuthEvent:addHandler(tns,authenticator)
  -- balancer cached in package.loaded, reused while this Lua VM lives
  if not package.loaded.my_OracleB then
    -- change to your own upstreams
    package.loaded.my_OracleB = require("suproxy.balancer.balancer"):new{
      {ip = "192.168.1.96", port = 1521, id = "remote", gid = "oracleServer"},
      --{ip="192.168.46.157",port=1522,id="local",gid="oracleServer"},
      --{ip="192.168.1.182",port=1521,id="182",gid="oracleServer"},
      --{ip="192.168.1.190",port=1521,id="oracle10",gid="oracleServer"},
    }
  end
  local channel = require("suproxy.channel"):new(package.loaded.my_OracleB, tns, {sessionMan = sessionManager})
  channel.OnConnectEvent:addHandler(channel, logConnect)
  channel:run()
end
--- LDAP demo (port 389).
switch[389] = function()
  local ldap = require("suproxy.ldap"):new()
  ldap.AuthSuccessEvent:addHandler(ldap, logAuth)
  ldap.AuthFailEvent:addHandler(ldap, logAuthFail)
  ldap.CommandEnteredEvent:addHandler(ldap, commandFilter)
  ldap.CommandFinishedEvent:addHandler(ldap, logCmd)
  ldap.BeforeAuthEvent:addHandler(ldap, getCredential)
  ldap.OnAuthEvent:addHandler(ldap, authenticator)
  ldap.c2pParser.events.SearchRequest:addHandler(ldap, ldap_SearchRequestHandler)
  -- Cache the balancer like the other ports do. Passing the raw upstream list
  -- to channel:new built a fresh balancer per connection, discarding blame
  -- state on every connection.
  if not package.loaded.my_LDAPB then
    -- change to your own upstreams
    package.loaded.my_LDAPB = require("suproxy.balancer.balancer"):new{
      {ip = "192.168.46.128", port = 389, id = "ldap1", gid = "ldapServer"},
    }
  end
  local channel = require("suproxy.channel"):new(package.loaded.my_LDAPB, ldap, {sessionMan = sessionManager})
  channel.OnConnectEvent:addHandler(channel, logConnect)
  channel:run()
end
--- SQL Server TDS demo (port 1433).
switch[1433] = function()
  local tds = require("suproxy.tds"):new({disableSSL = false, catchReply = true})
  tds.AuthSuccessEvent:addHandler(tds, logAuth)
  tds.CommandEnteredEvent:addHandler(tds, commandFilter)
  tds.CommandFinishedEvent:addHandler(tds, logCmd)
  tds.BeforeAuthEvent:addHandler(tds, getCredential)
  tds.OnAuthEvent:addHandler(tds, authenticator)
  tds.AuthFailEvent:addHandler(tds, logAuthFail)
  -- balancer cached in package.loaded, reused while this Lua VM lives
  if not package.loaded.my_SQLServerB then
    -- change to your own upstreams
    package.loaded.my_SQLServerB = require("suproxy.balancer.balancer"):new{
      {ip = "192.168.1.135", port = 1433, id = "srv12", gid = "sqlServer"},
      --{ip="192.168.1.120",port=1433,id="srv14",gid="sqlServer"}
    }
  end
  local channel = require("suproxy.channel"):new(package.loaded.my_SQLServerB, tds, {sessionMan = sessionManager})
  channel.OnConnectEvent:addHandler(channel, logConnect)
  channel:run()
end
306 |
-- Run the handler registered for the port this connection arrived on;
-- connections on unregistered ports are ignored.
local handler = switch[tonumber(ngx.var.server_port)]
if handler then handler() end
311 |
312 |
--------------------------------------------------------------------------------
/example/session.lua:
--------------------------------------------------------------------------------
1 | --[[
This demo shows how to manage sessions over HTTP. To test it, change the
redis ip and port settings below and add the following section to your nginx config file:
http {
include mime.types;
lua_code_cache off;
server {
listen 80;
server_name localhost;
default_type text/html;
location /suproxy/manage{
content_by_lua_file lualib/suproxy/example/session.lua;
}
}
]]
local utils= require "suproxy.utils.utils"
local cjson=require "cjson"
-- connect the session manager to the redis instance holding session data
local redisIP, port = "127.0.0.1", 6379
local sessionMan = require("suproxy.session.sessionManager"):new({
  ip = redisIP,
  port = port,
  expire = 30,
})
if not sessionMan then
  ngx.say(string.format("connect to Redis %s:%s failed", redisIP, port))
  return
end
2 |
-- Route /suproxy/manage/session/{kill,get,all,clear} requests.
local uri = ngx.var.request_uri
if uri:match("/suproxy/manage/session/kill") then
  -- kill by sid if given, else by uid; with neither, all sessions are cleared
  -- (getArgsFromRequest was misspelled here, which called a nil field)
  local sid = utils.getArgsFromRequest("sid")
  local uid = utils.getArgsFromRequest("uid")
  local result
  if not sid and not uid then
    result = sessionMan:clear()
  elseif sid then
    result = sessionMan:kill(sid)
  else
    result = sessionMan:killSessionOfUser(uid)
  end
  ngx.say(tostring(result) .. " items is removed")
elseif uri:match("/suproxy/manage/session/get") then
  local sid = utils.getArgsFromRequest("sid")
  local uid = utils.getArgsFromRequest("uid")
  if not sid and not uid then ngx.say("valid sid or uid should be provided") return end
  local result
  if sid then
    result = sessionMan:get(sid)
  else
    result = cjson.encode(sessionMan:getSessionOfUser(uid))
  end
  ngx.say(result)
elseif uri:match("/suproxy/manage/session/all") then
  local result, count = sessionMan:getAll()
  ngx.say("count:" .. tostring(count))
  ngx.say(cjson.encode(result))
elseif uri:match("/suproxy/manage/session/clear") then
  local result = sessionMan:clear()
  ngx.say(tostring(result) .. " items is removed")
end
29 |
30 |
31 |
--------------------------------------------------------------------------------
/http/handlers.lua:
--------------------------------------------------------------------------------
1 | -- Copyright (C) Joey Zhu
2 | local _M = {_VERSION="0.1.11"}
3 | local cjson=require("cjson")
4 | local utils=require("suproxy.utils.utils")
5 | local ssoProcessors=require("suproxy.http.ssoProcessors")
6 | local zlib = require('suproxy.utils.ffi-zlib')
7 |
--- Load content: fetch an internal location through an nginx subrequest.
-- @param url internal URI, requested with GET
-- @return the body of the subrequest response
function _M.loadUrl(url)
    -- `res` was previously an accidental global shared by every request in the worker
    local res = ngx.location.capture(url,{method=ngx.HTTP_GET})
    return res.body
end
13 |
-- Escape "%" so `s` is inserted literally as a string.gsub replacement.
-- URL-encoded values such as "%3D" would otherwise be read as capture references
-- and make gsub raise "invalid capture index".
local function toolbarReplacement(s)
    if s == nil then
        return ""
    end
    return (tostring(s):gsub("%%", "%%%%"))
end

--- Render the login toolbar for the current page.
-- Loads /auth/toolbar.html through a subrequest and fills in the {{username}}
-- and {{loginUrl}} placeholders. The login URL carries appCode and a base64 JSON
-- `context` ({method, targetURL}) so the login flow can return to this page.
-- @param localParams table providing loginUrl and appCode
-- @return rendered HTML and the number of {{loginUrl}} substitutions (string.gsub results)
function _M.loadToolBar(localParams)
    local request_uri = ngx.var.scheme.."://"..ngx.var.server_name..":"..ngx.var.server_port..ngx.var.request_uri
    local loginPath=utils.addParamToUrl(localParams.loginUrl,"appCode",localParams.appCode)
    -- tell the login handler where to come back to
    local ctx={method=ngx.var.request_method,targetURL=request_uri}
    loginPath=utils.addParamToUrl(loginPath,"context",ngx.encode_base64(cjson.encode(ctx)))
    -- `body` was previously an accidental global
    local body=_M.loadUrl("/auth/toolbar.html")
    local username=cjson.decode(ngx.var.userdata).user
    local html=string.gsub(body,"{{username}}",toolbarReplacement(username))
    return string.gsub(html,"{{loginUrl}}",toolbarReplacement(loginPath))
end
28 |
29 |
30 |
--- Replace content in the response body (a single substitution).
-- Thin wrapper over replaceResponseMutiple; call it from a body filter.
-- @param regex Lua pattern to search for
-- @param replacement string.gsub replacement
function _M.replaceResponse(regex,replacement)
    local single={regex=regex,replacement=replacement}
    _M.replaceResponseMutiple({single})
end
35 |
--- Replace content in the response body, applying several substitutions.
-- Body-filter helper: every chunk (ngx.arg[1]) is held back in ngx.ctx.buffered
-- until the last chunk (ngx.arg[2] true) arrives. The whole body is then optionally
-- gunzipped (when ngx.var.upstreamEncoding == "gzip") and every {regex, replacement}
-- pair in `subs` is applied with string.gsub, in order.
-- @param subs array of {regex=<Lua pattern>, replacement=<gsub replacement>}
function _M.replaceResponseMutiple(subs)
    local chunk = ngx.arg[1]
    local eof = ngx.arg[2]
    local ctx = ngx.ctx

    local pieces = ctx.buffered
    if not pieces then
        pieces = {}
        ctx.buffered = pieces
    end

    if chunk ~= "" then
        pieces[#pieces + 1] = chunk
        -- suppress this chunk; the rewritten body is emitted at eof
        ngx.arg[1] = nil
    end

    if not eof then
        return
    end

    local whole = table.concat(pieces)
    ctx.buffered = nil

    local inflated = ngx.var.upstreamEncoding == "gzip" and utils.unzip(whole)
    if inflated then
        whole = inflated
    end

    for _, sub in ipairs(subs) do
        whole = (string.gsub(whole, sub.regex, sub.replacement))
    end
    ngx.arg[1] = whole
end
67 |
-- Escape Lua pattern magic characters so `s` matches literally in string.gsub.
-- Host names contain "." and "-", which are pattern operators.
local function literalPattern(s)
    return (s:gsub("[%^%$%(%)%%%.%[%]%*%+%-%?]", "%%%0"))
end

-- Escape "%" so `s` is inserted literally as a string.gsub replacement.
local function literalReplacement(s)
    return (s:gsub("%%", "%%%%"))
end

-- Decode the base64 + JSON `context` request argument.
-- Returns the decoded table, or nil when it is absent or malformed.
local function decodeContext(context)
    if context == nil then
        return nil
    end
    local ok, decoded = pcall(function()
        return cjson.decode(ngx.decode_base64(context))
    end)
    if ok and type(decoded) == "table" then
        return decoded
    end
    return nil
end

--- Check access: enforce login / single sign-on for the current request.
-- 1. Requests without SSO credentials reuse an existing session. When there is
--    no session, they are redirected to the login page if checkLogin is set,
--    and rejected with 401 otherwise.
-- 2. Requests with SSO credentials are validated by the processor for
--    localParams.ssoParam. On success a session is started and the user is
--    redirected to the context's targetURL.
-- Writes ngx.var.callbackmethod, ngx.var.referer and ngx.var.userdata.
-- @param localParams table with appCode, loginUrl, checkLogin, defaultContext, ssoParam
function _M.accessCheck(localParams)
    local origin = ngx.var.scheme.."://"..ngx.var.server_name..":"..ngx.var.server_port
    local request_uri = origin..ngx.var.request_uri
    ngx.var.callbackmethod=ngx.var.request_method

    local referer=ngx.var.http_referer
    local backaddress=ngx.var.backaddress
    if referer~=nil and backaddress~=nil then
        -- point the referer at the backend instead of this gateway (literal match)
        ngx.var.referer=(referer:gsub(literalPattern(origin), literalReplacement(backaddress)))
    end

    local session = require "resty.session".open()
    local ssoParam=localParams.ssoParam
    local processor=ssoProcessors.getProcessor(ssoParam)
    if processor==nil then
        utils.error("can't find processor for sso protocal "..tostring(ssoParam.ssoProtocol),nil,500)
        return
    end
    local CHECK_STATUS=ssoProcessors.CHECK_STATUS

    -- does this request carry SSO credentials (token / ticket / code)?
    local result=processor:checkRequest()
    if result.status==CHECK_STATUS.NOT_SSO_CHECK then
        ngx.log(ngx.DEBUG,"sessionid:"..ngx.encode_base64(session.id))
        if not localParams.checkLogin then
            -- login check disabled: requests without credentials are rejected
            ngx.log(ngx.DEBUG,"checkLogin:false")
            utils.error(result.message,nil,ngx.HTTP_UNAUTHORIZED)
            return
        end
        ngx.log(ngx.DEBUG,"checkLogin:true"..request_uri)
        if session.present then
            -- already logged in: expose the user to later phases
            ngx.var.userdata=cjson.encode(session.data.user)
            return
        end
        -- no session: go to the login page, telling it where to return afterwards
        local loginPath=utils.addParamToUrl(localParams.loginUrl,"appCode",localParams.appCode)
        local ctx={method=ngx.var.request_method,targetURL=request_uri}
        loginPath=utils.addParamToUrl(loginPath,"context",ngx.encode_base64(cjson.encode(ctx)))
        if processor.formatLoginPath then
            loginPath=processor:formatLoginPath(loginPath)
        end
        ngx.log(ngx.DEBUG,"loginPath:"..loginPath)
        ngx.redirect(loginPath)
        return
    end

    result=processor:valiate()
    if result.status~=CHECK_STATUS.SUCCESS then
        -- AUTH_FAIL or any other status: both were handled identically before
        utils.error(result.message,"Status:"..tostring(result.status).." Message:"..tostring(result.message),500)
        return
    end

    -- validated: record the user and start a session
    ngx.var.userdata=cjson.encode(result.accountData)
    -- fall back to defaultContext when `context` is missing or cannot be decoded
    local contextJson=decodeContext(utils.getArgsFromRequest("context")) or localParams.defaultContext or {}
    if contextJson.method~=nil then
        ngx.var.callbackmethod=contextJson.method
    end
    session:start()
    session.data.user=result.accountData
    session:save()
    ngx.log(ngx.INFO, "Session Started -- " .. ngx.encode_base64(session.id))
    -- Jump to the originally requested page when it differs from the current URL.
    -- The comparison uses the full current URL, including any arguments
    -- (context, x_token, ticket, ...) appended by the login flow.
    if contextJson.targetURL~=nil and contextJson.targetURL~=request_uri then
        return ngx.redirect(contextJson.targetURL)
    end
    ngx.log(ngx.INFO, "Jump to targetURL -- " .. tostring(contextJson.targetURL))
    return
end
170 |
171 | return _M
172 |
173 |
--------------------------------------------------------------------------------
/http/initParam.lua:
--------------------------------------------------------------------------------
1 | local cjson=require("cjson")
-- Gateway HTTP/SSO settings; this script returns them as a JSON string.
local appCode = "gate"
-- Login page used when no session exists ("/auth/ssologin.html" for real SSO).
local loginUrl = "/auth/mocklogin.html"

local Global_Params = {
    appCode = appCode,
    loginUrl = loginUrl,
    -- Context used after login when the request carries none.
    defaultContext = {
        targetURL = "/",
        method = "GET",
    },
    -- Active protocol: JWT verified with a shared secret.
    ssoParam = {
        ssoProtocol = "JWT",
        secret = "lua-resty-jwt",
    },
    -- CAS alternative:
    --   ssoParam = {ssoProtocol="CAS", validate_url="xxxxxxxxxxxxxx", service="xxxxxxxxxxxx"}
    -- OAuth 2.0 alternative:
    --   ssoParam = {ssoProtocol="OAUTH", validate_code_url="xxxxxxxxxxxxx",
    --               profile_url="xxxxxxxxxxxxxxx", callbackurl="xxxxxxxxxxx",
    --               client_secret="xxxxxxxxxxxxx"}
}

-- SSO processors only see ssoParam (e.g. OAUTHProcessor sends ssoParam.appCode),
-- so mirror the application settings there.
local ssoParam = Global_Params.ssoParam
ssoParam.appCode = appCode
ssoParam.loginUrl = loginUrl
return cjson.encode(Global_Params)
35 |
--------------------------------------------------------------------------------
/http/mockauth.lua:
--------------------------------------------------------------------------------
1 | local jwt = require "resty.jwt"
local utils= require "suproxy.utils.utils"
local cjson=require("cjson")
2 |
-- Read the submitted username/password (query string for GET, form body for POST).
-- This is where real authentication would plug in.
local args = nil
if "GET" == ngx.var.request_method then
    args = ngx.req.get_uri_args()
elseif "POST" == ngx.var.request_method then
    ngx.req.read_body()
    args = ngx.req.get_post_args()
end
-- other methods carry no credentials; previously indexing nil `args` crashed
args = args or {}
local username=args["name"]
local password=args["pass"]
if username == nil or password==nil then
    ngx.log(ngx.WARN, "Username and Password can not be empty ")
    return ngx.exit(ngx.HTTP_UNAUTHORIZED)
end

---- mock authentication begin ----
if username~="admin1" or password~="aA123." then
    ngx.log(ngx.WARN, "Wrong Username or Password ")
    return ngx.exit(ngx.HTTP_UNAUTHORIZED)
end
---- mock authentication end ----

-- Issue a mock JWT signed with the same secret as the JWT ssoParam ("lua-resty-jwt").
local jwt_token = jwt:sign(
    "lua-resty-jwt",
    {
        header={typ="JWT", alg="HS256"},
        payload={accountName="admin1"}
    }
)

-- The token is always delivered to the gateway callback. The originally requested
-- URL travels inside `context` (base64 JSON), which is forwarded untouched, so it
-- is not decoded here. Decoding it previously crashed when `context` was absent,
-- and the result was discarded anyway.
local context=ngx.req.get_uri_args()["context"]
local redirectUrl="/gateway/callback"
-- Pass the parameters with an HTTP POST: `template` is rendered with
-- string.format(template, redirectUrl, jwt_token, context) and written to the client.
-- NOTE(review): the template markup appears to have been lost in this copy; only
-- dump gutter markers remain inside the long string, so there are no %s
-- placeholders and the three format arguments are ignored. Restore the markup.
local template=[[
36 |
40 |
41 | ]]
42 | ngx.say(string.format(template,redirectUrl,jwt_token,context))
43 |
44 |
45 | --[[
46 | --Alternative: pass the parameters with a GET redirect instead of a POST form.
-- NOTE(review): ngx_decode_base64 is not defined in this file. `context` is already
-- base64 (handlers.accessCheck decodes it with ngx.decode_base64), so it can
-- probably be passed through unchanged.
47 | redirectUrl=utils.addParamToUrl(redirectUrl,"context",ngx_decode_base64(context))
48 | redirectUrl=utils.addParamToUrl(redirectUrl,"x_token",jwt_token)
49 | ngx.redirect(redirectUrl)
50 | --]]
--------------------------------------------------------------------------------
/http/ssoProcessors.lua:
--------------------------------------------------------------------------------
1 | local utils = require "suproxy.utils.utils"
local jwt = require "resty.jwt"
local cjson=require("cjson")
local _M = {_VERSION="0.1.11"}
2 |
--- Status codes returned by every processor's checkRequest() / valiate().
_M.CHECK_STATUS = {
    SUCCESS = 0,              -- credentials present (checkRequest) / accepted (valiate)
    NOT_SSO_CHECK = 1,        -- request carries no SSO credentials
    AUTH_FAIL = 2,            -- credentials rejected
    UNKNOWN_SSO_PROTOCAL = 3, -- presumably: unsupported protocol (not referenced in this file)
}
9 |
10 | ------JWT Processor------------
local JWTProcessor={}

--- Create a JWT processor bound to ssoParam (ssoParam.secret verifies tokens).
-- Each call returns a fresh instance. Previously `new` stored ssoParam on the
-- shared JWTProcessor table, so locations with different settings overwrote
-- each other.
function JWTProcessor.new(ssoParam)
    return setmetatable({ssoParam=ssoParam},{__index=JWTProcessor})
end

--- Is this a JWT login request? It is when an `x_token` argument is present.
function JWTProcessor:checkRequest()
    local x_token=utils.getArgsFromRequest("x_token")
    if x_token == nil then
        return {status=_M.CHECK_STATUS.NOT_SSO_CHECK,message="Not valid sso auth request"}
    end
    return {status=_M.CHECK_STATUS.SUCCESS}
end

--- Verify the `x_token` signature with ssoParam.secret.
-- @return {status=SUCCESS, accountData={user=<accountName claim>, attributes=<payload>}}
--         or {status=AUTH_FAIL, message=...}
function JWTProcessor:valiate()
    local x_token=utils.getArgsFromRequest("x_token")
    local jwt_obj = jwt:verify(self.ssoParam.secret,x_token)
    -- anything other than an explicit success is a failure (a nil `verified` used to pass)
    if not jwt_obj.verified then
        return {status=_M.CHECK_STATUS.AUTH_FAIL,message="Invalid token: ".. tostring(jwt_obj.reason)}
    end
    return {status=_M.CHECK_STATUS.SUCCESS,accountData={user=jwt_obj.payload.accountName,attributes=jwt_obj.payload}}
end
35 | ------JWT Processor end------------
36 | ------CAS Processor------------
local CASProcessor={}

--- Create a CAS processor bound to ssoParam (validate_url, service).
-- Returns a fresh instance per call instead of mutating the shared table.
function CASProcessor.new(ssoParam)
    return setmetatable({ssoParam=ssoParam},{__index=CASProcessor})
end

--- Is this a CAS login request? It is when a `ticket` argument is present.
function CASProcessor:checkRequest()
    local ticket=utils.getArgsFromRequest("ticket")
    if ticket == nil then
        -- message added for consistency with the JWT/OAUTH processors
        return {status=_M.CHECK_STATUS.NOT_SSO_CHECK,message="Not valid sso auth request"}
    end
    return {status=_M.CHECK_STATUS.SUCCESS}
end

-- Validate `ticket` against the CAS server's JSON service-validate endpoint.
-- Always returns {success=true, data=<authenticationSuccess>} or {success=false, message=...}.
local function cas_ticket_verify(validate_url,service,ticket)
    local payload = {
        service = service,
        ticket = ticket
    }
    local status, body, err = utils.jget(validate_url, ngx.encode_args(payload))
    if not status or status ~= 200 then
        return {success=false,message=tostring(err or ("HTTP status "..tostring(status)))}
    end

    local ok,decodeResponse=pcall(cjson.decode,body)
    if not ok or type(decodeResponse)~="table" or type(decodeResponse.serviceResponse)~="table" then
        return {success=false,message="response format wrong, can't be parse to json"}
    end
    local serviceResponse=decodeResponse.serviceResponse

    local failure=serviceResponse.authenticationFailure
    if failure then
        return {success=false,
            message=string.format("ticket validation failure, code:%s description:%s",
                tostring(failure.code),
                tostring(failure.description))
        }
    end

    if serviceResponse.authenticationSuccess then
        return {success=true,data=serviceResponse.authenticationSuccess}
    end

    -- neither success nor failure: previously fell off the end and returned nil
    return {success=false,message="response contains neither authenticationSuccess nor authenticationFailure"}
end

--- Validate the `ticket` argument with the CAS server.
-- NOTE: when using ngiam SSO, its secret must be configured to match the application.
-- @return {status=SUCCESS, accountData={user=..., attributes=...}} or {status=AUTH_FAIL, message=...}
function CASProcessor:valiate()
    local ticket=utils.getArgsFromRequest("ticket")
    local cas_result = cas_ticket_verify(self.ssoParam.validate_url,self.ssoParam.service,ticket)
    if not cas_result.success then
        return {status=_M.CHECK_STATUS.AUTH_FAIL,message="Invalid ticket: ".. tostring(cas_result.message)}
    end
    return {status=_M.CHECK_STATUS.SUCCESS,accountData={user=cas_result.data.user,attributes=cas_result.data}}
end
94 | ------CAS Processor end------------
95 |
96 | ------OAUTH2.0 Processor------------
local OAUTHProcessor={}

--- Create an OAuth 2.0 processor bound to ssoParam
-- (appCode, client_secret, validate_code_url, profile_url).
-- Returns a fresh instance per call. The shared-table version was racy:
-- get_token yields on HTTP I/O before get_profile reads self.ssoParam, so
-- another request could swap the settings in between.
function OAUTHProcessor.new(ssoParam)
    return setmetatable({ssoParam=ssoParam},{__index=OAUTHProcessor})
end

--- Is this an OAuth callback? It is when a `code` argument is present.
-- (The old signature declared an extra `self` parameter that shadowed the receiver.)
function OAUTHProcessor:checkRequest()
    local code=utils.getArgsFromRequest("code")
    if code == nil then
        return {status=_M.CHECK_STATUS.NOT_SSO_CHECK,message="Not valid sso auth request"}
    end
    return {status=_M.CHECK_STATUS.SUCCESS}
end

-- Decode a JSON response body; nil when it is not a JSON object.
local function decodeJsonBody(body)
    local ok,result=pcall(cjson.decode,body)
    if ok and type(result)=="table" then
        return result
    end
    return nil
end

-- Failure result for an HTTP call that failed or did not return 200.
local function httpFailure(status,err)
    return {success=false,message=err or ("HTTP status "..tostring(status))}
end

--- Exchange an authorization code for an access token (POST to validate_code_url).
-- @return {success=true, token=...} or {success=false, message=...}
function OAUTHProcessor:get_token(code)
    local payload = {
        appcode = self.ssoParam.appCode,
        secret = self.ssoParam.client_secret,
        code = code
    }
    local status, body, err = utils.jpost(self.ssoParam.validate_code_url,ngx.encode_args(payload))
    if not status or status ~= 200 then
        return httpFailure(status,err)
    end
    local result=decodeJsonBody(body)
    if result==nil or result.errorCode then
        return {success=false,message=body}
    end
    return {success=true,token=result.accessToken}
end

--- Fetch the user profile for an access token (GET to profile_url).
-- @return {success=true, profile=...} or {success=false, message=...}
function OAUTHProcessor:get_profile(token)
    local status, body, err = utils.jget(
        self.ssoParam.profile_url,
        ngx.encode_args({
            appcode = self.ssoParam.appCode,
            secret = self.ssoParam.client_secret,
            token= token
        }))
    if not status or status ~= 200 then
        return httpFailure(status,err)
    end
    ngx.log(ngx.DEBUG,body)
    local result=decodeJsonBody(body)
    if result==nil or result.errorCode then
        return {success=false,message=body}
    end
    return {success=true,profile=result}
end

--- Validate an authorization code: code -> token -> profile.
-- @param code optional; defaults to the `code` request argument
-- @return {status=SUCCESS, accountData={user=<profile.accountName>, attributes=<profile>}}
--         or {status=AUTH_FAIL, message=...}
function OAUTHProcessor:valiate(code)
    if not code then
        code=utils.getArgsFromRequest("code")
    end

    local result = self:get_token(code)
    if result.success==false then
        return {status=_M.CHECK_STATUS.AUTH_FAIL,message=result.message}
    end
    local token=result.token
    result = self:get_profile(token)
    if result.success==false then
        return {status=_M.CHECK_STATUS.AUTH_FAIL,message=result.message}
    end
    return {status=_M.CHECK_STATUS.SUCCESS,accountData={user=result.profile.accountName,attributes=result.profile}}
end

--- Hook for adjusting the login URL before redirecting; identity for OAuth.
function OAUTHProcessor:formatLoginPath(loginUrl)
    return loginUrl
end
------OAUTH2.0 Processor end------------
174 |
175 |
-- ssoProtocol name -> factory building the matching processor.
local processorList = {
    JWT = function(param) return JWTProcessor.new(param) end,
    CAS = function(param) return CASProcessor.new(param) end,
    OAUTH = function(param) return OAUTHProcessor.new(param) end,
}

--- Build the processor for ssoParam.ssoProtocol.
-- @param ssoParam protocol settings; ssoProtocol selects the processor
-- @return processor instance, or nil when the protocol is not supported
function _M.getProcessor(ssoParam)
    local factory = processorList[ssoParam.ssoProtocol]
    if factory == nil then
        return nil
    end
    return factory(ssoParam)
end
195 |
196 | return _M
--------------------------------------------------------------------------------
/ldap.lua:
--------------------------------------------------------------------------------
1 | local asn1 = require("suproxy.utils.asn1")
local format = string.format
local ok,cjson=pcall(require,"cjson")
2 | if not ok then cjson = require("suproxy.utils.json") end
3 | local logger=require "suproxy.utils.compatibleLog"
local bunpack = asn1.bunpack
local fmt = string.format
4 | require "suproxy.utils.stringUtils"
5 | local tableUtils=require "suproxy.utils.tableUtils"
local pureluapack=require"suproxy.utils.pureluapack"
6 | local event=require "suproxy.utils.event"
7 | local ldapPackets=require "suproxy.ldap.ldapPackets"
local ResultCode=ldapPackets.ResultCode
local _M = {}
_M._PROTOCAL ='ldapv3'
local encoder,decoder=asn1.ASN1Encoder:new(),asn1.ASN1Decoder:new()

--- Create an LDAP protocol processor.
-- Creates the processor's events (name, and whether it is a return-value event)
-- and wires the client->proxy and server->proxy LDAP parsers to this module's handlers.
function _M:new()
    local processor = setmetatable({}, {__index=self})
    processor.ctx = {}

    local eventSpecs = {
        {"AuthSuccessEvent", false},
        {"AuthFailEvent", false},
        {"BeforeAuthEvent", true},
        {"OnAuthEvent", true},
        {"CommandEnteredEvent", true},
        {"CommandFinishedEvent", false},
        {"ContextUpdateEvent", false},
    }
    for _, spec in ipairs(eventSpecs) do
        local name, returnsValue = spec[1], spec[2]
        if returnsValue then
            processor[name] = event:newReturnEvent(processor, name)
        else
            processor[name] = event:new(processor, name)
        end
    end

    local ldapParser = require("suproxy.ldap.parser"):new()
    processor.c2pParser = ldapParser.C2PParser
    processor.s2pParser = ldapParser.S2PParser

    local fromClient = processor.c2pParser.events
    fromClient.SearchRequest:setHandler(processor, _M.SearchRequestHandler)
    fromClient.BindRequest:setHandler(processor, _M.BindRequestHandler)
    fromClient.UnbindRequest:setHandler(processor, _M.UnbindRequestHandler)

    local fromServer = processor.s2pParser.events
    fromServer.BindResponse:setHandler(processor, _M.BindResponseHandler)
    fromServer.SearchResultEntry:setHandler(processor, _M.SearchResultEntryHandler)
    fromServer.SearchResultDone:setHandler(processor, _M.SearchResultDoneHandler)

    return processor
end
32 | local cstr=cjson.encode{baseObject=p.baseObject,scope=p.scope,filter=p.filter}
33 | if self.CommandEnteredEvent:hasHandler() then
34 | local cmd,err=self.CommandEnteredEvent:trigger(cstr,self.ctx)
35 | if err then p.allBytes=nil return end
36 | if cmd.command~=p.filter then
37 | --todo: modify the filter
38 | end
39 | end
40 | self.command=cstr
41 | end
42 |
43 | function _M:BindRequestHandler(src,p)
44 | local cred
45 | if self.BeforeAuthEvent:hasHandler() then
46 | cred=self.BeforeAuthEvent:trigger({username=p.username,password=p.password},self.ctx)
47 | end
if self.OnAuthEvent:hasHandler() then
local ok,message,cred=self.OnAuthEvent:trigger({username=p.username,password=p.password},self.ctx)
if not ok then
local resp=ldapPackets.BindResponse:new({
messageId=p.messageId,
resultCode=ResultCode.invalidCredentials
}):pack()
self.channel:c2pSend(resp.allBytes)
p.allBytes=nil
return
end
end
48 | if cred and (p.username~=cred.username or p.password~=cred.password) then
49 | p.username=cred.username
50 | p.password=cred.password
51 | p:pack()
52 | end
53 | self.ctx.username=p.username
54 | if self.ContextUpdateEvent:hasHandler() then
55 | self.ContextUpdateEvent:trigger(self.ctx)
56 | end
57 | end
58 |
--- Server BindResponse: fire AuthSuccessEvent or AuthFailEvent for the bound user.
function _M:BindResponseHandler(src,p)
    if p.resultCode==ResultCode.success then
        if self.AuthSuccessEvent:hasHandler() then
            self.AuthSuccessEvent:trigger(self.ctx.username,self.ctx)
        end
        return
    end
    if self.AuthFailEvent:hasHandler() then
        local failure={username=self.ctx.username,message="fail code: "..tostring(p.resultCode)}
        self.AuthFailEvent:trigger(failure,self.ctx)
    end
end

--- Server SearchResultEntry: append the entry (JSON of name + attributes) to the reply.
function _M:SearchResultEntryHandler(src,p)
    local previous=self.reply or ""
    self.reply=previous..cjson.encode({p.objectName,p.attributes}).."\r\n"
end

--- Server SearchResultDone: report the collected reply, then reset it.
function _M:SearchResultDoneHandler(src,p)
    local finished=self.CommandFinishedEvent
    if finished:hasHandler() then
        finished:trigger(self.command,self.reply,self.ctx)
    end
    self.reply=""
end

--- Client UnbindRequest: the client is closing the LDAP session; end this request.
function _M:UnbindRequestHandler(src,p)
    ngx.exit(0)
end
85 |
--- Read one complete LDAP message (BER tag + length + payload) from the channel.
-- @param readMethod channel read function, called as readMethod(self.channel, n)
-- @return the raw message bytes, or nil, err
function _M:recv(readMethod)
    logger.log(logger.DEBUG,"start processRequest")
    -- tag byte plus the first length octet
    local lengthdata,err = readMethod(self.channel,2)
    if err then
        logger.log(logger.ERR,"err when reading length")
        return nil,err
    end
    local length=("B"):unpack(lengthdata,2)
    if length==0x80 then
        -- indefinite length is not allowed in LDAP; previously this read 128 bytes
        logger.log(logger.ERR,"err when reading length")
        return nil,"indefinite length encoding not supported"
    end
    -- long form: the low 7 bits give the number of following length octets
    local len_len=length-128
    local realLengthData=""
    if len_len>0 then
        realLengthData,err=readMethod(self.channel,len_len)
        if err then
            logger.log(logger.ERR,"err when reading real length")
            return nil,err
        end
        length=(">I"..len_len):unpack(realLengthData)
    end
    local payloadBytes
    payloadBytes,err = readMethod(self.channel,length)
    -- check before concatenating: payloadBytes is nil on failure
    if err or payloadBytes==nil then
        logger.log(logger.ERR,"err when reading packet")
        return nil,err
    end
    return lengthdata..realLengthData..payloadBytes
end
----------------implement processor methods---------------------
-- Receive one LDAP message with `readMethod` and feed it to the parser stored in
-- self[parserKey]; its event handlers may rewrite or drop the packet.
-- Returns the bytes to forward (nil when a handler cleared them), or nil, err.
local function relayMessage(self,readMethod,parserKey)
    local raw,err=self:recv(readMethod)
    if err then
        return nil,err
    end
    local packet=self[parserKey]:parse(raw)
    return packet.allBytes
end

--- Handle one client -> server message.
function _M.processUpRequest(self)
    return relayMessage(self,self.channel.c2pRead,"c2pParser")
end

--- Handle one server -> client message.
function _M.processDownRequest(self)
    return relayMessage(self,self.channel.p2sRead,"s2pParser")
end

--- The session was invalidated: end this request.
function _M:sessionInvalid(session)
    ngx.exit(0)
end
127 |
128 | return _M;
129 |
--------------------------------------------------------------------------------
/ldap/ldapPackets.lua:
--------------------------------------------------------------------------------
1 | local asn1 = require("suproxy.utils.asn1")
2 | local format = string.format
3 | local ok,cjson=pcall(require,"cjson")
4 | if not ok then cjson = require("suproxy.utils.json") end
5 | local logger=require "suproxy.utils.compatibleLog"
6 | local bunpack = asn1.bunpack
7 | local fmt = string.format
8 | require "suproxy.utils.stringUtils"
9 | require "suproxy.utils.pureluapack"
10 | local event=require "suproxy.utils.event"
11 | local tableUtils=require "suproxy.utils.tableUtils"
12 | local extends=tableUtils.extends
13 | local encoder,decoder=asn1.ASN1Encoder:new(),asn1.ASN1Decoder:new()
14 | local orderTable=tableUtils.OrderedTable
15 | local _M={}
16 |
-- LDAP protocolOp [APPLICATION n] tag numbers (RFC 4511).
local APPNO = {
    BindRequest = 0,
    BindResponse = 1,
    UnbindRequest = 2,
    SearchRequest = 3,
    SearchResultEntry = 4,
    SearchResultDone = 5,
    ModifyRequest = 6,
    ModifyResponse = 7,
    AddRequest = 8,
    AddResponse = 9,
    DelRequest = 10,
    DelResponse = 11,
    ModifyDNRequest = 12,
    ModifyDNResponse = 13,
    CompareRequest = 14,
    CompareResponse = 15,
    AbandonRequest = 16,
    ExtendedRequest = 23,
    ExtendedResponse = 24,
    IntermediateResponse = 25,
}
_M.APPNO=APPNO
25 |
-- LDAPResult resultCode values (RFC 4511); gaps are reserved or unused codes.
local ResultCode = {
    success = 0,
    operationsError = 1,
    protocolError = 2,
    timeLimitExceeded = 3,
    sizeLimitExceeded = 4,
    compareFalse = 5,
    compareTrue = 6,
    authMethodNotSupported = 7,
    strongerAuthRequired = 8,
    -- 9 reserved
    referral = 10,
    adminLimitExceeded = 11,
    unavailableCriticalExtension = 12,
    confidentialityRequired = 13,
    saslBindInProgress = 14,
    noSuchAttribute = 16,
    undefinedAttributeType = 17,
    inappropriateMatching = 18,
    constraintViolation = 19,
    attributeOrValueExists = 20,
    invalidAttributeSyntax = 21,
    noSuchObject = 32,
    aliasProblem = 33,
    invalidDNSyntax = 34,
    -- 35 reserved for isLeaf
    aliasDereferencingProblem = 36,
    -- 37-47 unused
    inappropriateAuthentication = 48,
    invalidCredentials = 49,
    insufficientAccessRights = 50,
    busy = 51,
    unavailable = 52,
    unwillingToPerform = 53,
    loopDetect = 54,
    -- 55-63 unused
    namingViolation = 64,
    objectClassViolation = 65,
    notAllowedOnNonLeaf = 66,
    notAllowedOnRDN = 67,
    entryAlreadyExists = 68,
    objectClassModsProhibited = 69,
    -- 70 reserved for CLDAP
    affectsMultipleDSAs = 71,
    -- 72-79 unused
    other = 80,
}
_M.ResultCode=ResultCode
-- Encode `data` under an APPLICATION-class BER tag numbered `appno`
-- (the LDAP protocolOp wrapper).
local function encodeLDAPOp(encoder, appno, isConstructed, data)
    local tag = asn1.BERtoInt(asn1.BERCLASS.Application, isConstructed, appno)
    return encoder:encode(data, tag)
end
--- Decode one BER-encoded LDAP search Filter into its string form, e.g. "(objectclass=*)".
-- @param packet raw bytes containing the filter
-- @param pos position of the filter's tag byte within packet
-- @return position just after the filter, filter string
-- NOTE(review): this copy is incomplete. The gutter numbering jumps from 83
-- (`while(newpos-lp=`) to 123, so the and/or/not loop body and, presumably, the
-- start of the equality/greaterOrEqual cases are missing. The stray `"` on that
-- line leaves an unterminated string, so this copy does not compile as shown.
45 | --TODO: filter types 4 (substrings) and 9 (extensibleMatch, per RFC 4511) are not decoded; replace the if-chain with a dispatch on ftype
46 | local function decodeFilter(packet,pos)
47 |
48 | local newpos=pos
49 |
50 | local filter="("
51 |
52 | --get filter type
53 | local newpos, tmp = bunpack(packet, "B", newpos)
54 |
55 | local field,condition
56 |
57 | local ftype = asn1.intToBER(tmp).number
58 |
59 | local newpos,flen=decoder.decodeLength(packet,newpos);
60 |
61 | local elenlen=newpos-pos-1
62 |
63 | logger.log(logger.DEBUG,"element:-----------------\n"..string.hex(packet,pos,pos+flen+elenlen,4,8,nil,nil,1,1))
64 |
65 | logger.log(logger.DEBUG,"filter type:"..ftype)
66 |
67 | logger.log(logger.DEBUG,"filter length:"..flen)
68 |
69 | logger.log(logger.DEBUG,"data:-----------------\n"..string.hex(packet,newpos,newpos+flen-1,4,8,nil,nil,1,1))
70 |
71 | --composite filters: 0 = and (&), 1 = or, 2 = not (!); the sub-filters follow
72 | if ftype<3 then
73 | if ftype==0 then
74 | filter=filter.."&"
75 | end
-- NOTE(review): RFC 4515 string filters use a single "|" for OR; "||" below looks like a bug.
76 | if ftype==1 then
77 | filter=filter.."||"
78 | end
79 | if ftype==2 then
80 | filter=filter.."!"
81 | end
82 | local lp=newpos
83 | while(newpos-lp="
123 |
124 | newpos, condition=decoder:decode(packet,newpos)
125 |
126 | logger.log(logger.DEBUG,"condition:"..condition)
127 |
128 | filter=filter..condition
129 |
130 | logger.log(logger.DEBUG,"filter5:"..filter)
131 | end
132 |
133 | --less or equal
134 | if ftype==6 then
135 |
136 | newpos, field=decoder:decode(packet,newpos);
137 |
138 | logger.log(logger.DEBUG,"field:"..field)
139 |
140 | filter=filter..field
141 |
142 | filter=filter.."<="
143 |
144 | newpos, condition=decoder:decode(packet,newpos)
145 |
146 | logger.log(logger.DEBUG,"condition:"..condition)
147 |
148 | filter=filter..condition
149 |
150 | logger.log(logger.DEBUG,"filter6:"..filter)
151 | end
152 |
153 |
154 | --present
155 | if ftype==7 then
156 | newpos,tmp=bunpack(packet, "c" .. flen, newpos)
157 | filter=filter..tmp.."=*"
158 | logger.log(logger.DEBUG,"filter7:"..filter)
159 | end
160 |
161 | filter=filter..")"
162 | return newpos,filter
163 | end
164 |
165 |
-- Base LDAP message. It parses and packs the common envelope shared by all operations:
--   LDAPMessage ::= SEQUENCE { messageID, protocolOp [APPLICATION opCode] ... }
-- i.e. the outer length, messageId and opCode. Concrete packet types extend this
-- table and provide parsePayload / packPayload.
_M.Packet={
    desc="base",

    -- Parse the envelope starting at `pos`; returns the position of the operation payload.
    parseHeader=function(self,allBytes,pos)
        local _,cursor = ("B"):unpack(allBytes,pos)          -- skip the SEQUENCE tag
        cursor,self.length = decoder.decodeLength(allBytes,cursor)
        cursor,self.messageId = decoder:decode(allBytes,cursor)
        local opTag
        cursor,opTag = bunpack(allBytes,"B",cursor)
        cursor = decoder.decodeLength(allBytes,cursor)        -- operation length, not kept
        self.opCode = asn1.intToBER(opTag).number
        return cursor
    end,

    -- Parse a whole message into self and remember the raw bytes.
    parse=function(self,allBytes,pos)
        local payloadPos=self:parseHeader(allBytes,pos)
        self:parsePayload(allBytes,payloadPos)
        self.allBytes=allBytes
        return self
    end,

    -- Overridden by concrete packet types.
    parsePayload=function(self,allBytes,pos)
    end,

    -- Re-encode self into self.allBytes (envelope around packPayload()).
    pack=function(self)
        local payloadBytes=self:packPayload()
        local idBytes=encoder:encode(self.messageId)
        local opBytes=encodeLDAPOp(encoder,self.opCode,true,payloadBytes)
        local allBytes=encoder:encodeSeq(idBytes..opBytes)
        logger.logWithTitle(logger.DEBUG,"packing",allBytes:hex16F())
        self.allBytes=allBytes
        return self
    end,

    -- Create an instance of this packet type (optionally seeded with fields from o).
    new=function(self,o)
        return orderTable.new(self,o or {})
    end
}
203 |
-- BindRequest ::= [APPLICATION 0] SEQUENCE { version, name, authentication (simple) }
_M.BindRequest={
    opCode=APPNO.BindRequest,
    desc="BindRequest",

    parsePayload=function(self,payload,pos)
        local cursor=pos
        cursor,self.version=decoder:decode(payload,cursor)
        logger.log(logger.DEBUG,"version:"..self.version)
        cursor,self.username=decoder:decode(payload,cursor)
        logger.log(logger.DEBUG,"username:"..self.username)
        cursor,self.password=decoder:decode(payload,cursor)
        -- empty name = anonymous bind; name without password = unauthenticated bind
        if self.username=="" then
            logger.log(logger.DEBUG,"anonymous login")
        elseif self.password=="" then
            logger.log(logger.DEBUG,"unauthorized login")
        end
        return self
    end,

    packPayload=function(self)
        local versionBytes=encoder:encode(self.version)
        local nameBytes=encoder:encode(self.username)
        local passwordBytes=encoder:encode(self.password,"simplePass")
        return versionBytes..nameBytes..passwordBytes
    end
}
extends(_M.BindRequest,_M.Packet)
227 |
-- BindResponse ::= [APPLICATION 1] LDAPResult { resultCode, matchedDN, diagnosticMessage, ... }
_M.BindResponse={
    opCode=APPNO.BindResponse,
    desc="BindResponse",

    -- only resultCode is read
    parsePayload=function(self,payload,pos)
        local _
        _,self.resultCode=decoder:decode(payload,pos)
        return self
    end,

    -- matchedDN and diagnosticMessage are always sent empty
    packPayload=function(self)
        local codeBytes=encoder:encode(self.resultCode,"enumerated")
        local matchedDN=encoder:encode('')
        local diagnostic=encoder:encode('')
        return codeBytes..matchedDN..diagnostic
    end
}
extends(_M.BindResponse,_M.Packet)
241 |
-- UnbindRequest carries no payload; only the common envelope is parsed.
do
    local unbind={opCode=APPNO.UnbindRequest,desc="UnbindRequest"}
    _M.UnbindRequest=extends(unbind,_M.Packet)
end
244 |
-- SearchRequest ::= [APPLICATION 3] SEQUENCE { baseObject, scope, derefAliases,
--   sizeLimit, timeLimit, typesOnly, filter, attributes }
_M.SearchRequest={
    opCode=APPNO.SearchRequest,
    desc="SearchRequest",

    parsePayload=function(self,payload,pos)
        logger.log(logger.DEBUG,"searchRequest:")
        -- the leading scalar fields are decoded and logged the same way
        for _,field in ipairs({"baseObject","scope","derefAlias","sizeLimit","timeLimit"}) do
            pos,self[field]=decoder:decode(payload,pos)
            logger.log(logger.DEBUG,field..":"..self[field])
        end
        pos,self.typesOnly=decoder:decode(payload,pos)
        logger.log(logger.DEBUG,"typesOnly:"..(self.typesOnly and "true" or "false"))
        pos,self.filter=decodeFilter(payload,pos)
        logger.log(logger.DEBUG,"self.filter:"..self.filter)
        pos,self.attributes=decoder:decode(payload,pos)
        logger.log(logger.DEBUG,"self.attributes:\n"..cjson.encode(self.attributes))
        logger.log(logger.DEBUG,"searchRequest finish")
        return self
    end
}
extends(_M.SearchRequest,_M.Packet)
288 |
-- SearchResultEntry ::= [APPLICATION 4] SEQUENCE { objectName, attributes }
-- Parsed attributes become {attrType=<first element>, values=<remaining elements>}.
_M.SearchResultEntry={
    opCode= APPNO.SearchResultEntry,
    desc="SearchResponseEntry",

    parsePayload=function(self,payload,pos)
        pos,self.objectName=decoder:decode(payload,pos)
        local attr
        pos,attr=decoder:decode(payload,pos)
        -- debug dump through the logger; a bare print() wrote to the nginx error log
        logger.log(logger.DEBUG,tostring(tableUtils.printTableF(attr)))
        self.attributes={}
        for _,v in ipairs(attr or {}) do
            local t=v[1]
            table.remove(v,1)
            table.insert(self.attributes,{attrType=t,values=v})
        end
        return self
    end,

    packPayload=function(self)
        local resultObjectName = encoder:encode(self.objectName)
        -- collect pieces and join once instead of repeated string concatenation
        local attributeParts={}
        for _,attribute in ipairs(self.attributes) do
            local valueParts={}
            for _,val in ipairs(attribute.values) do
                valueParts[#valueParts+1]=encoder:encode(val)
            end
            attributeParts[#attributeParts+1]=encoder:encodeSeq(
                encoder:encode(attribute.attrType)..encoder:encodeSet(table.concat(valueParts)))
        end
        local resultAttributes = encoder:encodeSeq(table.concat(attributeParts))
        return resultObjectName..resultAttributes
    end,
}
extends(_M.SearchResultEntry,_M.Packet)
320 |
-- SearchResultDone ::= [APPLICATION 5] LDAPResult { resultCode, matchedDN, diagnosticMessage, ... }
_M.SearchResultDone={
    opCode=APPNO.SearchResultDone,
    desc="SearchResponseDone",

    -- only resultCode is read
    parsePayload=function(self,payload,pos)
        local _
        _,self.resultCode=decoder:decode(payload,pos)
        return self
    end,

    -- matchedDN and diagnosticMessage are always sent empty
    packPayload=function(self)
        local codeBytes=encoder:encode(self.resultCode,"enumerated")
        local matchedDN=encoder:encode('')
        local diagnostic=encoder:encode('')
        return codeBytes..matchedDN..diagnostic
    end
}
extends(_M.SearchResultDone,_M.Packet)
334 |
335 | ---------------------test starts here -----------------------
_M.unitTest={}

--- Round-trip a captured BindRequest: parse, change credentials, repack, reparse.
function _M.unitTest.testBindRequest()
    local function parseRequest(raw)
        return require("suproxy.ldap.parser"):new().C2PParser:parse(raw)
    end
    local req=parseRequest(string.fromhex("30840000003402010a60840000002b020103041e636e3d61646d696e2c64633d7777772c64633d746573742c64633d636f6d800661413132332e"))
    assert(req.version==3,req.version)
    assert(req.username=="cn=admin,dc=www,dc=test,dc=com",req.username)
    assert(req.password=="aA123.",req.password)
    req.username="cn=admin,dc=www,dc=test,dc=cn"
    req.password="Aa123."
    req:pack(encoder)
    local reparsed=parseRequest(req.allBytes)
    assert(reparsed.username=="cn=admin,dc=www,dc=test,dc=cn",reparsed.username)
    assert(reparsed.password=="Aa123.",reparsed.password)
end
351 |
--- Parse a captured SearchRequest and check every decoded field.
function _M.unitTest.testSearchRequest()
    local raw=string.fromhex("30840000008302010c638400000046041564633d7777772c64633d746573742c64633d636f6d0a01010a010002010002013c010100870b6f626a656374636c61737330840000000d040b6f626a656374636c617373a0840000002e3084000000280416312e322e3834302e3131333535362e312e342e3331390101ff040b3084000000050201640400")
    local req=require("suproxy.ldap.parser"):new().C2PParser:parse(raw)
    -- {field, expected value}, checked in order
    local expectations={
        {"baseObject","dc=www,dc=test,dc=com"},
        {"scope",1},
        {"derefAlias",0},
        {"timeLimit",60},
        {"sizeLimit",0},
        {"typesOnly",false},
        {"filter","(objectclass=*)"},
    }
    for _,pair in ipairs(expectations) do
        local field,expected=pair[1],pair[2]
        assert(req[field]==expected,req[field])
    end
end
363 |
364 |
--- Round-trip a captured SearchResultEntry: parse, repack, reparse.
function _M.unitTest.testSearchResultEntry()
    local function parseEntry(raw)
        return require("suproxy.ldap.parser"):new().S2PParser:parse(raw)
    end
    local function check(entry)
        assert(entry.objectName=="dc=www,dc=test,dc=com",entry.objectName)
        assert(#(entry.attributes)==3,#(entry.attributes))
    end
    local entry=parseEntry(string.fromhex("306502016a6460041564633d7777772c64633d746573742c64633d636f6d3047302c040b6f626a656374436c617373311d0403746f70040864634f626a656374040c6f7267616e697a6174696f6e300a04016f31050403646576300b0402646331050403777777"))
    check(entry)
    entry:pack()
    check(parseEntry(entry.allBytes))
end
376 |
--- Run every function registered in _M.unitTest.
-- Tests run in alphabetical order of their names so that runs are
-- reproducible (`pairs` iteration order is unspecified in Lua).
function _M.test()
    local names={}
    for name in pairs(_M.unitTest) do
        names[#names+1]=name
    end
    table.sort(names)
    for _,name in ipairs(names) do
        print("------------running "..name)
        _M.unitTest[name]()
        print("------------"..name.." finished")
    end
end

return _M
--------------------------------------------------------------------------------
/ldap/parser.lua:
--------------------------------------------------------------------------------
1 | local P=require "suproxy.ldap.ldapPackets"
2 | local parser=require("suproxy.parser")
3 | local _M={}
4 |
-- Registration table for the LDAP parsers. Every handled operation uses the
-- same name three times: key = P.APPNO[name], parser = P[name], eventName = name.
local conf={}
for _,name in ipairs{
    "BindRequest",
    "UnbindRequest",
    "SearchRequest",
    "BindResponse",
    "SearchResultEntry",
    "SearchResultDone",
} do
    conf[#conf+1]={key=P.APPNO[name],parser=P[name],eventName=name}
end
19 | local o= setmetatable({},{__index=self})
20 | local C2PParser=parser:new()
21 | C2PParser.keyGenerator=keyG
22 | C2PParser:registerMulti(conf)
23 | C2PParser:registerDefaultParser(P.Packet)
24 | o.C2PParser=C2PParser
25 |
26 | local S2PParser=parser:new()
27 | S2PParser.keyGenerator=keyG
28 | S2PParser:registerMulti(conf)
29 | S2PParser:registerDefaultParser(P.Packet)
30 | o.S2PParser=S2PParser
31 | return o
32 | end
33 | return _M
34 |
35 |
--------------------------------------------------------------------------------
/parser.lua:
--------------------------------------------------------------------------------
1 | require "suproxy.utils.stringUtils"
2 | local event=require "suproxy.utils.event"
3 | local logger=require "suproxy.utils.compatibleLog"
4 | local tableUtils=require "suproxy.utils.tableUtils"
5 | local _M={}
--- Create an empty parser registry.
-- Fields: events (eventName -> event), parserList (key -> registration),
-- defaultParseEvent (fired for keys without their own event).
function _M:new()
    local instance={
        events={},
        parserList={},
        defaultParseEvent=event:newReturnEvent(nil,"defaultParseEvent"),
    }
    return setmetatable(instance,{__index=self})
end
13 |
--- Register a parser (and optionally an event) under `key`.
-- @param key        lookup key; must not already be registered
-- @param parserName name used in debug logs
-- @param parser     parser class (may be nil for event-only registrations)
-- @param eventName  when given, the event is stored in self.events under this
--                   name (which must be unused); `e` is used or a new one created
-- @param e          optional pre-built event
function _M:register(key,parserName,parser,eventName,e)
    assert(not self.parserList[key],string.format("unable to register parser %s, key already registered",key))
    local boundEvent=e
    if eventName then
        assert(not self.events[eventName],string.format("unable to register event %s, event already registered",eventName))
        if not boundEvent then
            boundEvent=event:new(nil,eventName)
        end
        self.events[eventName]=boundEvent
    end
    self.parserList[key]={parser=parser,parserName=parserName,event=boundEvent}
end
23 |
--- Remove the registration for `key` and, when named, its event.
function _M:unregister(key,eventName)
    self.parserList[key]=nil
    if not eventName then return end
    self.events[eventName]=nil
end
30 |
--- Return the registration record {parser,parserName,event} for `key`, or nil.
function _M:getParser(key)
    local registry=self.parserList
    return registry[key]
end
34 |
--- Register every entry of an array of {key,parser,parserName,eventName,e}.
-- When an entry has a parser but no parserName, the name falls back to
-- parser.desc, then tostring(parser).
function _M:registerMulti(t)
    for _,entry in ipairs(t) do
        local name=nil
        if entry.parser then
            name=entry.parserName or entry.parser.desc or tostring(entry.parser)
        end
        self:register(entry.key,name,entry.parser,entry.eventName,entry.e)
    end
end
42 |
--- Set the parser used for keys that have no registered parser.
function _M:registerDefaultParser(parser)
    self.defaultParser=assert(parser,"default parser can not be null")
end
47 |
--- Debug-log a packet: a title naming key/parser with a hex dump of the raw
-- bytes, each extra option, then the packet table (without allBytes).
function _M.printPacket(packet,allBytes,key,parserName,...)
    local title
    if parserName then
        title=string.format("packet with key %s will be parsed by parser %s ",key,parserName)
    else
        title=string.format("packet with key %s doesn't have parser",key)
    end
    logger.logWithTitle(logger.DEBUG,title,(allBytes and allBytes:hex16F() or ""))
    local options={...}
    for idx,option in ipairs(options) do
        logger.log(logger.DEBUG,"\r\noptions"..idx..":"..tableUtils.printTableF(option,{inline=true,printIndex=true}))
    end
    logger.log(logger.DEBUG,"\r\npacket:"..tableUtils.printTableF(packet,{ascii=true,excepts={"allBytes"}}))
end
60 |
--- Parse one packet and fire the event registered for its key.
-- The key is `key` itself, or the result of calling it (or self.keyGenerator
-- when `key` is nil) as key(allBytes,pos,...). The registered parser, or the
-- default parser, builds the packet via parser:new(nil,...):parse(allBytes,pos,...).
-- If parsing succeeded, the key's event (or defaultParseEvent) is triggered
-- with (packet,allBytes,key,...) when it has handlers. Errors raised by the
-- parser or the handlers are logged (with traceback) and swallowed.
-- @param allBytes raw packet bytes (required)
-- @param pos      start position, defaults to 1
-- @param key      lookup key, a key-generating function, or nil
-- @param ...      extra options forwarded to key generator, parser and event
-- @return the parsed packet, or a bare {allBytes=...} table when no parser ran,
--         it failed, or it returned nothing; packet.__key defaults to the key
function _M:parse(allBytes,pos,key,...)
    pos=pos or 1
    assert(allBytes,"bytes stream can not be null")
    if not key then key=self.keyGenerator end
    if type(key)=="function" then
        key=key(allBytes,pos,...)
    end
    assert(key,"key can not be null")

    -- global `unpack` on 5.1/LuaJIT, table.unpack on 5.2+
    local unpackArgs=unpack or table.unpack
    -- select("#") keeps nil options; unpack({...}) alone would stop at a hole
    local nargs=select("#",...)
    local args={...}
    local function forwarded() return unpackArgs(args,1,nargs) end
    local function logError(err)
        logger.log(logger.ERR,err)
        logger.log(logger.ERR,debug.traceback())
    end

    local parser,parserName,event
    local entry=self.parserList[key]
    if entry then
        parser=entry.parser
        parserName=entry.parserName
        event=entry.event
    end
    if not parser and self.defaultParser then
        parser=self.defaultParser
        parserName="Default Parser"
    end
    event=event or self.defaultParseEvent

    local packet={allBytes=allBytes}
    local ok=true
    if parser then
        local ret
        ok,ret=xpcall(function()
            return parser:new(nil,forwarded()):parse(allBytes,pos,forwarded())
        end,logError)
        -- keep the bare packet if the parser returned nothing
        if ok and ret~=nil then packet=ret end
    end
    if logger.getLogLevel().code>=logger.DEBUG.code then
        _M.printPacket(packet,allBytes,key,parserName,...)
    end
    packet.__key=packet.__key or key
    if ok and event and event:hasHandler() then
        xpcall(function()
            return event:trigger(packet,allBytes,key,forwarded())
        end,logError)
    end
    return packet
end
-- Legacy name kept as an alias of _M:parse for callers using doParse.
_M.doParse=_M.parse
return _M
--------------------------------------------------------------------------------
/session/session.lua:
--------------------------------------------------------------------------------
1 | require "suproxy.utils.stringUtils"
2 | require "suproxy.utils.pureluapack"
3 | local _M={}
--- Create a session, persist it through `manager`, and return a proxy table.
-- Reads fall through to the session data; every assignment on the proxy is
-- forwarded to manager:setProperty (together with a refreshed uptime) and
-- mirrored into the local data.
-- @param stype   session type, stored as-is
-- @param manager storage backend (create/setProperty/valid/kill)
-- @return proxy table, nil
function _M:new(stype,manager)
    assert(manager,"manager can not be null")
    local created=ngx.time()
    local data=setmetatable({
        sid=string.random(4):hex(),
        uptime=created,
        ctime=created,
        stype=stype,
        uid="_SUPROXY_UNKNOWN",
    },{__index=self})
    manager:create(data)
    local proxy={__data=data,__manager=manager}
    return setmetatable(proxy,{
        __index=data,
        __newindex=function(t,k,v)
            local now=ngx.time()
            manager:setProperty(data.sid,k,v,"uptime",now)
            t.__data[k]=v
            t.__data.uptime=now
        end,
    }),nil
end
22 |
--- Ask the session manager to terminate this session.
function _M:kill()
    local manager=self.__manager
    return manager:kill(self.sid)
end
26 |
--- Ask the session manager whether this session is still valid.
function _M:valid()
    local manager=self.__manager
    return manager:valid(self.sid)
end
30 |
--- A stub session manager that stores nothing: create is a no-op,
-- setProperty reports 1, and valid/kill always report true.
function _M.newDoNothing(self)
    local function noop() end
    return {
        create=noop,
        setProperty=function() return 1 end,
        valid=function() return true end,
        kill=function() return true end,
    }
end

return _M
--------------------------------------------------------------------------------
/session/sessionManager.lua:
--------------------------------------------------------------------------------
1 | local cjson=require("cjson")
2 | local logger=require("suproxy.utils.compatibleLog")
local redis = require "resty.redis"
local _M = {}
3 | ---------------required method for implements----------------------
-- options {ip=ip,port=port,sock=sock,timeout=timeout,expire=expire,extend=true}
--   ip, port  Redis TCP endpoint (defaults 127.0.0.1 and 6379)
--   sock      Redis unix socket; used instead of ip/port when set. Accepts
--             "/path/to/redis.sock" or "unix:/path/to/redis.sock"
--   timeout   passed to red:set_timeout (default 5000)
--   expire    session TTL in seconds (default 3600); create() only applies a
--             TTL when expire >= 0
--   extend    when true (default) each update resets the TTL to `expire`
--- Create a Redis-backed session manager.
-- @return the manager, or nil plus the connect error
function _M.new(self,options)
    options=options or {}
    local o=setmetatable({},{__index=self})
    local red=redis:new()
    local ip=options.ip or "127.0.0.1"
    local port=options.port or 6379
    local timeout=options.timeout or 5000
    o.expire=options.expire or 3600
    o.extend=options.extend
    if o.extend==nil then o.extend=true end
    local sock=options.sock
    red:set_timeout(timeout)
    local ok,err
    if sock then
        -- lua-resty-redis expects unix domain sockets as "unix:/path"
        if sock:sub(1,5)~="unix:" then sock="unix:"..sock end
        ok,err=red:connect(sock)
    else
        ok,err=red:connect(ip,port)
    end
    if not ok then
        logger.log(logger.ERR,"failed to connect: ",err)
        return nil,err
    end
    o.redis=red
    return o
end
-- Redis key under which the session with id `sid` is stored.
local function getKey(sid)
    local prefix="gateway_session_"
    return prefix..sid
end
--- Store a new session as JSON; applies the TTL when self.expire >= 0.
-- @return true, or false plus the redis error
function _M:create(session)
    assert(session,"session can not be null")
    assert(session.sid,"session.sid can not be null")
    local store=self.redis--ngx.shared.sessions
    local key=getKey(session.sid)
    local ok,err=store:set(key,cjson.encode(session))
    if ok and self.expire>=0 then
        ok,err=store:expire(key,self.expire)
    end
    if ok then return true end
    return false,err
end
--- Set one or more properties on a stored session and persist it in one write.
-- @param sid  session id
-- @param ...  alternating key, value pairs; a value may be nil to clear a key
-- @return number of properties written (0 when the session is missing or the
--         write failed), plus the redis error on failure
function _M:setProperty(sid,...)
    -- select("#") counts nil values; #{...} is unreliable once a value is nil
    local n=select("#",...)
    assert(n%2==0,"key value count should be even")
    local s=self:get(sid)
    if not s then return 0 end
    local args={...}
    local count=0
    for i=1,n,2 do
        s[args[i]]=args[i+1]
        count=count+1
    end
    local ok,err=self:update(sid,s)
    if not ok then
        logger.log(logger.ERR,"failed to update session ",tostring(sid)," with error message: ",tostring(err))
        return 0,err
    end
    return count
end
--- Overwrite the stored session and re-apply its TTL.
-- SET discards any existing TTL, so the TTL is set again after every write:
-- with extend, the full `expire` window restarts; otherwise the remaining
-- time since session.ctime is kept. A negative expire means no TTL, matching
-- create() (EXPIRE with a negative value would delete the key instead).
-- @return the result of the last redis call, plus its error
function _M:update(sid,session)
    assert(session,"session can not be null")
    assert(sid,"sid can not be null")
    local sessions=self.redis--ngx.shared.sessions
    local k=getKey(sid)
    local ok,err=sessions:set(k,cjson.encode(session))
    if ok and self.expire>=0 then
        if self.extend then
            ok,err=sessions:expire(k,self.expire)
        else
            local now=ngx.time()
            -- a record without ctime is treated as created now
            local elapsed=now-(session.ctime or now)
            ok,err=sessions:expire(k,self.expire-elapsed)
            if not ok then sessions:expire(k,0) end
        end
    end
    return ok,err
end
54 |
--- Load and decode the session stored for `sid`.
-- @return the decoded session table, or nil plus the redis error (nil error
--         when the key simply does not exist)
function _M:get(sid)
    assert(sid,"sid can not be null")
    local raw,err=self.redis:get(getKey(sid))
    local missing=err or not raw or raw==ngx.null
    if missing then
        return nil,err
    end
    return cjson.decode(raw),nil
end
--- True when a session record for `sid` can be read back.
function _M:valid(sid)
    assert(sid,"sid can not be null")
    local found=self:get(sid)
    if found then return true end
    return false
end
63 |
--- Delete the session record; returns what redis DEL returns.
function _M:kill(sid)
    assert(sid,"sid can not be null")
    local key=getKey(sid)
    return self.redis:del(key)
end
68 | ---------------------manage method-----------------------
--- Find every stored session owned by `uid`.
-- The owner is ctx.uid when a ctx table was stored, otherwise the top-level
-- uid field (which session.lua's new() writes).
-- @return table mapping redis key -> raw JSON string, or nil plus the error
function _M:getSessionOfUser(uid)
    assert(uid,"uid can not be null")
    local sessions=self.redis--ngx.shared.sessions
    local keys,err=sessions:keys("gateway_session_*")--sessions:get_keys()
    if err or not keys then return nil,err end
    local result={}
    for _,key in ipairs(keys) do
        local raw=sessions:get(key)
        -- a key can expire between KEYS and GET, in which case redis yields ngx.null
        if raw and raw~=ngx.null then
            local ok,s=pcall(cjson.decode,raw)
            if ok and type(s)=="table" then
                local ctx=s.ctx
                local owner=(type(ctx)=="table" and ctx.uid) or s.uid
                if owner==uid then
                    result[key]=raw
                end
            end
        end
    end
    return result,nil
end
82 |
--- Delete every stored session owned by `uid`.
-- The owner is ctx.uid when a ctx table was stored, otherwise the top-level
-- uid field (which session.lua's new() writes).
-- @return number of sessions deleted, or 0 plus the error from KEYS
function _M:killSessionOfUser(uid)
    assert(uid,"uid can not be null")
    local sessions=self.redis--ngx.shared.sessions
    local keys,err=sessions:keys("gateway_session_*")--sessions:get_keys()
    if err or not keys then return 0,err end
    local killed=0
    for _,key in ipairs(keys) do
        local raw=sessions:get(key)
        -- a key can expire between KEYS and GET, in which case redis yields ngx.null
        if raw and raw~=ngx.null then
            local ok,s=pcall(cjson.decode,raw)
            if ok and type(s)=="table" then
                local ctx=s.ctx
                local owner=(type(ctx)=="table" and ctx.uid) or s.uid
                if owner==uid then
                    sessions:del(key)
                    killed=killed+1
                end
            end
        end
    end
    return killed,nil
end
94 |
--- Load every stored session.
-- Keys that vanished between KEYS and GET, or hold undecodable JSON, are skipped.
-- @return table mapping redis key -> decoded session, count, nil
--         (or {}, 0, err when KEYS fails)
function _M:getAll()
    local sessions=self.redis--ngx.shared.sessions
    local keys,err=sessions:keys("gateway_session_*")--sessions:get_keys()
    local result={}
    if err or not keys then return result,0,err end
    local count=0
    for _,key in ipairs(keys) do
        local raw=sessions:get(key)
        if raw and raw~=ngx.null then
            local ok,decoded=pcall(cjson.decode,raw)
            if ok then
                result[key]=decoded
                count=count+1
            end
        end
    end
    return result,count,nil
end
--- Delete every stored session.
-- @return number of keys deleted, or 0 plus the error from KEYS
function _M:clear()
    local store=self.redis--ngx.shared.sessions
    local keys,err=store:keys("gateway_session_*")--sessions:get_keys()
    if err then return 0,err end
    local removed=0
    for _,key in ipairs(keys) do
        store:del(key)
        removed=removed+1
    end
    return removed
end

return _M
--------------------------------------------------------------------------------
/ssh2/commandCollector.lua:
--------------------------------------------------------------------------------
1 | local ssh2Packet=require "suproxy.ssh2.ssh2Packets"
local sc=require"suproxy.ssh2.shellCommand"
local event=require "suproxy.utils.event"
2 | local logger=require "suproxy.utils.compatibleLog"
3 |
local _M={}
4 |
--- Create a command collector for one SSH shell session.
-- It holds the line being typed (`command`), buffers for server output, and
-- three events: BeforeWelcomeEvent, CommandEnteredEvent (both return events)
-- and CommandFinishedEvent.
function _M:new()
    local o=setmetatable({
        reply="",
        commandReply="",
        welcome="",
        command=sc:new(),
        firstReply=true,
    },{__index=self})
    o.BeforeWelcomeEvent=event:newReturnEvent(o,"BeforeWelcomeEvent")
    o.CommandEnteredEvent=event:newReturnEvent(o,"CommandEnteredEvent")
    o.CommandFinishedEvent=event:new(o,"CommandFinishedEvent")
    return o
end
13 |
-- Strip ANSI escape sequences: ESC, one of "[", "]" or "(", optional
-- parameter characters, then a final letter/symbol byte.
-- Returns only the cleaned string (gsub's match count is dropped so it cannot
-- leak into callers as an extra argument).
local function removeANSIEscape(str)
    local cleaned=str:gsub(string.char(0x1b).."[%[%]%(][0-9%:%;%<%=%>%?]*".."[@A-Z%[%]%^_`a-z%{%|%}%~]","")
    return cleaned
end
16 |
-- Drop ASCII control characters (0-31) and DEL (127), keeping carriage
-- return (0x0d). Returns only the cleaned string (gsub's match count is dropped).
local function removeUnprintableAscii(str)
    local cleaned=str:gsub(".",
        function(x)
            local b=string.byte(x)
            if (b<=31 or b==127) and b~=0x0d then return "" end
        end
    )
    return cleaned
end
--- Track client-to-server keystrokes on an SSH shell channel.
-- Editing keys are replayed into self.command so the complete command line is
-- known when Enter arrives. On Enter, CommandEnteredEvent may reject the
-- command (second return value) or replace it (first return value).
-- Any keystroke received while a reply is pending ends that reply and fires
-- CommandFinishedEvent with the collected output.
-- @param processor used to send a rejection notice back to the client
-- @param packet    ChannelData packet; packet.data holds the keystroke bytes
-- @param ctx       passed through to the events
-- @return the packet, rewritten when the command was rejected or replaced
function _M:handleDataUp(processor,packet,ctx)
    self.waitForWelcome=false
    if self.waitForReply then
        self.waitForReply=false
        if self.commandReply and self.commandReply~="" then
            logger.log(logger.DEBUG,"--------------\r\n",self.commandReply)
            local reply=removeANSIEscape(self.commandReply)
            self.CommandFinishedEvent:trigger(self.lastCommand,reply,ctx)
        end
    end
    self.reply=""
    self.commandReply=""
    local letter=packet.data
    -- an empty data packet carries no keystroke (letter:byte(1) below would be nil)
    if letter==nil or letter=="" then return packet end
    logger.log(logger.DEBUG,"-------------letter---------------",letter:hex())
    local command=self.command
    if letter==string.char(0x1b,0x5b,0x41) or letter==string.char(0x1b,0x5b,0x42) then
        --up/down arrow: handleDataDown appends the history entry the shell echoes
        self.upArrowClicked=true
        command:clear()
    elseif letter==string.char(0x15) then
        --ctrl+u: delete everything left of the cursor; removeBefore() with no
        --arguments removes exactly one character, so call it once per character
        for _=1,command.cursor do
            command:removeBefore()
        end
    elseif letter==string.char(0x1b,0x5b,0x44) or letter==string.char(2) then
        --left arrow or ctrl+b
        command:moveCursor(-1)
    elseif letter==string.char(0x1b,0x5b,0x43) or letter==string.char(6) then
        --right arrow or ctrl+f
        command:moveCursor(1)
    elseif letter==string.char(0x1b,0x5b,0x31,0x7e) or letter==string.char(1) then
        --home or ctrl+a
        command:home()
    elseif letter==string.char(0x1b,0x5b,0x34,0x7e) or letter==string.char(5) then
        --end or ctrl+e
        command:toEnd()
    elseif letter==string.char(0x1b,0x5b,0x33,0x7e) or letter==string.char(4) then
        --delete or ctrl+d
        command:removeAfter()
    elseif letter==string.char(0x09) then
        --tab: handleDataDown appends the completion the shell echoes
        self.tabClicked=true
    elseif letter==string.char(0x7f) or letter==string.char(8) then
        --backspace
        command:removeBefore()
    elseif letter==string.char(0x03) then
        --ctrl+c
        command:clear()
    elseif letter==string.char(0x1f) then
        --ctrl+? still needs further process
        self.tabClicked=true
    elseif letter==string.char(0x0d) then
        --enter
        if command:getLength()>0 then
            local cstr=command:toString()
            command:clear()
            self.lastCommand=cstr
            self.waitForReply=true
            local newcmd,err=self.CommandEnteredEvent:trigger(cstr,ctx)
            if err then
                local message=err
                if type(err)=="table" then message=err.message end
                -- NOTE(review): recipient channel 256 is hard-coded; confirm it is
                -- the client's channel id for every supported client
                local toSend=ssh2Packet.ChannelData:new{
                    channel=256,
                    data=table.concat{"\r\n",tostring(message),"\r\n"}
                }:pack().allBytes
                processor:sendDown(toSend)
                --0x05 0x15 move cursor to the end and delete all, then enter
                packet.data=string.char(5,0x15,0x0d)
                packet:pack()
            elseif newcmd~=nil and newcmd~=cstr then
                --0x05 0x15 for move cursor to the end and delete all
                packet.data=string.char(5,0x15)..newcmd.."\n"
                packet:pack()
            end
        end
    else
        local first=letter:byte(1)
        --printable ASCII or any non-ASCII (UTF-8) byte
        if (first>31 and first<127) or first>=128 then
            command:append(letter)
        end
    end
    return packet
end
88 |
-- Patterns that mark the start of a new shell prompt, tried in order:
-- the OSC title sequence ESC]0; , a user@host style prompt ending in $ or #,
-- and the mysql client prompt.
local PROMPT_PATTERNS={
    string.char(0x1b,0x5d,0x30,0x3b),
    "[.*@.*:.*]?[%s]?%[?.*@.*%]?[\\$#][%s]?",
    "mysql>[%s]?",
}

-- Append server output to self.commandReply, splitting off a trailing prompt.
-- When a prompt is found, the text before it is appended, the prompt is stored
-- in self.lastPrompt, and (lastPrompt, commandReply) is returned; otherwise the
-- whole reply is appended and nothing is returned.
local function processReply(self,reply)
    if not reply then return end
    local promptStart
    for _,pattern in ipairs(PROMPT_PATTERNS) do
        promptStart=reply:find(pattern)
        if promptStart then break end
    end
    if not promptStart then
        self.commandReply=self.commandReply..reply
        return
    end
    self.lastPrompt=reply:sub(promptStart)
    self.commandReply=self.commandReply..reply:sub(1,promptStart-1)
    return self.lastPrompt,self.commandReply
end
--- Inspect server-to-client output on the SSH shell channel.
-- The first matching state decides what happens to the output:
--   upArrowClicked  echoed history line is appended to the command buffer
--   tabClicked      echoed completion is appended to the command buffer
--   waitForReply    output is collected as the last command's reply
--   firstReply      with a BeforeWelcomeEvent handler, the first output is
--                   replaced by the handler's welcome text
--   waitForWelcome  output is held back until a prompt shows up
-- @return the packet, possibly rewritten (allBytes=nil suppresses it)
function _M:handleDataDown(processor,packet,ctx)
    local reply=packet.data
    if self.upArrowClicked then
        --command may have leading 0x08 bytes, trim it
        self.command:append(removeUnprintableAscii(removeANSIEscape(reply)))
        self.upArrowClicked=false
        return packet
    end
    if self.tabClicked then
        self.command:append(removeUnprintableAscii(reply),self.commandPtr)
        self.tabClicked=false
        return packet
    end
    if self.waitForReply and reply then
        processReply(self,reply)
        return packet
    end
    if self.firstReply and self.BeforeWelcomeEvent:hasHandler() then
        self.firstReply=false
        local welcome,prepend=self.BeforeWelcomeEvent:trigger(ctx)
        if welcome then
            local prompt=processReply(self,reply)
            if not prompt then self.waitForWelcome=true end
            self.prepend=prepend
            -- prepend: welcome + original output; otherwise welcome + prompt
            local original=""
            if prepend then original=self.commandReply end
            local tail=""
            if not prepend then tail=prompt or ">" end
            packet.data=table.concat({welcome,original,tail})
            packet:pack()
        end
        return packet
    end
    if self.waitForWelcome then
        local prompt=processReply(self,reply)
        if not self.prepend then
            if prompt then
                packet.data=prompt
                packet:pack()
            else
                packet.allBytes=nil
            end
        end
    end
    return packet
end


return _M
--------------------------------------------------------------------------------
/ssh2/parser.lua:
--------------------------------------------------------------------------------
1 | --ssh2.0 protocol parser
local P=require "suproxy.ssh2.ssh2Packets"
local parser=require("suproxy.parser")
local _M={}
-- Registration table for the SSH2 parsers, in the original order.
local conf={}
-- packets decoded by a dedicated parser: key=P.PktType[name], parser=P[name],
-- parserName=name, eventName=name.."Event"
for _,name in ipairs{
    "KeyXInit","DHKeyXInit","DHKeyXReply","AuthReq","AuthFail","ChannelData","Disconnect",
} do
    conf[#conf+1]={key=P.PktType[name],parserName=name,parser=P[name],eventName=name.."Event"}
end
-- packets that only get an event (no parser registered)
for _,name in ipairs{"NewKeys","AuthSuccess"} do
    conf[#conf+1]={key=P.PktType[name],eventName=name.."Event"}
end
-- Parser key = byte 6 of the packet: after the 4-byte packet_length and the
-- 1-byte padding_length comes the first payload byte, the SSH message number.
local function keyG(allBytes)
    local MSG_TYPE_OFFSET=6
    return allBytes:byte(MSG_TYPE_OFFSET)
end
-- Build one parser for a single traffic direction, keyed by SSH message
-- number, with P.Base as the fallback for unregistered message types.
local function newDirectionParser()
    local p=parser:new()
    p.keyGenerator=keyG
    p:registerMulti(conf)
    p:registerDefaultParser(P.Base)
    return p
end

--- Create an SSH2 parser pair: C2PParser (client to proxy) and S2PParser
-- (server to proxy). Each has its own event objects.
function _M:new()
    local o=setmetatable({},{__index=self})
    o.C2PParser=newDirectionParser()
    o.S2PParser=newDirectionParser()
    return o
end
return _M
--------------------------------------------------------------------------------
/ssh2/shellCommand.lua:
--------------------------------------------------------------------------------
1 | require "suproxy.utils.stringUtils"
2 | require "suproxy.utils.pureluapack"
3 |
local _M = {}
--- Create an empty line buffer: `chars` holds one UTF-8 character per slot and
-- `cursor` is the number of characters left of the caret (0 = line start).
function _M:new()
    return setmetatable({cursor=0,chars={}},{__index=self})
end
10 |
-- Move tab[i] to tab[i+count] for i from #tab down to `cursor` (top-down so
-- nothing is overwritten before it is copied) and return tab. The walk includes
-- index `cursor`; with cursor==0 it copies tab[0] (nil) into tab[count].
local function rshiftArray(tab,cursor,count)
    local i=#tab
    while i>=cursor do
        tab[i+count]=tab[i]
        i=i-1
    end
    return tab
end
16 |
--- Insert `str` at the cursor, one UTF-8 character per slot, moving the cursor
-- past each inserted character.
-- The character length comes from the lead byte: 0xF0.. = 4 bytes,
-- 0xE0.. = 3 bytes, 0xC0.. = 2 bytes, anything else = 1 byte.
function _M.append(self,str)
    local i=1
    while i<=#str do
        local lead=string.byte(str,i)
        local size
        if lead>=0xF0 then
            size=4
        elseif lead>=0xE0 then
            size=3
        elseif lead>=0xC0 then
            size=2
        else
            size=1
        end
        self.chars=rshiftArray(self.chars,self.cursor,1)
        self.chars[self.cursor+1]=str:sub(i,i+size-1)
        i=i+size
        self.cursor=self.cursor+1
    end
end
39 |
--- Delete characters left of the cursor (backspace) and move the cursor back.
-- @param count number of characters to delete (default 1, clamped to the cursor)
-- @param all   when true, delete everything left of the cursor
function _M.removeBefore(self,count,all)
    if all then count=self.cursor end
    if not count then count=1 end
    if self.cursor<1 or count<1 then return end
    if self.cursor-count<0 then count=self.cursor end
    -- walk leftwards from the cursor so lower indices stay valid while removing
    for i=self.cursor,self.cursor-count+1,-1 do
        table.remove(self.chars,i)
    end
    self.cursor=self.cursor-count
end
50 |
--- Empty the buffer and put the cursor at the line start.
function _M:clear()
    self.chars={}
    self.cursor=0
end
55 |
--- Delete characters right of the cursor (delete key); the cursor stays put.
-- @param count number of characters to delete (default 1, clamped to the line end)
-- @param all   when true, delete everything right of the cursor
function _M.removeAfter(self,count,all)
    if all then count=#(self.chars)-self.cursor end
    if not count then count=1 end
    if self.cursor==#(self.chars) then return end
    if self.cursor+count>#(self.chars) then count=#(self.chars)-self.cursor end
    for _=1,count do
        -- the character right of the cursor sits at index cursor+1
        table.remove(self.chars,self.cursor+1)
    end
end
65 |
--- Move the cursor to the start of the line.
function _M:home()
    self.cursor = 0
end
69 |
--- Move the cursor past the last character.
function _M:toEnd()
    local last=#self.chars
    self.cursor=last
end
73 |
--- Move the cursor by `step` characters, clamped to [0, #chars].
function _M:moveCursor(step)
    local target=self.cursor+step
    local last=#self.chars
    if target>last then
        target=last
    elseif target<0 then
        target=0
    end
    self.cursor=target
end
82 |
--- Number of characters (not bytes) in the buffer.
function _M:getLength()
    local chars=self.chars
    return #chars
end
86 |
--- Character at 1-based position `pos`, or nil when out of range.
function _M:get(pos)
    local outOfRange=pos>#self.chars or pos<1
    if outOfRange then return nil end
    return self.chars[pos]
end
91 |
--- The buffered line as a single string.
function _M:toString()
    local chars=self.chars
    return table.concat(chars)
end
100 |
--- Self-test: multi-byte input, cursor clamping, deletion and mid-line insertion.
function _M.test()
    local str="中华A已经Bあまり哈哈哈1234567"
    print(str:hex())
    local command=_M:new()
    command:append(str)
    assert(command:getLength()==19,command:getLength())
    assert(command:get(1)=="中",command:get(1))
    assert(command:get(19)=="7",command:get(19))
    assert(command.cursor==19,command.cursor)
    -- moving past either end clamps, so this ends two characters before the end
    command:moveCursor(-100)
    command:moveCursor(100)
    command:moveCursor(-2)
    assert(command.cursor==17,command.cursor)
    for _=1,5 do
        command:removeAfter()
    end
    command:removeBefore()
    assert(command.cursor==16,command.cursor)
    assert(command:toString()=="中华A已经Bあまり哈哈哈1234",command:toString():hex())
    -- insert in front of existing text
    command:clear()
    command:append("34")
    command:moveCursor(-2)
    command:append("12")
    assert(command:toString()=="1234",command:toString())
    command:home()
    assert(command.cursor==0,command.cursor)
    command:toEnd()
    assert(command.cursor==command:getLength(),command.cursor)
    print(command:toString())
end

return _M
136 |
137 |
--------------------------------------------------------------------------------
/ssh2/ssh2CipherConf.lua:
--------------------------------------------------------------------------------
1 | local bn = require "resty.openssl.bn"
local rand= require "resty.openssl.rand"
local tableUtils = require "suproxy.utils.tableUtils"
local _M={}
-- Supported ciphers: SSH algorithm name and the matching OpenSSL cipher string.
_M.EncAlg={
    {name="aes128-ctr",cipherStr="aes-128-ctr"},
    {name="aes128-cbc",cipherStr="aes-128-cbc"},
}

--- Comma-separated list of the cipher names, in table order.
function _M.EncAlg:getList()
    local names={}
    for i=1,#self do
        names[i]=self[i].name
    end
    return table.concat(names,",")
end
_M.DHAlg={}

--http://ietf.org/rfc/rfc3526.txt
-- 2048-bit MODP prime (RFC 3526 group 14), shared by both group14 entries
local GROUP14_PRIME_HEX=[[
FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1
29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD
EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245
E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED
EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE45B3D
C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 FD24CF5F
83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D
670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B
E39E772C 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9
DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510
15728E5A 8AACAA68 FFFFFFFF FFFFFFFF
]]

-- 1024-bit MODP prime (RFC 2409 Oakley group 2), which SSH calls "group1"
local GROUP1_PRIME_HEX=[[
FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1
29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD
EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245
E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED
EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE65381
FFFFFFFF FFFFFFFF
]]

-- Append a key-exchange method; getList() reports entries in this order.
local function addDHAlg(name,primeHex,shaAlg)
    _M.DHAlg[#(_M.DHAlg)+1]={
        name=name,
        p=bn.from_binary(string.fromhex(primeHex)),
        shaAlg=shaAlg,
    }
end

addDHAlg("diffie-hellman-group14-sha256",GROUP14_PRIME_HEX,"sha256")
addDHAlg("diffie-hellman-group14-sha1",GROUP14_PRIME_HEX,"sha1")
-- NOTE(review): a 1024-bit group with SHA-1 is considered weak (RFC 9142
-- recommends against diffie-hellman-group1-sha1); kept here for old clients
addDHAlg("diffie-hellman-group1-sha1",GROUP1_PRIME_HEX,"sha1")
--- Build the KEXINIT kex_algorithms name-list.
-- Joins the `name` of every DH entry, in table order, with commas.
-- @treturn string comma-separated key-exchange method names
function _M.DHAlg:getList()
  local kexNames = {}
  for _, alg in ipairs(self) do
    table.insert(kexNames, alg.name)
  end
  return table.concat(kexNames, ",")
end
2 | -- Two bignums built from rand.bytes(1024), i.e. 1024 random bytes (8192 bits) each.
-- Presumably these are the DH private exponents (RFC 4253 calls them x for the
-- client side and y for the server side) -- confirm against the KEX code.
-- NOTE(review): they are computed once when this module is loaded, so every key
-- exchange handled by the same Lua VM would reuse the same exponents; RFC 4253
-- section 8 expects a fresh random exponent per exchange.
-- NOTE(review): 8192 bits is much larger than the 2048-bit group prime; a
-- smaller exponent would do the same job faster.
_M.y=bn.from_binary(rand.bytes(1024))
_M.x=bn.from_binary(rand.bytes(1024))
-- RSA key pair (PEM) embedded in the source, described by the original author
-- as the "key pair used for KEX" -- presumably the host key whose public part
-- is sent as K_S and which signs the exchange hash; confirm with the KEX code.
-- NOTE(review): a private key committed to a public repository is identical in
-- every deployment, so anyone holding this file can impersonate the proxy;
-- consider loading a per-deployment key from configuration instead.
_M.pubkey=
3 | [[
4 | -----BEGIN RSA PUBLIC KEY-----
5 | MIIBCgKCAQEAyfPdItqWAL0kLjr4C9FJUm1nyRqNePUfAEHZqH+zQDnUmRUnJc/t
6 | YvViQwoBS4O21LbEJJJyA2UQ3LsiCj6l511uTJKjs43jS8uufLamnZkovfnj766V
7 | AQuGLb/LL28kbDNrjEBILG7Z1SjKOMcj8ltt5Jno3hy8QbufK+9nk1AyjvJy2xxg
8 | mAUYOXxI8hYOmIybdL06sKmnqn3CcBjHm5al426f91BgZk0uiaK+8Tq3fVi36fss
9 | o5ZGI3V64zRF+FCE80RvGW3S4ErUm95+SwLRjVav6keCQXYfVHiQ9sacLxjuVve4
10 | /UKjlFztG8+U/ZrIO5GgHEEc8px2s5mqMwIDAQAB
11 | -----END RSA PUBLIC KEY-----
12 | ]]
13 |
-- Private half of the pair above (PKCS#1 "RSA PRIVATE KEY" PEM).
_M.privkey=
14 | [[
15 | -----BEGIN RSA PRIVATE KEY-----
16 | MIIEowIBAAKCAQEAyfPdItqWAL0kLjr4C9FJUm1nyRqNePUfAEHZqH+zQDnUmRUn
17 | Jc/tYvViQwoBS4O21LbEJJJyA2UQ3LsiCj6l511uTJKjs43jS8uufLamnZkovfnj
18 | 766VAQuGLb/LL28kbDNrjEBILG7Z1SjKOMcj8ltt5Jno3hy8QbufK+9nk1AyjvJy
19 | 2xxgmAUYOXxI8hYOmIybdL06sKmnqn3CcBjHm5al426f91BgZk0uiaK+8Tq3fVi3
20 | 6fsso5ZGI3V64zRF+FCE80RvGW3S4ErUm95+SwLRjVav6keCQXYfVHiQ9sacLxju
21 | Vve4/UKjlFztG8+U/ZrIO5GgHEEc8px2s5mqMwIDAQABAoIBAFsWEZxhyJxGsuXj
22 | FPOHjrGNxOzQfBSdQkFEch5sknWaX8g34TNNx/0FPi+MeK8Nlk30rRztrFzZnbRg
23 | 9uZ2ATAMVO5WiV031tfd4zI+04FrjhO5fNQjAvO4tek2gzc+wsfGnXBhoevgh4F7
24 | 51GaiB0MndEolf5wKXzgWddgIHAxQ3pgqTqBhCvr0h/U0VxGkntqqEDIzRKohB0D
25 | hd5MXP9hCdTTOud9Kfy/2DKl0a8UWC6N5oyT1EhmGI011Fpc3J+svIbd9fPnRo4B
26 | RoAaiKWezYOja6ruRYZo9+GBtjznV1IGlK9EttMv9W5MbCDyFGU4/MoNmal9pAUz
27 | +HrX/aECgYEA7fG5QWYY6YLCs5UEpwC6jxD5uTXR8LA5JP8VxNLHA4/HCY3nXThC
28 | 800iEWfdgLtdN3H95KslW+E7WLoT7hdONZKKcGq6zBywa4JawUXt8jMravhN17os
29 | 6DTEOHtUE6WETUslJhK0o3232h8wo2dZ99lPiA00Uk8nKUku0CMVULECgYEA2Ub4
30 | iUZKpCHGJ+HjnKINbua36TATBEVY3myr8XBwHDJIAB+LyK+DwJ9QiXJDKnlDGsa8
31 | XEdYaRkUYNIcZgVnTF4s3O1BuGFXBbc7Av+Z85zLtSo/1kDn9YI92Los9APPffdu
32 | UMtbIj9eXqtzVg1DjWcbhZQZAi8uONGEvh+CQiMCgYBIlayFnreKxDDQx2yb5UUD
33 | z5HeReS9H4TPHGFvoTzEgV+eMoOZlEgYIDd8R8ryMjXFbCifUPYciSCpeFoMD1/0
34 | R7ejg2toSHgo06MLwmFLuQBNqWFVpZ19WFtjP3vuYldxnLLAYoRoOzmSeGFF94ki
35 | alAwmJaVZT/1ADYfmBQwgQKBgFRokd0ihZTF2ilcRARxoC5ZS1E37+tU1XVzWkjt
36 | mWAa2IXTu4Y3SUPnoG4FCbrSaRNZ6Ysf3GTX7Wa/uXCY4Mx2OY+KTGHIzvnVeQNt
37 | MO3HGAxFYY9mn7Zs5oHvsc8KO+1/1kdk+P6RB6RXjvL7LCceyz5VjnGeyqIgIyWJ
38 | MB1pAoGBAN2ztGq/yOpNwiCrxeeDpoFe7iQHznaXIgK7yA+uFFUUT58/bl87kq9/
39 | huKG4OgDnZ+iHeSk9CFNseGZHIdcliTDGnud7ipOCFJCzjtlCcM/oHl1hL7oCJUL
40 | GrUlhvtkRKRV96mCDIEGxKCB0xzqjXKAgzTGApCQ9RUKG5EbNSnC
41 | -----END RSA PRIVATE KEY-----
42 | ]]
43 | return _M
--------------------------------------------------------------------------------
/ssh2/ssh2Packets.lua:
--------------------------------------------------------------------------------
1 | --ssh2.0 protocol parser and encoder
--Packet parser follows RFC 4251, RFC 4252 and RFC 4253
require "suproxy.utils.stringUtils"
require "suproxy.utils.pureluapack"
local tableUtils=require "suproxy.utils.tableUtils"
local ok,cjson=pcall(require,"cjson")
cjson = ok and cjson or require("suproxy.utils.json")
local extends=tableUtils.extends
local orderTable=tableUtils.OrderedTable
local asn1 = require "suproxy.utils.asn1"
local event=require "suproxy.utils.event"
local logger=require "suproxy.utils.compatibleLog"
local _M={}
--SSH message numbers for the packet types this module implements
--(only implemented types are listed), ordered by message number.
_M.PktType={
  Disconnect=0x01,   -- SSH_MSG_DISCONNECT        (RFC 4253)
  KeyXInit=0x14,     -- SSH_MSG_KEXINIT           (RFC 4253)
  NewKeys=0x15,      -- SSH_MSG_NEWKEYS           (RFC 4253)
  DHKeyXInit=0x1e,   -- SSH_MSG_KEXDH_INIT        (RFC 4253)
  DHKeyXReply=0x1f,  -- SSH_MSG_KEXDH_REPLY       (RFC 4253)
  AuthReq=0x32,      -- SSH_MSG_USERAUTH_REQUEST  (RFC 4252)
  AuthFail=0x33,     -- SSH_MSG_USERAUTH_FAILURE  (RFC 4252)
  AuthSuccess=0x34,  -- SSH_MSG_USERAUTH_SUCCESS  (RFC 4252)
  ChannelData=0x5e,  -- SSH_MSG_CHANNEL_DATA      (RFC 4254)
}
--Prepare a big-endian, non-negative magnitude for use as an SSH mpint body
--(RFC 4251 section 5). mpints are two's complement, so when the most
--significant bit of the first byte is set a leading zero byte is prepended to
--keep the value positive. The empty string (the mpint encoding of zero) is
--returned unchanged; previously n:byte(1) returned nil there and the
--comparison raised an error.
--@param n string  big-endian magnitude bytes
--@return string   bytes ready to be packed as a uint32-length string
local function paddingInt(n)
  local first=n:byte(1)
  if first~=nil and first>=128 then
    return "\0"..n
  end
  return n
end
2 |
--Frame a payload as an SSH binary packet (RFC 4253 section 6):
--  uint32 packet_length | byte padding_length | payload | padding
--Without an explicit `padding` argument, random padding is generated whose
--length makes 5 + #data + padding a multiple of 16, with at least 4 bytes of
--padding. A caller-supplied padding string is used as-is.
local function packSSHData(data,padding)
  local padLen=16-(#data+5)%16
  if padLen<4 then
    padLen=padLen+16
  end
  local pad=padding
  if not pad then
    pad=string.random(padLen)
  end
  local header=string.pack(">I4B",#data+1+#pad,#pad)
  return header..data..pad
end
--Base packet: decodes and encodes the binary-packet framing shared by all SSH
--messages (RFC 4253 section 6):
--   uint32    packet_length
--   byte      padding_length
--   byte[n1]  payload; n1 = packet_length - padding_length - 1
--   byte[n2]  random padding; n2 = padding_length
--   byte[m]   mac (Message Authentication Code - MAC); m = mac_length
--Subclasses override parsePayload/packPayload for their message body.
_M.Base={
  --Decode packet_length, padding_length and the message code, keep the raw
  --bytes, then hand the position right after the code to parsePayload.
  parse=function(self,allBytes)
    local nextPos
    self.dataLength,self.paddingLength,self.code,nextPos=string.unpack(">I4BB",allBytes)
    self.allBytes=allBytes
    self:parsePayload(allBytes,nextPos)
    return self
  end,
  --Default body decoder: nothing to read.
  parsePayload=function(self,allbytes,pos) return self end,
  --Encode message code + body into a framed packet stored in self.allBytes.
  pack=function(self)
    local body=string.char(self.code)..self:packPayload()
    self.allBytes=packSSHData(body)
    logger.logWithTitle(logger.DEBUG,"packing",self.allBytes:hex16F())
    return self
  end,
  --Default body encoder: empty body.
  packPayload=function(self) return "" end,
  --Create a packet instance (an ordered table) inheriting from this prototype.
  new=function(self,o)
    return orderTable.new(self,o or {})
  end
}
--Key Exchange Init packet, SSH_MSG_KEXINIT 0x14 (RFC 4253 section 7.1).
--Wire layout and the field each value is stored in:
--  byte[16]   cookie (random bytes)                     -> cookie
--  name-list  kex_algorithms                            -> kex_alg
--  name-list  server_host_key_algorithms                -> key_alg
--  name-list  encryption_algorithms_client_to_server    -> enc_alg_c2s
--  name-list  encryption_algorithms_server_to_client    -> enc_alg_s2c
--  name-list  mac_algorithms_client_to_server           -> mac_alg_c2s
--  name-list  mac_algorithms_server_to_client           -> mac_alg_s2c
--  name-list  compression_algorithms_client_to_server   -> comp_alg_c2s
--  name-list  compression_algorithms_server_to_client   -> comp_alg_s2c
--  name-list  languages_client_to_server                -> lan_c2s
--  name-list  languages_server_to_client                -> lan_s2c
--  boolean    first_kex_packet_follows                  -> kex_follows
--  uint32     0 (reserved for future extension)         -> reserved
--payloadBytes holds the raw payload including the message code (presumably
--kept for the exchange-hash computation; confirm with the KEX code).
local KEXINIT_LAYOUT=">c16s4s4s4s4s4s4s4s4s4s4BI4"
_M.KeyXInit={
  code=_M.PktType.KeyXInit,
  parsePayload=function(self,allBytes,pos)
    self.cookie,self.kex_alg,
    self.key_alg,self.enc_alg_c2s,
    self.enc_alg_s2c,self.mac_alg_c2s,
    self.mac_alg_s2c,self.comp_alg_c2s,
    self.comp_alg_s2c,self.lan_c2s,
    self.lan_s2c,self.kex_follows,
    self.reserved=string.unpack(KEXINIT_LAYOUT,allBytes,pos)
    -- payload occupies bytes 6 .. 5+n1, with n1 = packet_length - padding_length - 1
    local payloadLast=4+self.dataLength-self.paddingLength
    self.payloadBytes=allBytes:sub(6,payloadLast)
    return self
  end,
  packPayload=function(self)
    local body=string.pack(KEXINIT_LAYOUT,
      self.cookie,
      self.kex_alg,self.key_alg,
      self.enc_alg_c2s,self.enc_alg_s2c,
      self.mac_alg_c2s,self.mac_alg_s2c,
      self.comp_alg_c2s,self.comp_alg_s2c,
      self.lan_c2s,self.lan_s2c,
      self.kex_follows,self.reserved)
    self.payloadBytes=string.char(self.code)..body
    return body
  end
}
extends(_M.KeyXInit,_M.Base)
8 |
9 |
-- SSH_MSG_KEXDH_INIT 0x1e (RFC 4253 section 8):
--   byte   SSH_MSG_KEXDH_INIT
--   mpint  e
-- `e` is stored as the raw mpint bytes read from the wire; packPayload passes
-- it through paddingInt so a value with the high bit set gets its sign byte.
_M.DHKeyXInit={
  code=_M.PktType.DHKeyXInit,
  parsePayload=function(self,payload,pos)
    local eBytes=string.unpack(">s4",payload,pos)
    self.e=eBytes
    return self
  end,
  packPayload=function(self)
    local mpint=paddingInt(self.e)
    return string.pack(">s4",mpint)
  end
}
extends(_M.DHKeyXInit,_M.Base)
22 |
-- SSH_MSG_KEXDH_REPLY 0x1f (RFC 4253 section 8):
--   byte    SSH_MSG_KEXDH_REPLY
--   string  server public host key and certificates  -> K_S
--   mpint   f                                        -> f
--   string  signature of H, which itself holds:
--             string key algorithm name               -> key_alg
--             string signature blob                   -> signH
_M.DHKeyXReply={
  code=_M.PktType.DHKeyXReply,
  parsePayload=function(self,payload,pos)
    local sigBlob
    self.K_S,self.f,sigBlob=string.unpack(">s4s4s4",payload,pos)
    self.key_alg,self.signH=string.unpack(">s4s4",sigBlob)
    return self
  end,
  packPayload=function(self)
    local fBytes=paddingInt(self.f)
    local sigBlob=string.pack(">s4s4",self.key_alg,self.signH)
    return string.pack(">s4s4s4",self.K_S,fBytes,sigBlob)
  end
}
extends(_M.DHKeyXReply,_M.Base)
--User authentication request, SSH_MSG_USERAUTH_REQUEST 0x32 (RFC 4252
--sections 5 and 8):
--   byte     SSH_MSG_USERAUTH_REQUEST
--   string   user name        -> username
--   string   service name     -> serviceName
--   string   method name      -> method
--For method "password" these fields follow (other methods are not decoded):
--   boolean  change flag: FALSE for a normal login, TRUE for a password change
--   string   plaintext password in ISO-10646 UTF-8 encoding [RFC3629] -> password
--   string   new password, present only when the flag is TRUE        -> newPassword
--When the flag is TRUE the packet also gets changePassword=true.
_M.AuthReq={
  code=_M.PktType.AuthReq,
  parsePayload=function(self,payload,pos)
    local passStartPos
    self.username,
    self.serviceName,
    self.method,passStartPos=string.unpack(">s4s4s4",payload,pos)
    if self.method=="password" then
      -- Read the boolean and a full uint32-length string. The former code
      -- skipped 4 bytes and read a 1-byte length, which only worked for
      -- passwords shorter than 256 bytes.
      local changeFlag,nextPos
      changeFlag,self.password,nextPos=string.unpack(">Bs4",payload,passStartPos)
      if changeFlag~=0 then
        self.changePassword=true
        self.newPassword=string.unpack(">s4",payload,nextPos)
      end
    end
    return self
  end,
  packPayload=function(self)
    local req=string.pack(">s4s4s4",self.username,self.serviceName,self.method)
    if self.method=="password" then
      if self.changePassword then
        req=req..string.pack(">Bs4s4",1,self.password,self.newPassword)
      else
        -- Byte-identical to the former encoding for passwords < 256 bytes,
        -- and no longer fails for longer ones.
        req=req..string.pack(">Bs4",0,self.password)
      end
    end
    return req
  end
}
extends(_M.AuthReq,_M.Base)
69 |
--Authentication failure, SSH_MSG_USERAUTH_FAILURE 0x33 (RFC 4252 section 5.1):
--   byte       SSH_MSG_USERAUTH_FAILURE
--   name-list  authentications that can continue -> methods
--   boolean    partial success                   -> partialSuccess
--The boolean is a single byte. It was previously read and written as a
--uint32: parse then consumed padding bytes (partialSuccess almost always
--true), and pack emitted 3 stray trailing bytes with the value inverted.
_M.AuthFail={
  code=_M.PktType.AuthFail,
  parsePayload=function(self,payload,pos)
    self.methods,pos=string.unpack(">s4",payload,pos)
    -- `payload` is the whole packet; only the byte at `pos` is the boolean,
    -- anything after it is padding.
    if pos<=#payload then
      self.partialSuccess=string.unpack(">B",payload,pos)~=0
    end
    return self
  end,
  packPayload=function(self)
    return string.pack(">s4B",self.methods,self.partialSuccess and 1 or 0)
  end
}
extends(_M.AuthFail,_M.Base)
-- SSH_MSG_CHANNEL_DATA 0x5e (RFC 4254 section 5.2):
--   byte    SSH_MSG_CHANNEL_DATA
--   uint32  recipient channel  -> channel
--   string  data               -> data
local CHANNEL_DATA_LAYOUT=">I4s4"
_M.ChannelData={
  code=_M.PktType.ChannelData,
  parsePayload=function(self,payload,pos)
    self.channel,self.data=string.unpack(CHANNEL_DATA_LAYOUT,payload,pos)
    return self
  end,
  packPayload=function(self)
    return string.pack(CHANNEL_DATA_LAYOUT,self.channel,self.data)
  end
}
extends(_M.ChannelData,_M.Base)
-- SSH_MSG_DISCONNECT 0x01 (RFC 4253 section 11.1):
--   byte    SSH_MSG_DISCONNECT
--   uint32  reason code                                          -> reasonCode
--   string  description in ISO-10646 UTF-8 encoding [RFC3629]    -> message
--   string  language tag [RFC3066]                              -> lang
-- Parsing decodes reasonCode and message only; the language tag is not read.
-- Packing always writes the language tag (self.lang, or "" when unset): it is
-- part of the message format, and the former encoding omitted it, producing a
-- truncated DISCONNECT.
_M.Disconnect={
  code=_M.PktType.Disconnect,
  parsePayload=function(self,payload,pos)
    self.reasonCode,self.message=string.unpack(">I4s4",payload,pos)
    return self
  end,
  packPayload=function(self)
    return string.pack(">I4s4s4",self.reasonCode,self.message,self.lang or "")
  end
}
extends(_M.Disconnect,_M.Base)
98 | return _M
--------------------------------------------------------------------------------
/suproxy-v0.6.0-1.rockspec:
--------------------------------------------------------------------------------
-- LuaRocks package specification for suproxy.
package = "suproxy"
version = "v0.6.0-1"

source = {
  url = "git+https://github.com/yizhu2000/suproxy.git",
  tag = "v0.6.0",
}

description = {
  detailed = "Lua SSH2,LDAP,TNS,TDS proxy and mim library for OpenResty",
  homepage = "https://github.com/yizhu2000/suproxy",
  license = "BSD",
}

-- NOTE(review): confirm LuaRocks accepts the "v" prefix in "v0.1" below.
dependencies = {
  "lua >= 5.1",
  "lua-resty-openssl >= 0.6",
  "lua-resty-logger-socket >= v0.1",
}

build = {
  type = "builtin",
  modules = {
    -- top-level protocol processors and shared plumbing
    ["suproxy.channel"] = "channel.lua",
    ["suproxy.ldap"] = "ldap.lua",
    ["suproxy.parser"] = "parser.lua",
    ["suproxy.ssh2"] = "ssh2.lua",
    ["suproxy.tds"] = "tds.lua",
    ["suproxy.test"] = "test.lua",
    ["suproxy.tns"] = "tns.lua",
    -- balancer, examples and HTTP helpers
    ["suproxy.balancer.balancer"] = "balancer/balancer.lua",
    ["suproxy.example.gateway"] = "example/gateway.lua",
    ["suproxy.example.session"] = "example/session.lua",
    ["suproxy.http.handlers"] = "http/handlers.lua",
    ["suproxy.http.initParam"] = "http/initParam.lua",
    ["suproxy.http.mockauth"] = "http/mockauth.lua",
    ["suproxy.http.ssoProcessors"] = "http/ssoProcessors.lua",
    -- LDAP
    ["suproxy.ldap.ldapPackets"] = "ldap/ldapPackets.lua",
    ["suproxy.ldap.parser"] = "ldap/parser.lua",
    -- sessions
    ["suproxy.session.session"] = "session/session.lua",
    ["suproxy.session.sessionManager"] = "session/sessionManager.lua",
    -- SSH2
    ["suproxy.ssh2.commandCollector"] = "ssh2/commandCollector.lua",
    ["suproxy.ssh2.parser"] = "ssh2/parser.lua",
    ["suproxy.ssh2.shellCommand"] = "ssh2/shellCommand.lua",
    ["suproxy.ssh2.ssh2CipherConf"] = "ssh2/ssh2CipherConf.lua",
    ["suproxy.ssh2.ssh2Packets"] = "ssh2/ssh2Packets.lua",
    -- TDS (SQL Server)
    ["suproxy.tds.datetime"] = "tds/datetime.lua",
    ["suproxy.tds.parser"] = "tds/parser.lua",
    ["suproxy.tds.tdsPackets"] = "tds/tdsPackets.lua",
    ["suproxy.tds.token"] = "tds/token.lua",
    ["suproxy.tds.version"] = "tds/version.lua",
    -- TNS (Oracle)
    ["suproxy.tns.crypt"] = "tns/crypt.lua",
    ["suproxy.tns.parser"] = "tns/parser.lua",
    ["suproxy.tns.tnsPackets"] = "tns/tnsPackets.lua",
    -- utilities
    ["suproxy.utils.asn1"] = "utils/asn1.lua",
    ["suproxy.utils.compatibleLog"] = "utils/compatibleLog.lua",
    ["suproxy.utils.datetime"] = "utils/datetime.lua",
    ["suproxy.utils.event"] = "utils/event.lua",
    ["suproxy.utils.ffi-zlib"] = "utils/ffi-zlib.lua",
    ["suproxy.utils.json"] = "utils/json.lua",
    ["suproxy.utils.pureluapack"] = "utils/pureluapack.lua",
    ["suproxy.utils.stringUtils"] = "utils/stringUtils.lua",
    ["suproxy.utils.tableUtils"] = "utils/tableUtils.lua",
    ["suproxy.utils.unicode"] = "utils/unicode.lua",
    ["suproxy.utils.utils"] = "utils/utils.lua",
  },
}
--------------------------------------------------------------------------------
/tds.lua:
--------------------------------------------------------------------------------
1 | require "suproxy.utils.stringUtils"
2 | require "suproxy.utils.pureluapack"
3 | local event=require "suproxy.utils.event"
4 | local ok,cjson=pcall(require,"cjson")
5 | if not ok then cjson = require("suproxy.utils.json") end
6 | local logger=require "suproxy.utils.compatibleLog"
7 | local tdsPacket=require "suproxy.tds.tdsPackets"
8 | local tableUtils=require "suproxy.utils.tableUtils"
local _M = {}
9 | _M._PROTOCAL ='tds'
10 |
--Create a TDS (SQL Server) protocol processor.
--options.disableSSL defaults to true and options.catchReply to false; both are
--copied onto the instance, and catchReply is also handed to the parser
--factory. The instance exposes the auth/command/context events and routes
--its parser events to the handler methods of this module.
function _M.new(self,options)
  options=options or {}
  local o=setmetatable({},{__index=self})

  o.disableSSL=options.disableSSL
  if o.disableSSL==nil then o.disableSSL=true end
  o.catchReply=options.catchReply
  if o.catchReply==nil then o.catchReply=false end

  o.BeforeAuthEvent=event:newReturnEvent(o,"BeforeAuthEvent")
  o.OnAuthEvent=event:newReturnEvent(o,"OnAuthEvent")
  o.AuthSuccessEvent=event:new(o,"AuthSuccessEvent")
  o.AuthFailEvent=event:new(o,"AuthFailEvent")
  o.CommandEnteredEvent=event:newReturnEvent(o,"CommandEnteredEvent")
  o.CommandFinishedEvent=event:new(o,"CommandFinishedEvent")
  o.ContextUpdateEvent=event:new(o,"ContextUpdateEvent")
  o.ctx={}

  local tdsParser=require("suproxy.tds.parser"):new(o.catchReply)
  local c2p,s2p=tdsParser.C2PParser,tdsParser.S2PParser
  o.C2PParser=c2p
  o.S2PParser=s2p
  c2p.events.SQLBatch:addHandler(o,_M.SQLBatchHandler)
  c2p.events.Prelogin:addHandler(o,_M.PreloginHandler)
  c2p.events.Login7:addHandler(o,_M.Login7Handler)
  s2p.events.LoginResponse:addHandler(o,_M.LoginResponseHandler)
  s2p.events.SSLLoginResponse:addHandler(o,_M.LoginResponseHandler)
  s2p.events.SQLResponse:addHandler(o,_M.SQLResponseHandler)
  return o
end
37 | ----------------parser event handlers----------------------
--Handle a client SQL batch.
--If a CommandEnteredEvent handler reports an error (second return value),
--an error response carrying err.message/err.code is sent to the client and
--p.allBytes is cleared, so processUpRequest returns nil for this packet.
--Otherwise the SQL text is remembered in ngx.ctx.sql for the reply handler.
function _M:SQLBatchHandler(src,p)
  local vetoErr
  if self.CommandEnteredEvent:hasHandler() then
    vetoErr=select(2,self.CommandEnteredEvent:trigger(p.sql,self.ctx))
  end
  if vetoErr then
    self.channel:c2pSend(tdsPacket.packErrorResponse(vetoErr.message,vetoErr.code))
    p.allBytes=nil
    return
  end
  ngx.ctx.sql=p.sql
end
49 |
--Handle a client PRELOGIN packet.
--When the client's PRELOGIN carries an Encryption option and SSL is disabled
--on this processor, the option is rewritten to 2 (ENCRYPT_NOT_SUP in MS-TDS)
--and the packet re-encoded, so the connection proceeds without TLS.
function _M:PreloginHandler(src,p)
  local wantsRewrite=p.options.Encryption and self.disableSSL
  if not wantsRewrite then return end
  p.options.Encryption=2
  p:pack()
end
57 |
--Handle a client LOGIN7 packet.
--  * BeforeAuthEvent handlers may return replacement credentials.
--  * OnAuthEvent handlers return ok, message[, credentials]. When ok is falsy
--    the client gets SQL Server error 18456 (login failed) and the packet is
--    dropped (p.allBytes=nil). Credentials returned here take precedence over
--    BeforeAuthEvent's. They used to be bound to a new local that shadowed
--    `cred`, so they were silently discarded.
--  * If the resulting credentials differ from the client's, the packet is
--    rewritten with them before it is forwarded. The credentials are no
--    longer printed: the old debug print leaked plaintext passwords.
--Finally the session context is filled in and ContextUpdateEvent fired.
function _M:Login7Handler(src,p)
  local cred
  if self.BeforeAuthEvent:hasHandler() then
    cred=self.BeforeAuthEvent:trigger({username=p.username,password=p.password},self.ctx)
  end
  if self.OnAuthEvent:hasHandler() then
    local ok,message,authCred=self.OnAuthEvent:trigger({username=p.username,password=p.password},self.ctx)
    if not ok then
      local reason=message or ("login with "..tostring(p.username).." failed")
      self.channel:c2pSend(tdsPacket.packErrorResponse(reason,18456))
      p.allBytes=nil
      return
    end
    if authCred then cred=authCred end
  end
  if cred and (p.username~=cred.username or p.password~=cred.password) then
    p.username=cred.username
    p.password=cred.password
    p:pack()
  end
  self.ctx.username=p.username
  self.ctx.client=p.appName
  self.ctx.clientVer=p.ClientProgVer:hex()
  self.ctx.libName=p.libName
  self.ctx.tdsVer=p.TDSVersion:hex()
  if self.ContextUpdateEvent:hasHandler() then
    self.ContextUpdateEvent:trigger(self.ctx)
  end
end
86 |
--Handle the server's login response (plain or SSL-wrapped).
--On failure, AuthFailEvent receives the username and "[errNo]message".
--On success, AuthSuccessEvent fires, then the server and TDS versions are
--stored in the context and ContextUpdateEvent fires.
function _M:LoginResponseHandler(src,p)
  if not p.success then
    if self.AuthFailEvent:hasHandler() then
      local failure={username=self.ctx.username,message="["..p.errNo.."]"..p.message}
      self.AuthFailEvent:trigger(failure,self.ctx)
    end
    return
  end
  if self.AuthSuccessEvent:hasHandler() then
    self.AuthSuccessEvent:trigger(self.ctx.username,self.ctx)
  end
  self.ctx.serverVer=p.serverVersion.versionNumber
  self.ctx.tdsVer=p.TDSVersion:hex()
  if self.ContextUpdateEvent:hasHandler() then
    self.ContextUpdateEvent:trigger(self.ctx)
  end
end
103 |
--Handle the server's reply to a SQL batch.
--Fires CommandFinishedEvent with the SQL remembered by SQLBatchHandler and
--the reply text (p:tostring() when the packet provides it, otherwise "").
function _M:SQLResponseHandler(src,p)
  if not self.CommandFinishedEvent:hasHandler() then return end
  local reply=""
  if p.tostring then
    reply=p:tostring() or ""
  end
  self.CommandFinishedEvent:trigger(ngx.ctx.sql,reply,self.ctx)
end
111 |
112 | ----------------implement processor methods---------------------
--Read one complete packet from the channel with readMethod.
--An 8-byte TDS header is read first. If its first byte is 0x17 (TLS
--application-data record), the length comes from the 5-byte TLS record
--header; otherwise it is the TDS Length field, which includes the header.
--@return allBytes on success; (partial bytes, err) on a read failure or an
--        impossible length. Payload read errors were previously ignored and
--        crashed on concatenating nil.
local function recv(self,readMethod)
  local headerBytes,err,partial=readMethod(self.channel,8)
  if(err) then
    logger.log(logger.ERR,"err when reading header",err)
    return partial,err
  end
  local packet=tdsPacket.Packet:new()
  packet:parseHeader(headerBytes)
  local remaining
  if packet.code==0x17 then
    -- TLS record: type(1) version(2) length(2); `length` counts the bytes after
    -- the 5-byte record header, 3 of which were already consumed above.
    local _,_,_,recordLength=string.unpack(">BBBI2",headerBytes)
    remaining=recordLength-3
  else
    remaining=packet.dataLength-8
  end
  if remaining<0 then
    logger.log(logger.ERR,"invalid packet length",remaining)
    return headerBytes,"invalid packet length"
  end
  local payloadBytes,payloadErr,payloadPartial=readMethod(self.channel,remaining)
  if payloadErr then
    logger.log(logger.ERR,"err when reading payload",payloadErr)
    return headerBytes..(payloadPartial or ""),payloadErr
  end
  return headerBytes..payloadBytes
end
133 |
--Read one client packet, run it through the client-side parser and return
--the (possibly rewritten or nil) bytes to forward upstream. The packet code
--is stored in ngx.ctx.upPacket so processDownRequest can parse the reply.
function _M.processUpRequest(self)
  local allBytes,err=recv(self,self.channel.c2pRead)
  if err then return nil,err end
  local packet=self.C2PParser:parse(allBytes)
  ngx.ctx.upPacket=packet.code
  return packet.allBytes
end
142 |
--Read one server packet and parse it keyed by the last client packet code
--(ngx.ctx.upPacket); return the bytes to forward to the client.
function _M.processDownRequest(self)
  local allBytes,err=recv(self,self.channel.p2sRead)
  if err then return nil,err end
  local lastUpCode=ngx.ctx.upPacket
  local packet=self.S2PParser:parse(allBytes,nil,lastUpCode)
  return packet.allBytes
end
150 |
--Reject the client: send a TDS error response, then end the request.
function _M:sessionInvalid(session)
  local notice="you are not allowed to connect, please contact the admin"
  self.channel:c2pSend(tdsPacket.packErrorResponse(notice))
  ngx.exit(0)
end
155 |
156 | return _M
157 |
--------------------------------------------------------------------------------
/tds/datetime.lua:
--------------------------------------------------------------------------------
1 |
--[[
getNewDate(srcDateTime, interval, dateUnit)

Shift a timestamp by a signed amount of a given unit.

  srcDateTime  timestamp string in %Y%m%d%H%M%S form: a 4-digit year followed
               by 2-digit month, day, hour, minute and second
               (e.g. "20200101083000")
  interval     amount to apply; > 0 adds, < 0 subtracts
  dateUnit     'DAY', 'HOUR', 'MINUTE' or 'SECOND'; any other value applies
               no offset
  Examples:
    interval=10, dateUnit='DAY'   adds 10 days to the source time
    interval=-1, dateUnit='HOUR'  subtracts 1 hour from the source time

Returns the table produced by os.date("*t", ...): fields year, month, day,
hour, min, sec (plus wday, yday, isdst), in local time. Build whatever date
format is needed from those fields.
]]

-- NOTE(review): `_M` is assigned without `local`, so this writes a global.
_M={}

-- Seconds in one step of each supported dateUnit.
local SECONDS_PER_UNIT={ DAY=60*60*24, HOUR=60*60, MINUTE=60, SECOND=1 }

function _M.getNewDate(srcDateTime,interval ,dateUnit)
  -- Split the fixed-width string into os.time fields (numeric strings are
  -- accepted by os.time).
  local base=os.time{
    year=string.sub(srcDateTime,1,4),
    month=string.sub(srcDateTime,5,6),
    day=string.sub(srcDateTime,7,8),
    hour=string.sub(srcDateTime,9,10),
    min=string.sub(srcDateTime,11,12),
    sec=string.sub(srcDateTime,13,14),
  }

  local step=SECONDS_PER_UNIT[dateUnit]
  local offset=0
  if step then
    offset=step*interval
  end

  return os.date("*t",base+offset)
end
47 | end
48 |
--------------------------------------------------------------------------------
/tds/parser.lua:
--------------------------------------------------------------------------------
1 | --tds protocol parser and encoder
local P=require "suproxy.tds.tdsPackets"
local parser=require("suproxy.parser")
local _M={}
----------------------build parser---------------------------
-- Client-to-proxy packets: TDS packet type -> packet class and event name.
local c2pConf={
  {key=P.Login7.code,   parser=P.Login7,   eventName="Login7"},
  {key=P.SQLBatch.code, parser=P.SQLBatch, eventName="SQLBatch"},
  {key=P.Prelogin.code, parser=P.Prelogin, eventName="Prelogin"},
}
-- Server-to-proxy packets, keyed by the type of the client request they
-- answer; 0x17 is the TLS-wrapped login response.
local s2pConf={
  {key=P.Prelogin.code, parser=P.PreloginResponse, eventName="PreloginResponse"},
  {key=P.Login7.code,   parser=P.LoginResponse,    eventName="LoginResponse"},
  {key=0x17,            parser=P.LoginResponse,    eventName="SSLLoginResponse"},
  {key=P.SQLBatch.code, parser=P.SQLResponse,      eventName="SQLResponse"},
}

--Build the client-side and server-side parsers.
--@param catchReply when falsy, the SQLResponse registration is replaced by
--       one without a parser class (presumably so the event still fires but
--       the reply is not decoded -- confirm in suproxy.parser)
function _M:new(catchReply)
  local C2PParser=parser:new()
  -- client packets are keyed by their first byte (the TDS packet type)
  C2PParser.keyGenerator=function(allBytes) return allBytes:byte(1) end
  C2PParser:registerMulti(c2pConf)
  C2PParser:registerDefaultParser(P.Packet)

  local S2PParser=parser:new()
  S2PParser:registerMulti(s2pConf)
  if not catchReply then
    S2PParser:unregister(P.SQLBatch.code,"SQLResponse")
    S2PParser:register(P.SQLBatch.code,nil,nil,"SQLResponse")
  end
  S2PParser:registerDefaultParser(P.Packet)

  local o=setmetatable({},{__index=self})
  o.C2PParser=C2PParser
  o.S2PParser=S2PParser
  return o
end
return _M
--------------------------------------------------------------------------------
/tds/version.lua:
--------------------------------------------------------------------------------
--- SqlServerVersionInfo class
-- Holds a SQL Server version number plus the product name, branded version
-- and service-pack level inferred from it. Create instances with _M:new().

-- Version-number prefix pattern -> branded product version. The patterns are
-- mutually exclusive, so pairs() iteration order does not change the result.
local VERSION_LOOKUP_TABLE = {
  ["^6%.0"] = "6.0", ["^6%.5"] = "6.5", ["^7%.0"] = "7.0",
  ["^8%.0"] = "2000", ["^9%.0"] = "2005", ["^10%.0"] = "2008",
  ["^10%.50"] = "2008 R2", ["^11%.0"] = "2012", ["^12%.0"] = "2014",
  ["^13%.0"] = "2016", ["^14%.0"] = "2017", ["^15%.0"] = "2019"
}

-- Service pack lookup tables: branded version -> { {build, level}, ... } in
-- ascending build order.
-- For instances where a revised service pack was released (e.g. 2000 SP3a), we include the
-- build number for the original SP and the build number for the revision. However, leaving it
-- like this would make it appear that subsequent builds were a patched version of the revision
-- (e.g. a patch applied to 2000 SP3 that increased the build number to 780 would get displayed
-- as "SP3a+", when it was actually SP3+). To avoid this, we include an additional fake build
-- number that combines the two.
-- Built once here instead of on every lookup.
local SP_LOOKUP_TABLE = {
  ["6.5"] = {
    {201, "RTM"},
    {213, "SP1"},
    {240, "SP2"},
    {258, "SP3"},
    {281, "SP4"},
    {415, "SP5"},
    {416, "SP5a"},
    {417, "SP5/SP5a"},
  },

  ["7.0"] = {
    {623, "RTM"},
    {699, "SP1"},
    {842, "SP2"},
    {961, "SP3"},
    {1063, "SP4"},
  },

  ["2000"] = {
    {194, "RTM"},
    {384, "SP1"},
    {532, "SP2"},
    {534, "SP2"},
    {760, "SP3"},
    {766, "SP3a"},
    {767, "SP3/SP3a"},
    {2039, "SP4"},
  },

  ["2005"] = {
    {1399, "RTM"},
    {2047, "SP1"},
    {3042, "SP2"},
    {4035, "SP3"},
    {5000, "SP4"},
  },

  ["2008"] = {
    {1600, "RTM"},
    {2531, "SP1"},
    {4000, "SP2"},
    {5500, "SP3"},
    {6000, "SP4"},
  },

  ["2008 R2"] = {
    {1600, "RTM"},
    {2500, "SP1"},
    {4000, "SP2"},
    {6000, "SP3"},
  },

  ["2012"] = {
    {2100, "RTM"},
    {3000, "SP1"},
    {5058, "SP2"},
    {6020, "SP3"},
    {7001, "SP4"},
  },

  ["2014"] = {
    {2000, "RTM"},
    {4100, "SP1"},
    {5000, "SP2"},
    {6024, "SP3"},
  },

  ["2016"] = {
    {1601, "RTM"},
    {4001, "SP1"},
    {5026, "SP2"},
  },

  ["2017"] = {
    {1000, "RTM"},
    {3257, "CU18"},
  },

  ["2019"] = {
    {2000, "RTM"},
  },
}

local _M={
  versionNumber = "",     -- The full version string (e.g. "9.00.2047.00")
  major = nil,            -- The major version (e.g. 9)
  minor = nil,            -- The minor version (e.g. 0)
  build = nil,            -- The build number (e.g. 2047)
  subBuild = nil,         -- The sub-build number (e.g. 0)
  productName = nil,      -- The product name (e.g. "SQL Server 2005")
  brandedVersion = nil,   -- The branded version of the product (e.g. "2005")
  servicePackLevel = nil, -- The service pack level (e.g. "SP1")
  patched = nil,          -- Whether patches have been applied since SP installation (true/false/nil)
  source = nil,           -- The source of the version info (e.g. "SSRP", "SSNetLib")

  --- Create a version object; `o` optionally supplies initial fields.
  new = function(self,o)
    o = o or {}
    setmetatable(o, self)
    self.__index = self
    return o
  end,

  --- Sets the version using a version number string.
  -- Accepts "major.minor.build.subBuild" or "major.minor.build". Any other
  -- value (including nil) is reported via print and yields 0.00.0.00.
  -- The 4-part form used to capture only 3 groups, so subBuild was lost.
  -- @param versionNumber a version number string (e.g. "9.00.1399.00")
  -- @param source a string indicating the source of the version info (e.g. "SSRP", "SSNetLib")
  SetVersionNumber = function(self, versionNumber, source)
    local major, minor, build, subBuild
    if type(versionNumber) == "string" then
      major, minor, build, subBuild = versionNumber:match("^(%d+)%.(%d+)%.(%d+)%.(%d+)")
      if not major then
        major, minor, build = versionNumber:match("^(%d+)%.(%d+)%.(%d+)")
      end
    end
    if not major then
      print(string.format("%s: SetVersionNumber: versionNumber is not in correct format: %s",
        "MSSQL", tostring(versionNumber)))
    end
    self:SetVersion(major, minor, build, subBuild, source)
  end,

  --- Sets the version using the individual numeric components of the version
  -- number; missing components become 0.
  -- @param source a string indicating the source of the version info (e.g. "SSRP", "SSNetLib")
  SetVersion = function(self, major, minor, build, subBuild, source)
    self.source = source
    -- make sure our version numbers all end up as valid numbers
    self.major, self.minor, self.build, self.subBuild =
      tonumber( major or 0 ), tonumber( minor or 0 ), tonumber( build or 0 ), tonumber( subBuild or 0 )

    self.versionNumber = string.format( "%u.%02u.%u.%02u", self.major, self.minor, self.build, self.subBuild )

    self:_ParseVersionInfo()
  end,

  --- Using the version number, determines the branded version and product name.
  _InferProductVersion = function(self)
    local product = ""
    for pattern, branded in pairs(VERSION_LOOKUP_TABLE) do
      if self.versionNumber:match(pattern) then
        product = branded
        self.brandedVersion = product
        break
      end
    end
    self.productName = ("Microsoft SQL Server %s"):format(product)
  end,

  --- Returns the service-pack lookup table for this object's branded version
  -- (e.g. { {1600, "RTM"}, {2531, "SP1"} }), or nil when unknown. The
  -- branded version is inferred first when not already set.
  _GetSpLookupTable = function(self)
    if not self.brandedVersion then
      self:_InferProductVersion()
    end
    return SP_LOOKUP_TABLE[self.brandedVersion]
  end,

  --- Processes version data to determine (if possible) the product version,
  -- service pack level and patch status. Safe to call repeatedly: a previous
  -- "Pre-RTM" result no longer blocks a later service-pack assignment.
  _ParseVersionInfo = function(self)
    local spLookupTable = self:_GetSpLookupTable()

    if spLookupTable and #spLookupTable > 0 then
      local idx = 0
      local preRtm = false
      -- Walk the service pack levels until we find one whose build number is
      -- the same as, or just below, our build number.
      while idx < #spLookupTable do
        idx = idx + 1
        local spBuild = spLookupTable[idx][1]
        if spBuild == self.build then
          break
        elseif spBuild > self.build then
          if idx == 1 then
            -- The target build number is lower than the first release
            preRtm = true
          else
            -- we went too far - it's the previous SP, but with patches applied
            idx = idx - 1
          end
          break
        end
      end

      if preRtm then
        self.servicePackLevel = "Pre-RTM"
        self.patched = nil
      else
        self.servicePackLevel = spLookupTable[idx][2]
        self.patched = (spLookupTable[idx][1] ~= self.build)
      end

      -- Clean up some of our inferences. If the source of our revision number
      -- was the SSRP (SQL Server Browser) response, we need to recognize its
      -- limitations:
      --  * Versions of SQL Server prior to 2005 are reported with the RTM build
      --    number, regardless of the actual version (e.g. SQL Server 2000 is
      --    always 8.00.194).
      --  * Versions of SQL Server starting with 2005 (and going through at least
      --    2008) do better but are still only reported with the build number as
      --    of the last service pack (e.g. SQL Server 2005 SP3 with patches is
      --    still reported as 9.00.4035.00).
      if ( self.source == "SSRP" ) then
        self.patched = nil

        if ( self.major <= 8 ) then
          self.servicePackLevel = nil
        end
      end
    end

    return true
  end,

  --- Human-readable form, e.g. "Microsoft SQL Server 2005 SP1+" ("+" when
  -- patched); empty string when no product name is known.
  ToString = function(self)
    local rs = {}
    if self.productName then
      rs[#rs+1] = self.productName
      if self.servicePackLevel then
        rs[#rs+1] = " "
        rs[#rs+1] = self.servicePackLevel
      end
      if self.patched then
        rs[#rs+1] = "+"
      end
    end
    return table.concat(rs)
  end,
}
270 |
271 | return _M
--------------------------------------------------------------------------------
/test.lua:
--------------------------------------------------------------------------------
-- Run the self-tests exported by individual suproxy modules, in the order
-- listed. Each module must provide a test() function; a failing test is
-- expected to raise, which aborts the run, so reaching the final message
-- means every listed test completed.
print("start testing")

local MODULES_UNDER_TEST={
  "suproxy.utils.compatibleLog",
  "suproxy.utils.datetime",
  "suproxy.utils.event",
  "suproxy.utils.pureluapack",
  "suproxy.utils.tableUtils",
  "suproxy.utils.unicode",
  "suproxy.tns.tnsPackets",
  "suproxy.tds.tdsPackets",
  "suproxy.ssh2.shellCommand",
  "suproxy.ldap.ldapPackets",
  "suproxy.balancer.balancer",
}

for _,moduleName in ipairs(MODULES_UNDER_TEST) do
  local mod=require(moduleName)
  mod.test()
end

print("All test finished without error")
--------------------------------------------------------------------------------
/tns.lua:
--------------------------------------------------------------------------------
1 | require "suproxy.utils.stringUtils"
2 | require "suproxy.utils.pureluapack"
3 | local event=require "suproxy.utils.event"
4 | local logger=require "suproxy.utils.compatibleLog"
5 | local tnsPackets=require "suproxy.tns.tnsPackets"
6 | local tableUtils=require "suproxy.utils.tableUtils"
7 | local crypt= require "suproxy.tns.crypt"
local _M = {}
_M._PROTOCAL ='tns'

--- Create a TNS (Oracle) protocol processor.
-- @param options optional table. Reads `oracleVersion` (major version that
--   overrides the tnsPackets.Options default) and `swapPass` (enables
--   password swapping; defaults to false).
-- @return processor instance; methods are looked up in _M
function _M.new(self, options)
    options = options or {}
    local o = setmetatable({}, {__index = self})

    -- Notification events: every registered handler is called.
    for _, eventName in ipairs({"AuthSuccessEvent", "AuthFailEvent", "CommandFinishedEvent", "ContextUpdateEvent"}) do
        o[eventName] = event:new(o, eventName)
    end
    -- Return events: a single handler whose return values the processor uses.
    for _, eventName in ipairs({"BeforeAuthEvent", "OnAuthEvent", "CommandEnteredEvent"}) do
        o[eventName] = event:newReturnEvent(o, eventName)
    end

    o.options = tnsPackets.Options:new()
    o.options.oracleVersion.major = options.oracleVersion or o.options.oracleVersion.major
    o.swapPass = options.swapPass or false
    o.ctx = {}

    local tnsParser = require("suproxy.tns.parser"):new()
    o.C2PParser = tnsParser.C2PParser
    o.S2PParser = tnsParser.S2PParser

    -- Attach each {eventName, handler} pair to the given parser, with the
    -- processor as the handler context.
    local function bind(targetParser, bindings)
        for _, binding in ipairs(bindings) do
            targetParser.events[binding[1]]:setHandler(o, binding[2])
        end
    end

    bind(o.C2PParser, {
        {"ConnectEvent",        _M.ConnectHandler},
        {"AuthRequestEvent",    _M.AuthRequestHandler},
        {"SessionRequestEvent", _M.SessionRequestHandler},
        {"SetProtocolEvent",    _M.SetProtocolRequestHandler},
        {"SQLRequestEvent",     _M.SQLRequestHandler},
        {"Piggyback1169",       _M.PiggbackHandler},
        {"Piggyback116b",       _M.PiggbackHandler},
        {"MarkerEvent",         _M.MarkerHandler},
    })
    bind(o.S2PParser, {
        {"SessionResponseEvent", _M.SessionResponseHandler},
        {"VersionResponseEvent", _M.VersionResponseHandler},
        {"SetProtocolEvent",     _M.SetProtocolResponseHandler},
        {"AcceptEvent",          _M.AcceptHandler},
        {"AuthErrorEvent",       _M.AuthErrorHandler},
    })
    return o
end
42 |
43 | ----------------parser event handlers----------------------
--- Client CONNECT packet: force TNS version 314, record the connect string
-- in the context and re-serialise the packet.
-- NOTE(review): 314 presumably keeps both peers on a version whose header
-- uses the 2-byte length that recv() parses — confirm.
function _M:ConnectHandler(src,p)
    p:setTnsVersion(314)
    local connStr = p:getConnStr()
    self.ctx.connStr = connStr
    p:pack()
end
49 |
--- Server ACCEPT packet: remember the negotiated TNS version and whether
-- header/packet checksums are in use, for later packet parsing.
function _M:AcceptHandler(src,p)
    local opts = self.options
    opts.tnsVersion = p:getTnsVersion()
    opts.headerCheckSum = p:checkHeader()
    opts.packetCheckSum = p:checkPacket()
    self.ctx.tnsVer = p:getTnsVersion()
end
56 |
--- Client authentication request carrying AUTH_SESSKEY and AUTH_PASSWORD.
-- Replaces the username with the real one chosen in SessionRequestHandler.
-- When password swapping is active, the credentials are re-encrypted with the
-- real password, but only if the client proved it knows the temporary
-- password.
-- OnAuthEvent may veto the login. In that case:
--   * two marker packets are sent to the client;
--   * the request is dropped (p.allBytes=nil);
--   * responseError is set, so MarkerHandler answers the client's next
--     marker with an error.
function _M:AuthRequestHandler(src,p)
    local ckey=p:getAuthKey()
    local skey=self.serverKey
    local tmpKey=self.tmpKey
    local salt=self.salt
    local pass=p:getPassword()
    local username=p:getUsername()
    if username ~= self.username then
        p:setUsername(self.username)
    end
    local passChanged=false
    if self.swapPass and self.tempPass ~= self.realPass then
        local ck, clientPass
        if self.options.oracleVersion.major==11 then
            ck,clientPass=crypt:Decrypt11g(ckey,tmpKey,pass,self.tempPass,salt)
        end
        -- 10g decryption is not implemented (Decrypt10g is commented out in
        -- crypt.lua). For any other version ck stays nil and the packet is
        -- forwarded without re-encryption, instead of calling Encrypt11g with
        -- a nil client key.
        if ck and clientPass==self.tempPass then
            -- Only a client that supplied the temporary password gets its
            -- login re-encrypted with the real password. Anything else is
            -- forwarded untouched, so the server rejects it.
            ckey,pass=crypt:Encrypt11g(self.realPass,ck,skey,salt)
            p:setAuthKey(ckey)
            p:setPassword(pass)
            passChanged=true
        elseif ck then
            logger.log(logger.WARN,"client password does not match the temporary password; credentials forwarded unchanged")
        end
    end
    if self.OnAuthEvent:hasHandler() then
        local ok,message=self.OnAuthEvent:trigger({username=username,password=self.tempPass},self.ctx)
        if not ok then
            --return marker1
            self.channel:c2pSend(tnsPackets.Marker1:pack())
            --return marker2
            self.channel:c2pSend(tnsPackets.Marker2:pack())
            self.responseError=true
            p.allBytes=nil
            return
        end
    end
    if username ~= self.username or passChanged then
        p:pack()
    end
end
99 |
--- Client session request carrying the login username and client program.
-- Records client details in options/ctx, lets BeforeAuthEvent map the
-- client-supplied username to real credentials, fires ContextUpdateEvent
-- and rewrites the packet with the (possibly replaced) username.
-- NOTE(review): a BeforeAuthEvent handler that returns nil makes this raise
-- (indexing nil `cred`); that behaviour is intentionally left unchanged.
function _M:SessionRequestHandler(src,p)
    local clientUsername=p:getUsername()
    local program=p:getProgram()
    self.options.program=program
    self.options.is64Bit=p:is64Bit()
    self.ctx.client=program
    self.username=clientUsername
    -- Always remember the name the client logged in with. If this were set
    -- only when a BeforeAuthEvent handler exists, tempUsername would stay nil.
    -- SQLRequestHandler would then see tempUsername ~= username and rewrite
    -- every "ALTER SESSION SET CURRENT_SCHEMA" statement.
    self.tempUsername=clientUsername
    if self.BeforeAuthEvent:hasHandler() then
        local cred=self.BeforeAuthEvent:trigger({username=clientUsername},self.ctx)
        self.tempPass=cred.temppass
        self.realPass=cred.password
        self.username=cred.username
    end
    self.ctx.username=self.username
    if self.ContextUpdateEvent:hasHandler() then
        self.ContextUpdateEvent:trigger(self.ctx)
    end
    p:setUsername(self.username)
    p:pack()
end
119 |
--- Client set-protocol request: remember the client platform in options and ctx.
function _M:SetProtocolRequestHandler(src,p)
    local platform = p:getClientPlatform()
    self.options.platform = platform
    self.ctx.clientPlatform = platform
end
124 |
--- Server set-protocol response: remember the server platform in options and ctx.
-- The platform is read through the packet's getClientPlatform accessor.
function _M:SetProtocolResponseHandler(src,p)
    local platform = p:getClientPlatform()
    self.options.srvPlatform = platform
    self.ctx.srvPlatform = platform
end
129 |
--- Server version response.
-- A successful login is currently reported here (AuthSuccessEvent). The
-- server version details are then copied into options/ctx and
-- ContextUpdateEvent is fired.
function _M:VersionResponseHandler(src,p)
    --todo find a better chance to trigger authSuccess
    if self.AuthSuccessEvent:hasHandler() then
        self.AuthSuccessEvent:trigger(self.ctx.username,self.ctx)
    end
    local version = self.options.oracleVersion
    version.major = p:getMajor()
    version.minor = p:getMinor()
    version.build = p:getBuild()
    version.subbuild = p:getSub()
    version.fix = p:getFix()
    self.ctx.serverVer = p:getVersion()
    if self.ContextUpdateEvent:hasHandler() then
        self.ContextUpdateEvent:trigger(self.ctx)
    end
end
145 |
--- Server session response carrying the server session key and salt.
-- When password swapping is active on an Oracle 10/11 server, the server
-- session key is re-encrypted for the temporary password. The original key
-- and salt are kept for AuthRequestHandler.
function _M:SessionResponseHandler(src,p)
    -- nothing to do unless swapping is enabled and the passwords differ
    local swapNeeded = self.swapPass and self.tempPass ~= self.realPass
    if not swapNeeded then return end
    local major = self.options.oracleVersion.major
    if major ~= 11 and major ~= 10 then return end
    self.serverKey = p:getAuthKey()
    self.salt = p:getSalt()
    self.tmpKey = crypt:getServerKey(self.tempPass, self.realPass, self.serverKey, self.salt)
    p:setAuthKey(self.tmpKey)
    p:pack()
end
158 |
--- Piggyback packet: re-dispatch it to the event registered in the client
-- parser under the packet's own `__key`, if that event has a handler.
-- NOTE(review): assumes P.Piggyback sets `__key` to the embedded call's key — confirm.
function _M:PiggbackHandler(src,p)
    local key = p.__key
    if not key then return end
    local entry = self.C2PParser.parserList[key]
    local targetEvent = entry and entry.event
    if targetEvent and targetEvent:hasHandler() then
        targetEvent:trigger(p)
    end
end
166 |
--- Client marker packet.
-- If a previous handler flagged an error (responseError), answer with a
-- no-permission error and clear the flag. If the session was invalidated
-- (sessionStop), answer with the error and terminate the request via ngx.exit.
function _M:MarkerHandler(src,p)
    local function sendNoPermission()
        self.channel:c2pSend(tnsPackets.NoPermissionError:new(self.options):pack())
    end
    if self.responseError then
        sendNoPermission()
        self.responseError=false
    end
    if not self.sessionStop then return end
    sendNoPermission()
    ngx.exit(0)
end
179 |
--- Server rejected the login: fire AuthFailEvent with the username used for the attempt.
function _M:AuthErrorHandler(src,p)
    if not self.AuthFailEvent:hasHandler() then return end
    self.AuthFailEvent:trigger({username=self.username},self.ctx)
end
185 |
--- Client SQL request.
-- CommandEnteredEvent may veto the statement (second return value) or
-- replace its text (first return value). If the login username was replaced,
-- "ALTER SESSION SET CURRENT_SCHEMA" statements are redirected to the real
-- username. CommandFinishedEvent is fired with the final text.
function _M:SQLRequestHandler(src,p)
    local command=p:getCommand()
    --if command=="altersession" then replace end
    if command and command:len()>0 then
        if self.CommandEnteredEvent:hasHandler() then
            local cmd,err=self.CommandEnteredEvent:trigger(command,self.ctx)
            if err then
                --set a flag indicate error happen
                self.responseError=true
                --return marker1
                self.channel:c2pSend(tnsPackets.Marker1:pack())
                --return marker2
                self.channel:c2pSend(tnsPackets.Marker2:pack())
                p.allBytes=nil
                return
            end
            if cmd and cmd~=command then
                command=cmd
                p:setCommand(cmd)
                p:pack()
            end
        end
        -- If the username was changed during login, the schema the client
        -- selects must be updated to the real username. Skip when the
        -- login name was never recorded (tempUsername nil): there is nothing
        -- to compare against.
        if self.tempUsername ~= nil and self.tempUsername ~= self.username then
            if command:match("ALTER SESSION SET CURRENT_SCHEMA") then
                command=(command:gsub("%= .*","%= "..self.username:literalize()))
                p:setCommand(command)
                p:pack()
            end
        end

        if self.CommandFinishedEvent:hasHandler() then
            self.CommandFinishedEvent:trigger(command,"",self.ctx)
        end
    end
end
223 |
224 | -------------implement processor methods---------------
--- Read one client packet and run it through the client-side parser, which
-- fires the handlers bound in new(). The packet key is stored in
-- self.request so the matching server reply can be parsed in context.
-- @return bytes to forward (nil when a handler cleared p.allBytes), or nil and an error
function _M.processUpRequest(self)
    local allBytes, err = self:recv(self.channel.c2pRead)
    if err then return nil, err end
    local packet = self.C2PParser:parse(allBytes, nil, nil, self.options)
    self.request = packet.__key
    return packet.allBytes
end
233 |
--- Read one server packet and run it through the server-side parser.
-- If the last client request key mentions "callId" (an OCI function call),
-- that key is passed along so the response is dispatched per request.
-- @return bytes to forward, or nil and an error
function _M.processDownRequest(self)
    local allBytes, err = self:recv(self.channel.p2sRead)
    if err then return nil, err end
    local pending = self.request
    local ociCall = nil
    if pending and pending:match("callId") then
        ociCall = pending
    end
    local packet = self.S2PParser:parse(allBytes, nil, nil, self.options, ociCall)
    return packet.allBytes
end
246 |
--- Read one TNS packet from the channel.
-- The first two bytes hold the big-endian total packet length, including
-- those two bytes; the remaining bytes are read accordingly.
-- @param readMethod channel method (c2pRead/p2sRead), called as readMethod(channel, n)
-- @return the whole packet (length prefix included), or nil and an error
function _M:recv(readMethod)
    local lengthdata,err=readMethod(self.channel,2)
    if(err) then
        logger.log(logger.ERR,"err when reading length")
        return nil,err
    end
    local pktLen=string.unpack(">I2",lengthdata)
    -- The length field counts itself, so a value below 2 is malformed and
    -- would request a negative number of bytes from the channel.
    if pktLen<2 then
        logger.log(logger.ERR,string.format("invalid packet length %d",pktLen))
        return nil,"invalid packet length"
    end
    local data
    data,err=readMethod(self.channel,pktLen-2)
    if(err) then
        logger.log(logger.ERR,"err when reading packet")
        return nil,err
    end
    return lengthdata..data
end
262 |
--- Invalidate the session.
-- Sends two marker packets to the client and sets sessionStop, so the next
-- client marker is answered with an error and the request ends (MarkerHandler).
function _M:sessionInvalid(session)
    local function sendToClient(packet)
        self.channel:c2pSend(packet:pack())
    end
    sendToClient(tnsPackets.Marker1)
    sendToClient(tnsPackets.Marker2)
    self.sessionStop=true
end
270 |
271 | return _M
272 |
--------------------------------------------------------------------------------
/tns/crypt.lua:
--------------------------------------------------------------------------------
1 | --- Class that handles all Oracle encryption
2 | local cipher=require ("resty.openssl.cipher")
3 | local rand=require("resty.openssl.rand")
4 | require "suproxy.utils.stringUtils"
5 | require("suproxy.utils.pureluapack")
6 | local aes = require "resty.aes"
--- Oracle logon crypto helpers used by the TNS password swap.
-- Keys are derived as sha1(password .. salt) followed by four zero bytes,
-- giving the 24 bytes AES-192 requires.
-- NOTE(review): the CBC IV used here is the first 16 bytes of the key, not
-- all zeros. Only bytes 17+ of each decrypted value are consumed, which
-- presumably makes this harmless — confirm before changing.
local _M = {

    --- Re-encrypt the server session key for the temporary password.
    -- @param pass temporary password the client is expected to use
    -- @param realpass real database password
    -- @param s_sesskey server session key as sent by the server (encrypted under realpass)
    -- @param auth_vrfy_data password salt sent by the server
    -- @return the session key encrypted under `pass`; the plaintext session key
    getServerKey=function(self,pass,realpass,s_sesskey,auth_vrfy_data)
        local real_hash = ngx.sha1_bin(realpass .. auth_vrfy_data) .. "\0\0\0\0"
        local srv_sesskey=cipher.new("aes-192-cbc"):decrypt(real_hash,real_hash:sub(1,16),s_sesskey,true)
        local temp_hash = ngx.sha1_bin(pass .. auth_vrfy_data) .. "\0\0\0\0"
        local result=cipher.new("aes-192-cbc"):encrypt(temp_hash,temp_hash:sub(1,16),srv_sesskey,true)
        return result,srv_sesskey
    end,

    --- Decrypt an 11g client auth request made with the temporary password.
    -- @param c_sesskey client session key as sent by the client (encrypted)
    -- @param s_sesskey server session key the client received (encrypted under `pass`)
    -- @param auth_password encrypted password sent by the client
    -- @param pass temporary password
    -- @param salt password salt
    -- @return plaintext client session key; password typed by the client
    -- @raise when the password cannot be decrypted (e.g. bad padding)
    Decrypt11g = function(self, c_sesskey, s_sesskey, auth_password, pass, salt )
        local sha1 = ngx.sha1_bin(pass .. salt) .. "\0\0\0\0"
        local server_sesskey =cipher.new("aes-192-cbc"):decrypt(sha1,sha1:sub(1,16),s_sesskey,true)
        local client_sesskey = cipher.new("aes-192-cbc"):decrypt(sha1,sha1:sub(1,16),c_sesskey,true)
        -- combined key input: bytes 17..40 of server key XOR client key
        local combined_sesskey = {}
        for i=17, 40 do
            combined_sesskey[#combined_sesskey+1] = string.char( bit.bxor(string.byte(server_sesskey, i) , string.byte(client_sesskey,i)) )
        end
        combined_sesskey = table.concat(combined_sesskey)
        -- Session key material must never be printed: it would end up in the log.
        combined_sesskey = ( ngx.md5_bin( combined_sesskey:sub(1,16) ) .. ngx.md5_bin(combined_sesskey:sub(17) ) ):sub(1, 24)
        local p,err= cipher.new("aes-192-cbc"):decrypt(combined_sesskey,combined_sesskey:sub(1,16),auth_password)
        if not p then
            error("failed to decrypt AUTH_PASSWORD: " .. tostring(err))
        end
        -- the first 16 bytes are a random prefix
        return client_sesskey,p:sub(17)
    end,

    -- Oracle 10g helpers, kept as reference only. They use stdnse/openssl
    -- helpers that are not available in this project.

    -- -- - Creates an Oracle 10G password hash

    -- -- @param username containing the Oracle user name
    -- -- @param password containing the Oracle user password
    -- -- @return hash containing the Oracle hash
    -- HashPassword10g = function( self, username, password )
    -- local uspw = (username .. password):upper():gsub(".", "\0%1")
    -- local key = stdnse.fromhex("0123456789abcdef")

    -- -- do padding
    -- uspw = uspw .. string.rep('\0', (8 - (#uspw % 8)) % 8)

    -- local iv2 = openssl.encrypt( "DES-CBC", key, nil, uspw, false ):sub(-8)
    -- local enc = openssl.encrypt( "DES-CBC", iv2, nil, uspw, false ):sub(-8)
    -- return enc
    -- end,

    -- -- Test function, not currently in use
    -- Decrypt10g = function(self, user, pass, srv_sesskey_enc )
    -- local pwhash = self:HashPassword10g( user, pass ) .. "\0\0\0\0\0\0\0\0"
    -- local cli_sesskey_enc = stdnse.fromhex("7B244D7A1DB5ABE553FB9B7325110024911FCBE95EF99E7965A754BC41CF31C0")
    -- local srv_sesskey = openssl.decrypt( "AES-128-CBC", pwhash, nil, srv_sesskey_enc )
    -- local cli_sesskey = openssl.decrypt( "AES-128-CBC", pwhash, nil, cli_sesskey_enc )
    -- local auth_pass = stdnse.fromhex("4C5E28E66B6382117F9D41B08957A3B9E363B42760C33B44CA5D53EA90204ABE")
    -- local pass

    -- local combined_sesskey = {}
    -- for i=17, 32 do
    -- combined_sesskey[#combined_sesskey+1] = string.char( string.byte(srv_sesskey, i) ~ string.byte(cli_sesskey, i) )
    -- end
    -- combined_sesskey = openssl.md5( table.concat(combined_sesskey) )

    -- pass = openssl.decrypt( "AES-128-CBC", combined_sesskey, nil, auth_pass ):sub(17)

    -- print( stdnse.tohex( srv_sesskey ))
    -- print( stdnse.tohex( cli_sesskey ))
    -- print( stdnse.tohex( combined_sesskey ))
    -- print( "pass=" .. pass )
    -- end,

    -- -- - Performs the relevant encryption needed for the Oracle 10g response

    -- -- @param user containing the Oracle user name
    -- -- @param pass containing the Oracle user password
    -- -- @param srv_sesskey_enc containing the encrypted server session key as
    -- -- received from the PreAuth packet
    -- -- @return cli_sesskey_enc the encrypted client session key
    -- -- -- @return auth_pass the encrypted Oracle password
    -- Encrypt10g = function( self, user, pass, srv_sesskey_enc )

    -- local pwhash = self:HashPassword10g( user, pass ) .. "\0\0\0\0\0\0\0\0"
    -- -- We're currently using a static client session key, this should
    -- -- probably be changed to a random value in the future
    -- local cli_sesskey = stdnse.fromhex("FAF5034314546426F329B1DAB1CDC5B8FF94349E0875623160350B0E13A0DA36")
    -- local srv_sesskey = openssl.decrypt( "AES-128-CBC", pwhash, nil, srv_sesskey_enc )
    -- local cli_sesskey_enc = openssl.encrypt( "AES-128-CBC", pwhash, nil, cli_sesskey )
    -- -- This value should really be random, not this static cruft
    -- local rnd = stdnse.fromhex("4C31AFE05F3B012C0AE9AB0CDFF0C508")
    -- local auth_pass

    -- local combined_sesskey = {}
    -- for i=17, 32 do
    -- combined_sesskey[#combined_sesskey+1] = string.char( string.byte(srv_sesskey, i) ~ string.byte(cli_sesskey, i) )
    -- end
    -- combined_sesskey = openssl.md5( table.concat(combined_sesskey) )
    -- auth_pass = openssl.encrypt("AES-128-CBC", combined_sesskey, nil, rnd .. pass, true )
    -- auth_pass = stdnse.tohex(auth_pass)
    -- cli_sesskey_enc = stdnse.tohex(cli_sesskey_enc)
    -- return cli_sesskey_enc, auth_pass
    -- end,

    --- Performs the encryption needed for the Oracle 11g response.
    -- @param pass containing the Oracle user password
    -- @param cli_sesskey unencrypted client key
    -- @param srv_sesskey_enc containing the encrypted server session key as
    --   received from the PreAuth packet
    -- @param auth_vrfy_data containing the password salt as received from the
    --   PreAuth packet
    -- @return cli_sesskey_enc the encrypted client session key
    -- @return auth_pass the encrypted Oracle password (16 random bytes .. pass)
    Encrypt11g = function( self, pass,cli_sesskey, srv_sesskey_enc, auth_vrfy_data )
        local rnd = rand.bytes(16)
        local pw_hash = ngx.sha1_bin(pass .. auth_vrfy_data) .. "\0\0\0\0"
        local srv_sesskey=cipher.new("aes-192-cbc"):decrypt(pw_hash, pw_hash:sub(1,16),srv_sesskey_enc)
        -- combined key input: bytes 17..40 of server key XOR client key
        local combined_sesskey = {}
        for i=17, 40 do
            combined_sesskey[#combined_sesskey+1] = string.char( bit.bxor(string.byte(srv_sesskey, i) ,string.byte(cli_sesskey, i) ))
        end
        combined_sesskey = table.concat(combined_sesskey)
        combined_sesskey = ( ngx.md5_bin( combined_sesskey:sub(1,16) ) .. ngx.md5_bin( combined_sesskey:sub(17) ) ):sub(1, 24)
        local cli_sesskey_enc=cipher.new("aes-192-cbc"):encrypt(pw_hash,pw_hash:sub(1,16),cli_sesskey,true)
        local auth_password=cipher.new("aes-192-cbc"):encrypt(combined_sesskey,combined_sesskey:sub(1,16),rnd .. pass)
        return cli_sesskey_enc, auth_password
    end,

}
133 | return _M
--------------------------------------------------------------------------------
/tns/parser.lua:
--------------------------------------------------------------------------------
1 | --tns protocol parser
2 | local P=require "suproxy.tns.tnsPackets"
3 | local K=P.getKey
4 | local parser=require("suproxy.parser")
5 | local _M={}
6 |
-- Packet dispatch table, registered in both the client- and server-side
-- parsers. Each entry maps a key built with P.getKey (the same shape keyG
-- produces) to an optional packet class and the name of the event that
-- tns.lua binds a handler to.
--   code   : TNS packet type byte. Judging by the event names:
--            1 = Connect, 2 = Accept, 12 = Marker.
--   dataId : data id byte of a DATA packet.
--   callId : function call id of an OCI / piggyback function call.
--   req    : key of the client request a server packet answers. keyG
--            receives it as `request` (processDownRequest passes the
--            pending callId key).
-- NOTE(review): entries without `parser` presumably fall back to the default
-- parser registered via registerDefaultParser — confirm in suproxy.parser.
local conf={
    {key=K({code=1}), parser=P.Connect, eventName="ConnectEvent"},
    {key=K({callId=0x76}), parser=P.SessionRequest, eventName="SessionRequestEvent"},
    {key=K({callId=0x73}), parser=P.AuthRequest, eventName="AuthRequestEvent"},
    {key=K({dataId=1}), parser=P.SetProtocolRequest, eventName="SetProtocolEvent"},
    {key=K({callId=0x69}), parser=P.Piggyback, eventName="Piggyback1169"},
    {key=K({callId=0x6b}), parser=P.Piggyback, eventName="Piggyback116b"},
    {key=K({callId=0x5e}), parser=P.SQLRequest, eventName="SQLRequestEvent"},
    {key=K({code=2}), parser=P.Accept, eventName="AcceptEvent"},
    {key=K({code=12}), eventName="MarkerEvent"},
    -- server responses, keyed by the client request they answer
    {key=K({dataId=8,req=K({callId=0x76})}), parser=P.SessionResponse, eventName="SessionResponseEvent"},
    {key=K({dataId=8,req=K({callId=0x3b})}), parser=P.VersionResponse, eventName="VersionResponseEvent"},
    {key=K({code=12,req=K({callId=0x73})}), eventName="AuthErrorEvent"},
}
21 |
--- Build the dispatch key for a raw TNS packet.
-- Optional 2-byte packet and header checksums shift the positions of the
-- type, data-id and call-id bytes.
-- @param allBytes raw packet bytes
-- @param pos unused
-- @param options tnsPackets.Options or nil (defaults are used when nil)
-- @param request key of the pending client request, or nil
-- @return key built by P.getKey from code, dataId, callId and req
local keyG=function(allBytes,pos,options,request)
    --if options is null use default value
    options=P.Options:new(options)
    local pktChkLen=options:pktChk() and 2 or 0
    local hdrChkLen=options:hdrChk() and 2 or 0
    local pktType=allBytes:byte(3+pktChkLen)
    local dataId,callId
    if pktType==P.PacketType.DATA.code then
        dataId=allBytes:byte(7+pktChkLen+hdrChkLen)
    end
    if dataId==P.DataID.USER_OCI_FUNC.code or dataId==P.DataID.PIGGYBACK_FUNC.code then
        callId=allBytes:byte(8+pktChkLen+hdrChkLen)
    end
    return K({callId=callId,dataId=dataId,code=pktType,req=request})
end
33 |
--- Create the TNS parser pair: C2PParser for client traffic and S2PParser
-- for server traffic. Both use keyG and the shared dispatch table, and fall
-- back to P.Packet for unregistered keys.
function _M:new()
    local o = setmetatable({}, {__index = self})
    local function buildParser()
        local p = parser:new()
        p.keyGenerator = keyG
        p:registerMulti(conf)
        p:registerDefaultParser(P.Packet)
        return p
    end
    o.C2PParser = buildParser()
    o.S2PParser = buildParser()
    return o
end
49 | return _M
50 |
51 |
52 |
--------------------------------------------------------------------------------
/utils/compatibleLog.lua:
--------------------------------------------------------------------------------
1 | local tableUtils=require "suproxy.utils.tableUtils"
local _M={}

-- Log level descriptors.
--   code    : compared against the module threshold when running outside
--             nginx (see logInner); lower means more severe.
--   ngxCode : level passed to ngx.errlog.raw_log / set_filter_level.
-- The values appear to mirror nginx's ngx.STDERR .. ngx.DEBUG constants.
_M.STDERR ={code=0x00,ngxCode=0x00,desc="STDERR"}
_M.EMERG ={code=0x01,ngxCode=0x01,desc="EMERG" }
_M.ALERT ={code=0x02,ngxCode=0x02,desc="ALERT" }
_M.CRIT ={code=0x03,ngxCode=0x03,desc="CRIT" }
_M.ERR ={code=0x04,ngxCode=0x04,desc="ERR" }
_M.WARN ={code=0x05,ngxCode=0x05,desc="WARN" }
_M.NOTICE ={code=0x06,ngxCode=0x06,desc="NOTICE"}
_M.INFO ={code=0x07,ngxCode=0x07,desc="INFO" }
_M.DEBUG ={code=0x08,ngxCode=0x08,desc="DEBUG" }

-- Reverse lookup from the numeric level returned by
-- ngx.errlog.get_sys_filter_level to the descriptors above.
local _ngxMapping={
    [0x00]=_M.STDERR,[0x01]=_M.EMERG ,[0x02]=_M.ALERT ,
    [0x03]=_M.CRIT ,[0x04]=_M.ERR ,[0x05]=_M.WARN ,
    [0x06]=_M.NOTICE,[0x07]=_M.INFO ,[0x08]=_M.DEBUG ,
}

-- Threshold for the print path used outside nginx; DEBUG lets every level through.
local _logLevel=_M.DEBUG
21 |
-- Cached result of pcall(require, "ngx.errlog"), looked up on the first log
-- call. Outside OpenResty the require fails; failed requires are not cached
-- by Lua, so retrying would search package.path on every log call.
local _errlogLoaded, _errlog

--- Core logging routine.
-- When ngx.errlog is available, the message always goes to nginx, which
-- applies its own filter level. Otherwise it is printed when `level` is at
-- or above the module threshold.
-- @param level level descriptor (default NOTICE)
-- @param stackUpLevel debug.getinfo level of the frame reported as the
--   origin (default 2 = direct caller of logInner)
-- @param ... message parts, joined with tableUtils.concat
function _M.logInner(level,stackUpLevel,...)
    level=level or _M.NOTICE
    stackUpLevel=stackUpLevel or 2
    local args={...}
    -- Look the frame up once, and directly in this function, so the stack
    -- levels callers pass stay valid.
    local info=debug.getinfo(stackUpLevel,"Sl")
    local func=info and (info.short_src..":"..info.currentline) or "?"
    if _errlogLoaded==nil then
        _errlogLoaded,_errlog=pcall(require,"ngx.errlog")
    end
    if _errlogLoaded and _errlog then
        _errlog.raw_log(level.ngxCode,func..": "..tableUtils.concat(args))
    elseif level.code<=_logLevel.code then
        print(level.desc,func,":",tableUtils.concat(args))
    end
end
34 |
--- Log a message at `level`, reporting the caller of this function as the origin.
-- Stack level 3 inside logInner is the caller of log (1 = logInner, 2 = log).
-- Keep this a plain call rather than `return _M.logInner(...)`: a tail call
-- would remove this frame and shift the reported location.
function _M.log(level,...)
    _M.logInner(level,3,...)
end
38 |
--- Log a message preceded by `title` centred in a line of dashes (about 80
-- columns), reporting the caller of this function as the origin.
function _M.logWithTitle(level,title,...)
    -- string.rep needs an integral count: Lua 5.3+ rejects 38.5, and older
    -- runtimes converted it implicitly. A title longer than 80 gets no dashes.
    local pad=math.max(math.floor((80-#title)/2),0)
    local dashes=string.rep("-",pad)
    -- plain (non-tail) call so stack level 3 is the caller of logWithTitle
    _M.logInner(level,3,"\r\n"..dashes..title..dashes.."\r\n",...)
end
44 |
--- Return the current level descriptor.
-- When ngx.errlog is available, the module threshold is first refreshed
-- from nginx's system filter level.
function _M.getLogLevel()
    local hasErrlog, errlog = pcall(require, "ngx.errlog")
    if not hasErrlog then
        return _logLevel
    end
    _logLevel = _ngxMapping[errlog.get_sys_filter_level()]
    return _logLevel
end
50 |
--- Set the log threshold.
-- When ngx.errlog is available, nginx's filter level is updated too; a
-- failure there is reported through ngx.log.
-- @param level level descriptor, e.g. _M.ERR
function _M.setLogLevel(level)
    _logLevel=level
    local ok,errlog=pcall(require,"ngx.errlog")
    if ok then
        -- locals: these were previously leaked as globals
        local status,err=errlog.set_filter_level(level.ngxCode)
        if not status then
            ngx.log(ngx.ERR, err)
        end
    end
end
61 |
_M.unitTest={}

--- Smoke test. Emit a DEBUG and an ERR message with the threshold at ERR,
-- then repeat with the threshold at DEBUG.
function _M.test()
    for _, threshold in ipairs({_M.ERR, _M.DEBUG}) do
        _M.setLogLevel(threshold)
        _M.log(_M.DEBUG,"abc",1,nil,{})
        _M.logWithTitle(_M.ERR,"abc",1,nil,{},6)
    end
end
72 |
73 | return _M
74 |
--------------------------------------------------------------------------------
/utils/datetime.lua:
--------------------------------------------------------------------------------
--from https://www.cnblogs.com/wangzhitie/p/5209985.html
--[[
getNewDate(srcDateTime, interval, dateUnit)

Parameters:
  srcDateTime  source time string in %Y%m%d%H%M%S form: a 4-digit year
               followed by 2-digit month, day, hour, minute and second.
  interval     amount to shift by; > 0 adds, < 0 subtracts.
  dateUnit     'DAY', 'HOUR', 'MINUTE' or 'SECOND'. Any other value leaves
               the time unchanged.

Examples:
  interval=10, unit='DAY'   adds 10 days
  interval=-1, unit='HOUR'  subtracts one hour

Returns the os.date("*t") table of the shifted local time: year, month,
day, hour, min, sec, ... Pick the fields needed to build another format.
]]

local _M={}

-- Number of seconds in one unit of each supported dateUnit.
local SECONDS_PER_UNIT = { DAY = 86400, HOUR = 3600, MINUTE = 60, SECOND = 1 }

function _M.getNewDate(srcDateTime,interval ,dateUnit)
    -- slice a fixed-width field out of the timestamp string
    local function field(from, to)
        return string.sub(srcDateTime, from, to)
    end
    local epoch = os.time{
        year = field(1, 4), month = field(5, 6), day = field(7, 8),
        hour = field(9, 10), min = field(11, 12), sec = field(13, 14),
    }
    local unitSeconds = SECONDS_PER_UNIT[dateUnit]
    local offset = 0
    if unitSeconds then
        offset = unitSeconds * interval
    end
    return os.date("*t", epoch + tonumber(offset))
end

function _M.test()
    local oldTime="20130908232828"
    local function render(fmt, t)
        return string.format(fmt, t.year, t.month, t.day, t.hour, t.min, t.sec)
    end
    -- add 3 hours to the given time
    print('t1='..render('%d-%02d-%02d %02d:%02d:%02d', _M.getNewDate(oldTime,3,'HOUR')))
    -- add 1 day to the given time
    print('t2='..render('%d%02d%02d%02d%02d%02d', _M.getNewDate(oldTime,1,'DAY')))
end
64 | return _M
65 |
--------------------------------------------------------------------------------
/utils/event.lua:
--------------------------------------------------------------------------------
1 | local _M={}
2 | local logger=require "suproxy.utils.compatibleLog"
-- Append every handler in `...` to the event's chain, each paired with
-- `context` (passed to the handler as its first argument on trigger).
-- Raises when no handler is given.
local function addHandler(self,context,...)
    local handlers={...}
    if #handlers == 0 then
        print("no handler was added to event")
        error("no handler was added to event")
    end
    local chain = self.chain
    for _, fn in ipairs(handlers) do
        chain[#chain + 1] = {context = context, handler = fn}
    end
end
10 |
--- Create a notification event owned by `source`.
-- trigger(...) calls every handler as handler(context, source, ...) in
-- registration order and discards their return values.
-- @param source object passed to handlers as the event source
-- @param name event name, used in the debug log line
function _M:new(source,name)
    local o={
        source=source,
        name=name,
        chain={},
    }
    o.addHandler=addHandler
    o.trigger=function(self,...)
        logger.logWithTitle(logger.DEBUG,string.format("event %s triggered",self.name),"")
        -- Forward `...` directly. unpack({...}) depends on the length
        -- operator, which is unspecified when the arguments contain nil
        -- (e.g. trigger(nil, ctx)), so arguments could be lost.
        for _,entry in ipairs(self.chain) do
            entry.handler(entry.context,self.source,...)
        end
    end
    return setmetatable(o, {__index=self})
end
27 |
--- Create a "return" event: at most one handler, whose return values are
-- passed back from trigger(...).
-- Adding a second handler raises. Triggering with no handler returns nothing.
-- @param source object passed to the handler as the event source
-- @param name event name, used in the debug log line
function _M:newReturnEvent(source,name)
    local o=_M:new(source,name)
    o.addHandler=function(self,context,...)
        assert(#(self.chain)==0,"returnEvent cannot has more than one handler")
        addHandler(self,context,...)
    end
    o.trigger=function(self,...)
        logger.logWithTitle(logger.DEBUG,string.format("event %s triggered",self.name),"")
        local first=self.chain[1]
        if first then
            -- Return the handler's results directly. unpack{...} would rely
            -- on `#` over a list that may start with nil (e.g. `nil, err`),
            -- which is unspecified and could drop the error.
            return first.handler(first.context,self.source,...)
        end
    end
    return o
end
43 |
44 |
--- Drop any existing handlers, then register the given ones (same
-- arguments as addHandler).
function _M:setHandler(context,...)
    local freshChain = {}
    self.chain = freshChain
    self:addHandler(context,...)
end

--- True when at least one handler is registered.
function _M:hasHandler()
    return self.chain[1] ~= nil
end
53 |
_M.unitTest={}

--- Self-test for notification events and return events.
function _M.test()
    local src={}
    -- Events are named because trigger formats the name with %s, which
    -- Lua 5.1's string.format rejects for nil.
    src.eventA=_M:new(src,"eventA")
    src.eventB=_M:newReturnEvent(src,"eventB")
    local handlers={
        handle1=function(self,source,params)
            print("1 executed"..tostring(params))
            source.name=params
        end,
        handle2=function(self,source,params)
            print("2 executed"..tostring(params))
            self.name=params
        end,
        handle3=function(self,source,params)
            print("3 executed"..tostring(params))
        end,
    }
    -- every handler of a notification event runs
    src.eventA:addHandler(handlers,handlers.handle1,handlers.handle2,handlers.handle3)
    src.eventA:trigger("0","lala")
    assert(src.name=="0",src.name)
    assert(handlers.name=="0",handlers.name)
    -- a return event hands back its handler's results
    src.eventB:addHandler(src,function(self,source,params) source.name=params return "2",true end)
    local result1=src.eventB:trigger("0","la")
    assert(src.name=="0",src.name)
    assert(result1=="2",result1)
    -- a return event refuses a second handler
    local ok=pcall(src.eventB.addHandler,src.eventB,nil,function()end)
    assert(not ok)
end
83 |
84 | return _M
--------------------------------------------------------------------------------
/utils/ffi-zlib.lua:
--------------------------------------------------------------------------------
1 | local ffi = require "ffi"
2 | local ffi_new = ffi.new
3 | local ffi_str = ffi.string
4 | local ffi_sizeof = ffi.sizeof
5 | local ffi_copy = ffi.copy
6 | local tonumber = tonumber
7 |
-- Module table; the helpers defined below are exported as fields of _M.
local _M = {
    _VERSION = '0.01',
}

-- NOTE(review): `mt` is not referenced in the part of this file reviewed
-- here; confirm it is used further down before removing it.
local mt = { __index = _M }
13 |
14 |
-- Declare, for LuaJIT's FFI, the subset of the zlib C API this module uses:
-- the flush / return-code / level / strategy constants, the z_stream struct
-- and the inflate/deflate/checksum functions.
-- NOTE(review): the z_stream layout must match the installed zlib's
-- definition; confirm when targeting an unusual build or platform.
ffi.cdef([[
enum {
Z_NO_FLUSH = 0,
Z_PARTIAL_FLUSH = 1,
Z_SYNC_FLUSH = 2,
Z_FULL_FLUSH = 3,
Z_FINISH = 4,
Z_BLOCK = 5,
Z_TREES = 6,
/* Allowed flush values; see deflate() and inflate() below for details */
Z_OK = 0,
Z_STREAM_END = 1,
Z_NEED_DICT = 2,
Z_ERRNO = -1,
Z_STREAM_ERROR = -2,
Z_DATA_ERROR = -3,
Z_MEM_ERROR = -4,
Z_BUF_ERROR = -5,
Z_VERSION_ERROR = -6,
/* Return codes for the compression/decompression functions. Negative values
* are errors, positive values are used for special but normal events.
*/
Z_NO_COMPRESSION = 0,
Z_BEST_SPEED = 1,
Z_BEST_COMPRESSION = 9,
Z_DEFAULT_COMPRESSION = -1,
/* compression levels */
Z_FILTERED = 1,
Z_HUFFMAN_ONLY = 2,
Z_RLE = 3,
Z_FIXED = 4,
Z_DEFAULT_STRATEGY = 0,
/* compression strategy; see deflateInit2() below for details */
Z_BINARY = 0,
Z_TEXT = 1,
Z_ASCII = Z_TEXT, /* for compatibility with 1.2.2 and earlier */
Z_UNKNOWN = 2,
/* Possible values of the data_type field (though see inflate()) */
Z_DEFLATED = 8,
/* The deflate compression method (the only one supported in this version) */
Z_NULL = 0, /* for initializing zalloc, zfree, opaque */
};


typedef void* (* z_alloc_func)( void* opaque, unsigned items, unsigned size );
typedef void (* z_free_func) ( void* opaque, void* address );

typedef struct z_stream_s {
char* next_in;
unsigned avail_in;
unsigned long total_in;
char* next_out;
unsigned avail_out;
unsigned long total_out;
char* msg;
void* state;
z_alloc_func zalloc;
z_free_func zfree;
void* opaque;
int data_type;
unsigned long adler;
unsigned long reserved;
} z_stream;


const char* zlibVersion();
const char* zError(int);

int inflate(z_stream*, int flush);
int inflateEnd(z_stream*);
int inflateInit2_(z_stream*, int windowBits, const char* version, int stream_size);

int deflate(z_stream*, int flush);
int deflateEnd(z_stream* );
int deflateInit2_(z_stream*, int level, int method, int windowBits, int memLevel,int strategy, const char *version, int stream_size);

unsigned long adler32(unsigned long adler, const char *buf, unsigned len);
unsigned long crc32(unsigned long crc, const char *buf, unsigned len);
unsigned long adler32_combine(unsigned long, unsigned long, long);
unsigned long crc32_combine(unsigned long, unsigned long, long);

]])
97 |
-- Load the system zlib: "zlib1" on Windows, "z" (libz) elsewhere.
local zlib = ffi.load(ffi.os == "Windows" and "zlib1" or "z")
_M.zlib = zlib

-- Default buffer size (16 KiB) used when callers pass no bufsize; it sizes
-- both the input and the output buffer allocated by createStream.
local DEFAULT_CHUNK = 16384

-- Enum constants from the cdef, cached as plain locals for the hot paths.
local Z_OK = zlib.Z_OK
local Z_NO_FLUSH = zlib.Z_NO_FLUSH
local Z_STREAM_END = zlib.Z_STREAM_END
local Z_FINISH = zlib.Z_FINISH

--- Return zlib's textual description (zError) of a numeric return code.
local function zlib_err(err)
    return ffi_str(zlib.zError(err))
end
_M.zlib_err = zlib_err
113 |
--- Allocate a z_stream together with its input and output buffers.
-- The input buffer has one extra byte: ffi.copy(dst, str) also writes the
-- string's terminating NUL, and flate copies chunks of up to bufsize bytes.
-- The stream only stores raw pointers, so callers must keep inbuf/outbuf
-- referenced while the stream is in use.
-- @param bufsize size of each buffer in bytes
-- @return stream, inbuf, outbuf
local function createStream(bufsize)
    local inbuf = ffi_new('char[?]', bufsize+1)
    local outbuf = ffi_new('char[?]', bufsize)
    local stream = ffi_new("z_stream")
    stream.next_in = inbuf
    stream.avail_in = 0
    stream.next_out = outbuf
    stream.avail_out = 0
    return stream, inbuf, outbuf
end
_M.createStream = createStream
129 |
--- Initialise `stream` for decompression via inflateInit2_.
-- @param windowBits defaults to 15 + 32 (+32 enables automatic zlib/gzip header detection)
-- @return zlib return code (Z_OK on success)
local function initInflate(stream, windowBits)
    return zlib.inflateInit2_(
        stream,
        windowBits or (15 + 32),
        ffi_str(zlib.zlibVersion()),
        ffi_sizeof(stream)
    )
end
_M.initInflate = initInflate
138 |
--- Initialise `stream` for compression via deflateInit2_.
-- @param options optional table with the following fields:
--   level      (default Z_DEFAULT_COMPRESSION)
--   memLevel   (default 8)
--   strategy   (default Z_DEFAULT_STRATEGY)
--   windowBits (default 15 + 16; +16 writes a gzip wrapper instead of zlib)
-- @return zlib return code (Z_OK on success)
local function initDeflate(stream, options)
    -- every option has a default, so a missing table is fine
    options = options or {}
    local method = zlib.Z_DEFLATED
    local level = options.level or zlib.Z_DEFAULT_COMPRESSION
    local memLevel = options.memLevel or 8
    local strategy = options.strategy or zlib.Z_DEFAULT_STRATEGY
    local windowBits = options.windowBits or (15 + 16) -- +16 sets gzip wrapper not zlib
    local version = ffi_str(zlib.zlibVersion())

    return zlib.deflateInit2_(stream, level, method, windowBits, memLevel, strategy, version, ffi_sizeof(stream))
end
_M.initDeflate = initDeflate
151 |
--- Hand the bytes produced by the last flate call (bufsize - avail_out bytes
-- at the start of outbuf) to `output`. Does nothing when none were produced.
local function flushOutput(stream, bufsize, output, outbuf)
    local produced = bufsize - stream.avail_out
    if produced ~= 0 then
        output(ffi_str(outbuf, produced))
    end
end
161 |
-- zlib return codes that only flate needs (declared in the cdef).
local Z_BUF_ERROR = zlib.Z_BUF_ERROR
local Z_NEED_DICT = zlib.Z_NEED_DICT

--- Drive zlib_flate (inflate or deflate) over a whole stream.
-- @param zlib_flate zlib.inflate or zlib.deflate
-- @param zlib_flateEnd matching zlib.inflateEnd / zlib.deflateEnd; always called before returning
-- @param input function(bufsize) returning the next chunk (at most bufsize bytes), or nil at end of input
-- @param output function(str) receiving each produced chunk
-- @param bufsize buffer size used for createStream
-- @param stream, inbuf, outbuf values from createStream, already initialised
-- @return true, zlib message on success; false, error message, stream on failure
local function flate(zlib_flate, zlib_flateEnd, input, output, bufsize, stream, inbuf, outbuf)
    local err = 0
    local mode = Z_NO_FLUSH
    repeat
        -- Read some input
        local data = input(bufsize)
        if data ~= nil then
            -- inbuf holds bufsize+1 bytes (ffi.copy also writes a NUL), so a
            -- longer chunk would overflow it.
            if #data > bufsize then
                zlib_flateEnd(stream)
                return false, "FLATE: input chunk larger than buffer size", stream
            end
            ffi_copy(inbuf, data)
            stream.next_in, stream.avail_in = inbuf, #data
        else
            -- EOF, try and finish up
            mode = Z_FINISH
            stream.avail_in = 0
        end

        -- While the output buffer is being filled completely just keep going
        repeat
            stream.next_out = outbuf
            stream.avail_out = bufsize
            err = zlib_flate(stream, mode)
            -- Z_BUF_ERROR only means "no progress possible". That is normal
            -- after the output buffer was filled exactly and more input may
            -- follow. At Z_FINISH it means the input ended prematurely, and
            -- retrying would never terminate.
            if err == Z_BUF_ERROR and mode ~= Z_FINISH then
                err = Z_OK
            end
            -- Z_NEED_DICT is positive but cannot be satisfied here; treat it
            -- as an error instead of looping forever.
            if err < Z_OK or err == Z_NEED_DICT then
                zlib_flateEnd(stream)
                return false, "FLATE: "..zlib_err(err), stream
            end
            -- Write the data out
            flushOutput(stream, bufsize, output, outbuf)
        until stream.avail_out ~= 0

    until err == Z_STREAM_END

    -- Stream finished, clean up and return
    zlib_flateEnd(stream)
    return true, zlib_err(err)
end
_M.flate = flate
200 |
-- Adler-32 checksum of `str` (default ""), continuing from a previous checksum
-- `chksum`. When `chksum` is omitted, start from zlib's own initial value, as
-- obtained by adler32(0, NULL, 0). Adler-32 starts at 1, not 0; seeding with 0
-- produced non-standard checksums.
-- crc() below is unaffected because CRC-32's initial value really is 0.
local function adler(str, chksum)
    local chksum = chksum or zlib.adler32(0, nil, 0)
    local str = str or ""
    return zlib.adler32(chksum, str, #str)
end
_M.adler = adler
207 |
-- CRC-32 of `str` (default ""), continuing from a previous checksum `chksum`
-- (default 0).
local function crc(str, chksum)
    local data = str or ""
    local seed = chksum or 0
    return zlib.crc32(seed, data, #data)
end
_M.crc = crc
214 |
-- Decompress a stream: compressed chunks are pulled from `input(bufsize)` (nil at
-- end), decompressed chunks are pushed to `output(str)`.
-- bufsize defaults to DEFAULT_CHUNK; windowBits is passed to initInflate.
-- Returns ok, message ("INIT: ..." when initialisation fails).
function _M.inflateGzip(input, output, bufsize, windowBits)
    local size = bufsize or DEFAULT_CHUNK
    local stream, inbuf, outbuf = createStream(size)

    local status = initInflate(stream, windowBits)
    if status ~= Z_OK then
        -- Init error
        zlib.inflateEnd(stream)
        return false, "INIT: "..zlib_err(status)
    end

    local ok, err = flate(zlib.inflate, zlib.inflateEnd, input, output, size, stream, inbuf, outbuf)
    return ok, err
end
232 |
-- Compress a stream: plain chunks are pulled from `input(bufsize)` (nil at end),
-- compressed chunks are pushed to `output(str)`.
-- bufsize defaults to DEFAULT_CHUNK; options are passed to initDeflate (whose
-- default windowBits selects a gzip wrapper).
-- Returns ok, message ("INIT: ..." when initialisation fails).
function _M.deflateGzip(input, output, bufsize, options)
    local size = bufsize or DEFAULT_CHUNK
    options = options or {}
    local stream, inbuf, outbuf = createStream(size)

    local status = initDeflate(stream, options)
    if status ~= Z_OK then
        -- Init error
        zlib.deflateEnd(stream)
        return false, "INIT: "..zlib_err(status)
    end

    local ok, err = flate(zlib.deflate, zlib.deflateEnd, input, output, size, stream, inbuf, outbuf)
    return ok, err
end
251 |
-- Version string reported by the loaded zlib library.
function _M.version()
    local raw = zlib.zlibVersion()
    return ffi_str(raw)
end

return _M
--------------------------------------------------------------------------------
/utils/json.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- json.lua
3 | --
4 | -- Copyright (c) 2020 rxi
5 | --
6 | -- Permission is hereby granted, free of charge, to any person obtaining a copy of
7 | -- this software and associated documentation files (the "Software"), to deal in
8 | -- the Software without restriction, including without limitation the rights to
9 | -- use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
10 | -- of the Software, and to permit persons to whom the Software is furnished to do
11 | -- so, subject to the following conditions:
12 | --
13 | -- The above copyright notice and this permission notice shall be included in all
14 | -- copies or substantial portions of the Software.
15 | --
16 | -- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | -- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | -- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | -- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | -- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | -- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | -- SOFTWARE.
23 | --
24 |
local json = { _version = "0.1.2" }

-------------------------------------------------------------------------------
-- Encode
-------------------------------------------------------------------------------

-- Forward declaration: `encode` is assigned further down, after the per-type
-- encoders it dispatches to (encode_table calls it recursively).
local encode

-- Characters that must be backslash-escaped in JSON output, mapped to the
-- character that follows the backslash.
local escape_char_map = {
  [ "\\" ] = "\\",
  [ "\"" ] = "\"",
  [ "\b" ] = "b",
  [ "\f" ] = "f",
  [ "\n" ] = "n",
  [ "\r" ] = "r",
  [ "\t" ] = "t",
}

-- Reverse lookup for the decoder: escape letter -> raw character. "\/" is
-- accepted on input even though the encoder never produces it.
local escape_char_map_inv = { [ "/" ] = "/" }
do
  for raw, letter in pairs(escape_char_map) do
    escape_char_map_inv[letter] = raw
  end
end
47 |
48 |
-- Escape one character for a JSON string literal: the short form from
-- escape_char_map when there is one, otherwise \u00XX.
local function escape_char(c)
  local short = escape_char_map[c]
  if short then
    return "\\" .. short
  end
  return "\\" .. string.format("u%04x", c:byte())
end
52 |
53 |
-- nil always encodes as the JSON literal null; the argument is ignored.
local function encode_nil(_)
  return "null"
end
57 |
58 |
-- Encode a Lua table as a JSON array or object.
-- A table whose [1] is set, or which is empty, is treated as an array: all keys
-- must be numbers and the key count must equal #val (no holes). Any other table
-- is an object, and all keys must be strings.
-- `stack` records the tables currently being encoded, to detect cycles.
local function encode_table(val, stack)
  stack = stack or {}
  if stack[val] then error("circular reference") end
  stack[val] = true

  local pieces = {}
  local as_array = rawget(val, 1) ~= nil or next(val) == nil

  if as_array then
    local key_count = 0
    for key in pairs(val) do
      if type(key) ~= "number" then
        error("invalid table: mixed or invalid key types")
      end
      key_count = key_count + 1
    end
    if key_count ~= #val then
      error("invalid table: sparse array")
    end
    for _, item in ipairs(val) do
      pieces[#pieces + 1] = encode(item, stack)
    end
    stack[val] = nil
    return "[" .. table.concat(pieces, ",") .. "]"
  end

  for key, item in pairs(val) do
    if type(key) ~= "string" then
      error("invalid table: mixed or invalid key types")
    end
    pieces[#pieces + 1] = encode(key, stack) .. ":" .. encode(item, stack)
  end
  stack[val] = nil
  return "{" .. table.concat(pieces, ",") .. "}"
end
99 |
100 |
-- Quote a Lua string as JSON, escaping control bytes, backslash and double
-- quote through escape_char.
local function encode_string(val)
  local body = val:gsub('[%z\1-\31\\"]', escape_char)
  return '"' .. body .. '"'
end
104 |
105 |
-- Encode a number with up to 14 significant digits. NaN and the infinities have
-- no JSON representation, so they raise an error.
local function encode_number(val)
  local is_nan = val ~= val
  local is_infinite = val == math.huge or val == -math.huge
  if is_nan or is_infinite then
    error("unexpected number value '" .. tostring(val) .. "'")
  end
  return string.format("%.14g", val)
end
113 |
114 |
-- Encoder for each supported Lua type; other types are rejected by `encode`.
local type_func_map = {
  ["nil"] = encode_nil,
  table   = encode_table,
  string  = encode_string,
  number  = encode_number,
  boolean = tostring,
}
122 |
123 |
-- Dispatch on the Lua type of `val`; `stack` is threaded through for cycle detection.
encode = function(val, stack)
  local kind = type(val)
  local encoder = type_func_map[kind]
  if not encoder then
    error("unexpected type '" .. kind .. "'")
  end
  return encoder(val, stack)
end
132 |
133 |
-- Public entry point: encode `val` to a JSON string (exactly one return value).
function json.encode(val)
  local text = encode(val)
  return text
end
137 |
138 |
-------------------------------------------------------------------------------
-- Decode
-------------------------------------------------------------------------------

-- Forward declaration: `parse` is assigned further down; parse_array and
-- parse_object call it recursively.
local parse

-- Build a set (value -> true) from the given arguments.
local function create_set(...)
  local count = select("#", ...)
  local items = { ... }
  local set = {}
  for idx = 1, count do
    set[items[idx]] = true
  end
  return set
end
152 |
-- Character sets used by the decoder.
local space_chars  = create_set(" ", "\t", "\r", "\n")
local delim_chars  = create_set(" ", "\t", "\r", "\n", "]", "}", ",")
local escape_chars = create_set("\\", "/", '"', "b", "f", "n", "r", "t", "u")
local literals     = create_set("true", "false", "null")

-- Lua value for each JSON literal. "null" decodes to nil, which a table cannot
-- store, so it has no entry (a lookup yields nil as intended).
local literal_map = {
  ["true"]  = true,
  ["false"] = false,
}
163 |
164 |
-- Return the index of the first character at or after `idx` that is in `set`
-- (or, with `negate` = true, that is NOT in `set`); #str + 1 if there is none.
local function next_char(str, idx, set, negate)
  local last = #str
  local pos = idx
  while pos <= last do
    if set[str:sub(pos, pos)] ~= negate then
      return pos
    end
    pos = pos + 1
  end
  return last + 1
end
173 |
174 |
-- Raise `msg` annotated with the 1-based line and column of byte `idx` in `str`.
local function decode_error(str, idx, msg)
  local line, col = 1, 1
  for pos = 1, idx - 1 do
    if str:sub(pos, pos) == "\n" then
      line, col = line + 1, 1
    else
      col = col + 1
    end
  end
  error( string.format("%s at line %d col %d", msg, line, col) )
end
187 |
188 |
-- Encode code point `n` as a UTF-8 byte sequence of 1 to 4 bytes.
-- Raises an error for values above 0x10FFFF.
-- http://scripts.sil.org/cms/scripts/page.php?site_id=nrsi&id=iws-appendixa
local function codepoint_to_utf8(n)
  local floor = math.floor
  if n <= 0x7f then
    return string.char(n)
  end
  if n <= 0x7ff then
    local lead = floor(n / 64) + 192
    return string.char(lead, n % 64 + 128)
  end
  if n <= 0xffff then
    local lead = floor(n / 4096) + 224
    local mid = floor(n % 4096 / 64) + 128
    return string.char(lead, mid, n % 64 + 128)
  end
  if n <= 0x10ffff then
    local lead = floor(n / 262144) + 240
    local second = floor(n % 262144 / 4096) + 128
    local third = floor(n % 4096 / 64) + 128
    return string.char(lead, second, third, n % 64 + 128)
  end
  error( string.format("invalid unicode codepoint '%x'", n) )
end
204 |
205 |
-- Decode the hex text following "\u": either "XXXX", or "XXXX\uYYYY" for a
-- UTF-16 surrogate pair, which is combined into one code point.
local function parse_unicode_escape(s)
  local high = tonumber( s:sub(1, 4), 16 )
  local low = tonumber( s:sub(7, 10), 16 )
  if not low then
    return codepoint_to_utf8(high)
  end
  return codepoint_to_utf8((high - 0xd800) * 0x400 + (low - 0xdc00) + 0x10000)
end
216 |
217 |
-- Parse a JSON string whose opening quote is at index `i`.
-- Returns the decoded string and the index just past the closing quote.
local function parse_string(str, i)
  local parts = {}
  local pos = i + 1        -- first byte after the opening quote
  local run_start = pos    -- start of the current run of literal bytes
  local len = #str

  while pos <= len do
    local byte = str:byte(pos)

    if byte < 32 then
      decode_error(str, pos, "control character in string")

    elseif byte == 92 then -- `\`: flush the literal run, then decode the escape
      parts[#parts + 1] = str:sub(run_start, pos - 1)
      pos = pos + 1
      local esc = str:sub(pos, pos)
      if esc == "u" then
        -- a high surrogate followed by \uXXXX is consumed as one pair
        local hex = str:match("^[dD][89aAbB]%x%x\\u%x%x%x%x", pos + 1)
                 or str:match("^%x%x%x%x", pos + 1)
                 or decode_error(str, pos - 1, "invalid unicode escape in string")
        parts[#parts + 1] = parse_unicode_escape(hex)
        pos = pos + #hex
      else
        if not escape_chars[esc] then
          decode_error(str, pos - 1, "invalid escape char '" .. esc .. "' in string")
        end
        parts[#parts + 1] = escape_char_map_inv[esc]
      end
      run_start = pos + 1

    elseif byte == 34 then -- `"`: end of string
      parts[#parts + 1] = str:sub(run_start, pos - 1)
      return table.concat(parts), pos + 1
    end

    pos = pos + 1
  end

  decode_error(str, i, "expected closing quote for string")
end
257 |
258 |
-- Parse the number token starting at `i` (it runs up to the next delimiter).
-- Returns the number and the index of the delimiter.
local function parse_number(str, i)
  local stop = next_char(str, i, delim_chars)
  local text = str:sub(i, stop - 1)
  local value = tonumber(text)
  if not value then
    decode_error(str, i, "invalid number '" .. text .. "'")
  end
  return value, stop
end
268 |
269 |
-- Parse one of the literals true/false/null starting at `i`.
-- Returns its Lua value (nil for null) and the index of the following delimiter.
local function parse_literal(str, i)
  local stop = next_char(str, i, delim_chars)
  local word = str:sub(i, stop - 1)
  if not literals[word] then
    decode_error(str, i, "invalid literal '" .. word .. "'")
  end
  return literal_map[word], stop
end
278 |
279 |
-- Parse a JSON array whose "[" is at index `i`.
-- Returns a Lua sequence (null elements leave nil slots) and the index after "]".
local function parse_array(str, i)
  local items = {}
  local count = 0
  local pos = i + 1
  while true do
    pos = next_char(str, pos, space_chars, true)
    -- Empty / end of array?
    if str:sub(pos, pos) == "]" then
      return items, pos + 1
    end
    -- Read token
    local value
    value, pos = parse(str, pos)
    count = count + 1
    items[count] = value
    -- Next token must be "]" or ","
    pos = next_char(str, pos, space_chars, true)
    local delim = str:sub(pos, pos)
    pos = pos + 1
    if delim == "]" then
      return items, pos
    end
    if delim ~= "," then
      decode_error(str, pos, "expected ']' or ','")
    end
  end
end
305 |
306 |
-- Parse a JSON object whose "{" is at index `i`.
-- Returns a Lua table keyed by the string keys and the index after "}".
local function parse_object(str, i)
  local obj = {}
  local pos = i + 1
  while true do
    pos = next_char(str, pos, space_chars, true)
    local c = str:sub(pos, pos)
    -- Empty / end of object?
    if c == "}" then
      return obj, pos + 1
    end
    -- Read key
    if c ~= '"' then
      decode_error(str, pos, "expected string for key")
    end
    local key, value
    key, pos = parse(str, pos)
    -- Read ':' delimiter
    pos = next_char(str, pos, space_chars, true)
    if str:sub(pos, pos) ~= ":" then
      decode_error(str, pos, "expected ':' after key")
    end
    pos = next_char(str, pos + 1, space_chars, true)
    -- Read value
    value, pos = parse(str, pos)
    obj[key] = value
    -- Next token must be "}" or ","
    pos = next_char(str, pos, space_chars, true)
    local delim = str:sub(pos, pos)
    pos = pos + 1
    if delim == "}" then
      return obj, pos
    end
    if delim ~= "," then
      decode_error(str, pos, "expected '}' or ','")
    end
  end
end
342 |
343 |
-- Parser to use for each character that can start a JSON value.
local char_func_map = {
  [ '"' ] = parse_string,
  [ "-" ] = parse_number,
  [ "t" ] = parse_literal,
  [ "f" ] = parse_literal,
  [ "n" ] = parse_literal,
  [ "[" ] = parse_array,
  [ "{" ] = parse_object,
}
-- Any digit starts a number.
for digit = 0, 9 do
  char_func_map[tostring(digit)] = parse_number
end
363 |
364 |
-- Parse the value starting at `idx`, dispatching on its first character.
-- Returns the value and the index just past it.
parse = function(str, idx)
  local chr = str:sub(idx, idx)
  local handler = char_func_map[chr]
  if not handler then
    decode_error(str, idx, "unexpected character '" .. chr .. "'")
  end
  return handler(str, idx)
end
373 |
374 |
-- Public entry point: decode a JSON text into Lua values. Leading and trailing
-- whitespace is allowed; anything else after the value is an error.
function json.decode(str)
  local kind = type(str)
  if kind ~= "string" then
    error("expected argument of type string, got " .. kind)
  end
  local start = next_char(str, 1, space_chars, true)
  local result, pos = parse(str, start)
  pos = next_char(str, pos, space_chars, true)
  if pos <= #str then
    decode_error(str, pos, "trailing garbage")
  end
  return result
end


return json
--------------------------------------------------------------------------------
/utils/pureluapack.lua:
--------------------------------------------------------------------------------
1 | local bit=require("bit")
require "suproxy.utils.stringUtils"
-- Byte size used for the 'I'/'i' format options when no explicit size follows them.
local _INTLEN=4
-- Byte order used until a format string contains '<', '>' or '='; packNumber
-- writes the low byte first for "<" (little-endian).
local _DEFAULT_ENDIAN="<"
2 |
-- Encode `n` as an 8-byte IEEE-754 binary64 value.
-- e: ">" gives big-endian; any other value ("<", "=") gives little-endian,
--    consistent with packNumber.
-- l: size from the format string; ignored, the result is always 8 bytes.
local function packDouble(e,l,n)
    local sign=0
    -- 1/n tells -0.0 (gives -inf) apart from +0.0
    if n<0 or (n==0 and 1/n<0) then
        sign=0x80
        n=-n
    end
    local result
    if n~=n then
        -- NaN: same byte pattern as before
        result=string.char(0xFF,0xF8,0x00,0x00,0x00,0x00,0x00,0x00)
    elseif n==math.huge then
        result=string.char(sign+0x7F,0xF0,0x00,0x00,0x00,0x00,0x00,0x00)
    elseif n==0 then
        result=string.char(sign,0x00,0x00,0x00,0x00,0x00,0x00,0x00)
    else
        local mant,expo=math.frexp(n)   -- n = mant * 2^expo, 0.5 <= mant < 1
        expo=expo+0x3FE                 -- biased exponent
        if expo<=0 then
            -- subnormal: exponent field 0, fraction = n / 2^-1074
            -- (scaled in two steps because 2^1074 itself overflows)
            mant=n*2^1022*2^52
            expo=0
        else
            mant=(mant*2.0-1.0)*2^52    -- 52-bit fraction as an integer
        end
        result=string.char(
            sign+math.floor(expo/0x10),
            (expo%0x10)*0x10+math.floor(mant/0x1000000000000),
            math.floor(mant/0x10000000000)%0x100,
            math.floor(mant/0x100000000)%0x100,
            math.floor(mant/0x1000000)%0x100,
            math.floor(mant/0x10000)%0x100,
            math.floor(mant/0x100)%0x100,
            mant%0x100)
    end
    if e~=">" then
        result=result:reverse()
    end
    return result
end
-- Decode an IEEE-754 binary64 value from the first `l` (normally 8) bytes of `s`.
-- e: ">" means big-endian; any other value means little-endian, consistent with
--    unpackNumber.
-- Returns the number and the index just past it (9).
local function unpackDouble(e,l,s)
    local raw=s:sub(1,l)
    if e~=">" then raw=raw:reverse() end
    local b1,b2,b3,b4,b5,b6,b7,b8=raw:byte(1,8)
    local sign=(b1>0x7F) and -1 or 1
    local expo=(b1%0x80)*0x10+math.floor(b2/0x10)
    local mant=((((((b2%0x10)*0x100+b3)*0x100+b4)*0x100+b5)*0x100+b6)*0x100+b7)*0x100+b8
    local n
    if expo==0 then
        if mant==0 then
            n=sign*0.0
        else
            -- subnormal: 0.fraction * 2^-1022
            n=sign*(mant/2^52)*2^-1022
        end
    elseif expo==0x7FF then
        if mant==0 then
            -- was the undefined global `huge`, which crashed on infinities
            n=sign*math.huge
        else
            n=0.0/0.0
        end
    else
        n=sign*(1+mant/2^52)*2^(expo-0x3FF)
    end
    return n,9
end
-- Encode integer `k` into exactly `length` bytes.
-- e: ">" means big-endian; "<", "=" or nil means little-endian.
-- Negative values are stored in two's complement (256^length + k).
-- Returns nil when the value needs more than `length` bytes.
-- Uses arithmetic instead of bit.rshift: the bit library works on 32-bit values
-- only, so 8-byte sizes were silently truncated.
local function packNumber(e,length,k)
    if k<0 then k=256^length+k end
    local bytes={}
    -- least significant byte first
    while k>=1 do
        bytes[#bytes+1]=string.char(k%256)
        k=math.floor(k/256)
    end
    if #bytes>length then return nil end
    -- pad with zero bytes up to the requested width
    for _=#bytes+1,length do
        bytes[#bytes+1]="\0"
    end
    local rs=table.concat(bytes)
    if e==">" then
        return rs:reverse()
    end
    return rs
end
21 |
22 |
-- Encode `value` preceded by its byte length as an `l`-byte integer (option 's').
-- Raises an error when the length does not fit in `l` bytes. Before, this failed
-- with an opaque "attempt to concatenate a nil value".
local function packLengthPreStr(e,l,value)
    local prefix=packNumber(e,l,string.len(value))
    assert(prefix,"string length does not fit in given size")
    return prefix..value
end
28 |
-- Fixed-size string (option 'c'): right-pad `value` with NUL bytes to `l` bytes.
-- Values already `l` bytes or longer are returned unchanged.
local function packLengthStr(e,l,value)
    local missing=l-#value
    if missing<=0 then
        return value
    end
    return value..string.rep("\0",missing)
end
39 |
-- Zero-terminated string (option 'z'): `value` followed by one NUL byte.
local function packZeroEndStr(e,l,value)
    local terminator="\0"
    return value..terminator
end
43 |
-- Decode an unsigned `length`-byte integer from the start of `Str`.
-- endian ">" means big-endian; anything else means little-endian.
-- Returns the value and the index just past it (length + 1).
local function unpackNumber(endian,length,Str)
    local bytes=Str:sub(1,length)
    if endian==">" then
        bytes=bytes:reverse()
    end
    -- little-endian after the optional reverse: byte k weighs 256^(k-1)
    local total=bytes:byte(1)
    for k=2,length do
        total=total+bytes:byte(k)*256^(k-1)
    end
    return math.floor(total),length+1
end
57 |
-- Decode a two's-complement signed `length`-byte integer.
-- Returns the value and the index just past it.
local function unpackSignedNumber(endian,length,Str)
    local value,nextPos=unpackNumber(endian,length,Str)
    local range=256^length
    -- the upper half of the unsigned range holds the negative values
    if value>=range/2 then
        value=value-range
    end
    return value,nextPos
end
66 |
-- Decode a string preceded by an `l`-byte length (option 's').
-- Returns the string and the index just past it.
local function unpackLengthPreStr(e,l,value)
    local size=unpackNumber(e,l,value)
    local last=l+size
    return value:sub(l+1,last),last+1
end
72 |
-- Decode a fixed-size `l`-byte string (option 'c'); returns it and the next index.
local function unpackLengthStr(e,l,value)
    local piece=value:sub(1,l)
    return piece,l+1
end
75 |
-- Decode a zero-terminated string (option 'z').
-- Returns the text before the first NUL (or the whole input when there is none)
-- and the index just past that terminator position.
local function unpackZeroEndStr(e,l,value)
    local stop=value:find("\0",1,true) or (#value+1)
    return value:sub(1,stop-1),stop+1
end
80 |
-- Default byte size for each format option that may omit an explicit size.
local _DEFAULT_SIZES={
    B=1, b=1,
    H=2, h=2,
    d=8, f=4,
    s=1,
    I=_INTLEN, i=_INTLEN,
}
-- Resolve the byte size for option `t`: an explicit size `l` wins, otherwise
-- the default above (nil for options without one).
local function getLen(t,l)
    if l then
        return l
    end
    return _DEFAULT_SIZES[t]
end
--Pure-Lua replacement for string.pack.
--Supported options:
--  B/H/I unsigned and b/h/i signed integers (optional byte size)
--  d   double
--  s   string with a length prefix (size of the prefix)
--  c   fixed-size string (size required)
--  z   zero-terminated string
--A '<', '>' or '=' prefix sets the byte order for this and following options
--(the default is _DEFAULT_ENDIAN). Unsupported options raise "invalid format option".
function string.pack(fmt,...)
    local arg={...}
    assert(type(fmt)=="string","bad argument #1 to 'pack' (string expected, got "..type(fmt)..")")
    local rs=""
    local i=1
    local nativeEndian=_DEFAULT_ENDIAN
    for _,e,t,l in fmt:gmatch("(([<>=]?)([bBhH1LjJTiIfdnczsxX])([%d]*))") do
        l=tonumber(l)
        if(e:len()~=0) then nativeEndian=e end
        if(t=="I" or t=="B" or t=="H") then
            l=getLen(t,l)
            assert(type(arg[i]) == "number", "bad argument #"..(i+1).." to 'pack' (number expected, got "..type(arg[i])..")")
            assert(arg[i]<=256^l-1,"bad argument #"..(i+1).." to 'pack' (unsign integer overflow)")
            rs=rs..packNumber(nativeEndian,l,arg[i])
        elseif(t=="i" or t=="b" or t=="h") then
            l=getLen(t,l)
            assert(type(arg[i]) == "number", "bad argument #"..(i+1).." to 'pack' (number expected, got "..type(arg[i])..")")
            assert(arg[i]<256^l/2 and arg[i]>=-256^l/2,"bad argument #"..(i+1).." to 'pack' (signed interger overflow)")
            rs=rs..packNumber(nativeEndian,l,arg[i])
        elseif(t=="d") then
            l=getLen(t,l)
            assert(type(arg[i]) == "number", "bad argument #"..(i+1).." to 'pack' (number expected, got "..type(arg[i])..")")
            rs=rs..packDouble(nativeEndian,l,arg[i])
        elseif(t=="s") then
            assert(type(arg[i]) == "string", "bad argument #"..(i+1).." to 'pack' (string expected, got "..type(arg[i])..")")
            l=getLen(t,l)
            rs=rs..packLengthPreStr(nativeEndian,l,arg[i])
        elseif(t=="c") then
            assert(type(arg[i]) == "string", "bad argument #"..(i+1).." to 'pack' (string expected, got "..type(arg[i])..")")
            assert(l,"missing size for format option 'c'")
            rs=rs..packLengthStr(nativeEndian,l,arg[i])
        elseif(t=="z") then
            assert(type(arg[i]) == "string", "bad argument #"..(i+1).." to 'pack' (string expected, got "..type(arg[i])..")")
            rs=rs..packZeroEndStr(nativeEndian,l,arg[i])
        else
            -- closing quote was missing from this message
            error("invalid format option '"..t.."'")
        end
        i=i+1
    end

    return rs
end
119 |
120 |
--Pure-Lua replacement for string.unpack (same options as string.pack).
--value: the data to decode; pos: optional 1-based start index (default 1).
--Returns each decoded value followed by the index of the first unread byte.
function string.unpack(fmt,value,pos)
    assert(type(fmt)=="string","bad argument #1 to 'unpack' (string expected, got "..type(fmt)..")")
    assert(type(value)=="string","bad argument #2 to 'unpack' (string expected, got "..type(value)..")")
    if(pos) then assert(pos>=1 and pos<=value:len()+1,"pos invalid") end
    local rs={}
    local i=1
    if(pos)then i=pos end
    local nativeEndian=_DEFAULT_ENDIAN
    for _,e,t,l in fmt:gmatch("(([<>=]?)([bBhH1LjJTiIfdnczsxX])([%d]*))") do
        l=tonumber(l)
        if(e:len()~=0) then nativeEndian=e end
        local segment=value:sub(i)
        local ps,index
        if(t=="I" or t=="B" or t=="H") then
            l=getLen(t,l)
            assert(l>=1,"size out of limit")
            assert(segment:len()>=l,"bad argument #2 to 'unpack' (data string too short)")
            ps,index=unpackNumber(nativeEndian,l,segment)
        elseif(t=="i" or t=="b" or t=="h") then
            l=getLen(t,l)
            assert(segment:len()>=l,"bad argument #2 to 'unpack' (data string too short)")
            ps,index=unpackSignedNumber(nativeEndian,l,segment)
        elseif(t=="d") then
            l=getLen(t,l)
            assert(segment:len()>=l,"bad argument #2 to 'unpack' (data string too short)")
            ps,index=unpackDouble(nativeEndian,l,segment)
        elseif(t=="s") then
            l=getLen(t,l)
            assert(segment:len()>=l,"bad argument #2 to 'unpack' (data string too short)")
            ps,index=unpackLengthPreStr(nativeEndian,l,segment)
        elseif(t=="c") then
            -- validate before slicing; a missing size used to die with an arithmetic error
            assert(l,"missing size for format option 'c'")
            assert(segment:len()>=l,"bad argument #2 to 'unpack' (data string too short)")
            ps,index=unpackLengthStr(nativeEndian,l,segment)
        elseif(t=="z") then
            ps,index=unpackZeroEndStr(nativeEndian,l,segment)
        else
            error("invalid format option '"..t.."'")
        end

        table.insert(rs,ps)
        i=i+index-1
    end
    table.insert(rs,i)
    -- `unpack` is a global only in Lua 5.1/LuaJIT; Lua 5.2+ provides table.unpack
    return (table.unpack or unpack)(rs)
end
162 |
local _M={}
163 | function _M.test()
164 | local k=string.pack(">I4BHs1I4BHs1s2",673845,2,-200,"123")
182 | print(string.hex(m))
183 | --add some useless byte on start to test pos params
184 | m=string.char(0x2,0x3,0xff,0x00)..m
185 | local v1,v2,v3,v4,v5=string.unpack("I4Bh>s2",m,5)
186 | print(v1,v2,v3,v4,v5)
187 | assert(v1==673845)
188 | assert(v2==2)
189 | assert(v3==-200)
190 | assert(v4=="123")
191 |
192 | p=string.fromhex([[
193 | FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1
194 | 29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD
195 | EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245
196 | E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED
197 | EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE65381
198 | FFFFFFFF FFFFFFFF
199 | ]])
200 | q="FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF"
201 | assert(p:hex()==q,"fromhex err")
202 | print(p:hex(1,nil,4,8," "," ",1,1,12))
203 | p=string.fromhex
204 |
205 | print(p:hex(1,nil,8,16," "," ",1,1,12))
p=string.pack("31 and string.byte(x)<127 then
43 | asciiStr=asciiStr..string.char(string.byte(x))
44 | else
45 | asciiStr=asciiStr.."."
46 | end
47 | if linewidth and i%linewidth==0 then
48 | elseif columnwidth and i%columnwidth==0 then
49 | asciiStr=asciiStr..columnspan
50 | end
51 | end
52 | if linewidth and (i%linewidth==0 or i==#self) then
53 | d=d..columnspan..asciiStr.."\n"
54 | asciiStr=""
55 | elseif columnwidth and i%columnwidth==0 then
56 | d=d..columnspan
57 | else
58 | d=d..bytespan
59 | end
60 | return d
61 | end)
62 | return s
63 | end
64 |
-- Set once the math.random fallback has been seeded. Reseeding with the current
-- time on every call restarted the same sequence, so calls within the same
-- second returned identical bytes.
local randomSeeded=false
--return `byteNum` (default 4) random bytes.
--Uses resty.openssl.rand when it can be loaded. Otherwise it falls back to
--math.random, which is NOT suitable for secrets. The fallback yields full-range
--bytes (0..255), the same as the openssl path.
function string.random(byteNum)
    -- default first: rand.bytes used to receive nil when byteNum was omitted
    byteNum=byteNum or 4
    local ok,rand=pcall(require,"resty.openssl.rand")
    if ok then
        local bytes=rand.bytes(byteNum)
        if bytes then return bytes end
    end
    if not randomSeeded then
        local seed=ngx and ngx.time() or os.time()
        -- separate worker processes started in the same second get different seeds
        if ngx and ngx.worker and ngx.worker.pid then
            seed=seed+ngx.worker.pid()
        end
        math.randomseed(seed)
        randomSeeded=true
    end
    local result={}
    for i=1,byteNum do
        result[i]=string.char(math.random(0,255))
    end
    return table.concat(result)
end
76 |
--split `self` at every occurrence of `splitter`, which is matched literally (not
--as a Lua pattern). Empty pieces between adjacent separators or at either end
--are kept.
--A nil or empty splitter returns {self}; previously it looped forever.
function string.split(self,splitter)
    if splitter==nil or splitter=="" then
        return {self}
    end
    local pieces={}
    local start=1
    local sepLen=#splitter
    while true do
        -- plain find: "." or "%" in the separator must not act as pattern magic
        local found=string.find(self,splitter,start,true)
        if not found then
            pieces[#pieces+1]=string.sub(self,start)
            return pieces
        end
        pieces[#pieces+1]=string.sub(self,start,found-1)
        start=found+sepLen
    end
end
93 |
--make a string safe to print: printable ASCII (32..126) is kept and every other
--byte becomes ".". Unless `noFormat` is set, tab, LF and CR are also kept so
--line layout survives.
--The original condition was inverted: it kept raw control bytes and replaced
--tab/LF/CR. It also let 31 and 127 through.
--Returns the gsub results (string, number of characters visited).
function string.ascii(self,noFormat)
    return self:gsub(".",function(x)
        local code=string.byte(x)
        if code>=32 and code<=126 then
            return nil -- printable: keep as is
        end
        if not noFormat and (code==0x09 or code==0x0a or code==0x0d) then
            return x
        end
        return "."
    end)
end
-- Shared body of the formatted hex-dump shortcuts: fixes hex()'s 3rd and 4th
-- layout arguments to (w1, w2), the separators to " ", " " and the two flags to 1, 1.
-- `pos` defaults to 1; any extra arguments are forwarded after those.
local function hexWithLayout(self,pos,endpos,w1,w2,...)
    return self:hex(pos or 1,endpos,w1,w2," "," ",1,1,...)
end
--equivalent to hex(pos or 1,endpos,4,8," "," ",1,1,...)
function string.hexF(self,pos,endpos,...)
    return hexWithLayout(self,pos,endpos,4,8,...)
end
--equivalent to hex(pos or 1,endpos,8,16," "," ",1,1,...)
function string.hex16F(self,pos,endpos,...)
    return hexWithLayout(self,pos,endpos,8,16,...)
end
--equivalent to hex(pos or 1,endpos,8,32," "," ",1,1,...)
function string.hex32F(self,pos,endpos,...)
    return hexWithLayout(self,pos,endpos,8,32,...)
end
--format a number as "0x" followed by at least two upper-case hex digits
function string.dec2hex(input)
    assert(type(input)=="number","input must be a number")
    local hex=string.format("0x%02X",input)
    return hex
end
--format a number as decimal[0xHEX], e.g. 10 -> "10[0x0A]"
function string.dec2hexF(input)
    return string.format("%s[%s]",tostring(input),string.dec2hex(input))
end
128 |
--escape every Lua pattern magic character in `str` with "%", so it matches literally
function string.literalize(str)
    local result=str:gsub("[%(%)%.%%%+%-%*%?%[%]%^%$]", function(c) return "%" .. c end)
    return result
end

--replace every literal occurrence of `plainPattern` in `str` with the literal
--text `repl`. Returns the gsub results (new string, number of replacements).
--In a gsub replacement string only "%" is special. Escaping other magic
--characters (the old repl:literalize()) produces sequences like "%." that
--Lua 5.2+ rejects with "invalid use of '%' in replacement string".
function string.subPlain(str,plainPattern,repl)
    local p=plainPattern:literalize()
    local r=repl:gsub("%%","%%%%")
    return str:gsub(p,r)
end
139 |
--decode a hex string to raw bytes; any non-hex characters (spaces, newlines, ...)
--are ignored. The remaining digit count must be even.
function string.fromhex(value)
    local digits=value:gsub("%X","")
    assert(#digits%2 == 0,"value length % 2 must be 0")
    local bytes={}
    for pair in digits:gmatch("..") do
        bytes[#bytes+1]=string.char(tonumber(pair,16))
    end
    return table.concat(bytes)
end
146 |
--insert `element` after the first `i` bytes of `self` (default: at the end);
--`i` is clamped to 0..#self. Returns the new string and the index of the last
--inserted byte.
function string.append(self,element,i)
    local len=#self
    local at=i or len
    if at<0 then
        at=0
    elseif at>len then
        at=len
    end
    local joined=self:sub(1,at)..element..self:sub(at+1)
    return joined,at+#element
end
155 |
--strip repeated occurrences of `element` from both ends of `self`.
--Returns `self` unchanged when `element` is nil or empty, or longer than `self`.
--Before, a nil element crashed on #element before the nil check ran, and an
--empty element looped forever.
function string.trim(self,element)
    if element==nil or element=="" or #self<#element then return self end
    local width=#element
    local result=self
    -- sub(-width) is the trailing `width` bytes (the whole string once shorter)
    while result:sub(-width)==element do
        result=result:sub(1,#result-width)
    end
    while result:sub(1,width)==element do
        result=result:sub(width+1)
    end
    return result
end
167 |
--delete the byte at index `i` (clamped to 1..#self). Returns the new string and
--the index just before the removed byte.
function string.remove(self,i)
    local at=math.min(math.max(i,1),#self)
    return self:sub(1,at-1)..self:sub(at+1),at-1
end
173 |
--return, as multiple values, every index within `value` where the byte differs
--from `value1` at the same index. Positions past the end of `value1` count as
--different; positions past the end of `value` are not reported.
function string.compare(value,value1)
    local result={}
    for i=1,#value do
        if value:byte(i)~=value1:byte(i) then
            result[#result+1]=i
        end
    end
    -- the global `unpack` exists only in Lua 5.1/LuaJIT; 5.2+ has table.unpack
    return (table.unpack or unpack)(result)
end
180 |
--side-by-side hex dump (hexF layout) of two strings, each highlighting the
--positions reported by string.compare against the other one
function string.compareF(value,value1,pos,endpos)
    local first=value:hexF(pos,endpos,value:compare(value1))
    local second=value1:hexF(pos,endpos,value1:compare(value))
    return "---------------------1-------------------------\r\n"
        ..first
        .."\r\n----------------------2-----------------------\r\n"
        ..second
end

--same as compareF but with the wider hex16F layout
function string.compare16F(value,value1,pos,endpos)
    local first=value:hex16F(pos,endpos,value:compare(value1))
    local second=value1:hex16F(pos,endpos,value1:compare(value))
    return "--------------------------------------1---------------------------------------\r\n"
        ..first
        .."\r\n--------------------------------------2---------------------------------------\r\n"
        ..second
end
194 |
-- base64 alphabet shared by the encoder and decoder below
local b='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/'

-- encoding: each 3-byte group becomes 4 alphabet characters. A final partial
-- group emits 2 or 3 characters and is padded with "==" or "=".
function string.base64Encode(data)
    local out={}
    for i=1,#data,3 do
        local c1,c2,c3=data:byte(i,i+2)
        local n=c1*65536+(c2 or 0)*256+(c3 or 0)
        local s1=math.floor(n/262144)%64
        local s2=math.floor(n/4096)%64
        local s3=math.floor(n/64)%64
        local s4=n%64
        out[#out+1]=b:sub(s1+1,s1+1)..b:sub(s2+1,s2+1)
        if c2 then out[#out+1]=b:sub(s3+1,s3+1) end
        if c3 then out[#out+1]=b:sub(s4+1,s4+1) end
    end
    return table.concat(out)..({ '', '==', '=' })[#data%3+1]
end

-- decoding: characters outside the alphabet are dropped and "=" is ignored
-- wherever it appears. Every complete 8 bits of the remaining 6-bit values
-- become one byte; leftover bits are discarded.
function string.base64Decode(data)
    data = string.gsub(data, '[^'..b..'=]', '')
    local out={}
    local acc,nbits=0,0
    for ch in data:gmatch('.') do
        if ch~='=' then
            acc=acc*64+(b:find(ch,1,true)-1)
            nbits=nbits+6
            if nbits>=8 then
                nbits=nbits-8
                local scale=2^nbits
                out[#out+1]=string.char(math.floor(acc/scale))
                acc=acc%scale
            end
        end
    end
    return table.concat(out)
end
225 |
--------------------------------------------------------------------------------
/utils/tableUtils.lua:
--------------------------------------------------------------------------------
1 | require "suproxy.utils.stringUtils"
-- Module table; the table helpers below are attached to it.
local _M = {}
-- NOTE(review): printFlag is not referenced in the visible part of this file; confirm it is still used.
local printFlag = "___printted"
-- Key filter used by printTableF; returns true when `key` is to be left out.
-- A key matching any `excepts` pattern is skipped. When `includes` is given it
-- overrides that result: only keys matching one of its patterns are kept.
local function isSkippedKey(key,excepts,includes)
    local skip=false
    for _,e in ipairs(excepts) do
        if key:match(e) then
            skip=true
            break
        end
    end
    if includes then
        local inWhiteList=false
        for _,e in ipairs(includes) do
            if key:match(e) then
                inWhiteList=true
            end
        end
        skip=not inWhiteList
    end
    return skip
end

--format and print table in json style
--tab:     value to render. Tables recurse; strings show their length, text and
--         hex; numbers show decimal and hex (string.dec2hexF).
--options: printIndex (also render the metatable's __index), stopLayer (maximum
--         depth), inline (no line breaks or tabs), justLen (strings: only the
--         length), ascii (strings: render via string.ascii, default true),
--         excepts / includes (lists of Lua patterns filtering keys).
--         A table whose __orderred field is set is read as a sequence of
--         {key=..., value=...} entries.
--layer, map: recursion depth and the tables already rendered (keyed by tostring).
--returns the rendered string
function _M.printTableF(tab,options,layer,map)
    options=options or {}
    local printIndex=options.printIndex or false
    layer=layer or 1
    local stopLayer=options.stopLayer or 0xffff
    local inline=options.inline or false
    local wrap= (not inline) and "\r\n" or ""
    local tabb= (not inline) and "\t" or ""
    local justLen=options.justLen or false
    -- `options.ascii or true` was always true; only an explicit false disables it now
    local ascii=options.ascii
    if ascii==nil then ascii=true end
    local excepts=options.excepts or {}
    map=map or {}
    local includes=options.includes
    local logStr=""
    if layer>stopLayer then return logStr end
    if type(tab)=="table" and (not map[tostring(tab)]) then
        map[tostring(tab)]=""
        -- "list": keys are exactly 1..n; "short list": a list without table values
        local i=1
        local isList=true
        local shortList=true
        for k, v in pairs(tab) do
            if k~=i then isList=false break end
            if type(v)=="table" then shortList=false end
            i=i+1
        end
        --no item in table
        if i==1 then isList,shortList =false,false end
        if not tab.__orderred then
            for k, v in pairs(tab) do
                k=tostring(k)
                if not isSkippedKey(k,excepts,includes) then
                    if not isList or not shortList then logStr=logStr..wrap..string.rep(tabb,layer) end
                    if not isList then logStr=logStr.."\""..k.."\":" end
                    logStr=logStr.._M.printTableF(v,options,layer+1,map)
                    logStr=logStr..","
                end
            end
        else
            for _, v in ipairs(tab) do
                if not isSkippedKey(v.key,excepts,includes) then
                    if not isList or not shortList then logStr=logStr..wrap..string.rep(tabb,layer) end
                    if not isList then logStr=logStr.."\""..v.key.."\":" end
                    logStr=logStr.._M.printTableF(v.value,options,layer+1,map)
                    logStr=logStr..","
                end
            end
        end
        if printIndex and getmetatable(tab) and getmetatable(tab).__index then
            if not isList or not shortList then logStr=logStr..wrap..string.rep(tabb,layer) end
            if not isList then logStr=logStr.."\"__index\":" end
            logStr=logStr.._M.printTableF(getmetatable(tab).__index,options,layer+1,map)..","
        end
        -- drop the trailing ","
        if #logStr >0 then logStr=logStr:sub(1,#logStr-1) end
        logStr=(isList and "[" or "{")..logStr
        if not isList or not shortList then logStr=logStr..wrap..string.rep(tabb,layer-1) end
        logStr=logStr..(isList and "]" or "}")..tostring(tab)
    elseif type(tab)=="string" then
        logStr="[Len:"..tab:len().."]"..logStr
        if not justLen then
            if #tab<40 then
                logStr=logStr.."\""..(ascii and tab:ascii(true) or tostring(tab)).."\"".."["..tab:hex().."]"
            else
                logStr=logStr..wrap..string.rep(tabb,layer-1).."\""..(ascii and tab:ascii(true) or tostring(tab)).."\""..wrap..string.rep(tabb,layer-1).."["..tab:hex().."]"
            end
        end
    elseif type(tab)=="number" then
        logStr=logStr..string.dec2hexF(tab)
    else
        logStr=logStr..tostring(tab)
    end
    return logStr
end
--- nil- and table-tolerant concat.
-- Joins tab[1..n] like table.concat, but values that table.concat would
-- reject are rendered instead of raising: nil becomes "nil", and any value
-- that is not a string or number (tables, booleans, functions, ...) is
-- passed through tostring().
-- @param tab list to join; may contain nil holes
-- @param splitter separator placed between elements (nil means "")
-- @param n optional number of slots to join; defaults to #tab. Pass it
--   explicitly when tab has holes, because #tab is unspecified for such tables.
-- @return the joined string
function _M.concat(tab, splitter, n)
    local rs = {}
    for i = 1, n or #tab do
        local v = tab[i]
        local tv = type(v)
        if tv == "string" or tv == "number" then
            rs[i] = v
        elseif v == nil then
            rs[i] = "nil"
        else
            -- render the element itself (not the container); this also keeps
            -- `false` from being shown as "nil"
            rs[i] = tostring(v)
        end
    end
    return table.concat(rs, splitter)
end
121 |
--- Make the items of `tab` reachable by one of their fields.
-- For every table-valued item `v` of `tab` whose `v[key]` is not nil,
-- `tab[v[key]]` will resolve to `v` through `__index`. The lookup table is
-- attached at the end of `tab`'s existing `__index` chain, so anything already
-- reachable through that chain keeps precedence.
-- Notes: the lookup is a snapshot (items added later are not indexed), and if
-- the end of the chain is a shared prototype, everything using that prototype
-- sees the lookup.
-- @param tab table of records
-- @param key field of each record to index by
function _M.addIndex(tab, key)
    local lookUps = {}
    for _, v in pairs(tab) do
        if type(v) == "table" and v[key] ~= nil then
            lookUps[v[key]] = v
        end
    end

    -- Walk table-valued __index links to the last table of the chain.
    -- `seen` stops the walk on cyclic chains.
    local target = tab
    local seen = { [tab] = true }
    local mt = getmetatable(target)
    while mt and type(mt.__index) == "table" and not seen[mt.__index] do
        target = mt.__index
        seen[target] = true
        mt = getmetatable(target)
    end

    if mt == nil then
        setmetatable(target, { __index = lookUps })
    elseif mt.__index == nil then
        -- keep the existing metatable (and its other metamethods)
        mt.__index = lookUps
    else
        -- __index is a function (or closes a cycle): consult it first, then
        -- fall back to the lookup table
        local previous = mt.__index
        mt.__index = function(t, k)
            local v
            if type(previous) == "function" then
                v = previous(t, k)
            else
                v = previous[k]
            end
            if v == nil then
                v = lookUps[k]
            end
            return v
        end
    end
end
134 |
--- Imitate Java's `extends` keyword.
-- Copies every member of `parent` that `o` does not already define into `o`,
-- then records `parent` as `o.__base`. Members defined by `o` win, including
-- ones it deliberately set to `false`.
-- Limitation: this is a one-time copy; members added to `parent` after this
-- call are not picked up by `o`.
-- @param o the subclass/object table (must not be nil)
-- @param parent the base class table (must not be nil)
-- @return o
function _M.extends(o, parent)
    assert(o, "object can not be null")
    assert(parent, "parent can not be null")
    for k, v in pairs(parent) do
        -- compare with nil rather than truthiness so a `false` member of the
        -- subclass is not replaced by the parent's value
        if o[k] == nil then
            o[k] = v
        end
    end
    o.__base = parent
    return o
end
154 |
--- Insertion-ordered table; keys must not be numbers.
-- Every assigned key is stored twice: directly (so `t.key` reads work) and as
-- an entry `{key=..., value=...}` in the array part, in insertion order.
-- Limitation: `__newindex` only fires for keys absent from the raw table, so
-- re-assigning an existing non-nil key updates `t.key` but NOT the ordered
-- entry's `value`.
_M.OrderedTable = {
    --- Create an ordered table. If `o` is given it is reused; fields it
    -- already holds are not tracked in the order list.
    new = function(self, o)
        o = o or {}
        local k_i = {} -- key -> position of its entry in the array part
        local meta = {
            __index = self,
            __newindex = function(t, k, v)
                assert(type(k) ~= "number", "OrderedTable keys must not be numbers")
                local idx = k_i[k]
                if idx then
                    -- Key already tracked (it was assigned nil earlier, so it
                    -- is absent from the raw table): update its entry in place
                    -- instead of appending a duplicate.
                    rawget(t, idx).value = v
                else
                    idx = #t + 1
                    k_i[k] = idx
                    rawset(t, idx, { key = k, value = v })
                end
                rawset(t, k, v)
            end,
            __k_i = k_i,
        }
        o.__orderred = true
        return setmetatable(o, meta)
    end,

    --- Position of key `k` in the order list, or nil if it is not tracked.
    getIndex = function(self, k)
        assert(type(k) ~= "number")
        local k_i = getmetatable(self).__k_i
        return k_i[k]
    end,

    --- Plain key -> value table built from the ordered entries.
    getKVTable = function(self)
        local rs = {}
        for _, entry in ipairs(self) do
            rs[entry.key] = entry.value
        end
        return rs
    end,

    --- Remove key `k` and its ordered entry; unknown keys are ignored.
    remove = function(self, k)
        assert(type(k) ~= "number")
        local k_i = getmetatable(self).__k_i
        local removeIndex = k_i[k]
        if removeIndex == nil then
            -- table.remove(self, nil) would silently drop the LAST entry
            return
        end
        table.remove(self, removeIndex)
        k_i[k] = nil
        rawset(self, k, nil)
        -- entries after the removed one shifted down by one
        for i = removeIndex, #self do
            k_i[self[i].key] = i
        end
    end,
}
198 |
199 |
200 |
_M.unitTest = {}

-- Checks insertion order of an OrderedTable before and after removing a key.
function _M.unitTest.OrderedTable()
    local t = _M.OrderedTable:new()
    for value, name in ipairs({ "A", "B", "C", "D", "E" }) do
        t[name] = value
    end
    print(_M.printTableF(t))
    for i, expected in ipairs({ "A", "B", "C" }) do
        assert(t[i].key == expected, t[i].key)
    end
    assert(#t == 5, #t)

    t:remove("B")
    print(_M.printTableF(t))
    for i, expected in ipairs({ "A", "C", "D" }) do
        assert(t[i].key == expected, t[i].key)
    end
    assert(#t == 4)
end
221 |
-- concat must render the nil hole as "nil" and keep empty strings and numbers.
function _M.unitTest.concat()
    local input = {[1]=1,[2]="",[3]=nil,[4]="abc",[5]=4524354326}
    local joined = _M.concat(input, ";")
    assert(joined == "1;;nil;abc;4524354326", joined)
end
226 |
--- Run every function registered in _M.unitTest.
-- Tests run in name order so the output is reproducible (pairs() order is
-- unspecified). A failing test raises and stops the run.
function _M.test()
    local names = {}
    for name in pairs(_M.unitTest) do
        names[#names + 1] = name
    end
    table.sort(names)
    for _, name in ipairs(names) do
        print("------------running " .. name)
        _M.unitTest[name]()
        print("------------" .. name .. " finished")
    end
end
234 |
235 | return _M
--------------------------------------------------------------------------------
/utils/unicode.lua:
--------------------------------------------------------------------------------
1 | -- Localize a few functions for a tiny speed boost, since these will be looped
2 | -- over every char of a string
3 | require "suproxy.utils.stringUtils"
4 | require "suproxy.utils.pureluapack"
5 | local byte = string.byte
6 | local char = string.char
7 | local pack = string.pack
8 | local unpack = string.unpack
9 | local concat = table.concat
10 |
11 | local _M={}
12 |
---Decode a buffer containing Unicode data.
--@param buf The string/buffer to be decoded
--@param decoder A Unicode decoder function (such as utf8_dec). It is called
-- as decoder(buf, pos, bigendian) and must return the position just past the
-- character and the code point, or nil and an error message.
--@param bigendian For encodings that care about byte-order (such as UTF-16),
-- set this to true to force big-endian byte order. Default:
-- false (little-endian)
--@return A list-table containing the code points as numbers
--Raises the decoder's error message if a character cannot be decoded.
function _M.decode(buf, decoder, bigendian)
  local cp = {}
  local pos = 1
  while pos <= #buf do
    local nextPos, value = decoder(buf, pos, bigendian)
    if nextPos == nil then
      -- without this check the loop fails on `nil <= #buf` and the
      -- decoder's own message is lost
      error(value or ("decode failed at " .. pos), 2)
    end
    cp[#cp + 1] = value
    pos = nextPos
  end
  return cp
end
27 | end
28 |
29 | ---Encode a list of Unicode code points
30 | --@param list A list-table of code points as numbers
31 | --@param encoder A Unicode encoder function (such as utf8_enc)
32 | --@param bigendian For encodings that care about byte-order (such as UTF-16),
33 | -- set this to true to force big-endian byte order. Default:
34 | -- false (little-endian)
35 | --@return An encoded string
36 | function _M.encode(list, encoder, bigendian)
37 | local buf = {}
38 | for i, cp in ipairs(list) do
39 | buf[i] = encoder(cp, bigendian)
40 | end
41 | return table.concat(buf, "")
42 | end
43 |
---Transcode a string from one format to another
--
--The string will be decoded and re-encoded in one pass. This saves some
--overhead vs simply passing the output of unicode.decode to
--unicode.encode.
--@param buf The string/buffer to be transcoded
--@param decoder A Unicode decoder function (such as utf16_dec)
--@param encoder A Unicode encoder function (such as utf8_enc)
--@param bigendian_dec Set this to true to force big-endian decoding.
--@param bigendian_enc Set this to true to force big-endian encoding.
--@return An encoded string
--Raises an error if a character cannot be decoded or a code point cannot be
--encoded.
function _M.transcode(buf, decoder, encoder, bigendian_dec, bigendian_enc)
  local out = {}
  local pos = 1
  while pos <= #buf do
    local nextPos, cp = decoder(buf, pos, bigendian_dec)
    if nextPos == nil then
      -- decoders report failure as nil, message
      error(cp or ("decode failed at " .. pos), 2)
    end
    local bytes = encoder(cp, bigendian_enc)
    if bytes == nil then
      error(string.format("cannot encode code point %s decoded at %d", tostring(cp), pos), 2)
    end
    out[#out + 1] = bytes
    pos = nextPos
  end
  return table.concat(out)
end
65 |
--- Determine (poorly) the character encoding of a string
--
-- First, the string is checked for a Byte-order Mark (BOM). This can be
-- examined to determine UTF-16 with endianness or UTF-8. If no BOM is found,
-- the first `len` bytes of the string are examined.
--
-- If a null byte is encountered, UTF-16 is assumed. Endianness is determined
-- by byte position, assuming the null is the high-order byte. Otherwise, if
-- byte values over 127 are found, UTF-8 decoding is attempted. If this fails,
-- the result is 'other', otherwise it is 'utf-8'. If no high bytes are found,
-- the result is 'ascii'.
--
--@param buf The string/buffer to be identified
--@param len The number of bytes to inspect in order to identify the string.
-- Default: 100
--@return A string describing the encoding: 'ascii', 'utf-8', 'utf-16be',
-- 'utf-16le', or 'other' meaning some unidentified 8-bit encoding
function _M.chardet(buf, len)
  local limit = len or 100
  if limit > #buf then
    limit = #buf
  end
  -- Check BOM
  if limit >= 2 then
    local bom1, bom2 = byte(buf, 1, 2)
    if bom1 == 0xff and bom2 == 0xfe then
      return 'utf-16le'
    elseif bom1 == 0xfe and bom2 == 0xff then
      return 'utf-16be'
    elseif limit >= 3 then
      local bom3 = byte(buf, 3)
      if bom1 == 0xef and bom2 == 0xbb and bom3 == 0xbf then
        return 'utf-8'
      end
    end
  end
  -- Try bytes: inspect positions 1..limit inclusive.
  local pos = 1
  local high = false
  local validUtf8 = true
  while pos <= limit do
    local c = byte(buf, pos)
    if c == 0 then
      -- NUL is taken as the high-order byte of a UTF-16 unit; at an even
      -- (1-based) position it follows the low byte, i.e. little-endian.
      if pos % 2 == 0 then
        return 'utf-16le'
      end
      return 'utf-16be'
    elseif c > 127 then
      high = true
      if validUtf8 then
        local nextPos = _M.utf8_dec(buf, pos)
        if nextPos then
          pos = nextPos
        else
          validUtf8 = false
          pos = pos + 1
        end
      else
        pos = pos + 1
      end
    else
      pos = pos + 1
    end
  end
  if high then
    if validUtf8 then
      return 'utf-8'
    end
    return 'other'
  end
  return 'ascii'
end
145 |
---Encode a Unicode code point to UTF-16. See RFC 2781.
--
-- Windows OS prior to Windows 2000 only supports UCS-2, so beware using this
-- function to encode code points above 0xFFFF.
--@param cp The Unicode code point as a number
--@param bigendian Set this to true to encode big-endian UTF-16. Default is
-- false (little-endian)
--@return A string containing the code point in UTF-16 encoding, or nil when
-- cp is not an integer in the range 0..0x10FFFF.
function _M.utf16_enc(cp, bigendian)
  local fmt = bigendian and ">I2" or "<I2"
  if cp % 1.0 ~= 0.0 or cp < 0 then
    -- Only defined for nonnegative integers.
    return nil
  elseif cp <= 0xFFFF then
    return pack(fmt, cp)
  elseif cp <= 0x10FFFF then
    -- Supplementary plane: emit a high/low surrogate pair.
    cp = cp - 0x10000
    return pack(fmt, 0xD800 + bit.rshift(cp, 10)) .. pack(fmt, 0xDC00 + bit.band(cp, 0x3FF))
  end
  return nil
end

---Decodes a UTF-16 character.
--
-- Does not check that the returned code point is a real character. A high
-- surrogate followed by a low surrogate is combined into one code point; an
-- unpaired surrogate is returned as-is.
--@param buf A string containing the character
--@param pos The index in the string where the character begins (default 1)
--@param bigendian Set this to true to decode big-endian UTF-16. Default is
-- false (little-endian)
--@return pos The index just past the character, or nil on error
--@return cp The code point of the character as a number, or an error string
function _M.utf16_dec(buf, pos, bigendian)
  pos = pos or 1
  local fmt = bigendian and ">I2" or "<I2"
  if pos + 1 > #buf then
    return nil, string.format("Incomplete UTF-16 sequence at %d", pos)
  end
  local cp
  cp, pos = unpack(fmt, buf, pos)
  if cp >= 0xD800 and cp <= 0xDBFF and pos + 1 <= #buf then
    local low, nextPos = unpack(fmt, buf, pos)
    if low >= 0xDC00 and low <= 0xDFFF then
      cp = 0x10000 + bit.lshift(cp - 0xD800, 10) + (low - 0xDC00)
      pos = nextPos
    end
  end
  return pos, cp
end
199 |
---Encode a Unicode code point to UTF-8. See RFC 3629.
--
-- Does not check that cp is a real character; that is, doesn't exclude the
-- surrogate range U+D800 - U+DFFF and a handful of others.
--@param cp The Unicode code point as a number
--@return A string containing the code point in UTF-8 encoding, or nil when cp
-- is negative, not an integer, or above 0x10FFFF.
function _M.utf8_enc(cp)
  -- Only defined for nonnegative integers.
  if cp % 1.0 ~= 0.0 or cp < 0 then
    return nil
  end
  -- One-byte (ASCII) fast path.
  if cp <= 0x7F then
    return char(cp)
  end

  local length, lead
  if cp <= 0x7FF then
    length, lead = 2, 0xC0
  elseif cp <= 0xFFFF then
    length, lead = 3, 0xE0
  elseif cp <= 0x10FFFF then
    length, lead = 4, 0xF0
  else
    return nil
  end

  -- Fill continuation bytes from the last one backwards, 6 bits each;
  -- whatever bits remain go into the lead byte.
  local out = {}
  for i = length, 2, -1 do
    out[i] = char(0x80 + bit.band(cp, 0x3F))
    cp = bit.rshift(cp, 6)
  end
  out[1] = char(lead + cp)
  return table.concat(out)
end
237 |
---Decodes a UTF-8 character.
--
-- Does not check that the returned code point is a real character (overlong
-- forms and surrogates are accepted), but rejects bytes that cannot start a
-- sequence and malformed continuation bytes.
--@param buf A string containing the character
--@param pos The index in the string where the character begins (default 1)
--@return pos The index just past the character, or nil on error
--@return cp The code point of the character as a number, or an error string
function _M.utf8_dec(buf, pos)
  pos = pos or 1
  local bv = byte(buf, pos)
  if bv == nil then
    -- pos is past the end of buf
    return nil, string.format("Incomplete UTF-8 sequence at %d", pos)
  end
  local n, mask
  if bv <= 0x7F then
    return pos + 1, bv
  elseif bv <= 0xBF then
    -- 10xxxxxx is a continuation byte and can never start a character;
    -- treating it as a lead byte would yield a negative code point.
    return nil, string.format("Invalid UTF-8 byte at %d", pos)
  elseif bv <= 0xDF then
    --110xxxxx 10xxxxxx
    n, mask = 1, 0xC0
  elseif bv <= 0xEF then
    --1110xxxx 10xxxxxx 10xxxxxx
    n, mask = 2, 0xE0
  elseif bv <= 0xF7 then
    --11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
    n, mask = 3, 0xF0
  else
    return nil, string.format("Invalid UTF-8 byte at %d", pos)
  end

  if pos + n > #buf then
    return nil, string.format("Incomplete UTF-8 sequence at %d", pos)
  end

  local cp = bv - mask
  for i = 1, n do
    bv = byte(buf, pos + i)
    if bv < 0x80 or bv > 0xBF then
      return nil, string.format("Invalid UTF-8 sequence at %d", pos + i)
    end
    cp = bit.lshift(cp, 6) + bit.band(bv, 0x3F)
  end

  return pos + 1 + n, cp
end
282 |
---Convert UTF-16 (little-endian) text to UTF-8, e.g. a Windows/SMB unicode
--string into printable ASCII (a subset of UTF-8).
--@param from A string in UTF-16, little-endian
--@return The string in UTF-8
function _M.utf16to8(from)
  local decoder, encoder = _M.utf16_dec, _M.utf8_enc
  local inputIsBigEndian = false
  return _M.transcode(from, decoder, encoder, inputIsBigEndian, nil)
end
291 |
---Convert UTF-8 text (e.g. printable ASCII) to little-endian UTF-16, the form
--used for Windows/SMB unicode strings.
--@param from A string in UTF-8
--@return The string in UTF-16, little-endian
function _M.utf8to16(from)
  local decoder, encoder = _M.utf8_dec, _M.utf16_enc
  local outputIsBigEndian = false
  return _M.transcode(from, decoder, encoder, nil, outputIsBigEndian)
end
300 |
--- Self-check for the UTF-8 / UTF-16 helpers.
-- Prints intermediate results and raises if a round trip does not reproduce
-- its input.
function _M.test()
  local json = require("suproxy.utils.json")

  local str = "中华A已经Bあまり哈哈哈1234567"
  local codepoints = _M.decode(str, _M.utf8_dec, false)
  print(json.encode(codepoints))
  local reencoded = _M.encode(codepoints, _M.utf8_enc, false)
  print(reencoded:hexF())
  print(str:hexF())
  assert(reencoded == str, "UTF-8 round trip changed the string")

  local ascii = "12345abcd"
  local utf16 = _M.utf8to16(ascii)
  print(utf16:hex())
  -- In UTF-16LE every ASCII character is its byte followed by a NUL byte.
  assert(utf16 == (ascii:gsub("(.)", "%1\0")), "unexpected UTF-16LE encoding")
  assert(_M.utf16to8(utf16) == ascii, "UTF-16 round trip changed the string")
end
314 |
315 | return _M
--------------------------------------------------------------------------------
/utils/utils.lua:
--------------------------------------------------------------------------------
1 | local _M = {_VERSION="0.1.11"}
local table_insert = table.insert
local table_concat = table.concat
2 |
--- Current time in seconds: ngx.time() when the global `ngx` is present
-- (running under OpenResty), otherwise os.time().
function _M.getTime()
    local now = ngx and ngx.time()
    if not now then
        now = os.time()
    end
    return now
end
--- Append `paramName=paramValue` to a URL, using "?" for the first query
-- parameter and "&" when the URL already has a query string.
-- Name and value are inserted verbatim; URL-encode them beforehand if they
-- may contain reserved characters.
-- @param urlString base URL (nil is treated as "")
-- @param paramName parameter name; when nil the URL is returned unchanged
-- @param paramValue parameter value (nil is treated as "")
-- @return the resulting URL string
function _M.addParamToUrl(urlString, paramName, paramValue)
    if urlString == nil then urlString = "" end
    -- was `paramValue==nill`: a typo that only worked because the undefined
    -- global `nill` evaluates to nil
    if paramValue == nil then paramValue = "" end
    if paramName == nil then return urlString end

    -- plain find: "?" is a pattern metacharacter
    local separator = string.find(urlString, "?", 1, true) and "&" or "?"
    return urlString .. separator .. paramName .. "=" .. paramValue
end
18 |
--- Return 32 if the 36-bit literal 0xfffffffff compares equal to the 32-bit
-- literal 0xffffffff, otherwise 64.
-- NOTE(review): this probes how Lua numbers are represented, not the CPU/OS
-- word size; with double-precision numbers or 64-bit integers both literals
-- are distinct, so the result is 64.
function _M._86_64()
    local wide, narrow = 0xfffffffff, 0xffffffff
    if wide == narrow then
        return 32
    end
    return 64
end
22 |
--- Remove every occurrence of query parameter `paramName` from a URL.
-- Only parameters whose name matches exactly are removed ("id" does not
-- remove "idx"). If no query parameters remain, the "?" is dropped as well;
-- a "#fragment" suffix is preserved.
-- @param urlString URL to edit (nil is treated as "")
-- @param paramName name of the parameter to drop; when nil the URL is
--   returned unchanged
-- @return the edited URL string
function _M.removeParamFromUrl(urlString, paramName)
    if urlString == nil then urlString = "" end
    if paramName == nil then return urlString end
    paramName = tostring(paramName)

    -- Split off the fragment first so a "?" inside it is not taken as the
    -- start of the query string.
    local fragment = ""
    local hashPos = string.find(urlString, "#", 1, true)
    if hashPos then
        fragment = urlString:sub(hashPos)
        urlString = urlString:sub(1, hashPos - 1)
    end

    local qPos = string.find(urlString, "?", 1, true)
    if not qPos then
        return urlString .. fragment
    end

    local base, query = urlString:sub(1, qPos - 1), urlString:sub(qPos + 1)
    local kept = {}
    for segment in (query .. "&"):gmatch("([^&]*)&") do
        -- the parameter name is everything before the first "="
        if segment:match("^[^=]*") ~= paramName then
            kept[#kept + 1] = segment
        end
    end

    local result = base
    if #kept > 0 then
        result = result .. "?" .. table.concat(kept, "&")
    end
    result = result .. fragment

    -- was `"urlString:"+urlString`, which raised on every call
    if ngx then ngx.log(ngx.DEBUG, "urlString:" .. result) end
    return result
end
42 |
43 |
--- Look up a request argument by name.
-- URI (query-string) arguments are consulted first; only when the name is
-- absent there and the request method is POST is the body read and the POST
-- arguments consulted.
-- @param argName name of the argument
-- @return the value found in the ngx argument table, or nil
function _M.getArgsFromRequest(argName)
    local value = ngx.req.get_uri_args()[argName]
    if value ~= nil or ngx.var.request_method ~= "POST" then
        return value
    end
    ngx.req.read_body()
    return ngx.req.get_post_args()[argName]
end
55 |
56 |
--- Report an error as the JSON object {msg=..., detail=...}.
-- The JSON is written to the response with ngx.say and logged at ERR level.
-- When `status` is given it is set as ngx.status first and ngx.exit(status)
-- is called at the end.
function _M.error(msg, detail, status)
    local cjson = require("cjson")
    if status then ngx.status = status end
    local body = cjson.encode({ msg = msg, detail = detail })
    ngx.say(body)
    ngx.log(ngx.ERR, body)
    if status then ngx.exit(status) end
end
63 |
64 |
-- Error identifiers returned as the third value of _M.jget/_M.jput/_M.jpost.
_M.errors = {
    UNAVAILABLE = 'upstream-unavailable', -- the request could not be made
    QUERY_ERROR = 'query-failed',         -- upstream answered with status >= 400
}
local errors = _M.errors
70 |
-- Build an HTTP client function for `method` on top of resty.http.
-- The returned function(url, payload, headers) sends `payload` as the query
-- for GET and as the body otherwise, and returns:
--   status, body, nil                 when the response status is < 400
--   status, body, errors.QUERY_ERROR  when the response status is >= 400
--   nil, nil, errors.UNAVAILABLE      when request_uri returned no response
-- Content-Type defaults to application/x-www-form-urlencoded unless the
-- caller supplies one; the caller's `headers` table is never modified.
local function request(method)
    return function(url, payload, headers)
        -- Copy the headers so the caller's table is not mutated.
        local reqHeaders = {}
        local hasContentType = false
        for name, value in pairs(headers or {}) do
            reqHeaders[name] = value
            if type(name) == "string" and name:lower() == "content-type" then
                hasContentType = true
            end
        end
        if not hasContentType then
            reqHeaders['Content-Type'] = 'application/x-www-form-urlencoded'
        end

        local httpc = require("resty.http").new()
        local params = { headers = reqHeaders, method = method }
        if string.lower(string.sub(url, 1, 5)) == 'https' then
            params.ssl_verify = true
        end
        if method == 'GET' then
            params.query = payload
        else
            params.body = payload
        end

        local res, err = httpc:request_uri(url, params)
        if not res then
            -- tostring: payload may be a table (query args) or nil
            ngx.log(ngx.ERR, table.concat(
                { method .. ' fail', url, tostring(payload), tostring(err) }, '|'
            ))
            return nil, nil, errors.UNAVAILABLE
        end
        if res.status >= 400 then
            ngx.log(ngx.ERR, table.concat({
                method .. ' fail code', url, res.status, tostring(res.body),
            }, '|'))
            return res.status, res.body, errors.QUERY_ERROR
        end
        return res.status, res.body, nil
    end
end

_M.jget = request('GET')
_M.jput = request('PUT')
_M.jpost = request('POST')
101 |
102 |
--- Inflate gzip-compressed data with suproxy.utils.ffi-zlib.
-- @param inputString gzip-compressed string
-- @return the decompressed string (possibly partial if inflating failed)
-- @return nil on success, otherwise the error reported by inflateGzip
function _M.unzip(inputString)
    local zlib = require('suproxy.utils.ffi-zlib')
    local chunk = 16384
    local output_table = {}
    local offset = 1

    -- Reader callback: returns the next `bufsize` bytes of the input
    -- (an empty string once the input is exhausted). Only the chunk size is
    -- logged; the raw compressed bytes are not written to the log.
    local input = function(bufsize)
        local data = inputString:sub(offset, offset + bufsize - 1)
        offset = offset + #data
        ngx.log(ngx.DEBUG, "unzip read ", #data, " bytes")
        return data
    end

    -- Writer callback: collects decompressed chunks.
    local output = function(data)
        output_table[#output_table + 1] = data
    end

    local ok, err = zlib.inflateGzip(input, output, chunk)
    local output_data = table.concat(output_table, '')
    if not ok then
        ngx.log(ngx.ERR, "unzip error: ", tostring(err))
        return output_data, err
    end
    return output_data
end
131 |
132 |
133 |
134 | return _M
--------------------------------------------------------------------------------