├── .gitignore
├── README.md
├── config
│   ├── application.example.yml
│   ├── application.yml
│   └── logging.yml
├── dev
│   ├── build-package.sh
│   └── start.sh
├── pom.xml
└── src
    ├── main
    │   ├── java
    │   │   ├── com
    │   │   │   └── intigua
    │   │   │       └── antlr4
    │   │   │           └── autosuggest
    │   │   │               ├── AutoSuggester.java
    │   │   │               ├── CasePreference.java
    │   │   │               ├── DefaultToCharStream.scala
    │   │   │               ├── LexerAndParserFactory.java
    │   │   │               ├── LexerFactory.java
    │   │   │               ├── LexerWrapper.java
    │   │   │               ├── ParserFactory.java
    │   │   │               ├── ParserWrapper.java
    │   │   │               ├── ReflectionLexerAndParserFactory.java
    │   │   │               ├── ToCharStream.java
    │   │   │               ├── TokenSuggester.java
    │   │   │               └── TransitionWrapper.java
    │   │   └── tech
    │   │       └── mlsql
    │   │           └── autosuggest
    │   │               ├── AttributeExtractor.scala
    │   │               ├── AutoSuggestContext.scala
    │   │               ├── AutoSuggester.scala
    │   │               ├── FuncReg.scala
    │   │               ├── FunctionUtils.scala
    │   │               ├── POrCLiterals.scala
    │   │               ├── SpecialTableConst.scala
    │   │               ├── TokenPos.scala
    │   │               ├── app
    │   │               │   ├── Constants.scala
    │   │               │   ├── MLSQLAutoSuggestApp.scala
    │   │               │   ├── MysqlType.java
    │   │               │   ├── RDSchema.scala
    │   │               │   ├── RawDBTypeToJavaType.scala
    │   │               │   ├── SchemaRegistry.scala
    │   │               │   └── Standalone.scala
    │   │               ├── ast
    │   │               │   ├── NoneToken.scala
    │   │               │   └── TableTree.scala
    │   │               ├── dsl
    │   │               │   └── TokenMatcher.scala
    │   │               ├── funcs
    │   │               │   ├── Count.scala
    │   │               │   └── Splitter.scala
    │   │               ├── meta
    │   │               │   ├── LayeredMetaProvider.scala
    │   │               │   ├── MLSQLEngineMetaProvider.scala
    │   │               │   ├── MemoryMetaProvider.scala
    │   │               │   ├── MetaProvider.scala
    │   │               │   ├── RestMetaProvider.scala
    │   │               │   ├── StatementTempTableProvider.scala
    │   │               │   └── meta_protocal.scala
    │   │               ├── preprocess
    │   │               │   └── TablePreprocessor.scala
    │   │               ├── statement
    │   │               │   ├── LexerUtils.scala
    │   │               │   ├── LoadSuggester.scala
    │   │               │   ├── MLSQLStatementSplitter.scala
    │   │               │   ├── MatchAndExtractor.scala
    │   │               │   ├── PreProcessStatement.scala
    │   │               │   ├── RegisterSuggester.scala
    │   │               │   ├── SelectStatementUtils.scala
    │   │               │   ├── SelectSuggester.scala
    │   │               │   ├── StatementSplitter.scala
    │   │               │   ├── StatementSuggester.scala
    │   │               │   ├── StatementUtils.scala
    │   │               │   ├── SuggesterRegister.scala
    │   │               │   ├── TableExtractor.scala
    │   │               │   ├── TemplateSuggester.scala
    │   │               │   └── single_statement.scala
    │   │               └── utils
    │   │                   └── SchemaUtils.scala
    │   └── resources
    │       └── log4j.properties
    └── test
        └── java
            └── com
                └── intigua
                    └── antlr4
                        └── autosuggest
                            ├── AutoSuggestContextTest.scala
                            ├── BaseTest.scala
                            ├── BaseTestWithoutSparkSession.scala
                            ├── LexerUtilsTest.scala
                            ├── LoadSuggesterTest.scala
                            ├── MatchTokenTest.scala
                            ├── SelectSuggesterTest.scala
                            ├── TablePreprocessorTest.scala
                            └── TableStructureTest.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | target/
2 | .idea/
3 | spark-binlog.iml
4 | release.sh
5 | build
6 | logs
7 | dependency-reduced-pom.xml
8 | sql-code-intelligence.iml
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SQL Code Intelligence
SQL Code Intelligence is a code-completion backend engine. It can complete MLSQL syntax
as well as standard Spark SQL (select statements).

## Current Status
[Under active development; no stable release yet]


## Distribution

### Maven dependency

```xml
<dependency>
    <groupId>tech.mlsql</groupId>
    <artifactId>sql-code-intelligence</artifactId>
    <version>0.1.0</version>
</dependency>
```

With this dependency, users can easily embed the engine into their own web applications,
for example ones built with a web framework such as Spring.

### Prebuilt package


Download: [sql-code-intelligence](http://download.mlsql.tech/sql-code-intelligence/).

After downloading, unpack it with `tar xvf sql-code-intelligence-0.1.0.tar` and run:

```
./start.sh
```

You can also start it directly with the java command:

```
java -cp .:sql-code-intelligence-0.1.0.jar tech.mlsql.autosuggest.app.Standalone -config ./config/application.yml
```

For application.yml, see the examples under mlsql-autosuggest/config. The default port is 9004.


## Goals
[SQL Code Intelligence] has two goals. The first is standard SQL completion:

1. SQL keyword completion
2. Completion of tables / column attributes / functions
3. Extensible, so it can be wired to any custom schema provider

The second is MLSQL syntax completion (built on top of the standard SQL completion):

1. Suggestions for the various data sources
2. Suggestions for temporary tables (including completion of their columns)
3. Suggestions for the parameters and names of the various ET components

For table, column and function completion, unlike some other SQL completion tools, this plugin makes precise inferences from the information that is already available.

## Demo
[Click the animation below to watch a video of standard SQL completion]
[Standard SQL completion video](https://www.bilibili.com/video/BV1xk4y1z7tV)

### Standard SQL suggestions

```sql
select no_result_type, keywords, search_num, rank
from(
  select [cursor position] row_number() over (PARTITION BY no_result_type order by search_num desc) as rank
  from(
    select jack1.*,no_result_type, keywords, sum(search_num) AS search_num
    from jack.drugs_bad_case_di as jack1,jack.abc jack2
    where hp_stat_date >= date_sub(current_date,30)
    and action_dt >= date_sub(current_date,30)
    and action_type = 'search'
    and length(keywords) > 1
    and (split(av, '\\.')[0] >= 11 OR (split(av, '\\.')[0] = 10 AND split(av, '\\.')[1] = 9))
    --and no_result_type = 'indication'
    group by no_result_type, keywords
  )a
)b
where rank <=
```

With the cursor at line 3, column 10, the system automatically suggests:
1. a [table name]
2. all columns expanded from jack1
3. no_result_type
4. keywords
5. search_num


### Suggestions across multiple MLSQL statements

```sql
load hive.`db.table1` as table2;
select * from table2 as table3;
select [cursor position] from table3
```

Assume the columns of db.table1 are a, b, c and d.
With the cursor at line 3, column 7, the system suggests:

1. table3
2. a
3. b
4. c
5. d

As you can see, the system has strong cross-statement capabilities: it automatically expands `*` and infers the schema of every table in order to drive completion.

## MLSQL data source / ET component parameter suggestions

```sql
select spl from jack.drugs_bad_case_di as a;
load csv.`/tmp.csv` where [cursor position]
```

When loading a csv we usually need to specify whether it has a header and what the delimiter is, and normally we have to look those parameters up in the documentation. Now [SQL Code Intelligence] suggests them directly:

```json
[
{"name":"codec","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"dateFormat","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"delimiter","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"emptyValue","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"escape","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"header","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"inferSchema","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"quote","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}}]
```


## User Guide

### Deployment

To run it as an MLSQL plugin, see the deployment docs: [MLSQL deployment](http://docs.mlsql.tech/zh/installation/).
It ships as one of MLSQL's default plugins, so it works out of the box.

To run it as a standalone application, download [sql-code-intelligence](http://download.mlsql.tech/sql-code-intelligence/).
After downloading, unpack it with `tar xvf sql-code-intelligence-0.1.0.tar` and run:

```
java -cp .:sql-code-intelligence-0.1.0.jar tech.mlsql.autosuggest.app.Standalone -config ./config/application.yml
```

For application.yml, see the examples under mlsql-autosuggest/config. The default port is 9004.

### Editor integration (three minutes)

Write an HTTP request class:

[screenshot of the HTTP request class omitted]

Set the autocompleter callback in the editor:

[screenshot of the autocompleter callback omitted]

Done. Here I used reactjs, and the editor is ace editor.

### Schema information

[SQL Code Intelligence] needs the schema information of the underlying tables. Users currently have three options:

1. Register the schema information explicitly (good for trying it out and for debugging)
2. Provide a Rest interface that follows the expected contract; the system will call it to fetch schema information (recommended; requires no change to this project)
3. Extend the MetaProvider of [SQL Code Intelligence] so the system can obtain schema information (the class must be registered when this project starts)

The simplest is option 1: register table information through the HTTP interface. Below I do it with Scala code; you can also use a tool such as Postman.

```scala
def registerTable(port: Int = 9003) = {
  val time = System.currentTimeMillis()
  val response = Request.Post(s"http://127.0.0.1:${port}/run/script").bodyForm(
    Form.form().add("executeMode", "registerTable").add("schema",
      """
        |CREATE TABLE emps(
        |  empid INT NOT NULL,
        |  deptno INT NOT NULL,
        |  locationid INT NOT NULL,
        |  empname VARCHAR(20) NOT NULL,
        |  salary DECIMAL (18, 2),
        |  PRIMARY KEY (empid),
        |  FOREIGN KEY (deptno) REFERENCES depts(deptno),
        |  FOREIGN KEY (locationid) REFERENCES locations(locationid)
        |);
        |""".stripMargin).add("db", "db1").add("table", "emps").
      add("isDebug", "true").build()
  ).execute().returnContent().asString()
  println(response)
}
```

Three kinds of CREATE TABLE statements are supported: db, hive and json, which correspond to MySQL syntax, Hive syntax and the Spark SQL schema JSON format respectively. The default is MySQL syntax.

After that, the system can make suggestions:

```scala
def testSuggest(port: Int = 9003) = {
  val time = System.currentTimeMillis()
  val response = Request.Post(s"http://127.0.0.1:${port}/run/script").bodyForm(
    Form.form().add("executeMode", "autoSuggest").add("sql",
      """
        |select emp from db1.emps as a;
        |-- load csv.`/tmp.csv` where
        |""".stripMargin).add("lineNum", "2").add("columnNum", "10").
      add("isDebug", "true").build()
  ).execute().returnContent().asString()
  println(response)
}
```

The second option is to pass searchUrl and listUrl in the request parameters; the input and output of those endpoints must follow the contract defined in `tech.mlsql.autosuggest.meta.RestMetaProvider`.

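For example, the request might carry the two URLs as in the sketch below (a sketch only: the two endpoint URLs are hypothetical placeholders for your own schema service, which must implement the `RestMetaProvider` contract):

```scala
// Sketch: asking for suggestions while pointing the engine at a REST schema service.
// The searchUrl/listUrl values are placeholders, not real endpoints.
def suggestWithRestProvider(port: Int = 9003) = {
  val response = Request.Post(s"http://127.0.0.1:${port}/run/script").bodyForm(
    Form.form().add("executeMode", "autoSuggest").
      add("sql", "select  from db1.emps as a;").
      add("lineNum", "1").add("columnNum", "8").
      add("searchUrl", "http://127.0.0.1:8080/meta/table").
      add("listUrl", "http://127.0.0.1:8080/meta/tables").build()
  ).execute().returnContent().asString()
  println(response)
}
```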

The third option is to implement a custom MetaProvider, which lets you make full use of your own schema system:

```scala
trait MetaProvider {
  def search(key: MetaTableKey): Option[MetaTable]

  def list: List[MetaTable]
}
```
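
As an illustration, a minimal in-memory implementation might look like this sketch (the precise shape of MetaTable lives in `tech.mlsql.autosuggest.meta`; here we only store and return values constructed elsewhere):

```scala
// Sketch: a MetaProvider backed by a plain map, in the spirit of MemoryMetaProvider.
class MapBackedMetaProvider(tables: Map[MetaTableKey, MetaTable]) extends MetaProvider {
  override def search(key: MetaTableKey): Option[MetaTable] = tables.get(key)

  override def list: List[MetaTable] = tables.values.toList
}
```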

To make it take effect, set it on the AutoSuggestContext:

```
context.setUserDefinedMetaProvider(an instance of your implementation)
```

MetaTableKey is defined as follows:

```scala
case class MetaTableKey(prefix: Option[String], db: Option[String], table: String)
```

The prefix exists to identify the data source: the same table may be a hive table or a mysql table. If you only have one data warehouse and never access other data sources, just set it to None. For the following statement:

```sql
load hive.`db.table1` as table2;
```

[SQL Code Intelligence] sends the following MetaTableKey to your MetaProvider.search method:

```scala
MetaTableKey(Some("hive"), Some("db"), "table1")
```

For an ordinary SQL statement rather than an MLSQL statement, for example:

```sql
select * from db.table1
```

the MetaTableKey sent to MetaProvider.search looks like this:

```scala
MetaTableKey(None, Some("db"), "table1")
```

### Using the HTTP interface
Endpoint: http://127.0.0.1:9003/run/script?executeMode=autoSuggest

Parameter 1: sql, the SQL script
Parameter 2: lineNum, the line number of the cursor, counted from 1
Parameter 3: columnNum, the column number of the cursor, counted from 1

For example, a client written in Scala:

```scala
object Test {
  def main(args: Array[String]): Unit = {
    val time = System.currentTimeMillis()
    val response = Request.Post("http://127.0.0.1:9003/run/script").bodyForm(
      Form.form().add("executeMode", "autoSuggest").add("sql",
        """
          |select spl from jack.drugs_bad_case_di as a
          |""".stripMargin).add("lineNum", "2").add("columnNum", "10").build()
    ).execute().returnContent().asString()
    println(System.currentTimeMillis() - time)
    println(response)
  }

}
```

The final result looks like this:

```json
[{"name":"split",
"metaTable":{"key":{"db":"__FUNC__","table":"split"},
"columns":[
{"name":null,"dataType":"array","isNull":true,"extra":{"zhDoc":"\nsplit函数。用于切割字符串,返回字符串数组\n"}},{"name":"str","dataType":"string","isNull":false,"extra":{"zhDoc":"待切割字符"}},
{"name":"pattern","dataType":"string","isNull":false,"extra":{"zhDoc":"分隔符"}}]},
"extra":{}}]
```
The suggestion is split; it is a function, and both its parameters and its return value are documented.

### Programmatic use

Create an AutoSuggestContext, process the text with buildFromString, and call the suggest
method to get recommendations.

```scala

val sql = params("sql")
val lineNum = params("lineNum").toInt
val columnNum = params("columnNum").toInt

val sparkSession = SparkSession.builder().appName("local").master("local[*]").getOrCreate()
val context = new AutoSuggestContext(sparkSession,
  AutoSuggestController.mlsqlLexer,
  AutoSuggestController.sqlLexer)

JSONTool.toJsonStr(context.buildFromString(sql).suggest(lineNum,columnNum))
```

sparkSession may also be set to null, but some features will then be missing, such as data source suggestions.
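
A minimal sketch of running without Spark (assuming the same lexer instances are available; data-source-related suggestions will be missing, as noted above):

```scala
// Sketch: AutoSuggestContext with a null SparkSession.
val context = new AutoSuggestContext(null,
  AutoSuggestController.mlsqlLexer,
  AutoSuggestController.sqlLexer)

println(JSONTool.toJsonStr(context.buildFromString("select  from table3").suggest(1, 8)))
```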


## Developer Guide

### Parsing flow
[SQL Code Intelligence] reuses the MLSQL/Spark SQL lexers and rewrites the parser part. Code completion has a peculiarity of its own: while text is being typed it is malformed most of the time, so a strict parser cannot be used to analyze it.

Two lexers are used because the MLSQL lexer parses the whole MLSQL script, while the Spark SQL lexer handles the select statements of standard SQL. Since the project is highly extensible, users can also extend it to other standard SQL statements.

Taking completion inside a select statement as an example, the overall flow is:

1. Split the script into multiple statements with the MLSQL lexer
2. Hand each statement to a different Suggester for further analysis
3. Parse select statements with the SelectSuggester
4. First build a very coarse-grained AST of the select statement, with one node per subquery, and at the same time build TABLE_INFO, a layered cache of table structure information
5. Convert the cursor position into a global TokenPos
6. Convert the global TokenPos into a TokenPos relative to the select statement
7. Walk the select AST according to the TokenPos and locate the innermost simple sub-statement
8. Match with the project/where/group by/on/having sub-suggesters; the matching suggester finally produces the suggestions

In the AST, every sub-statement may be incomplete. As the flow above shows, the statement is the coarse-grained working context, and for a complex select statement the working context is further narrowed down to each subquery. This makes the implementation much easier.


### Contributing quickly
[SQL Code Intelligence] needs a large number of function definitions so it can assist users at call sites. Here is my implementation of the `split` function:

```scala
class Splitter extends FuncReg {

  override def register = {
    val func = MLSQLSQLFunction.apply("split").
      funcParam.
      param("str", DataType.STRING, false, Map("zhDoc" -> "待切割字符")).
      param("pattern", DataType.STRING, false, Map("zhDoc" -> "分隔符")).
      func.
      returnParam(DataType.ARRAY, true, Map(
        "zhDoc" ->
          """
            |split函数。用于切割字符串,返回字符串数组
            |""".stripMargin
      )).
      build
    func
  }

}
```

You only need FunctionBuilder to describe the function signature. Users then get detailed usage and parameter documentation when they call the function, and the signature also lets us derive the type of a column produced by nested function calls.

To add more functions, follow the pattern above and put them under the tech.mlsql.autosuggest.funcs package; the system scans that package and registers every implementation automatically (see the sketch below).
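
For instance, a hypothetical registration for a `trim` function, mirroring the Splitter pattern above (the function name, parameter docs and types here are illustrative, not part of the project):

```scala
// Sketch: a FuncReg for trim, following the same builder pattern as Splitter.
class Trimmer extends FuncReg {

  override def register = {
    MLSQLSQLFunction.apply("trim").
      funcParam.
      param("str", DataType.STRING, false, Map("zhDoc" -> "待处理的字符串")).
      func.
      returnParam(DataType.STRING, true, Map("zhDoc" -> "去掉首尾空白后的字符串")).
      build
  }

}
```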

### The TokenMatcher utility

The main job in [SQL Code Intelligence] is token matching, and the TokenMatcher utility is provided for that. TokenMatcher supports forward and backward matching. Given the token sequence:

```
select a , b , c from jack
```

Suppose that, starting around token index 3 (b), we want to match a comma and an identifier forward. That can be written as:

```scala
val tokenMatcher = TokenMatcher(tokens,4).forward.eat(Food(None, TokenTypeWrapper.DOT), Food(None, SqlBaseLexer.IDENTIFIER)).build
```

You can then call tokenMatcher.isSuccess to check whether the match succeeded, tokenMatcher.get to obtain the index after a successful match, and tokenMatcher.getMatchTokens to obtain the matched tokens.

Note that the TokenMatcher start position is inclusive: the token at the start position is itself part of the match. That is why start is 4 rather than 3 in the example above. See the source code for more examples.
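
Backward matching is expressed the same way (a sketch, assuming the backward direction is selected with `back`, mirroring `forward`):

```scala
// Sketch: starting at index 3 (b) and scanning backward for an identifier.
val backMatcher = TokenMatcher(tokens, 3).back.eat(Food(None, SqlBaseLexer.IDENTIFIER)).build
if (backMatcher.isSuccess) {
  println(backMatcher.getMatchTokens)
}
```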

### Subquery hierarchy

For the statement:

```sql
select no_result_type, keywords, search_num, rank
from(
  select keywords, search_num, row_number() over (PARTITION BY no_result_type order by search_num desc) as rank
  from(
    select *,no_result_type, keywords, sum(search_num) AS search_num
    from jack.drugs_bad_case_di,jack.abc jack
    where hp_stat_date >= date_sub(current_date,30)
    and action_dt >= date_sub(current_date,30)
    and action_type = 'search'
    and length(keywords) > 1
    and (split(av, '\\.')[0] >= 11 OR (split(av, '\\.')[0] = 10 AND split(av, '\\.')[1] = 9))
    --and no_result_type = 'indication'
    group by no_result_type, keywords
  )a
)b
where rank <=
```

the AST structure that gets built looks like this:

```sql
select no_result_type , keywords , search_num , rank from ( select keywords , search_num , row_number ( ) over
( PARTITION BY no_result_type order by search_num desc ) as rank from ( select * , no_result_type , keywords ,
sum ( search_num ) AS search_num from jack . drugs_bad_case_di , jack . abc jack where hp_stat_date >= date_sub (
current_date , 30 ) and action_dt >= date_sub ( current_date , 30 ) and action_type = 'search' and length (
keywords ) > 1 and ( split ( av , '\\.' ) [ 0 ] >= 11 OR ( split
( av , '\\.' ) [ 0 ] = 10 AND split ( av , '\\.' ) [ 1 ]
= 9 ) ) group by no_result_type , keywords ) a ) b where rank <=


=>select keywords , search_num , row_number ( ) over ( PARTITION BY no_result_type order by search_num desc ) as
rank from ( select * , no_result_type , keywords , sum ( search_num ) AS search_num from jack . drugs_bad_case_di
, jack . abc jack where hp_stat_date >= date_sub ( current_date , 30 ) and action_dt >= date_sub ( current_date
, 30 ) and action_type = 'search' and length ( keywords ) > 1 and ( split ( av ,
'\\.' ) [ 0 ] >= 11 OR ( split ( av , '\\.' ) [ 0 ] = 10
AND split ( av , '\\.' ) [ 1 ] = 9 ) ) group by no_result_type , keywords )
a ) b


==>select * , no_result_type , keywords , sum ( search_num ) AS search_num from jack . drugs_bad_case_di , jack
. abc jack where hp_stat_date >= date_sub ( current_date , 30 ) and action_dt >= date_sub ( current_date , 30
) and action_type = 'search' and length ( keywords ) > 1 and ( split ( av , '\\.' )
[ 0 ] >= 11 OR ( split ( av , '\\.' ) [ 0 ] = 10 AND split
( av , '\\.' ) [ 1 ] = 9 ) ) group by no_result_type , keywords ) a
```

We can see that there are two levels of nesting, with one subquery per level.

The TABLE_INFO structure built from it looks like this:

```
2:
List(
MetaTableKeyWrapper(MetaTableKey(None,Some(jack),drugs_bad_case_di),None),
MetaTableKeyWrapper(MetaTableKey(None,None,null),Some(a)),
MetaTableKeyWrapper(MetaTableKey(None,Some(jack),abc),Some(jack)))
1:
List(MetaTableKeyWrapper(MetaTableKey(None,None,null),Some(b)))
0:
List()
```

Level 0 is the outermost statement; level 1 is the first subquery; level 2 is the second subquery, and it contains the subquery aliases as well as all the physical tables inside that subquery.

The display above is abbreviated; the structure actually contains all column information as well. This means that to complete the projection of level 0, I only need the information at level 1: from it I can complete the table name b, or the columns of table b. The other levels work analogously.
--------------------------------------------------------------------------------
/config/application.example.yml:
--------------------------------------------------------------------------------
#mode
mode:
  development
#mode=production

path:
  conf: /Users/allwefantasy/CSDNWorkSpace/streamingpro-spark-2.4.x/external/mlsql-autosuggest/config/

###############datasource config##################
#configuration for data sources such as mysql, mongodb and redis
development:
  datasources:
    mysql:
      host: 127.0.0.1
      port: 3306
      database: mlsql_console
      username: xxx
      password: xxxx
      initialSize: 8
      disable: true
      removeAbandoned: true
      testWhileIdle: true
      removeAbandonedTimeout: 30
      maxWait: 100
      filters: stat,log4j
    mongodb:
      disable: true
    redis:
      disable: true
test:
  datasources:
    mysql:
      host: 127.0.0.1
      port: 3306
      database: wow
      username: root
      password: mlsql
      disable: true

production:
  datasources:
    mysql:
      host: 127.0.0.1
      port: 3306
      database: wow
      username: root
      password: mlsql
      disable: false

###############application config##################
#'model' for relational database like MySQL
#'document' for NoSQL database model configuration, MongoDB
auth_secret: "mlsql"
application:
  controller: tech.mlsql.autosuggest.app
  model: tech.mlsql.model
  test: test.com.example
  static:
    enable: false
  template:
    engine:
      enable: false

serviceframework:
  template:
    loader:
      classpath:
        enable: true
  static:
    loader:
      classpath:
        enable: true
        dir: "streamingpro/assets"
###############http config##################
http:
  port: 9004
  disable: false
  host: 0.0.0.0
  server:
    idleTimeout: 6000000
  client:
    accept:
      timeout: 43200000

#thrift:
#  disable: true
#  services:
#    net_csdn_controller_thrift_impl_CBayesianQueryServiceImpl:
#      port: 9001
#      min_threads: 100
#      max_threads: 1000
#
#  servers:
#    spam_bayes: ["127.0.0.1:9001"]



###############validator config##################
#To add a validator, just configure its fully-qualified class name
#To swap in a different validator implementation, replace the corresponding class name
#warning: custom validator implementations must be thread safe

validator:
  format: net.csdn.validate.impl.Format
  numericality: net.csdn.validate.impl.Numericality
  presence: net.csdn.validate.impl.Presence
  uniqueness: net.csdn.validate.impl.Uniqueness
  length: net.csdn.validate.impl.Length
  associated: net.csdn.validate.impl.Associated

mongo_validator:
  format: net.csdn.mongo.validate.impl.Format
  numericality: net.csdn.mongo.validate.impl.Numericality
  presence: net.csdn.mongo.validate.impl.Presence
  uniqueness: net.csdn.mongo.validate.impl.Uniqueness
  length: net.csdn.mongo.validate.impl.Length
  associated: net.csdn.mongo.validate.impl.Associated

################ database type mapping ####################
type_mapping: net.csdn.jpa.type.impl.MysqlType

qps:
  /say/hello: 10

qpslimit:
  enable: true
dubbo:
  disable: true
  server: true
--------------------------------------------------------------------------------
/config/application.yml:
--------------------------------------------------------------------------------
#mode
mode:
  development
#mode=production

path:
  conf: /Users/allwefantasy/CSDNWorkSpace/streamingpro-spark-2.4.x/external/mlsql-autosuggest/config/

###############datasource config##################
#configuration for data sources such as mysql, mongodb and redis
development:
  datasources:
    mysql:
      host: 127.0.0.1
      port: 3306
      database: mlsql_console
      username: xxx
      password: xxxx
      initialSize: 8
      disable: true
      removeAbandoned: true
      testWhileIdle: true
      removeAbandonedTimeout: 30
      maxWait: 100
      filters: stat,log4j
    mongodb:
      disable: true
    redis:
      disable: true
test:
  datasources:
    mysql:
      host: 127.0.0.1
      port: 3306
      database: wow
      username: root
      password: mlsql
      disable: true

production:
  datasources:
    mysql:
      host: 127.0.0.1
      port: 3306
      database: wow
      username: root
      password: mlsql
      disable: false

###############application config##################
#'model' for relational database like MySQL
#'document' for NoSQL database model configuration, MongoDB
auth_secret: "mlsql"
application:
  controller: tech.mlsql.autosuggest.app
  model: tech.mlsql.model
  test: test.com.example
  static:
    enable: false
  template:
    engine:
      enable: false

serviceframework:
  template:
    loader:
      classpath:
        enable: true
  static:
    loader:
      classpath:
        enable: true
        dir: "streamingpro/assets"
###############http config##################
http:
  port: 9004
  disable: false
  host: 0.0.0.0
  server:
    idleTimeout: 6000000
  client:
    accept:
      timeout: 43200000

#thrift:
#  disable: true
#  services:
#    net_csdn_controller_thrift_impl_CBayesianQueryServiceImpl:
#      port: 9001
#      min_threads: 100
#      max_threads: 1000
#
#  servers:
#    spam_bayes: ["127.0.0.1:9001"]



###############validator config##################
#To add a validator, just configure its fully-qualified class name
#To swap in a different validator implementation, replace the corresponding class name
#warning: custom validator implementations must be thread safe

validator:
  format: net.csdn.validate.impl.Format
  numericality: net.csdn.validate.impl.Numericality
  presence: net.csdn.validate.impl.Presence
  uniqueness: net.csdn.validate.impl.Uniqueness
  length: net.csdn.validate.impl.Length
  associated: net.csdn.validate.impl.Associated

mongo_validator:
  format: net.csdn.mongo.validate.impl.Format
  numericality: net.csdn.mongo.validate.impl.Numericality
  presence: net.csdn.mongo.validate.impl.Presence
  uniqueness: net.csdn.mongo.validate.impl.Uniqueness
  length: net.csdn.mongo.validate.impl.Length
  associated: net.csdn.mongo.validate.impl.Associated

################ database type mapping ####################
type_mapping: net.csdn.jpa.type.impl.MysqlType

qps:
  /say/hello: 10

qpslimit:
  enable: true
dubbo:
  disable: true
  server: true
--------------------------------------------------------------------------------
/config/logging.yml:
--------------------------------------------------------------------------------
rootLogger: INFO,console


appender:
  console:
    type: console
    threshold: INFO
    layout:
      type: consolePattern
      conversionPattern: "[%d{ISO8601}][%-5p][%-25c] %m%n"

  file:
    type: dailyRollingFile
    file: ${path.logs}/${cluster.name}.log
    datePattern: "'.'yyyy-MM-dd"
    layout:
      type: pattern
      conversionPattern: "[%d{ISO8601}][%-5p][%-25c] %m%n"
--------------------------------------------------------------------------------
/dev/build-package.sh:
--------------------------------------------------------------------------------
1 | mvn package -DskipTests -Pshade
2 | rm -rf build
3 | mkdir -p build/sql-code-intelligence-0.1.0
4 | cp target/sql-code-intelligence-0.1.0.jar build/sql-code-intelligence-0.1.0
5 | cp -r config build/sql-code-intelligence-0.1.0
6 | cp dev/start.sh build/sql-code-intelligence-0.1.0
7 | cd build
8 | tar cvf sql-code-intelligence-0.1.0.tar sql-code-intelligence-0.1.0
9 | scp sql-code-intelligence-0.1.0.tar mlsql2:/data/mlsql/releases/sql-code-intelligence/
--------------------------------------------------------------------------------
/dev/start.sh:
--------------------------------------------------------------------------------
1 | java -cp .:sql-code-intelligence-0.1.0.jar tech.mlsql.autosuggest.app.Standalone
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>

  <groupId>tech.mlsql</groupId>
  <artifactId>sql-code-intelligence</artifactId>
  <version>0.1.0</version>
  <name>SQL Code Intelligence</name>
  <url>https://github.com/allwefantasy/sql-code-intelligence</url>
  <description>
    SQL code autocomplete engine
  </description>

  <licenses>
    <license>
      <name>Apache 2.0 License</name>
      <url>http://www.apache.org/licenses/LICENSE-2.0.html</url>
      <distribution>repo</distribution>
    </license>
  </licenses>

  <developers>
    <developer>
      <id>allwefantasy</id>
      <name>ZhuHaiLin</name>
      <email>allwefantasy@gmail.com</email>
    </developer>
  </developers>

  <scm>
    <connection>scm:git:git@github.com:allwefantasy/sql-code-intelligence.git</connection>
    <developerConnection>scm:git:git@github.com:allwefantasy/sql-code-intelligence.git</developerConnection>
    <url>https://github.com/allwefantasy/sql-code-intelligence</url>
  </scm>
  <issueManagement>
    <url>https://github.com/allwefantasy/sql-code-intelligence/issues</url>
  </issueManagement>

  <properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <scala.version>2.11.8</scala.version>
    <scala.binary.version>2.11</scala.binary.version>
    <scala-xml.version>2.11.0-M3</scala-xml.version>

    <spark.version>2.4.3</spark.version>
    <spark.bigversion>2.4</spark.bigversion>

    <guava.version>16.0</guava.version>
    <httpclient.version>4.5.3</httpclient.version>
    <common-utils-version>0.3.5</common-utils-version>
    <pegdown.version>1.6.0</pegdown.version>

    <scope>provided</scope>
    <mlsql.version>1.6.0</mlsql.version>
  </properties>

  <dependencies>
    <dependency>
      <groupId>tech.mlsql</groupId>
      <artifactId>common-utils_${scala.binary.version}</artifactId>
      <version>${common-utils-version}</version>
      <exclusions>
        <exclusion>
          <groupId>org.python</groupId>
          <artifactId>jython-standalone</artifactId>
        </exclusion>
        <exclusion>
          <groupId>org.fusesource.leveldbjni</groupId>
          <artifactId>leveldbjni-all</artifactId>
        </exclusion>
      </exclusions>
    </dependency>

    <dependency>
      <groupId>tech.mlsql</groupId>
      <artifactId>streamingpro-dsl-${spark.bigversion}_${scala.binary.version}</artifactId>
      <version>${mlsql.version}</version>
    </dependency>

    <dependency>
      <groupId>tech.mlsql</groupId>
      <artifactId>streamingpro-spark-${spark.bigversion}.0-adaptor_${scala.binary.version}</artifactId>
      <version>${mlsql.version}</version>
      <exclusions>
        <exclusion>
          <groupId>tech.mlsql</groupId>
          <artifactId>spark-adhoc-kafka_2.11</artifactId>
        </exclusion>
        <exclusion>
          <groupId>tech.mlsql</groupId>
          <artifactId>mysql-binlog_2.11</artifactId>
        </exclusion>
      </exclusions>
    </dependency>

    <dependency>
      <groupId>tech.mlsql</groupId>
      <artifactId>streamingpro-common-${spark.bigversion}_${scala.binary.version}</artifactId>
      <version>${mlsql.version}</version>
    </dependency>

    <dependency>
      <groupId>org.pegdown</groupId>
      <artifactId>pegdown</artifactId>
      <version>${pegdown.version}</version>
      <scope>test</scope>
    </dependency>

    <dependency>
      <groupId>tech.mlsql</groupId>
      <artifactId>streamingpro-core-${spark.bigversion}_${scala.binary.version}</artifactId>
      <version>1.6.0-SNAPSHOT</version>
      <exclusions>
        <exclusion>
          <groupId>cn.edu.hfut.dmic.webcollector</groupId>
          <artifactId>WebCollector</artifactId>
        </exclusion>
        <exclusion>
          <groupId>tech.mlsql</groupId>
          <artifactId>pyjava-2.4_2.11</artifactId>
        </exclusion>
        <exclusion>
          <groupId>tech.mlsql</groupId>
          <artifactId>mlsql-scheduler</artifactId>
        </exclusion>
        <exclusion>
          <groupId>org.apache.spark</groupId>
          <artifactId>spark-streaming-kafka-0-8_2.11</artifactId>
        </exclusion>
      </exclusions>
    </dependency>

    <dependency>
      <groupId>com.alibaba</groupId>
      <artifactId>druid</artifactId>
      <version>1.1.16</version>
    </dependency>

    <dependency>
      <groupId>org.scalactic</groupId>
      <artifactId>scalactic_${scala.binary.version}</artifactId>
      <version>3.0.0</version>
      <scope>test</scope>
    </dependency>

    <dependency>
      <groupId>org.scalatest</groupId>
      <artifactId>scalatest_${scala.binary.version}</artifactId>
      <version>3.0.0</version>
      <scope>test</scope>
    </dependency>

    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
      <exclusions>
        <exclusion>
          <groupId>org.apache.parquet</groupId>
          <artifactId>parquet-hadoop</artifactId>
        </exclusion>
        <exclusion>
          <groupId>org.apache.parquet</groupId>
          <artifactId>parquet-column</artifactId>
        </exclusion>
        <exclusion>
          <groupId>org.apache.arrow</groupId>
          <artifactId>arrow-vector</artifactId>
        </exclusion>
        <exclusion>
          <groupId>com.github.luben</groupId>
          <artifactId>zstd-jni</artifactId>
        </exclusion>
        <exclusion>
          <groupId>org.apache.orc</groupId>
          <artifactId>orc-core</artifactId>
        </exclusion>
        <exclusion>
          <groupId>org.apache.hadoop</groupId>
          <artifactId>hadoop-client</artifactId>
        </exclusion>
      </exclusions>
    </dependency>

    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
      <classifier>tests</classifier>
      <scope>test</scope>
    </dependency>

    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
      <classifier>tests</classifier>
      <scope>test</scope>
    </dependency>

    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
      <classifier>tests</classifier>
      <scope>test</scope>
    </dependency>
  </dependencies>

  <profiles>
    <profile>
      <id>shade</id>
      <build>
        <plugins>
          <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-shade-plugin</artifactId>
            <version>2.4.3</version>
            <configuration>
              <filters>
                <filter>
                  <artifact>*:*</artifact>
                  <excludes>
                    <exclude>META-INF/*.SF</exclude>
                    <exclude>META-INF/*.DSA</exclude>
                    <exclude>META-INF/*.RSA</exclude>
                  </excludes>
                </filter>
              </filters>
            </configuration>
            <executions>
              <execution>
                <phase>package</phase>
                <goals>
                  <goal>shade</goal>
                </goals>
              </execution>
            </executions>
          </plugin>
        </plugins>
      </build>
    </profile>

    <profile>
      <id>disable-java8-doclint</id>
      <activation>
        <jdk>[1.8,)</jdk>
      </activation>
      <properties>
        <additionalparam>-Xdoclint:none</additionalparam>
        <doclint>none</doclint>
      </properties>
    </profile>

    <profile>
      <id>release-sign-artifacts</id>
      <activation>
        <property>
          <name>performRelease</name>
          <value>true</value>
        </property>
      </activation>
      <build>
        <plugins>
          <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-gpg-plugin</artifactId>
            <version>1.1</version>
            <executions>
              <execution>
                <id>sign-artifacts</id>
                <phase>verify</phase>
                <goals>
                  <goal>sign</goal>
                </goals>
              </execution>
            </executions>
          </plugin>
        </plugins>
      </build>
    </profile>
  </profiles>

  <build>
    <resources>
      <resource>
        <directory>src/main/resources</directory>
      </resource>
    </resources>

    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-surefire-plugin</artifactId>
        <version>3.0.0-M1</version>
        <configuration>
          <forkCount>1</forkCount>
          <reuseForks>true</reuseForks>
          <argLine>-Xmx4024m</argLine>
          <includes>
            <include>**/*.java</include>
            <include>**/*.scala</include>
          </includes>
        </configuration>
      </plugin>

      <plugin>
        <groupId>org.scala-tools</groupId>
        <artifactId>maven-scala-plugin</artifactId>
        <version>2.15.2</version>
        <configuration>
          <args>
            <arg>-g:vars</arg>
          </args>
          <verbose>true</verbose>
        </configuration>
        <executions>
          <execution>
            <id>compile</id>
            <goals>
              <goal>compile</goal>
            </goals>
            <phase>compile</phase>
          </execution>
          <execution>
            <id>testCompile</id>
            <goals>
              <goal>testCompile</goal>
            </goals>
            <phase>test</phase>
          </execution>
          <execution>
            <phase>process-resources</phase>
            <goals>
              <goal>compile</goal>
            </goals>
          </execution>
        </executions>
      </plugin>

      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-compiler-plugin</artifactId>
        <version>2.3.2</version>
        <configuration>
          <compilerArgument>-g</compilerArgument>
          <verbose>true</verbose>
          <source>1.8</source>
          <target>1.8</target>
        </configuration>
      </plugin>

      <plugin>
        <artifactId>maven-source-plugin</artifactId>
        <version>2.1</version>
        <configuration>
          <attach>true</attach>
        </configuration>
        <executions>
          <execution>
            <phase>compile</phase>
            <goals>
              <goal>jar</goal>
            </goals>
          </execution>
        </executions>
      </plugin>

      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-javadoc-plugin</artifactId>
        <executions>
          <execution>
            <id>attach-javadocs</id>
            <goals>
              <goal>jar</goal>
            </goals>
          </execution>
        </executions>
      </plugin>

      <plugin>
        <groupId>org.sonatype.plugins</groupId>
        <artifactId>nexus-staging-maven-plugin</artifactId>
        <version>1.6.7</version>
        <extensions>true</extensions>
        <configuration>
          <serverId>sonatype-nexus-staging</serverId>
          <nexusUrl>https://oss.sonatype.org/</nexusUrl>
          <autoReleaseAfterClose>true</autoReleaseAfterClose>
        </configuration>
      </plugin>

      <plugin>
        <groupId>org.scalatest</groupId>
        <artifactId>scalatest-maven-plugin</artifactId>
        <version>2.0.0</version>
        <configuration>
          <tagsToExclude>streaming.core.NotToRunTag</tagsToExclude>
          <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
          <junitxml>.</junitxml>
          <filereports>WDF TestSuite.txt</filereports>
          <htmlreporters>${project.build.directory}/html/scalatest</htmlreporters>
          <testFailureIgnore>false</testFailureIgnore>
        </configuration>
        <executions>
          <execution>
            <id>test</id>
            <goals>
              <goal>test</goal>
            </goals>
          </execution>
        </executions>
      </plugin>
    </plugins>
  </build>

  <distributionManagement>
    <snapshotRepository>
      <id>sonatype-nexus-snapshots</id>
      <url>https://oss.sonatype.org/content/repositories/snapshots</url>
    </snapshotRepository>
    <repository>
      <id>sonatype-nexus-staging</id>
      <url>https://oss.sonatype.org/service/local/staging/deploy/maven2/</url>
    </repository>
  </distributionManagement>
</project>
--------------------------------------------------------------------------------
/src/main/java/com/intigua/antlr4/autosuggest/AutoSuggester.java:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest;
2 |
3 | import com.intigua.antlr4.autosuggest.LexerWrapper.TokenizationResult;
4 | import org.antlr.v4.runtime.Token;
5 | import org.antlr.v4.runtime.atn.ATNState;
6 | import org.antlr.v4.runtime.atn.AtomTransition;
7 | import org.antlr.v4.runtime.atn.SetTransition;
8 | import org.antlr.v4.runtime.atn.Transition;
9 | import org.antlr.v4.runtime.misc.Interval;
10 | import org.slf4j.Logger;
11 | import org.slf4j.LoggerFactory;
12 |
13 | import java.util.*;
14 |
15 | /**
16 | * Suggests completions for given text, using a given ANTLR4 grammar.
17 | */
18 | public class AutoSuggester {
19 | private static final Logger logger = LoggerFactory.getLogger(AutoSuggester.class);
20 |
21 | private final ParserWrapper parserWrapper;
22 | private final LexerWrapper lexerWrapper;
23 | private final String input;
24 | private final Set<String> collectedSuggestions = new HashSet<>();
25 |
26 | private List<? extends Token> inputTokens;
27 | private String untokenizedText = "";
28 | private String indent = "";
29 | private CasePreference casePreference = CasePreference.BOTH;
30 |
31 | private Map<ATNState, Integer> parserStateToTokenListIndexWhereLastVisited = new HashMap<>();
32 |
33 | public AutoSuggester(LexerAndParserFactory lexerAndParserFactory, String input) {
34 | this.lexerWrapper = new LexerWrapper(lexerAndParserFactory, new DefaultToCharStream());
35 | this.parserWrapper = new ParserWrapper(lexerAndParserFactory, lexerWrapper.getVocabulary());
36 | this.input = input;
37 | }
38 |
39 | public void setCasePreference(CasePreference casePreference) {
40 | this.casePreference = casePreference;
41 | }
42 |
43 | public Collection<String> suggestCompletions() {
44 | tokenizeInput();
45 | runParserAtnAndCollectSuggestions();
46 | return collectedSuggestions;
47 | }
48 |
49 | private void tokenizeInput() {
50 | TokenizationResult tokenizationResult = lexerWrapper.tokenizeNonDefaultChannel(this.input);
51 | this.inputTokens = tokenizationResult.tokens;
52 | this.untokenizedText = tokenizationResult.untokenizedText;
53 | if (logger.isDebugEnabled()) {
54 | logger.debug("TOKENS FOUND IN FIRST PASS:");
55 | for (Token token : this.inputTokens) {
56 | logger.debug(token.toString());
57 | }
58 | }
59 | }
60 |
61 | private void runParserAtnAndCollectSuggestions() {
62 | ATNState initialState = this.parserWrapper.getAtnState(0);
63 | logger.debug("Parser initial state: " + initialState);
64 | parseAndCollectTokenSuggestions(initialState, 0);
65 | }
66 |
67 | /**
68 | * Recurses through the parser ATN to process all tokens. When successful (out of tokens), collects completion
69 | * suggestions.
70 | */
71 | private void parseAndCollectTokenSuggestions(ATNState parserState, int tokenListIndex) {
72 | indent = indent + " ";
73 | if (didVisitParserStateOnThisTokenIndex(parserState, tokenListIndex)) {
74 | logger.debug(indent + "State " + parserState + " had already been visited while processing token "
75 | + tokenListIndex + ", backtracking to avoid infinite loop.");
76 | return;
77 | }
78 | Integer previousTokenListIndexForThisState = setParserStateLastVisitedOnThisTokenIndex(parserState, tokenListIndex);
79 | try {
80 | if (logger.isDebugEnabled()) {
81 | logger.debug(indent + "State: " + parserWrapper.toString(parserState));
82 | logger.debug(indent + "State available transitions: " + parserWrapper.transitionsStr(parserState));
83 | }
84 |
85 | if (!haveMoreTokens(tokenListIndex)) { // stop condition for recursion
86 | suggestNextTokensForParserState(parserState);
87 | return;
88 | }
89 | for (Transition trans : parserState.getTransitions()) {
90 | if (trans.isEpsilon()) {
91 | handleEpsilonTransition(trans, tokenListIndex);
92 | } else if (trans instanceof AtomTransition) {
93 | handleAtomicTransition((AtomTransition) trans, tokenListIndex);
94 | } else {
95 | handleSetTransition((SetTransition) trans, tokenListIndex);
96 | }
97 | }
98 | } finally {
99 | indent = indent.substring(2);
100 | setParserStateLastVisitedOnThisTokenIndex(parserState, previousTokenListIndexForThisState);
101 | }
102 | }
103 |
104 | private boolean didVisitParserStateOnThisTokenIndex(ATNState parserState, Integer currentTokenListIndex) {
105 | Integer lastVisitedThisStateAtTokenListIndex = parserStateToTokenListIndexWhereLastVisited.get(parserState);
106 | return currentTokenListIndex.equals(lastVisitedThisStateAtTokenListIndex);
107 | }
108 |
109 | private Integer setParserStateLastVisitedOnThisTokenIndex(ATNState parserState, Integer tokenListIndex) {
110 | if (tokenListIndex == null) {
111 | return parserStateToTokenListIndexWhereLastVisited.remove(parserState);
112 | } else {
113 | return parserStateToTokenListIndexWhereLastVisited.put(parserState, tokenListIndex);
114 | }
115 | }
116 |
117 | private boolean haveMoreTokens(int tokenListIndex) {
118 | return tokenListIndex < inputTokens.size();
119 | }
120 |
121 | private void handleEpsilonTransition(Transition trans, int tokenListIndex) {
122 | // Epsilon transitions don't consume a token, so don't move the index
123 | parseAndCollectTokenSuggestions(trans.target, tokenListIndex);
124 | }
125 |
126 | private void handleAtomicTransition(AtomTransition trans, int tokenListIndex) {
127 | Token nextToken = inputTokens.get(tokenListIndex);
128 | int nextTokenType = inputTokens.get(tokenListIndex).getType();
129 | boolean nextTokenMatchesTransition = (trans.label == nextTokenType);
130 | if (nextTokenMatchesTransition) {
131 | logger.debug(indent + "Token " + nextToken + " following transition: " + parserWrapper.toString(trans));
132 | parseAndCollectTokenSuggestions(trans.target, tokenListIndex + 1);
133 | } else {
134 | logger.debug(indent + "Token " + nextToken + " NOT following transition: " + parserWrapper.toString(trans));
135 | }
136 | }
137 |
138 | private void handleSetTransition(SetTransition trans, int tokenListIndex) {
139 | Token nextToken = inputTokens.get(tokenListIndex);
140 | int nextTokenType = nextToken.getType();
141 | for (int transitionTokenType : trans.label().toList()) {
142 | boolean nextTokenMatchesTransition = (transitionTokenType == nextTokenType);
143 | if (nextTokenMatchesTransition) {
144 | logger.debug(indent + "Token " + nextToken + " following transition: " + parserWrapper.toString(trans) + " to " + transitionTokenType);
145 | parseAndCollectTokenSuggestions(trans.target, tokenListIndex + 1);
146 | } else {
147 | logger.debug(indent + "Token " + nextToken + " NOT following transition: " + parserWrapper.toString(trans) + " to " + transitionTokenType);
148 | }
149 | }
150 | }
151 |
152 | private void suggestNextTokensForParserState(ATNState parserState) {
153 | Set<Integer> transitionLabels = new HashSet<>();
154 | fillParserTransitionLabels(parserState, transitionLabels, new HashSet<>());
155 | TokenSuggester tokenSuggester = new TokenSuggester(this.untokenizedText, lexerWrapper, this.casePreference);
156 | Collection<String> suggestions = tokenSuggester.suggest(transitionLabels);
157 | parseSuggestionsAndAddValidOnes(parserState, suggestions);
158 | logger.debug(indent + "WILL SUGGEST TOKENS FOR STATE: " + parserState);
159 | }
160 |
161 | private void fillParserTransitionLabels(ATNState parserState, Collection<Integer> result, Set<TransitionWrapper> visitedTransitions) {
162 | for (Transition trans : parserState.getTransitions()) {
163 | TransitionWrapper transWrapper = new TransitionWrapper(parserState, trans);
164 | if (visitedTransitions.contains(transWrapper)) {
165 | logger.debug(indent + "Not following visited " + transWrapper);
166 | continue;
167 | }
168 | if (trans.isEpsilon()) {
169 | try {
170 | visitedTransitions.add(transWrapper);
171 | fillParserTransitionLabels(trans.target, result, visitedTransitions);
172 | } finally {
173 | visitedTransitions.remove(transWrapper);
174 | }
175 | } else if (trans instanceof AtomTransition) {
176 | int label = ((AtomTransition) trans).label;
177 | if (label >= 1) { // EOF would be -1
178 | result.add(label);
179 | }
180 | } else if (trans instanceof SetTransition) {
181 | for (Interval interval : ((SetTransition) trans).label().getIntervals()) {
182 | for (int i = interval.a; i <= interval.b; ++i) {
183 | result.add(i);
184 | }
185 | }
186 | }
187 | }
188 | }
189 |
190 | private void parseSuggestionsAndAddValidOnes(ATNState parserState, Collection<String> suggestions) {
191 | for (String suggestion : suggestions) {
192 | logger.debug("CHECKING suggestion: " + suggestion);
193 | Token addedToken = getAddedToken(suggestion);
194 | if (isParseableWithAddedToken(parserState, addedToken, new HashSet<TransitionWrapper>())) {
195 | collectedSuggestions.add(suggestion);
196 | } else {
197 | logger.debug("DROPPING non-parseable suggestion: " + suggestion);
198 | }
199 | }
200 | }
201 |
202 | private Token getAddedToken(String suggestedCompletion) {
203 | String completedText = this.input + suggestedCompletion;
204 | List<? extends Token> completedTextTokens = this.lexerWrapper.tokenizeNonDefaultChannel(completedText).tokens;
205 | if (completedTextTokens.size() <= inputTokens.size()) {
206 | return null; // Completion didn't yield whole token, could be just a token fragment
207 | }
208 | logger.debug("TOKENS IN COMPLETED TEXT: " + completedTextTokens);
209 | Token newToken = completedTextTokens.get(completedTextTokens.size() - 1);
210 | return newToken;
211 | }
212 |
213 | private boolean isParseableWithAddedToken(ATNState parserState, Token newToken, Set<TransitionWrapper> visitedTransitions) {
214 | if (newToken == null) {
215 | return false;
216 | }
217 | for (Transition parserTransition : parserState.getTransitions()) {
218 | if (parserTransition.isEpsilon()) { // Recurse through any epsilon transitions
219 | TransitionWrapper transWrapper = new TransitionWrapper(parserState, parserTransition);
220 | if (visitedTransitions.contains(transWrapper)) {
221 | continue;
222 | }
223 | visitedTransitions.add(transWrapper);
224 | try {
225 | if (isParseableWithAddedToken(parserTransition.target, newToken, visitedTransitions)) {
226 | return true;
227 | }
228 | } finally {
229 | visitedTransitions.remove(transWrapper);
230 | }
231 | } else if (parserTransition instanceof AtomTransition) {
232 | AtomTransition parserAtomTransition = (AtomTransition) parserTransition;
233 | int transitionTokenType = parserAtomTransition.label;
234 | if (transitionTokenType == newToken.getType()) {
235 | return true;
236 | }
237 | } else if (parserTransition instanceof SetTransition) {
238 | SetTransition parserSetTransition = (SetTransition) parserTransition;
239 | for (int transitionTokenType : parserSetTransition.label().toList()) {
240 | if (transitionTokenType == newToken.getType()) {
241 | return true;
242 | }
243 | }
244 | } else {
245 | throw new IllegalStateException("Unexpected: " + parserWrapper.toString(parserTransition));
246 | }
247 | }
248 | return false;
249 | }
250 |
251 |
252 | }
253 |
--------------------------------------------------------------------------------
/src/main/java/com/intigua/antlr4/autosuggest/CasePreference.java:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest;
2 |
3 |
4 | public enum CasePreference {
5 | /**
6 | * Suggest both uppercase and lowercase completions
7 | */
8 | BOTH,
9 |
10 | /**
11 | * In case both uppercase and lowercase are supported by the grammar, suggest only lowercase alternative
12 | */
13 | LOWER,
14 |
15 | /**
16 | * In case both uppercase and lowercase are supported by the grammar, suggest only uppercase alternative
17 | */
18 | UPPER
19 | }
20 |
--------------------------------------------------------------------------------
/src/main/java/com/intigua/antlr4/autosuggest/DefaultToCharStream.scala:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest
2 |
3 | import java.io.StringReader
4 |
5 | import org.antlr.v4.runtime.{CharStream, CharStreams}
6 | import tech.mlsql.autosuggest
7 |
8 | /**
9 | * 3/6/2020 WilliamZhu(allwefantasy@gmail.com)
10 | */
11 | class DefaultToCharStream extends ToCharStream {
12 | override def toCharStream(text: String): CharStream = {
13 | CharStreams.fromReader(new StringReader(text))
14 |
15 | }
16 | }
17 |
18 | class RawSQLToCharStream extends ToCharStream {
19 | override def toCharStream(text: String): CharStream = {
20 | new autosuggest.UpperCaseCharStream(CharStreams.fromString(text))
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/src/main/java/com/intigua/antlr4/autosuggest/LexerAndParserFactory.java:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest;
2 |
3 | public interface LexerAndParserFactory extends LexerFactory, ParserFactory {
4 | }
5 |
--------------------------------------------------------------------------------
/src/main/java/com/intigua/antlr4/autosuggest/LexerFactory.java:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest;
2 |
3 | import org.antlr.v4.runtime.CharStream;
4 | import org.antlr.v4.runtime.Lexer;
5 |
6 | public interface LexerFactory {
7 |
8 | Lexer createLexer(CharStream input);
9 | }
10 |
--------------------------------------------------------------------------------
/src/main/java/com/intigua/antlr4/autosuggest/LexerWrapper.java:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest;
2 |
3 | import org.antlr.v4.runtime.*;
4 | import org.antlr.v4.runtime.atn.ATNState;
5 | import org.antlr.v4.runtime.misc.ParseCancellationException;
6 |
7 | import java.util.List;
8 | import java.util.stream.Collectors;
9 |
10 | public class LexerWrapper {
11 | private final LexerFactory lexerFactory;
12 | private Lexer cachedLexer;
13 | private ToCharStream toCharStream;
14 |
15 | public static class TokenizationResult {
16 | public List<? extends Token> tokens;
17 | public String untokenizedText = "";
18 | }
19 |
20 | public LexerWrapper(LexerFactory lexerFactory, ToCharStream toCharStream) {
21 | super();
22 | this.lexerFactory = lexerFactory;
23 | this.toCharStream = toCharStream;
24 | }
25 |
26 | public TokenizationResult tokenizeNonDefaultChannel(String input) {
27 | TokenizationResult result = this.tokenize(input);
28 | result.tokens = result.tokens.stream().filter(t -> t.getChannel() == 0).collect(Collectors.toList());
29 | return result;
30 | }
31 |
32 | public String[] getRuleNames() {
33 | return getCachedLexer().getRuleNames();
34 | }
35 |
36 | public ATNState findStateByRuleNumber(int ruleNumber) {
37 | return getCachedLexer().getATN().ruleToStartState[ruleNumber];
38 | }
39 |
40 | public Vocabulary getVocabulary() {
41 | return getCachedLexer().getVocabulary();
42 | }
43 |
44 | private Lexer getCachedLexer() {
45 | if (cachedLexer == null) {
46 | cachedLexer = createLexer("");
47 | }
48 | return cachedLexer;
49 | }
50 |
51 | private TokenizationResult tokenize(String input) {
52 | Lexer lexer = this.createLexer(input);
53 | lexer.removeErrorListeners();
54 | final TokenizationResult result = new TokenizationResult();
55 | ANTLRErrorListener newErrorListener = new BaseErrorListener() {
56 | @Override
57 | public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int line,
58 | int charPositionInLine, String msg, RecognitionException e) throws ParseCancellationException {
59 | result.untokenizedText = input.substring(charPositionInLine); // intended side effect
60 | }
61 | };
62 | lexer.addErrorListener(newErrorListener);
63 | result.tokens = lexer.getAllTokens();
64 | return result;
65 | }
66 |
67 | private Lexer createLexer(CharStream input) {
68 | return this.lexerFactory.createLexer(input);
69 | }
70 |
71 | private Lexer createLexer(String lexerInput) {
72 | return this.createLexer(toCharStream.toCharStream(lexerInput));
73 | }
74 |
75 |
76 | }
77 |
--------------------------------------------------------------------------------
/src/main/java/com/intigua/antlr4/autosuggest/ParserFactory.java:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest;
2 |
3 | import org.antlr.v4.runtime.Parser;
4 | import org.antlr.v4.runtime.TokenStream;
5 |
6 | public interface ParserFactory {
7 |
8 | Parser createParser(TokenStream tokenStream);
9 |
10 | }
11 |
--------------------------------------------------------------------------------
/src/main/java/com/intigua/antlr4/autosuggest/ParserWrapper.java:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest;
2 |
3 | import org.antlr.v4.runtime.Parser;
4 | import org.antlr.v4.runtime.Vocabulary;
5 | import org.antlr.v4.runtime.atn.ATN;
6 | import org.antlr.v4.runtime.atn.ATNState;
7 | import org.antlr.v4.runtime.atn.AtomTransition;
8 | import org.antlr.v4.runtime.atn.Transition;
9 | import org.apache.commons.lang3.StringUtils;
10 | import org.slf4j.Logger;
11 | import org.slf4j.LoggerFactory;
12 |
13 | import java.util.Arrays;
14 | import java.util.List;
15 | import java.util.stream.Collectors;
16 | import java.util.stream.Stream;
17 |
18 | class ParserWrapper {
19 | private static final Logger logger = LoggerFactory.getLogger(ParserWrapper.class);
20 | private final Vocabulary lexerVocabulary;
21 |
22 | private final ATN parserAtn;
23 | private final String[] parserRuleNames;
24 |
25 | public ParserWrapper(ParserFactory parserFactory, Vocabulary lexerVocabulary) {
26 | this.lexerVocabulary = lexerVocabulary;
27 |
28 | Parser parserForAtnOnly = parserFactory.createParser(null);
29 | this.parserAtn = parserForAtnOnly.getATN();
30 | this.parserRuleNames = parserForAtnOnly.getRuleNames();
31 | logger.debug("Parser rule names: " + StringUtils.join(parserForAtnOnly.getRuleNames(), ", "));
32 | }
33 |
34 | public String toString(ATNState parserState) {
35 | String ruleName = this.parserRuleNames[parserState.ruleIndex];
36 | return "*" + ruleName + "* " + parserState.getClass().getSimpleName() + " " + parserState;
37 | }
38 |
39 | public String toString(Transition t) {
40 | String nameOrLabel = t.getClass().getSimpleName();
41 | if (t instanceof AtomTransition) {
42 | nameOrLabel += ' ' + this.lexerVocabulary.getDisplayName(((AtomTransition) t).label);
43 | }
44 | return nameOrLabel + " -> " + toString(t.target);
45 | }
46 |
47 | public String transitionsStr(ATNState state) {
48 | Stream<Transition> transitionsStream = Arrays.asList(state.getTransitions()).stream();
49 | List<String> transitionStrings = transitionsStream.map(this::toString).collect(Collectors.toList());
50 | return StringUtils.join(transitionStrings, ", ");
51 | }
52 |
53 | public ATNState getAtnState(int stateNumber) {
54 | return parserAtn.states.get(stateNumber);
55 | }
56 | }
57 |
--------------------------------------------------------------------------------
/src/main/java/com/intigua/antlr4/autosuggest/ReflectionLexerAndParserFactory.java:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest;
2 |
3 | import org.antlr.v4.runtime.CharStream;
4 | import org.antlr.v4.runtime.Lexer;
5 | import org.antlr.v4.runtime.Parser;
6 | import org.antlr.v4.runtime.TokenStream;
7 |
8 | import java.lang.reflect.Constructor;
9 | import java.lang.reflect.InvocationTargetException;
10 |
11 | public class ReflectionLexerAndParserFactory implements LexerAndParserFactory {
12 |
13 | private final Constructor<? extends Lexer> lexerCtr;
14 | private final Constructor<? extends Parser> parserCtr;
15 |
16 | public ReflectionLexerAndParserFactory(Class<? extends Lexer> lexerClass, Class<? extends Parser> parserClass) {
17 | lexerCtr = getConstructor(lexerClass, CharStream.class);
18 | parserCtr = getConstructor(parserClass, TokenStream.class);
19 | }
20 |
21 | @Override
22 | public Lexer createLexer(CharStream input) {
23 | return create(lexerCtr, input);
24 | }
25 |
26 | @Override
27 | public Parser createParser(TokenStream tokenStream) {
28 | return create(parserCtr, tokenStream);
29 | }
30 |
31 | private static <T> Constructor<? extends T> getConstructor(Class<? extends T> givenClass, Class<?> argClass) {
32 | try {
33 | return givenClass.getConstructor(argClass);
34 | } catch (NoSuchMethodException | SecurityException e) {
35 | throw new IllegalArgumentException(
36 | givenClass.getSimpleName() + " must have constructor from " + argClass.getSimpleName() + ".");
37 | }
38 | }
39 |
40 | private <T> T create(Constructor<? extends T> constructor, Object arg) {
41 | try {
42 | return constructor.newInstance(arg);
43 | } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
44 | throw new IllegalArgumentException(e);
45 | }
46 | }
47 |
48 | }
49 |
--------------------------------------------------------------------------------
/src/main/java/com/intigua/antlr4/autosuggest/ToCharStream.java:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest;
2 |
3 | import org.antlr.v4.runtime.CharStream;
4 |
5 | /**
6 | * 3/6/2020 WilliamZhu(allwefantasy@gmail.com)
7 | */
8 | public interface ToCharStream {
9 | public CharStream toCharStream(String text);
10 | }
11 |
--------------------------------------------------------------------------------
/src/main/java/com/intigua/antlr4/autosuggest/TokenSuggester.java:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest;
2 |
3 | import org.antlr.v4.runtime.atn.ATNState;
4 | import org.antlr.v4.runtime.atn.AtomTransition;
5 | import org.antlr.v4.runtime.atn.SetTransition;
6 | import org.antlr.v4.runtime.atn.Transition;
7 | import org.slf4j.Logger;
8 | import org.slf4j.LoggerFactory;
9 |
10 | import java.util.*;
11 | import java.util.stream.Collectors;
12 |
13 | /**
14 | * Given an ATN state and the lexer ATN, suggests auto-completion texts.
15 | */
16 | class TokenSuggester {
17 | private static final Logger logger = LoggerFactory.getLogger(TokenSuggester.class);
18 |
19 | private final LexerWrapper lexerWrapper;
20 | private final CasePreference casePreference;
21 |
22 | private final Set<String> suggestions = new TreeSet<>();
23 | private final List<Integer> visitedLexerStates = new ArrayList<>();
24 | private String origPartialToken;
25 |
26 | public TokenSuggester(LexerWrapper lexerWrapper, String input) {
27 | this(input, lexerWrapper, CasePreference.BOTH);
28 | }
29 |
30 | public TokenSuggester(String origPartialToken, LexerWrapper lexerWrapper, CasePreference casePreference) {
31 | this.origPartialToken = origPartialToken;
32 | this.lexerWrapper = lexerWrapper;
33 | this.casePreference = casePreference;
34 | }
35 |
36 | public Collection<String> suggest(Collection<Integer> nextParserTransitionLabels) {
37 | logTokensUsedForSuggestion(nextParserTransitionLabels);
38 | for (int nextParserTransitionLabel : nextParserTransitionLabels) {
39 | int nextTokenRuleNumber = nextParserTransitionLabel - 1; // Count from 0 not from 1
40 | ATNState lexerState = this.lexerWrapper.findStateByRuleNumber(nextTokenRuleNumber);
41 | suggest("", lexerState, origPartialToken);
42 | }
43 | return suggestions;
44 | // return suggestions.stream().filter(s -> this.lexerWrapper.isValidSuggestion(input, s)).collect(Collectors.toList());
45 | }
46 |
47 | private void logTokensUsedForSuggestion(Collection<Integer> ruleIndices) {
48 | if (!logger.isDebugEnabled()) {
49 | return;
50 | }
51 | String ruleNames = ruleIndices.stream().map(r -> lexerWrapper.getRuleNames()[r-1]).collect(Collectors.joining(" "));
52 | logger.debug("Suggesting tokens for lexer rules: " + ruleNames);
53 | }
54 |
55 |
56 | private void suggest(String tokenSoFar, ATNState lexerState, String remainingText) {
57 | logger.debug(
58 | "SUGGEST: tokenSoFar=" + tokenSoFar + " remainingText=" + remainingText + " lexerState=" + toString(lexerState));
59 | if (visitedLexerStates.contains(lexerState.stateNumber)) {
60 | return; // avoid infinite loop and stack overflow
61 | }
62 | visitedLexerStates.add(lexerState.stateNumber);
63 | try {
64 | Transition[] transitions = lexerState.getTransitions();
65 | boolean tokenNotEmpty = tokenSoFar.length() > 0;
66 | boolean noMoreCharactersInToken = (transitions.length == 0);
67 | if (tokenNotEmpty && noMoreCharactersInToken) {
68 | addSuggestedToken(tokenSoFar);
69 | return;
70 | }
71 | for (Transition trans : transitions) {
72 | suggestViaLexerTransition(tokenSoFar, remainingText, trans);
73 | }
74 | } finally {
75 | visitedLexerStates.remove(visitedLexerStates.size() - 1);
76 | }
77 | }
78 |
79 | private String toString(ATNState lexerState) {
80 | String ruleName = this.lexerWrapper.getRuleNames()[lexerState.ruleIndex];
81 | return ruleName + " " + lexerState.getClass().getSimpleName() + " " + lexerState;
82 | }
83 |
84 | private void suggestViaLexerTransition(String tokenSoFar, String remainingText, Transition trans) {
85 | if (trans.isEpsilon()) {
86 | suggest(tokenSoFar, trans.target, remainingText);
87 | } else if (trans instanceof AtomTransition) {
88 | String newTokenChar = getAddedTextFor((AtomTransition) trans);
89 | if (remainingText.isEmpty() || remainingText.startsWith(newTokenChar)) {
90 | logger.debug("LEXER TOKEN: " + newTokenChar + " remaining=" + remainingText);
91 | suggestViaNonEpsilonLexerTransition(tokenSoFar, remainingText, newTokenChar, trans.target);
92 | } else {
93 | logger.debug("NONMATCHING LEXER TOKEN: " + newTokenChar + " remaining=" + remainingText);
94 | }
95 | } else if (trans instanceof SetTransition) {
96 | List<Integer> symbols = ((SetTransition) trans).label().toList();
97 | for (Integer symbol : symbols) {
98 | char[] charArr = Character.toChars(symbol);
99 | String charStr = new String(charArr);
100 | boolean shouldIgnoreCase = shouldIgnoreThisCase(charArr[0], symbols); // TODO: check for non-BMP
101 | if (!shouldIgnoreCase && (remainingText.isEmpty() || remainingText.startsWith(charStr))) {
102 | suggestViaNonEpsilonLexerTransition(tokenSoFar, remainingText, charStr, trans.target);
103 | }
104 | }
105 | }
106 | }
107 |
108 | private void suggestViaNonEpsilonLexerTransition(String tokenSoFar, String remainingText,
109 | String newTokenChar, ATNState targetState) {
110 | String newRemainingText = (remainingText.length() > 0) ? remainingText.substring(1) : remainingText;
111 | suggest(tokenSoFar + newTokenChar, targetState, newRemainingText);
112 | }
113 |
114 | private void addSuggestedToken(String tokenToAdd) {
115 | String justTheCompletionPart = chopOffCommonStart(tokenToAdd, this.origPartialToken);
116 | suggestions.add(justTheCompletionPart);
117 | }
118 |
119 | private String chopOffCommonStart(String a, String b) {
120 | int charsToChopOff = Math.min(b.length(), a.length());
121 | return a.substring(charsToChopOff);
122 | }
123 |
124 | private String getAddedTextFor(AtomTransition transition) {
125 | return new String(Character.toChars(transition.label));
126 | }
127 |
128 | private boolean shouldIgnoreThisCase(char transChar, List<Integer> allTransChars) {
129 | if (this.casePreference == null) {
130 | return false;
131 | }
132 | switch(this.casePreference) {
133 | case BOTH:
134 | return false;
135 | case LOWER:
136 | return Character.isUpperCase(transChar) && allTransChars.contains((int) Character.toLowerCase(transChar));
137 | case UPPER:
138 | return Character.isLowerCase(transChar) && allTransChars.contains((int) Character.toUpperCase(transChar));
139 | default:
140 | return false;
141 | }
142 | }
143 |
144 | }
145 |
--------------------------------------------------------------------------------
/src/main/java/com/intigua/antlr4/autosuggest/TransitionWrapper.java:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest;
2 |
3 | import org.antlr.v4.runtime.atn.ATNState;
4 | import org.antlr.v4.runtime.atn.Transition;
5 |
6 | public class TransitionWrapper {
7 | private final ATNState source;
8 | private final Transition transition;
9 |
10 | public TransitionWrapper(ATNState source, Transition transition) {
11 | super();
12 | this.source = source;
13 | this.transition = transition;
14 | }
15 |
16 | @Override
17 | public int hashCode() {
18 | final int prime = 31;
19 | int result = 1;
20 | result = prime * result + ((source == null) ? 0 : source.hashCode());
21 | result = prime * result + ((transition == null) ? 0 : transition.hashCode());
22 | return result;
23 | }
24 |
25 | @Override
26 | public boolean equals(Object obj) {
27 | if (this == obj)
28 | return true;
29 | if (obj == null)
30 | return false;
31 | if (getClass() != obj.getClass())
32 | return false;
33 | TransitionWrapper other = (TransitionWrapper) obj;
34 | if (source == null) {
35 | if (other.source != null)
36 | return false;
37 | } else if (!source.equals(other.source))
38 | return false;
39 | if (transition == null) {
40 | if (other.transition != null)
41 | return false;
42 | } else if (!transition.equals(other.transition))
43 | return false;
44 | return true;
45 | }
46 |
47 | @Override
48 | public String toString() {
49 | return transition.getClass().getSimpleName() + " from " + source + " to " + transition.target;
50 | }
51 |
52 |
53 | }
54 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/AttributeExtractor.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest
2 |
3 | import org.antlr.v4.runtime.Token
4 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer
5 | import tech.mlsql.autosuggest.dsl.{Food, TokenMatcher}
6 | import tech.mlsql.autosuggest.meta.MetaTableKey
7 | import tech.mlsql.autosuggest.statement.{MatchAndExtractor, MetaTableKeyWrapper, SingleStatementAST}
8 |
9 | import scala.collection.mutable.ArrayBuffer
10 |
11 | /**
12 | * 4/6/2020 WilliamZhu(allwefantasy@gmail.com)
13 | */
14 | class AttributeExtractor(autoSuggestContext: AutoSuggestContext, ast: SingleStatementAST, tokens: List[Token]) extends MatchAndExtractor[String] {
15 |
16 | override def matcher(start: Int): TokenMatcher = {
17 | return asteriskM(start)
18 | }
19 |
20 | private def attributeM(start: Int): TokenMatcher = {
21 | val temp = TokenMatcher(tokens, start).
22 | eat(Food(None, SqlBaseLexer.IDENTIFIER), Food(None, SqlBaseLexer.T__3)).optional.
23 | eat(Food(None, SqlBaseLexer.IDENTIFIER)).
24 | eat(Food(None, SqlBaseLexer.AS)).optional.
25 | eat(Food(None, SqlBaseLexer.IDENTIFIER)).optional.
26 | build
27 | temp.isSuccess match {
28 | case true => temp
29 | case false =>
30 | TokenMatcher(tokens, start).
31 | eatOneAny.
32 | eat(Food(None, SqlBaseLexer.AS)).
33 | eat(Food(None, SqlBaseLexer.IDENTIFIER)).
34 | build
35 | }
36 | }
37 |
38 | private def functionM(start: Int): TokenMatcher = {
39 | // deal with something like: sum(a[1],fun(b)) as a
40 | val temp = TokenMatcher(tokens, start).eat(Food(None, SqlBaseLexer.IDENTIFIER), Food(None, SqlBaseLexer.T__0)).build
41 | // function match
42 | if (temp.isSuccess) {
43 | // try to find AS
44 | // we need to take care of situation like this: cast(a as int) as b
45 | // In future, we should get first function and get the return type so we can get the b type.
46 | val index = TokenMatcher(tokens, start).index(Array(Food(None, SqlBaseLexer.T__1), Food(None, SqlBaseLexer.AS), Food(None, SqlBaseLexer.IDENTIFIER)))
47 | if (index != -1) {
48 | //index + 1 to skip )
49 | val aliasName = TokenMatcher(tokens, index + 1).eat(Food(None, SqlBaseLexer.AS)).
50 | eat(Food(None, SqlBaseLexer.IDENTIFIER)).build
51 | if (aliasName.isSuccess) {
52 | return TokenMatcher.resultMatcher(tokens, start, aliasName.get)
53 | }
54 |
55 | }
56 | // no AS found: fall through to attribute matching below
57 | null // dead value; execution continues to attributeM(start)
58 | }
59 | return attributeM(start)
60 | }
61 |
62 | private def asteriskM(start: Int): TokenMatcher = {
63 | val temp = TokenMatcher(tokens, start).
64 | eat(Food(None, SqlBaseLexer.IDENTIFIER), Food(None, SqlBaseLexer.T__3)).optional.
65 | eat(Food(None, SqlBaseLexer.ASTERISK)).build
66 | if (temp.isSuccess) {
67 | return TokenMatcher.resultMatcher(tokens, start, temp.get)
68 | }
69 | return functionM(start)
70 | }
71 |
72 | override def extractor(start: Int, end: Int): List[String] = {
73 |
74 |
75 | val attrTokens = tokens.slice(start, end)
76 | val token = attrTokens.last
77 | if (token.getType == SqlBaseLexer.ASTERISK) {
78 | return attrTokens match {
79 | case List(tableName, _, _) =>
80 | //expand output
81 | ast.selectSuggester.table_info(ast.level).
82 | get(MetaTableKeyWrapper(MetaTableKey(None, None, null), Option(tableName.getText))).orElse{
83 | ast.selectSuggester.table_info(ast.level).
84 | get(MetaTableKeyWrapper(MetaTableKey(None, None, tableName.getText), None))
85 | } match {
86 | case Some(table) =>
87 | //if this is a temp table, we need to expand it further
88 | val columns = if (table.key.db == Option(SpecialTableConst.TEMP_TABLE_DB_KEY)) {
89 | autoSuggestContext.metaProvider.search(table.key) match {
90 | case Some(item) => item.columns
91 | case None => List()
92 | }
93 | } else table.columns
94 | columns.map(_.name).toList
95 | case None => List()
96 | }
97 | case List(starAttr) =>
98 | val table = ast.tables(tokens).head
99 | ast.selectSuggester.table_info(ast.level).
100 | get(table) match {
101 | case Some(table) =>
102 | //if this is a temp table, we need to expand it further
103 | val columns = if (table.key.db == Option(SpecialTableConst.TEMP_TABLE_DB_KEY)) {
104 | autoSuggestContext.metaProvider.search(table.key) match {
105 | case Some(item) => item.columns
106 | case None => List()
107 | }
108 | } else table.columns
109 | columns.map(_.name).toList
110 | case None => List()
111 | }
112 | }
113 | }
114 | List(token.getText)
115 | }
116 |
117 | override def iterate(start: Int, end: Int, limit: Int): List[String] = {
118 | val attributes = ArrayBuffer[String]()
119 | var matchRes = matcher(start)
120 | var whileLimit = 1000
121 | while (matchRes.isSuccess && whileLimit > 0) {
122 | attributes ++= extractor(matchRes.start, matchRes.get)
123 | whileLimit -= 1
124 | val temp = TokenMatcher(tokens, matchRes.get).eat(Food(None, SqlBaseLexer.T__2)).build
125 | if (temp.isSuccess) {
126 | matchRes = matcher(temp.get)
127 | } else whileLimit = 0
128 | }
129 | attributes.toList
130 | }
131 | }
132 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/AutoSuggestContext.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest
2 |
3 | import com.intigua.antlr4.autosuggest.LexerWrapper
4 | import org.antlr.v4.runtime.misc.Interval
5 | import org.antlr.v4.runtime.{CharStream, CodePointCharStream, IntStream, Token}
6 | import org.apache.spark.sql.SparkSession
7 | import tech.mlsql.autosuggest.meta._
8 | import tech.mlsql.autosuggest.preprocess.TablePreprocessor
9 | import tech.mlsql.autosuggest.statement._
10 | import tech.mlsql.common.utils.log.Logging
11 | import tech.mlsql.common.utils.reflect.ClassPath
12 | import tech.mlsql.common.utils.serder.json.JSONTool
13 |
14 | import scala.collection.JavaConverters._
15 | import scala.collection.mutable.ArrayBuffer
16 |
17 | object AutoSuggestContext {
18 | private[this] val autoSuggestContext: ThreadLocal[AutoSuggestContext] = new ThreadLocal[AutoSuggestContext]
19 |
20 | def context(): AutoSuggestContext = autoSuggestContext.get
21 | def setContext(ec: AutoSuggestContext): Unit = autoSuggestContext.set(ec)
22 |
23 | val memoryMetaProvider = new MemoryMetaProvider()
24 | var isInit = false
25 |
26 | def init: Unit = {
27 | val funcRegs = ClassPath.from(classOf[AutoSuggestContext].getClassLoader).getTopLevelClasses(POrCLiterals.FUNCTION_PACKAGE).iterator()
28 | while (funcRegs.hasNext) {
29 | val wow = funcRegs.next()
30 | val funcMetaTable = wow.load().newInstance().asInstanceOf[FuncReg].register
31 | MLSQLSQLFunction.funcMetaProvider.register(funcMetaTable)
32 | }
33 | isInit = true
34 | }
35 |
36 | if (!isInit) {
37 | init
38 | }
39 |
40 | }
41 |
42 | /**
43 | * A new instance should be created for every request
44 | */
45 | class AutoSuggestContext(val session: SparkSession,
46 | val lexer: LexerWrapper,
47 | val rawSQLLexer: LexerWrapper, val options: Map[String, String] = Map()) extends Logging {
48 |
49 | AutoSuggestContext.setContext(this)
50 | private var _debugMode = false
51 |
52 | private var _rawTokens: List[Token] = List()
53 | private var _statements = List[List[Token]]()
54 | private val _tempTableProvider: StatementTempTableProvider = new StatementTempTableProvider()
55 | private var _rawLineNum = 0
56 | private var _rawColumnNum = 0
57 | private var userDefinedProvider: MetaProvider = new MetaProvider {
58 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = None
59 |
60 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = List()
61 | }
62 | private var _metaProvider: MetaProvider = new LayeredMetaProvider(tempTableProvider, userDefinedProvider)
63 |
64 | private val _statementProcessors = ArrayBuffer[PreProcessStatement]()
65 | addStatementProcessor(new TablePreprocessor(this))
66 |
67 | private var _statementSplitter: StatementSplitter = new MLSQLStatementSplitter()
68 |
69 | def setDebugMode(isDebug: Boolean) = {
70 | this._debugMode = isDebug
71 | }
72 |
73 | def isSchemaInferEnabled = {
74 | !options.getOrElse("schemaInferUrl","").isEmpty && session != null
75 | }
76 |
77 | def isInDebugMode = _debugMode
78 |
79 | def metaProvider = _metaProvider
80 |
81 | def statements = {
82 | _statements
83 | }
84 |
85 | def reqParams = {
86 | JSONTool.parseJson[Map[String,String]](AutoSuggestContext.context().options("params"))
87 | }
88 |
89 | def rawTokens = _rawTokens
90 |
91 | def tempTableProvider = {
92 | _tempTableProvider
93 | }
94 |
95 | def setStatementSplitter(_statementSplitter: StatementSplitter) = {
96 | this._statementSplitter = _statementSplitter
97 | this
98 | }
99 |
100 | def setUserDefinedMetaProvider(_metaProvider: MetaProvider) = {
101 | userDefinedProvider = _metaProvider
102 | this._metaProvider = new LayeredMetaProvider(tempTableProvider, userDefinedProvider)
103 | this
104 | }
105 |
106 | def setRootMetaProvider(_metaProvider: MetaProvider) = {
107 | this._metaProvider = _metaProvider
108 | this
109 | }
110 |
111 | def addStatementProcessor(item: PreProcessStatement) = {
112 | this._statementProcessors += item
113 | this
114 | }
115 |
116 | def buildFromString(str: String): AutoSuggestContext = {
117 | build(lexer.tokenizeNonDefaultChannel(str).tokens.asScala.toList)
118 | this
119 | }
120 |
121 | def build(_tokens: List[Token]): AutoSuggestContext = {
122 | _rawTokens = _tokens
123 | _statements = _statementSplitter.split(rawTokens)
124 | // preprocess
125 | _statementProcessors.foreach { sta =>
126 | _statements.foreach(sta.process(_))
127 | }
128 | return this
129 | }
130 |
131 | def toRelativePos(tokenPos: TokenPos): (TokenPos, Int) = {
132 | var skipSize = 0
133 | var targetIndex = 0
134 | var targetPos: TokenPos = null
135 | var targetStaIndex = 0
136 | _statements.zipWithIndex.foreach { case (sta, index) =>
137 | val relativePos = tokenPos.pos - skipSize
138 | if (relativePos >= 0 && relativePos < sta.size) {
139 | targetPos = tokenPos.copy(pos = tokenPos.pos - skipSize)
140 | targetStaIndex = index
141 | }
142 | skipSize += sta.size
143 | targetIndex += 1
144 | }
145 | return (targetPos, targetStaIndex)
146 | }
147 |
148 | def suggest(lineNum: Int, columnNum: Int): List[SuggestItem] = {
149 | if (isInDebugMode) {
150 | logInfo(s"lineNum:${lineNum} columnNum:${columnNum}")
151 | }
152 | _rawLineNum = lineNum
153 | _rawColumnNum = columnNum
154 | val tokenPos = LexerUtils.toTokenPos(rawTokens, lineNum, columnNum)
155 | _suggest(tokenPos)
156 | }
157 |
158 | /**
159 | * Notice that the pos in tokenPos is relative to the whole script.
160 | * We need to convert it to the relative pos within each statement.
161 | */
162 | private[autosuggest] def _suggest(tokenPos: TokenPos): List[SuggestItem] = {
163 | assert(_rawLineNum != 0 || _rawColumnNum != 0, "lineNum and columnNum should be set")
164 | if (isInDebugMode) {
165 | logInfo("Global Pos::" + tokenPos.str + s"::${rawTokens(tokenPos.pos)}")
166 | }
167 | if (tokenPos.pos == -1) {
168 | return firstWords
169 | }
170 | val (relativeTokenPos, index) = toRelativePos(tokenPos)
171 |
172 | if (isInDebugMode) {
173 | logInfo(s"Relative Pos in ${index}-statement ::" + relativeTokenPos.str)
174 | logInfo(s"${index}-statement:\n${_statements(index).map(_.getText).mkString(" ")}")
175 | }
176 | val items = _statements(index).headOption.map(_.getText) match {
177 | case Some("load") =>
178 | val suggester = new LoadSuggester(this, _statements(index), relativeTokenPos)
179 | suggester.suggest()
180 | case Some("select") =>
181 | // we should recompute the token pos since select statements use the spark sql lexer instead of the MLSQL lexer
182 | val selectTokens = _statements(index)
183 | val startLineNum = selectTokens.head.getLine
184 | val relativeLineNum = _rawLineNum - startLineNum + 1 // startLineNum starts from 1
185 | val relativeColumnNum = _rawColumnNum - selectTokens.head.getCharPositionInLine // charPos starts from 0
186 |
187 | if (isInDebugMode) {
188 | logInfo(s"select[${index}] relativeLineNum:${relativeLineNum} relativeColumnNum:${relativeColumnNum}")
189 | }
190 | val relativeTokenPos = LexerUtils.toTokenPosForSparkSQL(LexerUtils.toRawSQLTokens(this, selectTokens), relativeLineNum, relativeColumnNum)
191 | if (isInDebugMode) {
192 | logInfo(s"select[${index}] relativeTokenPos:${relativeTokenPos.str}")
193 | }
194 | val suggester = new SelectSuggester(this, selectTokens, relativeTokenPos)
195 | suggester.suggest()
196 | case Some(value) => firstWords.filter(_.name.startsWith(value))
197 | case None => firstWords
198 | }
199 | items.distinct
200 | }
201 |
202 | private val firstWords = List("load", "select", "include", "register", "run", "train", "save", "set").map(SuggestItem(_, SpecialTableConst.KEY_WORD_TABLE, Map()))
203 | }
204 |
205 | class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
206 | override def consume(): Unit = wrapped.consume
207 |
208 | override def getSourceName(): String = wrapped.getSourceName
209 |
210 | override def index(): Int = wrapped.index
211 |
212 | override def mark(): Int = wrapped.mark
213 |
214 | override def release(marker: Int): Unit = wrapped.release(marker)
215 |
216 | override def seek(where: Int): Unit = wrapped.seek(where)
217 |
218 | override def size(): Int = wrapped.size
219 |
220 | override def getText(interval: Interval): String = {
221 | // ANTLR 4.7's CodePointCharStream implementations have bugs when
222 | // getText() is called with an empty stream, or intervals where
223 | // the start > end. See
224 | // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix
225 | // that is not yet in a released ANTLR artifact.
226 | if (size() > 0 && (interval.b - interval.a >= 0)) {
227 | wrapped.getText(interval)
228 | } else {
229 | ""
230 | }
231 | }
232 |
233 | override def LA(i: Int): Int = {
234 | val la = wrapped.LA(i)
235 | if (la == 0 || la == IntStream.EOF) la
236 | else Character.toUpperCase(la)
237 | }
238 | }
239 |
--------------------------------------------------------------------------------
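AutoSuggestContext is the entry point of the whole suggester: build it from a script string, then ask for suggestions at a cursor position. Below is a minimal sketch, not a file in this repository, assuming the `mlsqlLexer`/`sqlLexer` singletons defined in AutoSuggestController later in this dump and no SparkSession (so schema inference stays disabled):

```scala
import tech.mlsql.autosuggest.AutoSuggestContext
import tech.mlsql.autosuggest.app.AutoSuggestController

object SuggestSketch {
  def main(args: Array[String]): Unit = {
    // null SparkSession: isSchemaInferEnabled will be false
    val context = new AutoSuggestContext(null,
      AutoSuggestController.mlsqlLexer,
      AutoSuggestController.sqlLexer)
    // cursor on line 1, column 7, i.e. right after "load hi"
    val items = context.buildFromString("load hi").suggest(1, 7)
    items.foreach(item => println(item.name))
  }
}
```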
/src/main/java/tech/mlsql/autosuggest/AutoSuggester.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest
2 |
3 | /**
4 | * 1/6/2020 WilliamZhu(allwefantasy@gmail.com)
5 | */
6 | class AutoSuggester {
7 |
8 | }
9 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/FuncReg.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest
2 |
3 | import tech.mlsql.autosuggest.meta.MetaTable
4 |
5 | /**
6 | * 9/6/2020 WilliamZhu(allwefantasy@gmail.com)
7 | */
8 | trait FuncReg {
9 | val DOC = "doc"
10 | val COLUMN = "column"
11 | val IS_AGG = "agg"
12 | val YES = "yes"
13 | val NO = "no"
14 | val DEFAULT_VALUE = "default"
15 |
16 | def register: MetaTable
17 | }
18 |
19 |
20 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/FunctionUtils.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest
2 |
3 | import tech.mlsql.autosuggest.MLSQLSQLFunction.DB_KEY
4 | import tech.mlsql.autosuggest.meta.{MetaProvider, MetaTable, MetaTableColumn, MetaTableKey}
5 |
6 | import scala.collection.mutable.ArrayBuffer
7 |
8 |
9 | object MLSQLSQLFunction {
10 | val DB_KEY = "__FUNC__"
11 | val RETURN_KEY = "__RETURN__"
12 | val funcMetaProvider = new FuncMetaProvider()
13 |
14 | def apply(name: String): MLSQLSQLFunction = new MLSQLSQLFunction(name)
15 | }
16 |
17 | class MLSQLSQLFunction(name: String) {
18 |
19 | private val _params = ArrayBuffer[MetaTableColumn]()
20 | private val _returnParam = ArrayBuffer[MetaTableColumn]()
21 | private var _tableKey: MetaTableKey = MetaTableKey(None, Option(DB_KEY), name)
22 | private val _funcDescParam = ArrayBuffer[MetaTableColumn]()
23 |
24 | def funcName(name: String) = {
25 | _tableKey = MetaTableKey(None, Option(DB_KEY), name)
26 | this
27 | }
28 |
29 | def funcParam = {
30 | new MLSQLFuncParam(this)
31 | }
32 |
33 | def desc(extra: Map[String, String]) = {
34 | assert(_funcDescParam.size == 0, "desc can only be invoked once")
35 | _funcDescParam += MetaTableColumn(MLSQLSQLFunction.DB_KEY, "", false, extra)
36 | this
37 | }
38 |
39 | def returnParam(dataType: String, isNull: Boolean, extra: Map[String, String]) = {
40 | assert(_returnParam.size == 0, "returnParam can only be invoked once")
41 | _returnParam += MetaTableColumn(MLSQLSQLFunction.RETURN_KEY, dataType, isNull, extra)
42 | this
43 | }
44 |
45 | def addColumn(column: MetaTableColumn) = {
46 | _params += column
47 | this
48 | }
49 |
50 | def build = {
51 | MetaTable(_tableKey, (_funcDescParam ++ _returnParam ++ _params).toList)
52 | }
53 |
54 | }
55 |
56 | class MLSQLFuncParam(_func: MLSQLSQLFunction) {
57 | def param(name: String, dataType: String) = {
58 | _func.addColumn(MetaTableColumn(name, dataType, true, Map()))
59 | this
60 | }
61 |
62 | def param(name: String, dataType: String, isNull: Boolean) = {
63 | _func.addColumn(MetaTableColumn(name, dataType, isNull, Map()))
64 | this
65 | }
66 |
67 | def param(name: String, dataType: String, isNull: Boolean, extra: Map[String, String]) = {
68 | _func.addColumn(MetaTableColumn(name, dataType, isNull, extra))
69 | this
70 | }
71 |
72 | def func = {
73 | _func
74 | }
75 | }
76 |
77 | class FuncMetaProvider extends MetaProvider {
78 | private val funcs = scala.collection.mutable.HashMap[MetaTableKey, MetaTable]()
79 |
80 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
81 | funcs.get(key)
82 | }
83 |
84 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = {
85 | funcs.map(_._2).toList
86 | }
87 |
88 | def register(func: MetaTable) = {
89 | this.funcs.put(func.key, func)
90 | this
91 | }
92 |
93 | // for test
94 | def clear = {
95 | funcs.clear()
96 | }
97 | }
98 |
99 |
100 | object DataType {
101 | val STRING = "string"
102 | val INT = "integer"
103 | val LONG = "long"
104 | val DOUBLE = "double"
105 | val NUMBER = "number"
106 | val DATE = "date"
107 | val DATE_TIMESTAMP = "date_timestamp"
108 | val ARRAY = "array"
109 | val MAP = "map"
110 | }
111 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/POrCLiterals.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest
2 |
3 | /**
4 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com)
5 | */
6 | object POrCLiterals {
7 | val FUNCTION_PACKAGE = "tech.mlsql.autosuggest.funcs"
8 | }
9 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/SpecialTableConst.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest
2 |
3 | import tech.mlsql.autosuggest.meta.{MetaTable, MetaTableKey}
4 |
5 | /**
6 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com)
7 | */
8 | object SpecialTableConst {
9 | val KEY_WORD = "__KEY__WORD__"
10 | val DATA_SOURCE_KEY = "__DATA__SOURCE__"
11 | val OPTION_KEY = "__OPTION__"
12 | val TEMP_TABLE_DB_KEY = "__TEMP_TABLE__"
13 |
14 | val OTHER_TABLE_KEY = "__OTHER__TABLE__"
15 |
16 | val TOP_LEVEL_KEY = "__TOP_LEVEL__"
17 |
18 | def KEY_WORD_TABLE = MetaTable(MetaTableKey(None, None, SpecialTableConst.KEY_WORD), List())
19 |
20 | def DATA_SOURCE_TABLE = MetaTable(MetaTableKey(None, None, SpecialTableConst.DATA_SOURCE_KEY), List())
21 |
22 | def OPTION_TABLE = MetaTable(MetaTableKey(None, None, SpecialTableConst.OPTION_KEY), List())
23 |
24 | def OTHER_TABLE = MetaTable(MetaTableKey(None, None, SpecialTableConst.OTHER_TABLE_KEY), List())
25 |
26 | def tempTable(name: String) = MetaTable(MetaTableKey(None, Option(TEMP_TABLE_DB_KEY), name), List())
27 |
28 | def subQueryAliasTable = {
29 | MetaTableKey(None, None, null)
30 | }
31 | }
32 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/TokenPos.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest
2 |
3 | /**
4 | * TokenPos mark the cursor position
5 | */
6 | object TokenPosType {
7 | val END = -1
8 | val CURRENT = -2
9 | val NEXT = -3
10 | }
11 |
12 | /**
13 | * input "load" the cursor can be (situation 1) in the end of load
14 | * or (situation 2) with white space
15 | *
16 | * Looks like this:
17 | * 1. loa[cursor] => tab hit => loa[d] => pos:0 currentOrNext:TokenPosType.CURRENT offsetInToken:3
18 | * 2. load [cursor] => tab hit => load [DataSource list] => pos:0 currentOrNext:TokenPosType.NEXT offsetInToken:0
19 | * 2.1 load hi[cursor]ve.`db.table` => tab hit => load hi[ve] => pos:1 currentOrNext:TokenPosType.CURRENT offsetInToken:2
20 | *
21 | * the second situation can also show like this:
22 | *
23 | * load [cursor]hive.`db.table`
24 | *
25 | * we should still show the DataSource list. This means we don't care about the tokens after the cursor when we provide
26 | * the suggestion list.
27 | *
28 | */
29 | case class TokenPos(pos: Int, currentOrNext: Int, offsetInToken: Int = -1) {
30 | def str = {
31 | val posType = currentOrNext match {
32 | case TokenPosType.NEXT => "next"
33 | case TokenPosType.CURRENT => "current"
34 | }
35 | s"TokenPos: Index(${pos}) currentOrNext(${posType}) offsetInToken(${offsetInToken})"
36 | }
37 | }
38 |
39 |
40 |
--------------------------------------------------------------------------------
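To make the doc above concrete, here is a sketch (not part of the repository, assuming the `mlsqlLexer` from AutoSuggestController later in this dump) that maps a cursor position to a TokenPos:

```scala
import scala.collection.JavaConverters._
import tech.mlsql.autosuggest.app.AutoSuggestController
import tech.mlsql.autosuggest.statement.LexerUtils

object TokenPosSketch {
  def main(args: Array[String]): Unit = {
    val tokens = AutoSuggestController.mlsqlLexer.
      tokenizeNonDefaultChannel("load hive.`db.table`").tokens.asScala.toList
    // the cursor at line 1, column 7 sits inside "hive" ("hi|ve"), so this
    // should print pos:1 currentOrNext:current offsetInToken:2 (case 2.1 above)
    println(LexerUtils.toTokenPos(tokens, 1, 7).str)
  }
}
```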
/src/main/java/tech/mlsql/autosuggest/app/Constants.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.app
2 |
3 | /**
4 | * 2019-07-21 WilliamZhu(allwefantasy@gmail.com)
5 | */
6 | object Constants {
7 | val DB = "db"
8 | val HIVE = "hive"
9 | val JSON = "json"
10 | val TEXT = "text"
11 | }
12 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/app/MLSQLAutoSuggestApp.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.app
2 |
3 | import com.intigua.antlr4.autosuggest.{DefaultToCharStream, LexerWrapper, RawSQLToCharStream, ReflectionLexerAndParserFactory}
4 | import org.apache.spark.sql.catalyst.parser.{SqlBaseLexer, SqlBaseParser}
5 | import streaming.dsl.ScriptSQLExec
6 | import streaming.dsl.parser.{DSLSQLLexer, DSLSQLParser}
7 | import tech.mlsql.app.CustomController
8 | import tech.mlsql.autosuggest.meta.{MLSQLEngineMetaProvider, RestMetaProvider}
9 | import tech.mlsql.autosuggest.{AutoSuggestContext, MLSQLSQLFunction}
10 | import tech.mlsql.common.utils.serder.json.JSONTool
11 | import tech.mlsql.runtime.AppRuntimeStore
12 | import tech.mlsql.version.VersionCompatibility
13 |
14 | /**
15 | * 9/6/2020 WilliamZhu(allwefantasy@gmail.com)
16 | */
17 | class MLSQLAutoSuggestApp extends tech.mlsql.app.App with VersionCompatibility {
18 | override def run(args: Seq[String]): Unit = {
19 | AutoSuggestContext.init
20 | AppRuntimeStore.store.registerController("autoSuggest", classOf[AutoSuggestController].getName)
21 | AppRuntimeStore.store.registerController("registerTable", classOf[RegisterTableController].getName)
22 | AppRuntimeStore.store.registerController("sqlFunctions", classOf[SQLFunctionController].getName)
23 | }
24 |
25 | override def supportedVersions: Seq[String] = {
26 | Seq("1.5.0-SNAPSHOT", "1.5.0", "1.6.0-SNAPSHOT", "1.6.0")
27 | }
28 | }
29 |
30 | object AutoSuggestController {
31 | val lexerAndParserFactory = new ReflectionLexerAndParserFactory(classOf[DSLSQLLexer], classOf[DSLSQLParser])
32 | val mlsqlLexer = new LexerWrapper(lexerAndParserFactory, new DefaultToCharStream)
33 |
34 | val lexerAndParserFactory2 = new ReflectionLexerAndParserFactory(classOf[SqlBaseLexer], classOf[SqlBaseParser])
35 | val sqlLexer = new LexerWrapper(lexerAndParserFactory2, new RawSQLToCharStream)
36 |
37 | def getSchemaRegistry = {
38 |
39 | new SchemaRegistry(getSession)
40 | }
41 |
42 | def getSession = {
43 | val session = if (ScriptSQLExec.context() != null) ScriptSQLExec.context().execListener.sparkSession else Standalone.sparkSession
44 | session
45 | }
46 | }
47 |
48 | class SQLFunctionController extends CustomController {
49 | override def run(params: Map[String, String]): String = {
50 | JSONTool.toJsonStr(MLSQLSQLFunction.funcMetaProvider.list(Map()))
51 | }
52 | }
53 |
54 | class RegisterTableController extends CustomController {
55 | override def run(params: Map[String, String]): String = {
56 | def hasParam(str: String) = params.contains(str)
57 |
58 | def paramOpt(name: String) = {
59 | if (hasParam(name)) Option(params(name)) else None
60 | }
61 |
62 | val prefix = paramOpt("prefix")
63 | val db = paramOpt("db")
64 | require(hasParam("table"), "table is required")
65 | require(hasParam("schema"), "schema is required")
66 | val table = params("table")
67 |
68 | val registry = AutoSuggestController.getSchemaRegistry
69 | params.getOrElse("schemaType", "") match {
70 | case Constants.DB => registry.createTableFromDBSQL(prefix, db, table, params("schema"))
71 | case Constants.HIVE => registry.createTableFromHiveSQL(prefix, db, table, params("schema"))
72 | case Constants.JSON => registry.createTableFromJson(prefix, db, table, params("schema"))
73 | case _ => registry.createTableFromDBSQL(prefix, db, table, params("schema"))
74 | }
75 | JSONTool.toJsonStr(Map("success" -> true))
76 | }
77 | }
78 |
79 | class AutoSuggestController extends CustomController {
80 | override def run(params: Map[String, String]): String = {
81 | val sql = params("sql")
82 | val lineNum = params("lineNum").toInt
83 | val columnNum = params("columnNum").toInt
84 | val isDebug = params.getOrElse("isDebug", "false").toBoolean
85 | val size = params.getOrElse("size", "30").toInt
86 | val includeTableMeta = params.getOrElse("includeTableMeta", "false").toBoolean
87 |
88 | val schemaInferUrl = params.getOrElse("schemaInferUrl", "")
89 |
90 | val enableMemoryProvider = params.getOrElse("enableMemoryProvider", "true").toBoolean
91 | val session = AutoSuggestController.getSession
92 | val context = new AutoSuggestContext(session,
93 | AutoSuggestController.mlsqlLexer,
94 | AutoSuggestController.sqlLexer,Map("schemaInferUrl"->schemaInferUrl,"params"->JSONTool.toJsonStr(params)))
95 | context.setDebugMode(isDebug)
96 | if (enableMemoryProvider) {
97 | context.setUserDefinedMetaProvider(AutoSuggestContext.memoryMetaProvider)
98 | }
99 |
100 | if(!schemaInferUrl.isEmpty){
101 | context.setUserDefinedMetaProvider(new MLSQLEngineMetaProvider())
102 | }
103 |
104 | val searchUrl = params.get("searchUrl")
105 | val listUrl = params.get("listUrl")
106 | (searchUrl, listUrl) match {
107 | case (Some(searchUrl), Some(listUrl)) =>
108 | context.setUserDefinedMetaProvider(new RestMetaProvider(searchUrl, listUrl))
109 | case (None, None) =>
110 | }
111 |
112 |
113 |
114 | var resItems = context.buildFromString(sql).suggest(lineNum, columnNum).take(size)
115 | if (!includeTableMeta) {
116 | resItems = resItems.map { item =>
117 | item.copy(metaTable = null)
118 | }.take(size)
119 | }
120 | JSONTool.toJsonStr(resItems)
121 | }
122 | }
123 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/app/RDSchema.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.app
2 |
3 | import java.sql.{JDBCType, SQLException}
4 |
5 | import com.alibaba.druid.sql.SQLUtils
6 | import com.alibaba.druid.sql.ast.SQLDataType
7 | import com.alibaba.druid.sql.ast.statement.{SQLColumnDefinition, SQLCreateTableStatement}
8 | import com.alibaba.druid.sql.repository.SchemaRepository
9 | import org.apache.spark.sql.jdbc.JdbcDialects
10 | import org.apache.spark.sql.types.DecimalType.{MAX_PRECISION, MAX_SCALE}
11 | import org.apache.spark.sql.types._
12 |
13 | import scala.collection.JavaConverters._
14 | import scala.math.min
15 |
16 | /**
17 | * 2019-07-18 WilliamZhu(allwefantasy@gmail.com)
18 | */
19 | class RDSchema(dbType: String) {
20 |
21 | private val repository = new SchemaRepository(dbType)
22 |
23 | def createTable(sql: String) = {
24 | repository.console(sql)
25 | val tableName = SQLUtils.parseStatements(sql, dbType).get(0).asInstanceOf[SQLCreateTableStatement].
26 | getTableSource.getName.getSimpleName
27 | SQLUtils.normalize(tableName)
28 | }
29 |
30 | def getTableSchema(table: String) = {
31 | val dialect = JdbcDialects.get(s"jdbc:${dbType}")
32 |
33 |
34 | def extractFieldSize = (dataType: SQLDataType) => {
35 | dataType.getArguments.asScala.map { f =>
36 | try {
37 | f.toString.toInt
38 | } catch {
39 | case e: Exception => 0
40 | }
41 |
42 | }.headOption
43 | }
44 |
45 | val fields = repository.findTable(table).getStatement.asInstanceOf[SQLCreateTableStatement].
46 | getTableElementList().asScala.filter(f => f.isInstanceOf[SQLColumnDefinition]).
47 | map {
48 | _.asInstanceOf[SQLColumnDefinition]
49 | }.map { column =>
50 |
51 | val columnName = column.getName.getSimpleName
52 | val dataType = RawDBTypeToJavaType.convert(dbType, column.getDataType.getName)
53 | val isNullable = !column.containsNotNullConstaint() // note: spelling is Druid's own API method name
54 |
55 | val fieldSize = extractFieldSize(column.getDataType) match {
56 | case Some(i) => i
57 | case None => 0
58 | }
59 | val fieldScale = 0
60 |
61 | val columnType = dialect.getCatalystType(dataType, column.getDataType.getName, fieldSize, new MetadataBuilder()).
62 | getOrElse(
63 | getCatalystType(dataType, fieldSize, fieldScale, false))
64 |
65 | StructField(columnName, columnType, isNullable)
66 | }
67 | new StructType(fields.toArray)
68 |
69 | }
70 |
71 | private def getCatalystType(
72 | sqlType: Int,
73 | precision: Int,
74 | scale: Int,
75 | signed: Boolean): DataType = {
76 | val answer = sqlType match {
77 | // scalastyle:off
78 | case java.sql.Types.ARRAY => null
79 | case java.sql.Types.BIGINT => if (signed) {
80 | LongType
81 | } else {
82 | DecimalType(20, 0)
83 | }
84 | case java.sql.Types.BINARY => BinaryType
85 | case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
86 | case java.sql.Types.BLOB => BinaryType
87 | case java.sql.Types.BOOLEAN => BooleanType
88 | case java.sql.Types.CHAR => StringType
89 | case java.sql.Types.CLOB => StringType
90 | case java.sql.Types.DATALINK => null
91 | case java.sql.Types.DATE => DateType
92 | case java.sql.Types.DECIMAL
93 | if precision != 0 || scale != 0 => DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
94 | case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
95 | case java.sql.Types.DISTINCT => null
96 | case java.sql.Types.DOUBLE => DoubleType
97 | case java.sql.Types.FLOAT => FloatType
98 | case java.sql.Types.INTEGER => if (signed) {
99 | IntegerType
100 | } else {
101 | LongType
102 | }
103 | case java.sql.Types.JAVA_OBJECT => null
104 | case java.sql.Types.LONGNVARCHAR => StringType
105 | case java.sql.Types.LONGVARBINARY => BinaryType
106 | case java.sql.Types.LONGVARCHAR => StringType
107 | case java.sql.Types.NCHAR => StringType
108 | case java.sql.Types.NCLOB => StringType
109 | case java.sql.Types.NULL => null
110 | case java.sql.Types.NUMERIC
111 | if precision != 0 || scale != 0 => DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
112 | case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
113 | case java.sql.Types.NVARCHAR => StringType
114 | case java.sql.Types.OTHER => null
115 | case java.sql.Types.REAL => DoubleType
116 | case java.sql.Types.REF => StringType
117 | case java.sql.Types.REF_CURSOR => null
118 | case java.sql.Types.ROWID => LongType
119 | case java.sql.Types.SMALLINT => IntegerType
120 | case java.sql.Types.SQLXML => StringType
121 | case java.sql.Types.STRUCT => StringType
122 | case java.sql.Types.TIME => TimestampType
123 | case java.sql.Types.TIME_WITH_TIMEZONE
124 | => null
125 | case java.sql.Types.TIMESTAMP => TimestampType
126 | case java.sql.Types.TIMESTAMP_WITH_TIMEZONE
127 | => null
128 | case java.sql.Types.TINYINT => IntegerType
129 | case java.sql.Types.VARBINARY => BinaryType
130 | case java.sql.Types.VARCHAR => StringType
131 | case _ =>
132 | throw new SQLException("Unrecognized SQL type " + sqlType)
133 | // scalastyle:on
134 | }
135 |
136 | if (answer == null) {
137 | throw new SQLException("Unsupported type " + JDBCType.valueOf(sqlType).getName)
138 | }
139 | answer
140 | }
141 | }
142 |
143 |
144 |
--------------------------------------------------------------------------------
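A sketch of the intended flow (the DDL below is illustrative): feed a MySQL CREATE TABLE into RDSchema, get back the normalized table name, then derive a Spark StructType from it:

```scala
import com.alibaba.druid.util.JdbcConstants
import tech.mlsql.autosuggest.app.RDSchema

object RDSchemaSketch {
  def main(args: Array[String]): Unit = {
    val rd = new RDSchema(JdbcConstants.MYSQL)
    // register the DDL with the Druid repository and parse out the table name
    val tableName = rd.createTable(
      "create table orders(id bigint not null, amount double)")
    // column types are mapped via MysqlType and getCatalystType above
    println(rd.getTableSchema(tableName))
  }
}
```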
/src/main/java/tech/mlsql/autosuggest/app/RawDBTypeToJavaType.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.app
2 |
3 | import com.alibaba.druid.util.JdbcConstants
4 |
5 |
6 | /**
7 | * 2019-07-18 WilliamZhu(allwefantasy@gmail.com)
8 | */
9 | object RawDBTypeToJavaType {
10 | def convert(dbType: String, name: String) = {
11 | dbType match {
12 | case JdbcConstants.MYSQL => MysqlType.valueOf(name.toUpperCase()).getJdbcType
13 | case _ => throw new RuntimeException(s"dbType ${dbType} is not supported yet")
14 | }
15 |
16 | }
17 |
18 | }
19 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/app/SchemaRegistry.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.app
2 |
3 | import com.alibaba.druid.util.JdbcConstants
4 | import org.apache.spark.sql.SparkSession
5 | import org.apache.spark.sql.types.{DataType, StructType}
6 | import tech.mlsql.autosuggest.AutoSuggestContext
7 | import tech.mlsql.autosuggest.meta.MetaTableKey
8 | import tech.mlsql.autosuggest.utils.SchemaUtils
9 |
10 | /**
11 | * 2019-07-18 WilliamZhu(allwefantasy@gmail.com)
12 | */
13 | class SchemaRegistry(_spark: SparkSession) {
14 | val spark = _spark
15 |
16 | def createTableFromDBSQL(prefix: Option[String], db: Option[String], tableName: String, createSQL: String) = {
17 | val rd = new RDSchema(JdbcConstants.MYSQL)
18 | val parsedTableName = rd.createTable(createSQL)
19 | val schema = rd.getTableSchema(parsedTableName)
20 | val table = MetaTableKey(prefix, db, parsedTableName)
21 | AutoSuggestContext.memoryMetaProvider.register(table, SchemaUtils.toMetaTable(table, schema));
22 | }
23 |
24 | def createTableFromHiveSQL(prefix: Option[String], db: Option[String], tableName: String, createSQL: String) = {
25 | spark.sql(createSQL)
26 | val schema = spark.table(tableName).schema
27 | val table = MetaTableKey(prefix, db, tableName)
28 | AutoSuggestContext.memoryMetaProvider.register(table, SchemaUtils.toMetaTable(table, schema));
29 |
30 | }
31 |
32 |
33 | def createTableFromJson(prefix: Option[String], db: Option[String], tableName: String, schemaJson: String) = {
34 | val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType]
35 | val table = MetaTableKey(prefix, db, tableName)
36 | AutoSuggestContext.memoryMetaProvider.register(table, SchemaUtils.toMetaTable(table, schema));
37 | }
38 |
39 |
40 | }
41 |
--------------------------------------------------------------------------------
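A sketch of registering a table from a serialized Spark schema (table and column names are illustrative). The JSON path never touches the SparkSession, so it also works in the standalone mode shown below:

```scala
import tech.mlsql.autosuggest.app.AutoSuggestController

object RegisterSketch {
  def main(args: Array[String]): Unit = {
    val registry = AutoSuggestController.getSchemaRegistry
    // the schema string is a standard Spark StructType JSON serialization
    registry.createTableFromJson(None, None, "orders",
      """{"type":"struct","fields":[
        |{"name":"id","type":"long","nullable":false,"metadata":{}},
        |{"name":"amount","type":"double","nullable":true,"metadata":{}}]}""".stripMargin)
  }
}
```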
/src/main/java/tech/mlsql/autosuggest/app/Standalone.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.app
2 |
3 | import net.csdn.ServiceFramwork
4 | import net.csdn.annotation.rest.At
5 | import net.csdn.bootstrap.Application
6 | import net.csdn.modules.http.RestRequest.Method
7 | import net.csdn.modules.http.{ApplicationController, ViewType}
8 | import org.apache.spark.sql.SparkSession
9 | import streaming.dsl.ScriptSQLExec
10 | import tech.mlsql.autosuggest.AutoSuggestContext
11 | import tech.mlsql.common.utils.shell.command.ParamsUtil
12 |
13 | import scala.collection.JavaConverters._
14 |
15 | /**
16 | * 11/6/2020 WilliamZhu(allwefantasy@gmail.com)
17 | */
18 | object Standalone {
19 | var sparkSession: SparkSession = null
20 |
21 | def main(args: Array[String]): Unit = {
22 | AutoSuggestContext.init
23 | val params = new ParamsUtil(args)
24 | val enableMLSQL = params.getParam("enableMLSQL", "false").toBoolean
25 | if (enableMLSQL) {
26 | sparkSession = SparkSession.builder().appName("local").master("local[*]").getOrCreate()
27 | }
28 | val applicationYamlName = params.getParam("config", "application.yml")
29 |
30 | ServiceFramwork.applicaionYamlName(applicationYamlName)
31 | ServiceFramwork.scanService.setLoader(classOf[Standalone])
32 | Application.main(args)
33 |
34 | }
35 | }
36 |
37 | class Standalone
38 |
39 | class SuggestController extends ApplicationController {
40 | @At(path = Array("/run/script"), types = Array(Method.POST))
41 | def runScript = {
42 |
43 | if (param("executeMode", "") == "registerTable") {
44 | val respStr = new RegisterTableController().run(params().asScala.toMap)
45 | render(200, respStr, ViewType.json)
46 |
47 | }
48 | if (param("executeMode", "") == "sqlFunctions") {
49 | val respStr = new SQLFunctionController().run(params().asScala.toMap)
50 | render(200, respStr, ViewType.json)
51 |
52 | }
53 | val respStr = new AutoSuggestController().run(params.asScala.toMap)
54 | render(200, respStr, ViewType.json)
55 | }
56 |
57 | }
58 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/ast/NoneToken.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.ast
2 |
3 | import org.antlr.v4.runtime.{CharStream, Token, TokenSource}
4 |
5 | /**
6 | * 24/6/2020 WilliamZhu(allwefantasy@gmail.com)
7 | */
8 | class NoneToken(token: Token) extends Token {
9 | override def getText: String = NoneToken.TEXT
10 |
11 | override def getType: Int = NoneToken.TYPE
12 |
13 | override def getLine: Int = token.getLine
14 |
15 | override def getCharPositionInLine: Int = token.getCharPositionInLine
16 |
17 | override def getChannel: Int = token.getChannel
18 |
19 | override def getTokenIndex: Int = token.getTokenIndex
20 |
21 | override def getStartIndex: Int = token.getStartIndex
22 |
23 | override def getStopIndex: Int = token.getStopIndex
24 |
25 | override def getTokenSource: TokenSource = token.getTokenSource
26 |
27 | override def getInputStream: CharStream = token.getInputStream
28 | }
29 |
30 | object NoneToken {
31 | val TEXT = "__NONE__"
32 | val TYPE = -3
33 | }
34 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/ast/TableTree.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.ast
2 |
3 | import tech.mlsql.autosuggest.meta.{MetaTable, MetaTableKey}
4 | import tech.mlsql.autosuggest.statement.MetaTableKeyWrapper
5 |
6 | import scala.collection.mutable.ArrayBuffer
7 |
8 | object TableTree {
9 | def ROOT = {
10 | TableTree(-1, MetaTableKeyWrapper(MetaTableKey(None, None, null), None), None, ArrayBuffer())
11 | }
12 |
13 | }
14 |
15 | case class TableTree(level: Int, key: MetaTableKeyWrapper, table: Option[MetaTable], subNodes: ArrayBuffer[TableTree]) {
16 |
17 |
18 | def children = subNodes
19 |
20 | def collectByLevel(level: Int) = {
21 | val buffer = ArrayBuffer[TableTree]()
22 | visitDown(0) { case (table, _level) =>
23 | if (_level == level) {
24 | buffer += table
25 | }
26 | }
27 | buffer.toList
28 | }
29 |
30 | def visitDown(level: Int)(rule: PartialFunction[(TableTree, Int), Unit]): Unit = {
31 | rule.apply((this, level))
32 | this.children.foreach(_.visitDown(level + 1)(rule))
33 | }
34 |
35 | def visitUp(level: Int)(rule: PartialFunction[(TableTree, Int), Unit]): Unit = {
36 | this.children.foreach(_.visitUp(level + 1)(rule))
37 | rule.apply((this, level))
38 | }
39 |
40 | def fastEquals(other: TableTree): Boolean = {
41 | this.eq(other) || this == other
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/dsl/TokenMatcher.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.dsl
2 |
3 | import org.antlr.v4.runtime.Token
4 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer
5 | import streaming.dsl.parser.DSLSQLLexer
6 |
7 | import scala.collection.mutable.ArrayBuffer
8 |
9 | /**
10 | * 4/6/2020 WilliamZhu(allwefantasy@gmail.com)
11 | *
12 | */
13 | class TokenMatcher(tokens: List[Token], val start: Int) {
14 | val foods = ArrayBuffer[FoodWrapper]()
15 | var cacheResult = -2
16 | private var direction: String = MatcherDirection.FORWARD
17 |
18 | def forward = {
19 | assert(foods.size == 0, "this function should be invoked before eat")
20 | direction = MatcherDirection.FORWARD
21 | this
22 | }
23 |
24 | def back = {
25 | assert(foods.size == 0, "this function should be invoked before eat")
26 | direction = MatcherDirection.BACK
27 | this
28 | }
29 |
30 | def eat(food: Food*) = {
31 | foods += FoodWrapper(AndOrFood(food.toList, true), false)
32 | this
33 | }
34 |
35 | def eatOneAny = {
36 | foods += FoodWrapper(AndOrFood(List(Food(None, -2)), true), false)
37 | this
38 | }
39 |
40 | /**
41 | *
42 | * Keep moving forward until we meet the token we need; on success return that index, otherwise return -1
43 | */
44 | def orIndex(_foods: Array[Food], upperBound: Int = tokens.size) = {
45 | if (foods.size != 0) {
46 | throw new RuntimeException("eat/optional/asStart should not before index")
47 | }
48 | direction match {
49 | case MatcherDirection.FORWARD =>
50 | var targetIndex = -1
51 | (start until upperBound).foreach { idx =>
52 | if (targetIndex == -1) {
53 | // step by step until success
54 | var matchValue = -1
55 | _foods.zipWithIndex.foreach { case (food, _) =>
56 | if (matchValue == -1 && matchToken(food, idx) != -1) {
57 | matchValue = 0
58 | }
59 | }
60 | if (matchValue != -1) {
61 | targetIndex = idx
62 | }
63 | }
64 |
65 | }
66 | targetIndex
67 | case MatcherDirection.BACK =>
68 | var _start = start
69 | var targetIndex = -1
70 | while (_start >= 0) {
71 | if (targetIndex == -1) {
72 | // step by step until success
73 | var matchValue = -1
74 | _foods.zipWithIndex.foreach { case (food, _) =>
75 | if (matchValue == -1 && matchToken(food, _start) != -1) {
76 | matchValue = 0
77 | }
78 | }
79 | if (matchValue != -1) {
80 | targetIndex = _start
81 | }
82 | }
83 | _start = _start - 1
84 | }
85 | targetIndex
86 | }
87 |
88 | }
89 |
90 | // find the first match
91 | def index(_foods: Array[Food], upperBound: Int = tokens.size) = {
92 | if (foods.size != 0) {
93 | throw new RuntimeException("eat/optional/asStart should not before index")
94 | }
95 | assert(direction == MatcherDirection.FORWARD, "index only support forward")
96 | var targetIndex = -1
97 | (start until upperBound).foreach { idx =>
98 | if (targetIndex == -1) {
99 | // step by step until success
100 | var matchValue = 0
101 | _foods.zipWithIndex.foreach { case (food, idx2) =>
102 | if (matchValue == 0 && matchToken(food, idx + idx2) == -1) {
103 | matchValue = -1
104 | }
105 | }
106 | if (matchValue != -1) {
107 | targetIndex = idx
108 | }
109 | }
110 |
111 | }
112 | targetIndex
113 |
114 | }
115 |
116 | def asStart(food: Food, offset: Int = 0) = {
117 | if (foods.size != 0) {
118 | throw new RuntimeException("eat/optional should not before asStart")
119 | }
120 | var targetIndex = -1
121 | (start until tokens.size).foreach { idx =>
122 | if (targetIndex == -1) {
123 | val index = matchToken(food, idx)
124 | if (index != -1) {
125 | targetIndex = index
126 | }
127 | }
128 |
129 | }
130 | TokenMatcher(tokens, targetIndex + offset)
131 | }
132 |
133 | def optional = {
134 | foods.lastOption.foreach(_.optional = true)
135 | this
136 | }
137 |
138 | private def matchToken(food: Food, currentIndex: Int): Int = {
139 | if (currentIndex < 0) return -1
140 | if (currentIndex >= tokens.size) return -1
141 | if (food.tp == -2) {
142 | return currentIndex
143 | }
144 | food.name match {
145 | case Some(name) => if (tokens(currentIndex).getType == food.tp && tokens(currentIndex).getText == name) {
146 | currentIndex
147 | } else -1
148 | case None =>
149 | if (tokens(currentIndex).getType == food.tp) {
150 | currentIndex
151 | } else -1
152 | }
153 | }
154 |
155 | private def forwardBuild: TokenMatcher = {
156 | var currentIndex = start
157 | var isFail = false
158 |
159 |
160 | foods.foreach { foodw =>
161 |
162 | if (currentIndex >= tokens.size && !foodw.optional) {
163 | isFail = true
164 | } else {
165 | val stepSize = foodw.foods.count
166 | var matchValue = 0
167 | foodw.foods.foods.zipWithIndex.foreach { case (food, idx) =>
168 | if (matchValue == 0 && matchToken(food, currentIndex + idx) == -1) {
169 | matchValue = -1
170 | }
171 | }
172 | if (foodw.optional) {
173 | if (matchValue != -1) {
174 | currentIndex = currentIndex + stepSize
175 | }
176 | } else {
177 | if (matchValue != -1) {
178 | currentIndex = currentIndex + stepSize
179 |
180 | } else {
181 | //mark fail
182 | isFail = true
183 | }
184 | }
185 | }
186 | }
187 |
188 | val targetIndex = if (isFail) -1 else currentIndex
189 | cacheResult = targetIndex
190 | this
191 | }
192 |
193 | private def backBuild: TokenMatcher = {
194 | var currentIndex = start
195 | var isFail = false
196 |
197 |
198 | foods.foreach { foodw =>
199 | // if out of bound then mark fail
200 | if (currentIndex <= -1 && !foodw.optional) {
201 | isFail = true
202 | } else {
203 | val stepSize = foodw.foods.count
204 | var matchValue = 0
205 | foodw.foods.foods.zipWithIndex.foreach { case (food, idx) =>
206 | if (matchValue == 0 && matchToken(food, currentIndex - idx) == -1) {
207 | matchValue = -1
208 | }
209 | }
210 | if (foodw.optional) {
211 | if (matchValue != -1) {
212 | currentIndex = currentIndex - stepSize
213 | }
214 | } else {
215 | if (matchValue != -1) {
216 | currentIndex = currentIndex - stepSize
217 |
218 | } else {
219 | //mark fail
220 | isFail = true
221 | }
222 | }
223 | }
224 |
225 | }
226 |
227 | if (!isFail && currentIndex == -1) {
228 | currentIndex = 0
229 | }
230 | val targetIndex = if (isFail) -1 else currentIndex
231 | cacheResult = targetIndex
232 | this
233 | }
234 |
235 | def build: TokenMatcher = {
236 | direction match {
237 | case MatcherDirection.FORWARD =>
238 | forwardBuild
239 | case MatcherDirection.BACK =>
240 | backBuild
241 | }
242 | }
243 |
244 | def get = {
245 | if (this.cacheResult == -2) this.build
246 | this.cacheResult
247 | }
248 |
249 | def isSuccess = {
250 | if (this.cacheResult == -2) this.build
251 | this.cacheResult != -1
252 | }
253 |
254 | def getMatchTokens = {
255 | direction match {
256 | case MatcherDirection.BACK => tokens.slice(get + 1, start + 1)
257 | case MatcherDirection.FORWARD => tokens.slice(start, get)
258 | }
259 |
260 | }
261 | }
262 |
263 | object MatcherDirection {
264 | val FORWARD = "forward"
265 | val BACK = "back"
266 | }
267 |
268 | object TokenTypeWrapper {
269 | val LEFT_BRACKET = SqlBaseLexer.T__0 //(
270 | val RIGHT_BRACKET = SqlBaseLexer.T__1 //)
271 | val COMMA = SqlBaseLexer.T__2 //,
272 | val DOT = SqlBaseLexer.T__3 //.
273 | val LEFT_SQUARE_BRACKET = SqlBaseLexer.T__7 //[
274 | val RIGHT_SQUARE_BRACKET = SqlBaseLexer.T__8 //]
275 | val COLON = SqlBaseLexer.T__9 //:
276 |
277 | val LIST = List(LEFT_BRACKET, RIGHT_BRACKET, COMMA, DOT, LEFT_SQUARE_BRACKET, RIGHT_SQUARE_BRACKET, COLON)
278 | val MAP = LIST.map((_, 1)).toMap
279 | }
280 |
281 | object MLSQLTokenTypeWrapper {
282 | val DOT = DSLSQLLexer.T__0
283 | }
284 |
285 | object TokenMatcher {
286 | def apply(tokens: List[Token], start: Int): TokenMatcher = new TokenMatcher(tokens, start)
287 |
288 | def resultMatcher(tokens: List[Token], start: Int, stop: Int) = {
289 | val temp = new TokenMatcher(tokens, start)
290 | temp.cacheResult = stop
291 | temp
292 | }
293 |
294 | def SQL_SPLITTER_KEY_WORDS = List(
295 | SqlBaseLexer.SELECT,
296 | SqlBaseLexer.FROM,
297 | SqlBaseLexer.JOIN,
298 | SqlBaseLexer.WHERE,
299 | SqlBaseLexer.GROUP,
300 | SqlBaseLexer.ON,
301 | SqlBaseLexer.BY,
302 | SqlBaseLexer.LIMIT,
303 | SqlBaseLexer.ORDER
304 | )
305 | }
306 |
307 | case class Food(name: Option[String], tp: Int)
308 |
309 | case class FoodWrapper(foods: AndOrFood, var optional: Boolean)
310 |
311 | case class AndOrFood(foods: List[Food], var and: Boolean) {
312 | def count = {
313 | if (and) foods.size
314 | else 1
315 | }
316 | }
317 |
318 | object TokenWalker {
319 | def apply(tokens: List[Token], start: Int): TokenWalker = new TokenWalker(tokens, start)
320 | }
321 |
322 | class TokenWalker(tokens: List[Token], start: Int) {
323 |
324 | var currentToken: Option[Token] = Option(tokens(start))
325 | var currentIndex = start
326 |
327 | def nextSafe: TokenWalker = {
328 | val token = if ((currentIndex + 1) < tokens.size) {
329 | currentIndex += 1
330 | Option(tokens(currentIndex))
331 | } else None
332 | currentToken = token
333 | this
334 | }
335 |
336 | def range: TokenCharRange = {
337 | if (currentToken.isEmpty) return TokenCharRange(-1, -1)
338 | val start = currentToken.get.getCharPositionInLine
339 | val end = currentToken.get.getCharPositionInLine + currentToken.get.getText.size
340 | TokenCharRange(start, end)
341 | }
342 | }
343 |
344 | case class TokenCharRange(start: Int, end: Int)
345 |
346 |
--------------------------------------------------------------------------------
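To illustrate the eat/optional/build DSL above, a small sketch (the helper name is ours, not from the repo) that matches an optionally qualified column like `a.b` and returns its text:

```scala
import org.antlr.v4.runtime.Token
import org.apache.spark.sql.catalyst.parser.SqlBaseLexer
import tech.mlsql.autosuggest.dsl.{Food, TokenMatcher}

object TokenMatcherSketch {
  // matches IDENTIFIER [ "." IDENTIFIER ] starting at `start`
  def matchQualifiedColumn(tokens: List[Token], start: Int): Option[String] = {
    val matcher = TokenMatcher(tokens, start).
      eat(Food(None, SqlBaseLexer.IDENTIFIER), Food(None, SqlBaseLexer.T__3)).optional.
      eat(Food(None, SqlBaseLexer.IDENTIFIER)).
      build
    // get is the index right after the last matched token; -1 on failure
    if (matcher.isSuccess) Some(matcher.getMatchTokens.map(_.getText).mkString)
    else None
  }
}
```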
/src/main/java/tech/mlsql/autosuggest/funcs/Count.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.funcs
2 |
3 | import tech.mlsql.autosuggest.{DataType, FuncReg, MLSQLSQLFunction}
4 |
5 | /**
6 | * 14/7/2020 WilliamZhu(allwefantasy@gmail.com)
7 | */
8 | class Count extends FuncReg {
9 | override def register = {
10 |
11 | val func = MLSQLSQLFunction.apply("count").desc(Map(
12 | "zhDoc" ->
13 | """
14 | |count,统计数目,可单独或者配合group by 使用
15 | |""".stripMargin,
16 | IS_AGG -> YES
17 | )).
18 | funcParam.
19 | param("column", DataType.STRING, false, Map("zhDoc" -> "列名或者*或者常熟",COLUMN->YES)).
20 | func.
21 | returnParam(DataType.NUMBER, true, Map(
22 | "zhDoc" ->
23 | """
24 | |long类型数字
25 | |""".stripMargin
26 | )).
27 | build
28 | func
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/funcs/Splitter.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.funcs
2 |
3 | import tech.mlsql.autosuggest.{DataType, FuncReg, MLSQLSQLFunction}
4 |
5 | /**
6 | * 9/6/2020 WilliamZhu(allwefantasy@gmail.com)
7 | */
8 | class Splitter extends FuncReg {
9 |
10 | override def register = {
11 | val func = MLSQLSQLFunction.apply("split").desc(Map(
12 | "zhDoc" ->
13 | """
14 | |split函数。用于切割字符串,返回字符串数组.
15 | |示例: split("a,b",",") == [a,b]
16 | |""".stripMargin
17 | )).
18 | funcParam.
19 | param("str", DataType.STRING, false, Map("zhDoc" -> "待切割字符", COLUMN -> YES)).
20 | param("pattern", DataType.STRING, false, Map("zhDoc" -> "分隔符", DEFAULT_VALUE -> ",")).
21 | func.
22 | returnParam(DataType.ARRAY, true, Map(
23 | "zhDoc" ->
24 | """
25 | |返回值是数组类型
26 | |""".stripMargin
27 | )).
28 | build
29 | func
30 | }
31 |
32 | }
33 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/meta/LayeredMetaProvider.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.meta
2 |
3 | /**
4 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com)
5 | */
6 | class LayeredMetaProvider(tempTableProvider: StatementTempTableProvider, userDefinedProvider: MetaProvider) extends MetaProvider {
7 | def search(key: MetaTableKey,extra: Map[String, String] = Map()): Option[MetaTable] = {
8 | tempTableProvider.search(key).orElse {
9 | userDefinedProvider.search(key)
10 | }
11 | }
12 |
13 | def list(extra: Map[String, String] = Map()): List[MetaTable] = {
14 | tempTableProvider.list(extra) ++ userDefinedProvider.list(extra)
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/meta/MLSQLEngineMetaProvider.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.meta
2 |
3 | import java.util.UUID
4 |
5 | import org.apache.http.client.fluent.{Form, Request}
6 | import org.apache.http.util.EntityUtils
7 | import tech.mlsql.autosuggest.AutoSuggestContext
8 | import tech.mlsql.common.utils.serder.json.JSONTool
9 |
10 |
11 | /**
12 | * 26/7/2020 WilliamZhu(allwefantasy@gmail.com)
13 | */
14 | class MLSQLEngineMetaProvider() extends MetaProvider {
15 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
16 | val form = Form.form()
17 |
18 | if (key.prefix.isEmpty) return None
19 | var path = ""
20 |
21 | if (key.db.isDefined) {
22 | path += key.db.get + "."
23 | }
24 |
25 | path += key.table
26 | val tableName = UUID.randomUUID().toString.replaceAll("-", "")
27 |
28 | val sql =
29 | s"""
30 | |load ${key.prefix.get}.`${path}` where header="true" as ${tableName};!desc ${tableName};
31 | |""".stripMargin
32 | val params = JSONTool.parseJson[Map[String, String]](AutoSuggestContext.context().options("params"))
33 | params.foreach { case (k, v) =>
34 |       if (k != "sql" && k != "executeMode") {
35 | form.add(k, v)
36 | }
37 | }
38 | form.add("sql", sql)
39 | val resp = Request.Post(params("schemaInferUrl")).bodyForm(form.build()).execute().returnResponse()
40 | if (resp.getStatusLine.getStatusCode == 200) {
41 | val str = EntityUtils.toString(resp.getEntity)
42 | val columns = JSONTool.parseJson[List[TableSchemaColumn]](str)
43 | val table = MetaTable(key, columns.map { item =>
44 | MetaTableColumn(item.col_name, item.data_type, true, Map())
45 | })
46 | Option(table)
47 | } else None
48 | }
49 |
50 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = {
51 | List()
52 | }
53 | }
54 |
55 | case class TableSchemaColumn(col_name: String, data_type: String, comment: Option[String])
56 |
--------------------------------------------------------------------------------
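The provider above delegates schema inference to a running MLSQL engine: it rewrites the table reference into a `load ...; !desc ...;` script, posts it to the `schemaInferUrl` found in the request params, and maps each returned column to a `MetaTableColumn`. A sketch of the response shape `search` expects back (the column values are made up; a missing `comment` field is assumed to deserialize to `None`):

```scala
import tech.mlsql.common.utils.serder.json.JSONTool

// Field names come from TableSchemaColumn; values are illustrative.
val body = """[{"col_name":"id","data_type":"long"},{"col_name":"name","data_type":"string"}]"""
val columns = JSONTool.parseJson[List[TableSchemaColumn]](body)
// columns.map(_.col_name) == List("id", "name")
```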
/src/main/java/tech/mlsql/autosuggest/meta/MemoryMetaProvider.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.meta
2 |
3 | import scala.collection.JavaConverters._
4 |
5 | /**
6 | * 15/6/2020 WilliamZhu(allwefantasy@gmail.com)
7 | */
8 | class MemoryMetaProvider extends MetaProvider {
9 | private val cache = new java.util.concurrent.ConcurrentHashMap[MetaTableKey, MetaTable]()
10 |
11 |   override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
12 | if (cache.containsKey(key)) Option(cache.get(key)) else None
13 | }
14 |
15 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = {
16 | cache.values().asScala.toList
17 | }
18 |
19 | def register(key: MetaTableKey, value: MetaTable) = {
20 | cache.put(key, value)
21 | this
22 | }
23 |
24 | def unRegister(key: MetaTableKey) = {
25 | cache.remove(key)
26 | this
27 | }
28 |
29 | def clear = {
30 | cache.clear()
31 | this
32 | }
33 |
34 | }
35 |
--------------------------------------------------------------------------------
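`MemoryMetaProvider` is the simplest backing store, and combined with `LayeredMetaProvider` it shows the lookup order used during suggestion: temp tables registered while analyzing the current script shadow user-defined tables of the same name. A minimal sketch, assuming a `StatementTempTableProvider` (defined later in this package) as the first layer; the table and column names are illustrative:

```scala
import tech.mlsql.autosuggest.meta._

val userTables = new MemoryMetaProvider()
val key = MetaTableKey(None, Some("db1"), "orders")
userTables.register(key, MetaTable(key, List(
  MetaTableColumn("id", "long", true, Map())
)))

val tempTables = new StatementTempTableProvider()
val layered = new LayeredMetaProvider(tempTables, userTables)
// No temp table named "orders" is registered, so the lookup falls
// through to userTables and returns the MetaTable registered above.
val found = layered.search(key)
```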
/src/main/java/tech/mlsql/autosuggest/meta/MetaProvider.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.meta
2 |
3 | /**
4 |  * A function is also modeled as a table, with its
5 |  * columns treated as the function's parameters.
6 |  * Without a MetaProvider we cannot
7 |  * suggest columns or functions.
8 |  *
9 |  * If search returns None, the key is not a table,
10 |  * or the table being searched for does not exist.
11 | */
12 | trait MetaProvider {
13 | def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable]
14 |
15 | def list(extra: Map[String, String] = Map()): List[MetaTable]
16 | }
17 |
18 |
19 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/meta/RestMetaProvider.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.meta
2 |
3 | import org.apache.http.Header
4 | import org.apache.http.client.fluent.{Form, Request}
5 | import org.apache.http.util.EntityUtils
6 | import tech.mlsql.common.utils.serder.json.JSONTool
7 |
8 | /**
9 | * 15/6/2020 WilliamZhu(allwefantasy@gmail.com)
10 | */
11 | class RestMetaProvider(searchUrl: String, listUrl: String) extends MetaProvider {
12 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
13 | val form = Form.form()
14 | if (key.prefix.isDefined) {
15 | form.add("prefix", key.prefix.get)
16 | }
17 | if (key.db.isDefined) {
18 | form.add("db", key.db.get)
19 | }
20 | form.add("table", key.table)
21 | val resp = Request.Post(searchUrl).bodyForm(form.build()).execute().returnResponse()
22 | if (resp.getStatusLine.getStatusCode == 200) {
23 | val metaTable = JSONTool.parseJson[MetaTable](EntityUtils.toString(resp.getEntity))
24 | Option(metaTable)
25 | } else None
26 | }
27 |
28 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = {
29 | val form = Form.form()
30 | extra.foreach { case (k, v) =>
31 | form.add(k, v)
32 | }
33 |     val resp = Request.Post(listUrl).bodyForm(form.build()).execute().returnResponse()
34 | if (resp.getStatusLine.getStatusCode == 200) {
35 | val metaTables = JSONTool.parseJson[List[MetaTable]](EntityUtils.toString(resp.getEntity))
36 | metaTables
37 | } else List()
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
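The REST contract implied by the code above: `search` posts the form fields `prefix`, `db` and `table` to `searchUrl` and expects a single `MetaTable` JSON object back, while `list` forwards the `extra` map as form fields to `listUrl` and expects a JSON array of `MetaTable`. A sketch with hypothetical endpoints:

```scala
import tech.mlsql.autosuggest.meta.{MetaTableKey, RestMetaProvider}

// Both URLs are hypothetical; the service must answer with JSON bodies
// matching the MetaTable case classes in meta_protocal.scala.
val provider = new RestMetaProvider(
  searchUrl = "http://meta-service/api/table/search",
  listUrl = "http://meta-service/api/table/list"
)
val table = provider.search(MetaTableKey(Some("hive"), Some("db1"), "orders"))
```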
/src/main/java/tech/mlsql/autosuggest/meta/StatementTempTableProvider.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.meta
2 |
3 | /**
4 | * 3/6/2020 WilliamZhu(allwefantasy@gmail.com)
5 | */
6 | class StatementTempTableProvider extends MetaProvider {
7 | private val cache = scala.collection.mutable.HashMap[String, MetaTable]()
8 |
9 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
10 | cache.get(key.table)
11 | }
12 |
13 | def register(name: String, metaTable: MetaTable) = {
14 | cache += (name -> metaTable)
15 | this
16 | }
17 |
18 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = cache.values.toList
19 | }
20 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/meta/meta_protocal.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.meta
2 |
3 | case class MetaTableKey(prefix: Option[String], db: Option[String], table: String)
4 |
5 | case class MetaTable(key: MetaTableKey, columns: List[MetaTableColumn])
6 |
7 | case class MetaTableColumn(name: String, dataType: String, isNull: Boolean, extra: Map[String, String])
8 |
--------------------------------------------------------------------------------
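The three-part key mirrors how tables are referenced in MLSQL scripts: `prefix` carries the load format, `db` the database, `table` the table name. A few illustrative mappings (see `TablePreprocessor`, which constructs these keys):

```scala
import tech.mlsql.autosuggest.meta.MetaTableKey

// load hive.`db1.orders` as t;  ->
MetaTableKey(Some("hive"), Some("db1"), "orders")
// load csv.`orders` as t;       ->
MetaTableKey(Some("csv"), None, "orders")
// a plain temp table named "t"  ->
MetaTableKey(None, None, "t")
```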
/src/main/java/tech/mlsql/autosuggest/preprocess/TablePreprocessor.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.preprocess
2 |
3 | import org.antlr.v4.runtime.Token
4 | import streaming.dsl.parser.DSLSQLLexer
5 | import tech.mlsql.autosuggest.SpecialTableConst.TEMP_TABLE_DB_KEY
6 | import tech.mlsql.autosuggest.dsl.{Food, MLSQLTokenTypeWrapper, TokenMatcher}
7 | import tech.mlsql.autosuggest.meta.{MetaTable, MetaTableColumn, MetaTableKey}
8 | import tech.mlsql.autosuggest.statement.{PreProcessStatement, SelectSuggester}
9 | import tech.mlsql.autosuggest.{AutoSuggestContext, SpecialTableConst, TokenPos, TokenPosType}
10 |
11 | /**
12 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com)
13 | */
14 | class TablePreprocessor(context: AutoSuggestContext) extends PreProcessStatement {
15 |
16 | def cleanStr(str: String) = {
17 | if (str.startsWith("`") || str.startsWith("\"") || (str.startsWith("'") && !str.startsWith("'''")))
18 | str.substring(1, str.length - 1)
19 | else str
20 | }
21 |
22 | /**
23 | *
24 |    * load and select statements are special:
25 |    * a load statement must resolve the real (physical) table,
26 |    * while a select statement must capture its final select's output.
27 |    *
28 |    * When resolving the real table for a load statement, a prefix equal to the format in the load statement is attached.
29 | */
30 | def process(statement: List[Token]): Unit = {
31 | val tempTableProvider = context.tempTableProvider
32 | val tempMatcher = TokenMatcher(statement, statement.size - 2).back.eat(Food(None, DSLSQLLexer.IDENTIFIER), Food(None, DSLSQLLexer.AS)).build
33 |
34 | if (tempMatcher.isSuccess) {
35 | val tableName = tempMatcher.getMatchTokens.last.getText
36 | val defaultTable = SpecialTableConst.tempTable(tableName)
37 | val table = statement(0).getText.toLowerCase match {
38 | case "load" =>
39 | val formatMatcher = TokenMatcher(statement, 1).
40 | eat(Food(None, DSLSQLLexer.IDENTIFIER),
41 | Food(None, MLSQLTokenTypeWrapper.DOT),
42 | Food(None, DSLSQLLexer.BACKQUOTED_IDENTIFIER)).build
43 | if (formatMatcher.isSuccess) {
44 |
45 | formatMatcher.getMatchTokens.map(_.getText) match {
46 | case List(format, _, path) =>
47 | cleanStr(path).split("\\.", 2) match {
48 | case Array(db, table) =>
49 | // if(context.isSchemaInferEnabled){
50 | //
51 | // }
52 | context.metaProvider.search(MetaTableKey(Option(format), Option(db), table)).getOrElse(defaultTable)
53 | case Array(table) =>
54 | context.metaProvider.search(MetaTableKey(Option(format), None, table)).getOrElse(defaultTable)
55 | }
56 | }
57 | } else {
58 | defaultTable
59 | }
60 | case "select" =>
61 |             // statement.size - 3 drops the trailing "as <tableName>" clause
62 | val selectSuggester = new SelectSuggester(context, statement.slice(0, statement.size - 3), TokenPos(0, TokenPosType.NEXT, -1))
63 | val columns = selectSuggester.sqlAST.output(selectSuggester.tokens).map { name =>
64 | MetaTableColumn(name, null, true, Map())
65 | }
66 | MetaTable(MetaTableKey(None, Option(TEMP_TABLE_DB_KEY), tableName), columns)
67 | case _ => defaultTable
68 | }
69 |
70 | tempTableProvider.register(tableName, table)
71 | }
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
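Concretely, for the `load` branch the preprocessor splits the backquoted path on the first dot and asks the meta provider for the physical schema, falling back to a placeholder temp table; for the `select` branch it derives the columns from the statement's own AST output. A sketch with made-up table names:

```scala
// Script: load csv.`sales.orders` where header="true" as orders;
// The "load" branch queries:
//   context.metaProvider.search(MetaTableKey(Some("csv"), Some("sales"), "orders"))
// and registers the result (or SpecialTableConst.tempTable("orders"))
// under the name "orders".
//
// Script: select id, amount from orders as big_orders;
// The "select" branch registers "big_orders" with columns [id, amount]
// extracted from the select statement's AST output.
```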
/src/main/java/tech/mlsql/autosuggest/statement/LexerUtils.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import org.antlr.v4.runtime.Token
4 | import org.antlr.v4.runtime.misc.Interval
5 | import streaming.dsl.parser.DSLSQLLexer
6 | import tech.mlsql.autosuggest.dsl.{MLSQLTokenTypeWrapper, TokenTypeWrapper}
7 | import tech.mlsql.autosuggest.{AutoSuggestContext, TokenPos, TokenPosType}
8 |
9 | import scala.collection.JavaConverters._
10 |
11 | /**
12 | * 1/6/2020 WilliamZhu(allwefantasy@gmail.com)
13 | */
14 | object LexerUtils {
15 |
16 | def toRawSQLTokens(autoSuggestContext: AutoSuggestContext, wow: List[Token]): List[Token] = {
17 | val originalText = toRawSQLStr(autoSuggestContext, wow)
18 | val newTokens = autoSuggestContext.rawSQLLexer.tokenizeNonDefaultChannel(originalText).tokens.asScala.toList
19 | return newTokens
20 | }
21 |
22 | def toRawSQLStr(autoSuggestContext: AutoSuggestContext, wow: List[Token]): String = {
23 | val start = wow.head.getStartIndex
24 | val stop = wow.last.getStopIndex
25 |
26 | val input = wow.head.getTokenSource.asInstanceOf[DSLSQLLexer]._input
27 | val interval = new Interval(start, stop)
28 | val originalText = input.getText(interval)
29 | originalText
30 | }
31 |
32 | def filterPrefixIfNeeded(candidates: List[SuggestItem], tokens: List[Token], tokenPos: TokenPos) = {
33 | if (tokenPos.offsetInToken != 0) {
34 | candidates.filter(s => s.name.startsWith(tokens(tokenPos.pos).getText.substring(0, tokenPos.offsetInToken)))
35 | } else candidates
36 | }
37 |
38 | def tableTokenPrefix(tokens: List[Token], tokenPos: TokenPos): String = {
39 | var temp = tokens(tokenPos.pos).getText.substring(0, tokenPos.offsetInToken)
40 | if (tokenPos.pos > 1 && tokens(tokenPos.pos - 1).getType == TokenTypeWrapper.DOT) {
41 | temp = tokens(tokenPos.pos - 2).getText + "." + temp
42 | }
43 | temp
44 | }
45 |
46 |
47 | /**
48 | *
49 | * @param tokens
50 |      * @param lineNum line number, counted from 1
51 |      * @param colNum  column number, counted from 1
52 |      * @return the pos inside the returned TokenPos is counted from 0
53 | */
54 | def toTokenPos(tokens: List[Token], lineNum: Int, colNum: Int): TokenPos = {
55 | /**
56 | * load hi[cursor]... in token
57 | * load [cursor] out token
58 | * load[cursor] in token
59 | */
60 |
61 | if (tokens.size == 0) {
62 | return TokenPos(-1, TokenPosType.NEXT, -1)
63 | }
64 |
65 | val oneLineTokens = tokens.zipWithIndex.filter { case (token, index) =>
66 | token.getLine == lineNum
67 | }
68 |
69 | val firstToken = oneLineTokens.headOption match {
70 | case Some(head) => head
71 | case None =>
72 | tokens.zipWithIndex.filter { case (token, index) =>
73 | token.getLine == lineNum - 1
74 | }.head
75 | }
76 | val lastToken = oneLineTokens.lastOption match {
77 | case Some(last) => last
78 | case None =>
79 | tokens.zipWithIndex.filter { case (token, index) =>
80 | token.getLine == lineNum + 1
81 | }.last
82 | }
83 |
84 | if (colNum < firstToken._1.getCharPositionInLine) {
85 | return TokenPos(firstToken._2 - 1, TokenPosType.NEXT, 0)
86 | }
87 |
88 | if (colNum > lastToken._1.getCharPositionInLine + lastToken._1.getText.size) {
89 | return TokenPos(lastToken._2, TokenPosType.NEXT, 0)
90 | }
91 |
92 | if (colNum > lastToken._1.getCharPositionInLine
93 | && colNum <= lastToken._1.getCharPositionInLine + lastToken._1.getText.size
94 | &&
95 | (lastToken._1.getType != DSLSQLLexer.UNRECOGNIZED
96 | && lastToken._1.getType != MLSQLTokenTypeWrapper.DOT)
97 | ) {
98 | return TokenPos(lastToken._2, TokenPosType.CURRENT, colNum - lastToken._1.getCharPositionInLine)
99 | }
100 | oneLineTokens.map { case (token, index) =>
101 | val start = token.getCharPositionInLine
102 | val end = token.getCharPositionInLine + token.getText.size
103 |       //A cursor glued to the end of a token, with no space, is normally treated as part of
104 |       //that token (the user is still typing), unless the token is one of [(,).] and the like
105 | if (colNum == end && (1 <= token.getType)
106 | && (
107 | token.getType == DSLSQLLexer.UNRECOGNIZED
108 | || token.getType == MLSQLTokenTypeWrapper.DOT
109 | )) {
110 | TokenPos(index, TokenPosType.NEXT, 0)
111 | } else if (start < colNum && colNum <= end) {
112 | // in token
113 | TokenPos(index, TokenPosType.CURRENT, colNum - start)
114 | } else if (colNum <= start) {
115 | TokenPos(index - 1, TokenPosType.NEXT, 0)
116 | } else {
117 | TokenPos(-2, -2, -2)
118 | }
119 |
120 |
121 | }.filterNot(_.pos == -2).head
122 | }
123 |
124 | def toTokenPosForSparkSQL(tokens: List[Token], lineNum: Int, colNum: Int): TokenPos = {
125 | /**
126 | * load hi[cursor]... in token
127 | * load [cursor] out token
128 | * load[cursor] in token
129 | */
130 |
131 | if (tokens.size == 0) {
132 | return TokenPos(-1, TokenPosType.NEXT, -1)
133 | }
134 |
135 | val oneLineTokens = tokens.zipWithIndex.filter { case (token, index) =>
136 | token.getLine == lineNum
137 | }
138 |
139 | val firstToken = oneLineTokens.headOption match {
140 | case Some(head) => head
141 | case None =>
142 | tokens.zipWithIndex.filter { case (token, index) =>
143 | token.getLine == lineNum - 1
144 | }.head
145 | }
146 | val lastToken = oneLineTokens.lastOption match {
147 | case Some(last) => last
148 | case None =>
149 | tokens.zipWithIndex.filter { case (token, index) =>
150 | token.getLine == lineNum + 1
151 | }.last
152 | }
153 |
154 | if (colNum < firstToken._1.getCharPositionInLine) {
155 | return TokenPos(firstToken._2 - 1, TokenPosType.NEXT, 0)
156 | }
157 |
158 | if (colNum > lastToken._1.getCharPositionInLine + lastToken._1.getText.size) {
159 | return TokenPos(lastToken._2, TokenPosType.NEXT, 0)
160 | }
161 |
162 | if (colNum > lastToken._1.getCharPositionInLine
163 | && colNum <= lastToken._1.getCharPositionInLine + lastToken._1.getText.size
164 | && !TokenTypeWrapper.MAP.contains(lastToken._1.getType)
165 |
166 | ) {
167 | return TokenPos(lastToken._2, TokenPosType.CURRENT, colNum - lastToken._1.getCharPositionInLine)
168 | }
169 | oneLineTokens.map { case (token, index) =>
170 | val start = token.getCharPositionInLine
171 | val end = token.getCharPositionInLine + token.getText.size
172 |       //A cursor glued to the end of a token, with no space, is normally treated as part of
173 |       //that token (the user is still typing), unless the token is one of [(,).] and the like
174 | if (colNum == end && (1 <= token.getType)
175 | && (
176 | TokenTypeWrapper.MAP.contains(token.getType)
177 | )) {
178 | TokenPos(index, TokenPosType.NEXT, 0)
179 | } else if (start < colNum && colNum <= end) {
180 | // in token
181 | TokenPos(index, TokenPosType.CURRENT, colNum - start)
182 | } else if (colNum <= start) {
183 | TokenPos(index - 1, TokenPosType.NEXT, 0)
184 | } else {
185 | TokenPos(-2, -2, -2)
186 | }
187 |
188 |
189 | }.filterNot(_.pos == -2).head
190 | }
191 |
192 | def isInWhereContext(tokens: List[Token], tokenPos: Int): Boolean = {
193 | if (tokenPos < 1) return false
194 | var wherePos = -1
195 | (1 until tokenPos).foreach { index =>
196 | if (tokens(index).getType == DSLSQLLexer.WHERE || tokens(index).getType == DSLSQLLexer.OPTIONS) {
197 | wherePos = index
198 | }
199 | }
200 | if (wherePos != -1) {
201 | if (wherePos == tokenPos || wherePos == tokenPos - 1) return true
202 |       val noEnd = (wherePos until tokenPos).filter(index =>
203 |         tokens(index).getType == DSLSQLLexer.AS || tokens(index).getType == DSLSQLLexer.PARTITIONBY).isEmpty
204 | if (noEnd) return true
205 |
206 | }
207 |
208 | return false
209 | }
210 |
211 | def isWhereKey(tokens: List[Token], tokenPos: Int): Boolean = {
212 | LexerUtils.isInWhereContext(tokens, tokenPos) && (tokens(tokenPos).getText == "and" || tokens(tokenPos - 1).getText == "and")
213 |
214 | }
215 | }
216 |
--------------------------------------------------------------------------------
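`toTokenPos` reduces the cursor to one of the three cases listed in its comment. A hedged worked example for the single-line script `load hi` (exact offsets depend on the lexer's 0-based character positions):

```scala
// "load hi|"   cursor at the end of "hi": inside the token
//              -> TokenPos(1, TokenPosType.CURRENT, 2)
// "load | hi"  cursor on the space before "hi": between tokens
//              -> TokenPos(0, TokenPosType.NEXT, 0)
// "load|"      cursor glued to "load": still treated as inside it
//              -> TokenPos(0, TokenPosType.CURRENT, 4)
// e.g. LexerUtils.toTokenPos(tokens, lineNum = 1, colNum = 7)
```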
/src/main/java/tech/mlsql/autosuggest/statement/LoadSuggester.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import org.antlr.v4.runtime.Token
4 | import streaming.core.datasource.{DataSourceRegistry, MLSQLSourceInfo}
5 | import streaming.dsl.parser.DSLSQLLexer
6 | import tech.mlsql.autosuggest.dsl.{Food, MLSQLTokenTypeWrapper, TokenMatcher}
7 | import tech.mlsql.autosuggest.{AutoSuggestContext, SpecialTableConst, TokenPos, TokenPosType}
8 |
9 | import scala.collection.mutable
10 |
11 | /**
12 | * 1/6/2020 WilliamZhu(allwefantasy@gmail.com)
13 | *
14 | *
15 | */
16 | class LoadSuggester(val context: AutoSuggestContext, val _tokens: List[Token], val _tokenPos: TokenPos) extends StatementSuggester
17 | with SuggesterRegister {
18 |
19 | private val subInstances = new mutable.HashMap[String, StatementSuggester]()
20 |
21 | register(classOf[LoadPathSuggester])
22 | register(classOf[LoadFormatSuggester])
23 | register(classOf[LoadOptionsSuggester])
24 | register(classOf[LoadPathQuoteSuggester])
25 |
26 | override def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister = {
27 | val instance = clzz.getConstructor(classOf[LoadSuggester]).newInstance(this).asInstanceOf[StatementSuggester]
28 | subInstances.put(instance.name, instance)
29 | this
30 | }
31 |
32 | override def isMatch(): Boolean = {
33 | _tokens.headOption.map(_.getType) match {
34 | case Some(DSLSQLLexer.LOAD) => true
35 | case _ => false
36 | }
37 | }
38 |
39 | private def keywordSuggest: List[SuggestItem] = {
40 | _tokenPos match {
41 | case TokenPos(pos, TokenPosType.NEXT, offsetInToken) =>
42 | var items = List[SuggestItem]()
43 | val temp = TokenMatcher(_tokens, pos).back.
44 | eat(Food(None, DSLSQLLexer.BACKQUOTED_IDENTIFIER)).
45 | eat(Food(None, MLSQLTokenTypeWrapper.DOT)).
46 | build
47 | if (temp.isSuccess) {
48 | items = List(SuggestItem("where", SpecialTableConst.KEY_WORD_TABLE, Map()), SuggestItem("as", SpecialTableConst.KEY_WORD_TABLE, Map()))
49 | }
50 | items
51 |
52 | case _ => List()
53 | }
54 |
55 | }
56 |
57 | override def suggest(): List[SuggestItem] = {
58 | keywordSuggest ++ defaultSuggest(subInstances.toMap)
59 | }
60 |
61 |
62 | override def name: String = "load"
63 | }
64 |
65 | class LoadFormatSuggester(loadSuggester: LoadSuggester) extends StatementSuggester with StatementUtils {
66 | override def isMatch(): Boolean = {
67 |
68 | (tokenPos.pos, tokenPos.currentOrNext) match {
69 | case (0, TokenPosType.NEXT) => true
70 | case (1, TokenPosType.CURRENT) => true
71 | case (_, _) => false
72 | }
73 |
74 | }
75 |
76 | override def suggest(): List[SuggestItem] = {
77 | // datasource type suggest
78 | val sources = (DataSourceRegistry.allSourceNames.toSet.toSeq ++ Seq(
79 | "parquet", "csv", "jsonStr", "csvStr", "json", "text", "orc", "kafka", "kafka8", "kafka9", "crawlersql", "image",
80 | "script", "hive", "xml", "mlsqlAPI", "mlsqlConf"
81 | )).toList
82 | LexerUtils.filterPrefixIfNeeded(
83 | sources.map(SuggestItem(_, SpecialTableConst.DATA_SOURCE_TABLE,
84 | Map("desc" -> "DataSource"))),
85 | tokens, tokenPos)
86 |
87 | }
88 |
89 |
90 | override def tokens: List[Token] = loadSuggester._tokens
91 |
92 | override def tokenPos: TokenPos = loadSuggester._tokenPos
93 |
94 | override def name: String = "format"
95 | }
96 |
97 | class LoadOptionsSuggester(loadSuggester: LoadSuggester) extends StatementSuggester with StatementUtils {
98 | override def isMatch(): Boolean = {
99 | backAndFirstIs(DSLSQLLexer.OPTIONS) || backAndFirstIs(DSLSQLLexer.WHERE)
100 | }
101 |
102 | override def suggest(): List[SuggestItem] = {
103 | val source = tokens(1)
104 | val datasources = DataSourceRegistry.fetch(source.getText, Map[String, String]()) match {
105 | case Some(ds) => ds.asInstanceOf[MLSQLSourceInfo].
106 | explainParams(loadSuggester.context.session).collect().
107 | map(row => (row.getString(0), row.getString(1))).
108 | toList
109 | case None => List()
110 | }
111 | LexerUtils.filterPrefixIfNeeded(datasources.map(tuple =>
112 | SuggestItem(tuple._1, SpecialTableConst.OPTION_TABLE, Map("desc" -> tuple._2))),
113 | tokens, tokenPos)
114 |
115 | }
116 |
117 | override def name: String = "options"
118 |
119 | override def tokens: List[Token] = loadSuggester._tokens
120 |
121 | override def tokenPos: TokenPos = loadSuggester._tokenPos
122 | }
123 |
124 | class LoadPathQuoteSuggester(loadSuggester: LoadSuggester) extends StatementSuggester with StatementUtils {
125 | override def name: String = "pathQuote"
126 |
127 | override def isMatch(): Boolean = {
128 | val temp = TokenMatcher(tokens, tokenPos.pos).back.
129 | eat(Food(None, MLSQLTokenTypeWrapper.DOT)).
130 | eat(Food(None, DSLSQLLexer.IDENTIFIER)).
131 | eat(Food(None, DSLSQLLexer.LOAD)).build
132 | temp.isSuccess
133 | }
134 |
135 | override def suggest(): List[SuggestItem] = {
136 | LexerUtils.filterPrefixIfNeeded(List(SuggestItem("``", SpecialTableConst.OTHER_TABLE, Map("desc" -> "path or table"))),
137 | tokens, tokenPos)
138 | }
139 |
140 | override def tokens: List[Token] = loadSuggester._tokens
141 |
142 | override def tokenPos: TokenPos = loadSuggester._tokenPos
143 | }
144 |
145 | //Here you can implement Hive table / HDFS Path auto suggestion
146 | class LoadPathSuggester(loadSuggester: LoadSuggester) extends StatementSuggester with StatementUtils {
147 | override def isMatch(): Boolean = {
148 | false
149 | }
150 |
151 | override def suggest(): List[SuggestItem] = {
152 | List()
153 | }
154 |
155 | override def name: String = "path"
156 |
157 |
158 | override def tokens: List[Token] = loadSuggester._tokens
159 |
160 | override def tokenPos: TokenPos = loadSuggester._tokenPos
161 | }
162 |
163 |
164 |
--------------------------------------------------------------------------------
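Putting the sub-suggesters together, the cursor position decides which one fires. An illustrative map (paths and option values are made up):

```scala
// load [cursor]                     -> LoadFormatSuggester: data source names
// load csv.[cursor]                 -> LoadPathQuoteSuggester: suggests ``
// load csv.`/tmp/a` [cursor]        -> keywordSuggest: "where" / "as"
// load csv.`/tmp/a` where [cursor]  -> LoadOptionsSuggester: options of that source
```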
/src/main/java/tech/mlsql/autosuggest/statement/MLSQLStatementSplitter.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import org.antlr.v4.runtime.Token
4 | import streaming.dsl.parser.DSLSQLLexer
5 |
6 | import scala.collection.mutable.ArrayBuffer
7 |
8 | /**
9 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com)
10 | */
11 | class MLSQLStatementSplitter extends StatementSplitter {
12 | override def split(_tokens: List[Token]): List[List[Token]] = {
13 | val _statements = ArrayBuffer[List[Token]]()
14 | val tokens = _tokens.zipWithIndex
15 | var start = 0
16 | var end = 0
17 | tokens.foreach { case (token, index) =>
18 | // statement end
19 | if (token.getType == DSLSQLLexer.T__1) {
20 | end = index
21 | _statements.append(tokens.filter(p => p._2 >= start && p._2 <= end).map(_._1))
22 | start = index + 1
23 | }
24 |
25 | }
26 |     // collect the trailing statement, if any, that has no terminator yet
27 |     val theLeft = tokens.filter(p => p._2 >= start && p._2 < tokens.size).map(_._1).toList
28 | if (theLeft.size > 0) {
29 | _statements.append(theLeft)
30 | }
31 | _statements.toList
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
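`DSLSQLLexer.T__1` is the statement terminator (`;`), so the splitter yields one token list per terminated statement plus, if present, the unterminated tail the user is still typing. A small sketch:

```scala
// `tokens` is the output of DSLSQLLexer via tokenizeNonDefaultChannel.
val statements: List[List[Token]] = new MLSQLStatementSplitter().split(tokens)
// "load hive.`t` as t1; select * from t1 as t2;" -> statements.size == 2
// "load hive.`t` as t1; select * from "          -> 2 statements, the second
//   being the trailing tokens without a terminator
```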
/src/main/java/tech/mlsql/autosuggest/statement/MatchAndExtractor.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import tech.mlsql.autosuggest.dsl.TokenMatcher
4 |
5 | /**
6 | * 4/6/2020 WilliamZhu(allwefantasy@gmail.com)
7 | */
8 | trait MatchAndExtractor[T] {
9 | def matcher(start: Int): TokenMatcher
10 |
11 | def extractor(start: Int, end: Int): List[T]
12 |
13 | def iterate(start: Int, end: Int, limit: Int = 100): List[T]
14 | }
15 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/PreProcessStatement.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import org.antlr.v4.runtime.Token
4 |
5 | /**
6 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com)
7 | */
8 | trait PreProcessStatement {
9 | def process(statement: List[Token]): Unit
10 | }
11 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/RegisterSuggester.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import org.antlr.v4.runtime.Token
4 | import streaming.dsl.parser.DSLSQLLexer
5 | import tech.mlsql.autosuggest.{AutoSuggestContext, TokenPos}
6 |
7 | import scala.collection.mutable
8 |
9 | /**
10 | * 30/6/2020 WilliamZhu(allwefantasy@gmail.com)
11 | */
12 | class RegisterSuggester(val context: AutoSuggestContext, val _tokens: List[Token], val _tokenPos: TokenPos) extends StatementSuggester
13 | with SuggesterRegister {
14 | private val subInstances = new mutable.HashMap[String, StatementSuggester]()
15 |
16 |
17 | override def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister = {
18 |     val instance = clzz.getConstructor(classOf[RegisterSuggester]).newInstance(this).asInstanceOf[StatementSuggester]
19 | subInstances.put(instance.name, instance)
20 | this
21 | }
22 |
23 | override def isMatch(): Boolean = {
24 | _tokens.headOption.map(_.getType) match {
25 | case Some(DSLSQLLexer.REGISTER) => true
26 | case _ => false
27 | }
28 | }
29 |
30 |
31 | override def suggest(): List[SuggestItem] = {
32 | List()
33 | }
34 |
35 |
36 | override def name: String = "register"
37 | }
38 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/SelectStatementUtils.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import org.antlr.v4.runtime.Token
4 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer
5 | import tech.mlsql.autosuggest.dsl.{Food, TokenMatcher, TokenTypeWrapper}
6 | import tech.mlsql.autosuggest._
7 | import tech.mlsql.common.utils.log.Logging
8 |
9 | /**
10 | * 8/6/2020 WilliamZhu(allwefantasy@gmail.com)
11 | */
12 | trait SelectStatementUtils extends Logging {
13 | def selectSuggester: SelectSuggester
14 |
15 | def tokenPos: TokenPos
16 |
17 | def tokens: List[Token]
18 |
19 |
20 | def levelFromTokenPos = {
21 | var targetLevel = 0
22 | selectSuggester.sqlAST.visitDown(0) { case (ast, _level) =>
23 | if (tokenPos.pos >= ast.start && tokenPos.pos < ast.stop) targetLevel = _level
24 | }
25 | targetLevel
26 | }
27 |
28 | def getASTFromTokenPos: Option[SingleStatementAST] = {
29 | var targetAst: Option[SingleStatementAST] = None
30 | selectSuggester.sqlAST.visitUp(0) { case (ast, level) =>
31 |       if (targetAst.isEmpty && (ast.start <= tokenPos.pos && tokenPos.pos < ast.stop)) {
32 | targetAst = Option(ast)
33 | }
34 | }
35 | targetAst
36 | }
37 |
38 | def table_info = {
39 | selectSuggester.table_info.get(levelFromTokenPos)
40 | }
41 |
42 |
43 | def tableSuggest(): List[SuggestItem] = {
44 | val tempStart = tokenPos.currentOrNext match {
45 | case TokenPosType.CURRENT =>
46 | tokenPos.pos - 1
47 | case TokenPosType.NEXT =>
48 | tokenPos.pos
49 | }
50 |
51 | val temp = TokenMatcher(tokens, tempStart).back.eat(Food(None, TokenTypeWrapper.DOT), Food(None, SqlBaseLexer.IDENTIFIER)).build
52 |
53 | if (selectSuggester.context.isInDebugMode) {
54 | logInfo(s"tableSuggest, table_info:\n")
55 | logInfo("========tableSuggest start=======")
56 | if (table_info.isDefined) {
57 | table_info.get.foreach { case (key, metaTable) =>
58 | logInfo(key.toString + s"=>\n key: ${metaTable.key} \n columns: ${metaTable.columns.map(_.name).mkString(",")}")
59 | }
60 | }
61 |
62 | logInfo(s"Final: suggest ${!temp.isSuccess}")
63 |
64 | logInfo("========tableSuggest end =======")
65 | }
66 |
67 | if (temp.isSuccess) return List()
68 |
69 | table_info match {
70 | case Some(tb) => tb.map { case (key, value) =>
71 | (key.aliasName.getOrElse(key.metaTableKey.table), value)
72 | }.map { case (name, table) =>
73 | SuggestItem(name, table, Map())
74 | }.toList
75 | case None =>
76 | val tokenPrefix = LexerUtils.tableTokenPrefix(tokens, tokenPos)
77 | val owner = AutoSuggestContext.context().reqParams.getOrElse("owner", "")
78 | val extraParam = Map("searchPrefix" -> tokenPrefix, "owner" -> owner)
79 | selectSuggester.context.metaProvider.list(extraParam).map { item =>
80 | SuggestItem(item.key.table, item, Map())
81 | }
82 | }
83 | }
84 |
85 | def attributeSuggest(): List[SuggestItem] = {
86 | val tempStart = tokenPos.currentOrNext match {
87 | case TokenPosType.CURRENT =>
88 | tokenPos.pos - 1
89 | case TokenPosType.NEXT =>
90 | tokenPos.pos
91 | }
92 | if (selectSuggester.context.isInDebugMode) {
93 | logInfo(s"attributeSuggest:\n")
94 | logInfo("========attributeSuggest start=======")
95 | }
96 |
97 | def allOutput = {
98 | /**
99 |        * Prefer aliased subquery tables first
100 | */
101 | val res = table_info.get.filter { item => item._2.key == SpecialTableConst.subQueryAliasTable && item._1.aliasName.isDefined }.flatMap { table =>
102 | if (selectSuggester.context.isInDebugMode) {
103 | val columns = table._2.columns.map { item => s"${item.name} ${item}" }.mkString("\n")
104 | logInfo(s"TARGET table: ${table._1} \n columns: \n[${columns}] ")
105 | }
106 | table._2.columns.map(column => SuggestItem(column.name, table._2, Map()))
107 |
108 | }.toList
109 |
110 | if (res.isEmpty) {
111 | if (selectSuggester.context.isInDebugMode) {
112 | val tables = table_info.get.map { case (key, table) =>
113 | val columns = table.columns.map { item => s"${item.name} ${item}" }.mkString("\n")
114 | s"${key}:\n ${columns}"
115 | }.mkString("\n")
116 | logInfo(s"ALL tables: \n ${tables}")
117 | }
118 | table_info.get.flatMap { case (_, metaTable) =>
119 | metaTable.columns.map(column => SuggestItem(column.name, metaTable, Map())).toList
120 | }.toList
121 | } else res
122 |
123 |
124 | }
125 |
126 | val temp = TokenMatcher(tokens, tempStart).back.eat(Food(None, TokenTypeWrapper.DOT), Food(None, SqlBaseLexer.IDENTIFIER)).build
127 | if (selectSuggester.context.isInDebugMode) {
128 | logInfo(s"Try to match attribute db prefix: Status(${temp.isSuccess})")
129 | }
130 | val res = if (temp.isSuccess) {
131 | val table = temp.getMatchTokens.head.getText
132 | table_info.get.filter { case (key, value) =>
133 | (key.aliasName.isDefined && key.aliasName.get == table) || key.metaTableKey.table == table
134 | }.headOption match {
135 | case Some(table) =>
136 | if (selectSuggester.context.isInDebugMode) {
137 | logInfo(s"table[${table._1}] found, return ${table._2.key} columns.")
138 | }
139 | table._2.columns.map(column => SuggestItem(column.name, table._2, Map())).toList
140 | case None =>
141 | if (selectSuggester.context.isInDebugMode) {
142 | logInfo(s"No table found, so return all table[${table_info.get.map { case (_, metaTable) => metaTable.key.toString }}] columns.")
143 | }
144 | allOutput
145 | }
146 | } else allOutput
147 |
148 | if (selectSuggester.context.isInDebugMode) {
149 | logInfo("========attributeSuggest end=======")
150 | }
151 | res
152 |
153 | }
154 |
155 | def functionSuggest(): List[SuggestItem] = {
156 | if (selectSuggester.context.isInDebugMode) {
157 | logInfo(s"functionSuggest:\n")
158 | logInfo("========functionSuggest start=======")
159 | }
160 |
161 | def allOutput = {
162 | MLSQLSQLFunction.funcMetaProvider.list(Map()).map(item => SuggestItem(item.key.table, item, Map()))
163 | }
164 |
165 | val tempStart = tokenPos.currentOrNext match {
166 | case TokenPosType.CURRENT =>
167 | tokenPos.pos - 1
168 | case TokenPosType.NEXT =>
169 | tokenPos.pos
170 | }
171 |
172 |     // If this matches, the token is a column reference, so do not suggest functions
173 | val temp = TokenMatcher(tokens, tempStart).back.eat(Food(None, TokenTypeWrapper.DOT), Food(None, SqlBaseLexer.IDENTIFIER)).build
174 | val res = if (temp.isSuccess) {
175 | List()
176 | } else allOutput
177 |
178 | if (selectSuggester.context.isInDebugMode) {
179 | logInfo(s"functions: ${allOutput.map(_.name).mkString(",")}")
180 | logInfo("========functionSuggest end=======")
181 | }
182 | res
183 |
184 | }
185 | }
186 |
--------------------------------------------------------------------------------
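The three suggest methods above share one disambiguation trick: matching an `identifier.` prefix backwards from the cursor. Whether that prefix matches decides between one table's columns, all columns, or no functions at all. An illustrative summary:

```scala
// select t1.[cursor] from db.orders t1
//   -> attributeSuggest: only t1's columns (the "identifier." prefix matched)
// select [cursor] from (select a, b from x) t1
//   -> attributeSuggest: output of the aliased subquery (a, b) is preferred
// select t1.[cursor] ...
//   -> functionSuggest: nothing; a dotted prefix means a column, not a function
```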
/src/main/java/tech/mlsql/autosuggest/statement/SelectSuggester.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import org.antlr.v4.runtime.Token
4 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer
5 | import streaming.dsl.parser.DSLSQLLexer
6 | import tech.mlsql.autosuggest.dsl.{Food, TokenMatcher, TokenTypeWrapper}
7 | import tech.mlsql.autosuggest.meta.{MetaTable, MetaTableColumn, MetaTableKey}
8 | import tech.mlsql.autosuggest.{AutoSuggestContext, SpecialTableConst, TokenPos}
9 |
10 | import scala.collection.mutable
11 |
12 | /**
13 | * 3/6/2020 WilliamZhu(allwefantasy@gmail.com)
14 | */
15 | class SelectSuggester(val context: AutoSuggestContext, val _tokens: List[Token], val tokenPos: TokenPos) extends StatementSuggester with SuggesterRegister {
16 |
17 | private val subInstances = new mutable.HashMap[String, StatementSuggester]()
18 | register(classOf[ProjectSuggester])
19 | register(classOf[FromSuggester])
20 | register(classOf[FilterSuggester])
21 | register(classOf[JoinSuggester])
22 | register(classOf[JoinOnSuggester])
23 | register(classOf[OrderSuggester])
24 |
25 | override def name: String = "select"
26 |
27 | private lazy val newTokens = LexerUtils.toRawSQLTokens(context, _tokens)
28 | private lazy val TABLE_INFO = mutable.HashMap[Int, mutable.HashMap[MetaTableKeyWrapper, MetaTable]]()
29 | private lazy val selectTree: SingleStatementAST = buildTree()
30 |
31 | def sqlAST = selectTree
32 |
33 | def tokens = newTokens
34 |
35 | def table_info = TABLE_INFO
36 |
37 | override def isMatch(): Boolean = {
38 | _tokens.headOption.map(_.getType) match {
39 | case Some(DSLSQLLexer.SELECT) => true
40 | case _ => false
41 | }
42 | }
43 |
44 | private def buildTree() = {
45 | val root = SingleStatementAST.build(this, newTokens)
46 | import scala.collection.mutable
47 |
48 | root.visitUp(level = 0) { case (ast: SingleStatementAST, level: Int) =>
49 | if (!TABLE_INFO.contains(level)) {
50 | TABLE_INFO.put(level, new mutable.HashMap[MetaTableKeyWrapper, MetaTable]())
51 | }
52 |
53 | if (level != 0 && !TABLE_INFO.contains(level - 1)) {
54 | TABLE_INFO.put(level - 1, new mutable.HashMap[MetaTableKeyWrapper, MetaTable]())
55 | }
56 |
57 |
58 | ast.tables(newTokens).foreach { item =>
59 | if (item.aliasName.isEmpty || item.metaTableKey != MetaTableKey(None, None, null)) {
60 | context.metaProvider.search(item.metaTableKey) match {
61 | case Some(res) =>
62 | TABLE_INFO(level) += (item -> res)
63 | case None =>
64 | }
65 | }
66 | }
67 |
68 | val nameOpt = ast.name(newTokens)
69 | if (nameOpt.isDefined) {
70 |
71 | val metaTableKey = MetaTableKey(None, None, null)
72 | val metaTableKeyWrapper = MetaTableKeyWrapper(metaTableKey, nameOpt)
73 | val metaColumns = ast.output(newTokens).map { attr =>
74 | MetaTableColumn(attr, null, true, Map())
75 | }
76 | TABLE_INFO(level - 1) += (metaTableKeyWrapper -> MetaTable(metaTableKey, metaColumns))
77 | }
78 |
79 | }
80 |
81 | if (context.isInDebugMode) {
82 | logInfo(s"SQL[${newTokens.map(_.getText).mkString(" ")}]")
83 | logInfo(s"STRUCTURE: \n")
84 | TABLE_INFO.foreach { item =>
85 | logInfo(s"Level:${item._1}")
86 | item._2.foreach { table =>
87 | logInfo(s"${table._1} => ${table._2}")
88 | }
89 | }
90 |
91 | }
92 |
93 | root
94 | }
95 |
96 | override def suggest(): List[SuggestItem] = {
97 | var instance: StatementSuggester = null
98 | subInstances.foreach { _instance =>
99 | if (instance == null && _instance._2.isMatch()) {
100 | instance = _instance._2
101 | }
102 | }
103 | if (instance == null) List()
104 | else instance.suggest()
105 |
106 | }
107 |
108 | override def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister = {
109 | val instance = clzz.getConstructor(classOf[SelectSuggester]).newInstance(this).asInstanceOf[StatementSuggester]
110 | subInstances.put(instance.name, instance)
111 | this
112 | }
113 | }
114 |
115 |
116 | class ProjectSuggester(_selectSuggester: SelectSuggester) extends StatementSuggester with SelectStatementUtils with SuggesterRegister {
117 |
118 | def tokens = _selectSuggester.tokens
119 |
120 | def tokenPos = _selectSuggester.tokenPos
121 |
122 | def selectSuggester = _selectSuggester
123 |
124 | def backAndFirstIs(t: Int, keywords: List[Int] = TokenMatcher.SQL_SPLITTER_KEY_WORDS): Boolean = {
125 |     // Locate the enclosing subquery (which may be the outermost query)
126 | val ast = getASTFromTokenPos
127 | if (ast.isEmpty) return false
128 |
129 |     // Walk back from the cursor to the first keyword
130 | val temp = TokenMatcher(tokens, tokenPos.pos).back.orIndex(keywords.map(Food(None, _)).toArray)
131 | if (temp == -1) return false
132 |     //The first keyword must be the specified one and must lie inside the subquery
133 | if (tokens(temp).getType == t && temp >= ast.get.start && temp < ast.get.stop) return true
134 | return false
135 | }
136 |
137 |
138 | override def name: String = "project"
139 |
140 | override def isMatch(): Boolean = {
141 | val temp = backAndFirstIs(SqlBaseLexer.SELECT)
142 | if (selectSuggester.context.isInDebugMode) {
143 | logInfo(s"${name} is matched")
144 | }
145 | temp
146 | }
147 |
148 | override def suggest(): List[SuggestItem] = {
149 | LexerUtils.filterPrefixIfNeeded(tableSuggest() ++ attributeSuggest() ++ functionSuggest(), tokens, tokenPos)
150 | }
151 |
152 | override def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister = ???
153 | }
154 |
155 | class FilterSuggester(_selectSuggester: SelectSuggester) extends ProjectSuggester(_selectSuggester) {
156 |
157 |
158 | override def name: String = "filter"
159 |
160 | override def isMatch(): Boolean = {
161 | backAndFirstIs(SqlBaseLexer.WHERE)
162 |
163 | }
164 |
165 | override def suggest(): List[SuggestItem] = {
166 | LexerUtils.filterPrefixIfNeeded(tableSuggest() ++ attributeSuggest() ++ functionSuggest(), tokens, tokenPos)
167 | }
168 |
169 | override def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister = ???
170 | }
171 |
172 | class JoinOnSuggester(_selectSuggester: SelectSuggester) extends ProjectSuggester(_selectSuggester) {
173 | override def name: String = "join_on"
174 |
175 | override def isMatch(): Boolean = {
176 | backAndFirstIs(SqlBaseLexer.ON)
177 | }
178 |
179 | override def suggest(): List[SuggestItem] = {
180 | LexerUtils.filterPrefixIfNeeded(tableSuggest() ++ attributeSuggest() ++ functionSuggest(), tokens, tokenPos)
181 | }
182 | }
183 |
184 | class JoinSuggester(_selectSuggester: SelectSuggester) extends ProjectSuggester(_selectSuggester) {
185 | override def name: String = "join"
186 |
187 | override def isMatch(): Boolean = {
188 | backAndFirstIs(SqlBaseLexer.JOIN)
189 | }
190 |
191 | override def suggest(): List[SuggestItem] = {
192 | LexerUtils.filterPrefixIfNeeded(tableSuggest(), tokens, tokenPos)
193 | }
194 | }
195 |
196 | class FromSuggester(_selectSuggester: SelectSuggester) extends ProjectSuggester(_selectSuggester) {
197 | override def name: String = "from"
198 |
199 | override def isMatch(): Boolean = {
200 | backAndFirstIs(SqlBaseLexer.FROM)
201 | }
202 |
203 | override def suggest(): List[SuggestItem] = {
204 |
205 | val tokenPrefix = LexerUtils.tableTokenPrefix(tokens, tokenPos)
206 | val owner = AutoSuggestContext.context().reqParams.getOrElse("owner", "")
207 | val extraParam = Map("searchPrefix" -> tokenPrefix, "owner" -> owner)
208 |
209 | val allTables = _selectSuggester.context.metaProvider.list(extraParam).map { item =>
210 |       val prefix = (item.key.prefix, item.key.db) match {
211 |         case (Some(prefix), _) => prefix
212 |         case (None, Some(SpecialTableConst.TEMP_TABLE_DB_KEY)) => "temp table"
213 |         case (None, Some(db)) => db
214 |         case (None, None) => ""
215 |       }
216 | SuggestItem(item.key.table, item, Map("desc" -> prefix))
217 | }
218 | LexerUtils.filterPrefixIfNeeded(tableSuggest() ++ allTables, tokens, tokenPos)
219 | }
220 | }
221 |
222 | class OrderSuggester(_selectSuggester: SelectSuggester) extends ProjectSuggester(_selectSuggester) {
223 | override def name: String = "order"
224 |
225 | override def isMatch(): Boolean = {
226 | backAndFirstIs(SqlBaseLexer.ORDER)
227 | }
228 |
229 | override def suggest(): List[SuggestItem] = {
230 | LexerUtils.filterPrefixIfNeeded(attributeSuggest() ++ functionSuggest(), tokens, tokenPos)
231 | }
232 | }
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
--------------------------------------------------------------------------------
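Dispatch inside `SelectSuggester` is keyword-driven: each sub-suggester walks back from the cursor to the first SQL keyword inside the enclosing subquery and claims the position if it owns that keyword. An illustrative map:

```scala
// select [cursor] from t          -> ProjectSuggester: columns, functions, tables
// select a from [cursor]          -> FromSuggester: tables from the meta provider
// select a from t where [cursor]  -> FilterSuggester: columns and functions
// ... join [cursor]               -> JoinSuggester; ... on [cursor] -> JoinOnSuggester
// ... order by [cursor]           -> OrderSuggester: columns and functions only
```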
/src/main/java/tech/mlsql/autosuggest/statement/StatementSplitter.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import org.antlr.v4.runtime.Token
4 |
5 | /**
6 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com)
7 | */
8 | trait StatementSplitter {
9 | def split(_tokens: List[Token]): List[List[Token]]
10 | }
11 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/StatementSuggester.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import tech.mlsql.autosuggest.meta.MetaTable
4 | import tech.mlsql.common.utils.log.Logging
5 |
6 | /**
7 | * 1/6/2020 WilliamZhu(allwefantasy@gmail.com)
8 | */
9 | trait StatementSuggester extends Logging{
10 | def name: String
11 |
12 | def isMatch(): Boolean
13 |
14 | def suggest(): List[SuggestItem]
15 |
16 | def defaultSuggest(subInstances: Map[String, StatementSuggester]): List[SuggestItem] = {
17 | var instance: StatementSuggester = null
18 | subInstances.foreach { _instance =>
19 | if (instance == null && _instance._2.isMatch()) {
20 | instance = _instance._2
21 | }
22 | }
23 | if (instance == null) List()
24 | else instance.suggest()
25 |
26 | }
27 | }
28 |
29 | case class SuggestItem(name: String, metaTable: MetaTable, extra: Map[String, String])
30 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/StatementUtils.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import org.antlr.v4.runtime.Token
4 | import streaming.dsl.parser.DSLSQLLexer
5 | import tech.mlsql.autosuggest.TokenPos
6 | import tech.mlsql.autosuggest.dsl.{Food, TokenMatcher}
7 |
8 | /**
9 | * 9/6/2020 WilliamZhu(allwefantasy@gmail.com)
10 | */
11 | trait StatementUtils {
12 |
13 | def tokens: List[Token]
14 |
15 | def tokenPos: TokenPos
16 |
17 | def SPLIT_KEY_WORDS = {
18 | List(DSLSQLLexer.OPTIONS, DSLSQLLexer.WHERE, DSLSQLLexer.AS)
19 | }
20 |
21 | def backAndFirstIs(t: Int, keywords: List[Int] = SPLIT_KEY_WORDS): Boolean = {
22 |
23 |
24 |     // Walk back from the cursor to the first keyword
25 | val temp = TokenMatcher(tokens, tokenPos.pos).back.orIndex(keywords.map(Food(None, _)).toArray)
26 | if (temp == -1) return false
27 |     //The first keyword must be the specified one
28 | if (tokens(temp).getType == t) return true
29 | return false
30 | }
31 | }
32 |
33 |
--------------------------------------------------------------------------------
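`backAndFirstIs` is the small core of statement-local dispatch: walk back from the cursor to the first of the split keywords (`OPTIONS`/`WHERE`/`AS` by default) and check that it is the expected one. Two illustrative cases:

```scala
// load csv.`a` where x="1" and [cursor]
//   -> first keyword behind the cursor is WHERE -> backAndFirstIs(WHERE) == true
// load csv.`a` where x="1" as t1 [cursor]
//   -> first keyword behind the cursor is AS    -> backAndFirstIs(WHERE) == false
```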
/src/main/java/tech/mlsql/autosuggest/statement/SuggesterRegister.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | /**
4 | * 2/6/2020 WilliamZhu(allwefantasy@gmail.com)
5 | */
6 | trait SuggesterRegister {
7 | def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister
8 | }
9 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/TableExtractor.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import org.antlr.v4.runtime.Token
4 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer
5 | import tech.mlsql.autosuggest.AutoSuggestContext
6 | import tech.mlsql.autosuggest.dsl.{Food, TokenMatcher, TokenTypeWrapper}
7 | import tech.mlsql.autosuggest.meta.MetaTableKey
8 |
9 | import scala.collection.mutable.ArrayBuffer
10 |
11 | /**
12 | * 4/6/2020 WilliamZhu(allwefantasy@gmail.com)
13 | */
14 | class TableExtractor(autoSuggestContext: AutoSuggestContext, ast: SingleStatementAST, tokens: List[Token]) extends MatchAndExtractor[MetaTableKeyWrapper] {
15 | override def matcher(start: Int): TokenMatcher = {
16 | val temp = TokenMatcher(tokens, start).
17 | eat(Food(None, SqlBaseLexer.IDENTIFIER), Food(None, TokenTypeWrapper.DOT)).optional.
18 | eat(Food(None, SqlBaseLexer.IDENTIFIER)).
19 | eat(Food(None, SqlBaseLexer.AS)).optional.
20 | eat(Food(None, SqlBaseLexer.IDENTIFIER)).optional.
21 | build
22 | temp
23 | }
24 |
25 | override def extractor(start: Int, end: Int): List[MetaTableKeyWrapper] = {
26 | val dbTableTokens = tokens.slice(start, end)
27 | val dbTable = dbTableTokens.length match {
28 | case 2 =>
29 | val List(tableToken, aliasToken) = dbTableTokens
30 |         if (aliasToken.getText.toLowerCase() == "as") { // a bare "as" carries no alias
31 |           MetaTableKeyWrapper(MetaTableKey(None, None, tableToken.getText), None)
32 |         } else {
33 |           MetaTableKeyWrapper(MetaTableKey(None, None, tableToken.getText), Option(aliasToken.getText))
34 |         }
35 |
36 | case 3 =>
37 | val List(dbToken, _, tableToken) = dbTableTokens
38 | MetaTableKeyWrapper(MetaTableKey(None, Option(dbToken.getText), tableToken.getText), None)
39 | case 4 =>
40 | val List(dbToken, _, tableToken, aliasToken) = dbTableTokens
41 | MetaTableKeyWrapper(MetaTableKey(None, Option(dbToken.getText), tableToken.getText), Option(aliasToken.getText))
42 | case 5 =>
43 | val List(dbToken, _, tableToken, _, aliasToken) = dbTableTokens
44 | MetaTableKeyWrapper(MetaTableKey(None, Option(dbToken.getText), tableToken.getText), Option(aliasToken.getText))
45 | case _ => MetaTableKeyWrapper(MetaTableKey(None, None, dbTableTokens.head.getText), None)
46 | }
47 |
48 | List(dbTable)
49 | }
50 |
51 | override def iterate(start: Int, end: Int, limit: Int = 100): List[MetaTableKeyWrapper] = {
52 | val tables = ArrayBuffer[MetaTableKeyWrapper]()
53 | var matchRes = matcher(start)
54 | var whileLimit = limit
55 | while (matchRes.isSuccess && whileLimit > 0) {
56 | tables ++= extractor(matchRes.start, matchRes.get)
57 | whileLimit -= 1
58 | val temp = TokenMatcher(tokens, matchRes.get).eat(Food(None, SqlBaseLexer.T__2)).build
59 | if (temp.isSuccess) {
60 | matchRes = matcher(temp.get)
61 | } else whileLimit = 0
62 | }
63 |
64 | tables.toList
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
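The extractor recognizes `[db .] table [[as] alias]` shapes; the matched token count alone decides how the tokens are interpreted. Illustrative inputs and results:

```scala
// "orders t1"         -> MetaTableKeyWrapper(MetaTableKey(None, None, "orders"), Some("t1"))
// "db . orders"       -> MetaTableKeyWrapper(MetaTableKey(None, Some("db"), "orders"), None)
// "db . orders t1"    -> MetaTableKeyWrapper(MetaTableKey(None, Some("db"), "orders"), Some("t1"))
// "db . orders as t1" -> MetaTableKeyWrapper(MetaTableKey(None, Some("db"), "orders"), Some("t1"))
```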
/src/main/java/tech/mlsql/autosuggest/statement/TemplateSuggester.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import org.antlr.v4.runtime.Token
4 | import streaming.dsl.parser.DSLSQLLexer
5 | import tech.mlsql.autosuggest.{AutoSuggestContext, TokenPos}
6 |
7 | import scala.collection.mutable
8 |
9 | /**
10 | * 30/6/2020 WilliamZhu(allwefantasy@gmail.com)
11 | */
12 | class TemplateSuggester(val context: AutoSuggestContext, val _tokens: List[Token], val _tokenPos: TokenPos) extends StatementSuggester
13 | with SuggesterRegister {
14 | private val subInstances = new mutable.HashMap[String, StatementSuggester]()
15 |
16 |
17 | override def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister = {
18 |     val instance = clzz.getConstructor(classOf[TemplateSuggester]).newInstance(this).asInstanceOf[StatementSuggester]
19 | subInstances.put(instance.name, instance)
20 | this
21 | }
22 |
23 | override def isMatch(): Boolean = {
24 | _tokens.headOption.map(_.getType) match {
25 | case Some(DSLSQLLexer.REGISTER) => true
26 | case _ => false
27 | }
28 | }
29 |
30 |
31 | override def suggest(): List[SuggestItem] = {
32 | List()
33 | }
34 |
35 |
36 | override def name: String = "template"
37 | }
38 |
39 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/single_statement.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.statement
2 |
3 | import org.antlr.v4.runtime.Token
4 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer
5 | import tech.mlsql.autosuggest.AttributeExtractor
6 | import tech.mlsql.autosuggest.ast.NoneToken
7 | import tech.mlsql.autosuggest.dsl.{Food, TokenMatcher, TokenTypeWrapper}
8 | import tech.mlsql.autosuggest.meta.MetaTableKey
9 |
10 | import scala.collection.mutable
11 | import scala.collection.mutable.ArrayBuffer
12 |
13 | case class MetaTableKeyWrapper(metaTableKey: MetaTableKey, aliasName: Option[String])
14 |
15 | /**
16 |  * an atomic query statement contains only:
17 |  * select, from, group by, where, join, limit
18 |  * Note that we do not verify the sql is correct
19 | *
20 | */
21 | class SingleStatementAST(val selectSuggester: SelectSuggester, var start: Int, var stop: Int, var parent: SingleStatementAST) {
22 | val children = ArrayBuffer[SingleStatementAST]()
23 |
24 | def isLeaf = {
25 |     children.isEmpty
26 | }
27 |
28 | def name(tokens: List[Token]): Option[String] = {
29 | if (parent == null) None
30 | else Option(tokens.slice(start, stop).last.getText)
31 | }
32 |
33 |
34 | // private def isInSubquery(holes: List[(Int, Int)],) = {
35 | //
36 | // }
37 |
38 | def tables(tokens: List[Token]) = {
39 |
40 | // replace token
41 | val range = children.map { ast =>
42 | (ast.start, ast.stop)
43 | }
44 |
45 | def inRange(index:Int) = {
46 | range.filter { item =>
47 | item._1 <= index && index <= item._2
48 | }.headOption.isDefined
49 | }
50 |
51 | val tokensWithoutSubQuery = tokens.zipWithIndex.map { case (token,index) =>
52 | if (inRange(index)) new NoneToken(token)
53 | else token
54 | }
55 |
56 |
57 | // collect table first
58 | // T__3 == .
59 | // extract from `from`
60 | val fromTables = new TableExtractor(selectSuggester.context, this, tokensWithoutSubQuery)
61 | val fromStart = TokenMatcher(tokensWithoutSubQuery.slice(0, stop), start).asStart(Food(None, SqlBaseLexer.FROM), 1).start
62 | val tempTables = fromTables.iterate(fromStart, tokensWithoutSubQuery.size)
63 |
64 | // extract from `join`
65 | val joinTables = new TableExtractor(selectSuggester.context, ast = this, tokensWithoutSubQuery)
66 | val joinStart = TokenMatcher(tokensWithoutSubQuery.slice(0, stop), start).asStart(Food(None, SqlBaseLexer.JOIN), offset = 1).start
67 | val tempJoinTables = joinTables.iterate(joinStart, tokens.size)
68 |
69 | // extract subquery name
70 | val subqueryTables = children.map(_.name(tokens).get).map { name =>
71 | MetaTableKeyWrapper(MetaTableKey(None, None, null), Option(name))
72 | }.toList
73 |
74 | tempTables ++ tempJoinTables ++ subqueryTables
75 | }
76 |
77 | def level = {
78 | var count = 0
79 | var temp = this.parent
80 | while (temp != null) {
81 | temp = temp.parent
82 | count += 1
83 | }
84 | count
85 | }
86 |
87 |
88 | def output(tokens: List[Token]): List[String] = {
89 | val selectStart = TokenMatcher(tokens.slice(0, stop), start).asStart(Food(None, SqlBaseLexer.SELECT), 1).start
90 | val extractor = new AttributeExtractor(selectSuggester.context, this, tokens)
91 | extractor.iterate(selectStart, tokens.size)
92 | }
93 |
94 | def visitDown(level: Int)(rule: PartialFunction[(SingleStatementAST, Int), Unit]): Unit = {
95 | rule.apply((this, level))
96 | this.children.map(_.visitDown(level + 1)(rule))
97 | }
98 |
99 | def visitUp(level: Int)(rule: PartialFunction[(SingleStatementAST, Int), Unit]): Unit = {
100 | this.children.map(_.visitUp(level + 1)(rule))
101 | rule.apply((this, level))
102 | }
103 |
104 | def fastEquals(other: SingleStatementAST): Boolean = {
105 | this.eq(other) || this == other
106 | }
107 |
108 | def printAsStr(_tokens: List[Token], _level: Int): String = {
109 | val tokens = _tokens.slice(start, stop)
110 | val stringBuilder = new mutable.StringBuilder()
111 | var count = 1
112 | stringBuilder.append(tokens.map { item =>
113 | count += 1
114 | val suffix = if (count % 20 == 0) "\n" else ""
115 | item.getText + suffix
116 | }.mkString(" "))
117 | stringBuilder.append("\n")
118 | stringBuilder.append("\n")
119 | stringBuilder.append("\n")
120 | children.zipWithIndex.foreach { case (item, index) => stringBuilder.append("=" * (_level + 1) + ">" + item.printAsStr(_tokens, _level + 1)) }
121 | stringBuilder.toString()
122 | }
123 | }
124 |
125 |
126 | object SingleStatementAST {
127 |
128 | def matchTableAlias(tokens: List[Token], start: Int) = {
129 | tokens(start)
130 | }
131 |
132 | def build(selectSuggester: SelectSuggester, tokens: List[Token]) = {
133 | _build(selectSuggester, tokens, 0, tokens.size, false)
134 | }
135 |
136 | def _build(selectSuggester: SelectSuggester, tokens: List[Token], start: Int, stop: Int, isSub: Boolean = false): SingleStatementAST = {
137 | val ROOT = new SingleStatementAST(selectSuggester, start, stop, null)
138 | // context start: ( select
139 | // context end: )
140 |
141 | var bracketStart = 0
142 | var jumpIndex = -1
143 | for (index <- (start until stop) if index >= jumpIndex) {
144 | val token = tokens(index)
145 | if (token.getType == TokenTypeWrapper.LEFT_BRACKET && index < stop - 1 && tokens(index + 1).getType == SqlBaseLexer.SELECT) {
146 | // println(s"enter: ${tokens.slice(index, index + 5).map(_.getText).mkString(" ")}")
147 | val item = SingleStatementAST._build(selectSuggester, tokens, index + 1, stop, true)
148 | jumpIndex = item.stop
149 | ROOT.children += item
150 | item.parent = ROOT
151 |
152 | } else {
153 | if (isSub) {
154 | if (token.getType == TokenTypeWrapper.LEFT_BRACKET) {
155 | bracketStart += 1
156 | }
157 | if (token.getType == TokenTypeWrapper.RIGHT_BRACKET && bracketStart != 0) {
158 | bracketStart -= 1
159 | }
160 | else if (token.getType == TokenTypeWrapper.RIGHT_BRACKET && bracketStart == 0) {
161 | // check the alias
162 | val matcher = TokenMatcher(tokens, index + 1).eat(Food(None, SqlBaseLexer.AS)).optional.eat(Food(None, SqlBaseLexer.IDENTIFIER)).build
163 | val stepSize = if (matcher.isSuccess) matcher.get else index
164 | ROOT.start = start
165 | ROOT.stop = stepSize
166 | // println(s"out: ${tokens.slice(stepSize - 5, stepSize).map(_.getText).mkString(" ")}")
167 | return ROOT
168 | }
169 |
170 | } else {
171 | // do nothing
172 | }
173 | }
174 |
175 | }
176 | ROOT
177 | }
178 | }
179 |
--------------------------------------------------------------------------------
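`_build` recurses whenever it sees `( select`, so each subquery becomes a child node whose span includes its optional alias. A sketch of the resulting tree and traversal order:

```scala
// For: select a from (select b from t1) t2
//   level 0: select a from ( ... ) t2   -- root
//   level 1: select b from t1 ... t2    -- child; name(tokens) == Some("t2")
// visitDown visits root then child; visitUp visits child then root, which is
// why SelectSuggester.buildTree can register a subquery's output columns into
// its parent's level of TABLE_INFO.
```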
/src/main/java/tech/mlsql/autosuggest/utils/SchemaUtils.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.autosuggest.utils
2 |
3 | import org.apache.spark.sql.types.StructType
4 | import tech.mlsql.autosuggest.meta.{MetaTable, MetaTableColumn, MetaTableKey}
5 |
6 | /**
7 | * 15/6/2020 WilliamZhu(allwefantasy@gmail.com)
8 | */
9 | object SchemaUtils {
10 | def toMetaTable(table: MetaTableKey, st: StructType) = {
11 | val columns = st.fields.map { item =>
12 | MetaTableColumn(item.name, item.dataType.typeName, item.nullable, Map())
13 | }.toList
14 | MetaTable(table, columns)
15 | }
16 |
17 | }
18 |
--------------------------------------------------------------------------------
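A minimal usage sketch: converting a Spark schema into the `MetaTable` shape the suggesters consume (the table name and fields are made up):

```scala
import org.apache.spark.sql.types._
import tech.mlsql.autosuggest.meta.MetaTableKey
import tech.mlsql.autosuggest.utils.SchemaUtils

val schema = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("name", StringType, nullable = true)
))
val table = SchemaUtils.toMetaTable(MetaTableKey(None, Some("db1"), "orders"), schema)
// table.columns.map(_.name) == List("id", "name")
```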
/src/main/resources/log4j.properties:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | # Set everything to be logged to the console
18 | log4j.rootCategory=INFO, console,file
19 | log4j.appender.console=org.apache.log4j.ConsoleAppender
20 | log4j.appender.console.target=System.err
21 | log4j.appender.console.layout=org.apache.log4j.PatternLayout
22 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %X{owner} %p %c{1}: %m%n
23 | log4j.appender.file=org.apache.log4j.RollingFileAppender
24 | log4j.appender.file.File=./logs/mlsql_engine.log
25 | log4j.appender.file.rollingPolicy=org.apache.log4j.rolling.TimeBasedRollingPolicy
26 | log4j.appender.file.rollingPolicy.fileNamePattern=./logs/mlsql_engine.%d.gz
27 | log4j.appender.file.layout=org.apache.log4j.PatternLayout
28 | log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %X{owner} %p %c{1}: %m%n
29 | log4j.appender.file.MaxBackupIndex=5
30 | # Set the default spark-shell log level to WARN. When running the spark-shell, the
31 | # log level for this class is used to overwrite the root logger's log level, so that
32 | # the user can have different defaults for the shell and regular Spark apps.
33 | log4j.logger.org.apache.spark=WARN
34 | #log4j.logger.org.apache.spark=WARN
35 | # Settings to quiet third party logs that are too verbose
36 | log4j.logger.org.spark_project.jetty=WARN
37 | log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR
38 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
39 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
40 | log4j.logger.org.apache.parquet=ERROR
41 | log4j.logger.org.apache.spark.ContextCleaner=ERROR
42 | log4j.logger.org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator=ERROR
43 | log4j.logger.parquet=ERROR
44 | # SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support
45 | log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL
46 | log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR
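47 | # Note: the stock org.apache.log4j.RollingFileAppender ignores the
48 | # rollingPolicy.* keys; time-based rolling as configured above requires the
49 | # log4j-extras appender (org.apache.log4j.rolling.RollingFileAppender), and
50 | # MaxBackupIndex applies only to size-based rolling.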
--------------------------------------------------------------------------------
/src/test/java/com/intigua/antlr4/autosuggest/AutoSuggestContextTest.scala:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest
2 |
3 | import org.antlr.v4.runtime.Token
4 | import org.scalatest.BeforeAndAfterEach
5 | import tech.mlsql.autosuggest.meta.{MetaProvider, MetaTable, MetaTableColumn, MetaTableKey}
6 | import tech.mlsql.autosuggest.statement.{LexerUtils, SuggestItem}
7 | import tech.mlsql.autosuggest.{DataType, SpecialTableConst, TokenPos, TokenPosType}
8 |
9 | import scala.collection.JavaConverters._
10 |
11 | /**
12 | * 2/6/2020 WilliamZhu(allwefantasy@gmail.com)
13 | */
14 | class AutoSuggestContextTest extends BaseTest with BeforeAndAfterEach {
15 | override def afterEach(): Unit = {
16 | // context.statements.clear()
17 | }
18 |
19 | test("parse") {
20 | val wow = context.lexer.tokenizeNonDefaultChannel(
21 | """
22 | | -- yes
23 | | load hive.`` as -- jack
24 | | table1;
25 | | select * from table1 as table2;
26 | |""".stripMargin).tokens.asScala.toList
27 | context.build(wow)
28 |
29 | assert(context.statements.size == 2)
30 |
31 | }
32 | test("parse partial") {
33 | val wow = context.lexer.tokenizeNonDefaultChannel(
34 | """
35 | | -- yes
36 | | load hive.`` as -- jack
37 | | table1;
38 | | select * from table1
39 | |""".stripMargin).tokens.asScala.toList
40 | context.build(wow)
41 | printStatements(context.statements)
42 | assert(context.statements.size == 2)
43 | }
44 |
45 | def printStatements(items: List[List[Token]]) = {
46 | items.foreach { item =>
47 | println(item.map(_.getText).mkString(" "))
48 | println()
49 | }
50 | }
51 |
52 | test("relative pos convert") {
53 | val wow = context.lexer.tokenizeNonDefaultChannel(
54 | """
55 | | -- yes
56 | | load hive.`` as -- jack
57 | | table1;
58 | | select * from table1
59 | |""".stripMargin).tokens.asScala.toList
60 | context.build(wow)
61 |
62 | assert(context.statements.size == 2)
63 | // select * f[cursor]rom table1
64 | val tokenPos = LexerUtils.toTokenPos(wow, 5, 11)
65 | assert(tokenPos == TokenPos(9, TokenPosType.CURRENT, 1))
66 | assert(context.toRelativePos(tokenPos)._1 == TokenPos(2, TokenPosType.CURRENT, 1))
67 | }
68 |
69 | test("keyword") {
70 | val wow = context.lexer.tokenizeNonDefaultChannel(
71 | """
72 | | -- yes
73 | | loa
74 | |""".stripMargin).tokens.asScala.toList
75 | context.build(wow)
76 | val tokenPos = LexerUtils.toTokenPos(wow, 3, 4)
77 | assert(tokenPos == TokenPos(0, TokenPosType.CURRENT, 3))
78 | assert(context.suggest(3, 4) == List(SuggestItem("load", SpecialTableConst.KEY_WORD_TABLE, Map())))
79 | }
80 |
81 | test("spark sql") {
82 | val wow = context.rawSQLLexer.tokenizeNonDefaultChannel(
83 | """
84 | |SELECT CAST(25.65 AS int) from jack;
85 | |""".stripMargin).tokens.asScala.toList
86 |
87 | wow.foreach(item => println(s"${item.getText} ${item.getType}"))
88 | }
89 |
90 | test("load/select 4/10 select ke[cursor] from") {
91 | val wow =
92 | """
93 | | -- yes
94 | | load hive.`jack.db` as table1;
95 | | select ke from (select keywords,search_num,c from table1) table2
96 | |""".stripMargin
97 | val items = context.buildFromString(wow).suggest(4, 10)
98 | assert(items.map(_.name) == List("keywords"))
99 | }
100 |
101 | test("load/select 4/22 select from (select [cursor]keywords") {
102 | context.setUserDefinedMetaProvider(new MetaProvider {
103 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
104 | val tableKey = MetaTableKey(None, None, "table1") // avoid shadowing the `key` parameter; this stub always answers with table1
105 | val value = Option(MetaTable(
106 | tableKey, List(
107 | MetaTableColumn("keywords", DataType.STRING, true, Map()),
108 | MetaTableColumn("search_num", DataType.STRING, true, Map()),
109 | MetaTableColumn("c", DataType.STRING, true, Map()),
110 | MetaTableColumn("d", DataType.STRING, true, Map())
111 | )
112 | ))
113 | value
114 | }
115 |
116 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = List()
117 | })
118 | val wow = context.lexer.tokenizeNonDefaultChannel(
119 | """
120 | | -- yes
121 | | load hive.`jack.db` as table1;
122 | | select from (select keywords,search_num,c from table1) table2
123 | |""".stripMargin).tokens.asScala.toList
124 | val items = context.build(wow).suggest(4, 8)
125 | // items.foreach(println(_))
126 | assert(items.map(_.name) == List("table2", "keywords", "search_num", "c"))
127 |
128 | }
129 |
130 | test("load/select table with star") {
131 | context.setUserDefinedMetaProvider(new MetaProvider {
132 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
133 | if (key.prefix == Option("hive")) {
134 | Option(MetaTable(key, List(
135 | MetaTableColumn("a", DataType.STRING, true, Map()),
136 | MetaTableColumn("b", DataType.STRING, true, Map()),
137 | MetaTableColumn("c", DataType.STRING, true, Map()),
138 | MetaTableColumn("d", DataType.STRING, true, Map())
139 | )))
140 | } else None
141 | }
142 |
143 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = ???
144 | })
145 | val wow = context.lexer.tokenizeNonDefaultChannel(
146 | """
147 | | -- yes
148 | | load hive.`db.table1` as table2;
149 | | select * from table2 as table3;
150 | | select from table3
151 | |""".stripMargin).tokens.asScala.toList
152 | val items = context.build(wow).suggest(5, 8)
153 | println(items)
154 |
155 | }
156 |
157 | test("load/select table with star and func") {
158 | context.setDebugMode(true)
159 | context.setUserDefinedMetaProvider(new MetaProvider {
160 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
161 | if (key.prefix == Option("hive")) {
162 | Option(MetaTable(key, List(
163 | MetaTableColumn("a", DataType.STRING, true, Map()),
164 | MetaTableColumn("b", DataType.STRING, true, Map()),
165 | MetaTableColumn("c", DataType.STRING, true, Map()),
166 | MetaTableColumn("d", DataType.STRING, true, Map())
167 | )))
168 | } else None
169 | }
170 |
171 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = ???
172 | })
173 | val sql =
174 | """
175 | | -- yes
176 | | load hive.`db.table1` as table2;
177 | | select * from table2 as table3;
178 | | select sum() from table3
179 | |""".stripMargin
180 | val items = context.buildFromString(sql).suggest(5, 12)
181 | println(items)
182 |
183 | }
184 | test("table alias with temp table") {
185 | val sql =
186 | """
187 | |select a,b,c from table1 as table1;
188 | |select aa,bb,cc from table2 as table2;
189 | |select from table1 t1 left join table2 t2 on t1.a = t2.
190 | |""".stripMargin
191 |
192 | val items = context.buildFromString(sql).suggest(4, 58)
193 | assert(items.map(_.name) == List("aa", "bb", "cc"))
194 |
195 | }
196 | }
197 |
198 |
199 |
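200 | // Reading the TokenPos assertions above (an inferred summary, not API docs):
201 | // TokenPos(i, TokenPosType.CURRENT, n) places the cursor inside token i after
202 | // n characters (completion filters on that prefix), while
203 | // TokenPos(i, TokenPosType.NEXT, 0) places it in the gap after token i, so the
204 | // suggester proposes what may legally come next.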
--------------------------------------------------------------------------------
/src/test/java/com/intigua/antlr4/autosuggest/BaseTest.scala:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest
2 |
3 | import org.antlr.v4.runtime.Token
4 | import org.apache.spark.sql.SparkSession
5 | import org.apache.spark.sql.catalyst.parser.{SqlBaseLexer, SqlBaseParser}
6 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
7 | import streaming.dsl.parser.{DSLSQLLexer, DSLSQLParser}
8 | import tech.mlsql.autosuggest.{AutoSuggestContext, MLSQLSQLFunction}
9 |
10 | import scala.collection.JavaConverters._
11 |
12 | /**
13 | * 2/6/2020 WilliamZhu(allwefantasy@gmail.com)
14 | */
15 | class BaseTest extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
16 | val lexerAndParserFactory = new ReflectionLexerAndParserFactory(classOf[DSLSQLLexer], classOf[DSLSQLParser])
17 | val loadLexer = new LexerWrapper(lexerAndParserFactory, new DefaultToCharStream)
18 |
19 | val lexerAndParserFactory2 = new ReflectionLexerAndParserFactory(classOf[SqlBaseLexer], classOf[SqlBaseParser])
20 | val rawSQLloadLexer = new LexerWrapper(lexerAndParserFactory2, new RawSQLToCharStream)
21 |
22 | var context: AutoSuggestContext = _
23 | var tokens: List[Token] = _
24 | var sparkSession: SparkSession = _
25 |
26 | override def beforeAll(): Unit = {
27 | sparkSession = SparkSession.builder().appName("local").master("local[*]").getOrCreate()
28 |
29 | }
30 |
31 | override def afterAll(): Unit = {
32 | context.session.close()
33 | }
34 |
35 |
36 | override def beforeEach(): Unit = {
37 | MLSQLSQLFunction.funcMetaProvider.clear
38 | context = new AutoSuggestContext(sparkSession, loadLexer, rawSQLloadLexer)
39 | context.setDebugMode(true)
40 | val tr = loadLexer.tokenizeNonDefaultChannel(
41 | """
42 | | -- yes
43 | | load hive.`` as -- jack
44 | | table1;
45 | |""".stripMargin)
46 | tokens = tr.tokens.asScala.toList
47 | }
48 |
49 | def getMLSQLTokens(sql: String) = {
50 | context.lexer.tokenizeNonDefaultChannel(sql).tokens.asScala.toList
51 | }
52 |
53 | }
54 |
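55 | // Fixture summary (inferred): every test gets a fresh AutoSuggestContext in
56 | // debug mode, wired with an MLSQL lexer (DSLSQLLexer) and a raw Spark SQL
57 | // lexer (SqlBaseLexer), plus `tokens` pre-tokenized from a small `load hive`
58 | // statement; the shared SparkSession is created once in beforeAll.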
--------------------------------------------------------------------------------
/src/test/java/com/intigua/antlr4/autosuggest/BaseTestWithoutSparkSession.scala:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest
2 |
3 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
4 | import tech.mlsql.autosuggest.app.AutoSuggestController
5 | import tech.mlsql.autosuggest.{AutoSuggestContext, MLSQLSQLFunction}
6 |
7 | /**
8 | * 11/6/2020 WilliamZhu(allwefantasy@gmail.com)
9 | */
10 | class BaseTestWithoutSparkSession extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
11 |
12 | var context: AutoSuggestContext = _
13 |
14 |
15 | override def beforeAll(): Unit = {
16 |
17 | }
18 |
19 | override def afterAll(): Unit = {
20 | context.session.close()
21 | }
22 |
23 |
24 | override def beforeEach(): Unit = {
25 | MLSQLSQLFunction.funcMetaProvider.clear
26 | context = new AutoSuggestContext(null, AutoSuggestController.mlsqlLexer, AutoSuggestController.sqlLexer)
27 | context.setDebugMode(true)
28 | }
29 |
30 | }
31 |
--------------------------------------------------------------------------------
/src/test/java/com/intigua/antlr4/autosuggest/LexerUtilsTest.scala:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest
2 |
3 | import tech.mlsql.autosuggest.statement.LexerUtils
4 | import tech.mlsql.autosuggest.{TokenPos, TokenPosType}
5 |
6 | /**
7 | * 2/6/2020 WilliamZhu(allwefantasy@gmail.com)
8 | */
9 | class LexerUtilsTest extends BaseTest {
10 |
11 | test(" load [cursor]hive.`` as -- jack") {
12 | assert(LexerUtils.toTokenPos(tokens, 3, 6) == TokenPos(0, TokenPosType.NEXT, 0))
13 |
14 | }
15 | test(" load h[cursor]ive.`` as -- jack") {
16 | assert(LexerUtils.toTokenPos(tokens, 3, 7) == TokenPos(1, TokenPosType.CURRENT, 1))
17 | }
18 |
19 | test("[cursor] load hive.`` as -- jack") {
20 | assert(LexerUtils.toTokenPos(tokens, 3, 0) == TokenPos(-1, TokenPosType.NEXT, 0))
21 | }
22 | test(" load hive.`` as -- jack [cursor]") {
23 | assert(LexerUtils.toTokenPos(tokens, 3, 23) == TokenPos(4, TokenPosType.NEXT, 0))
24 | }
25 |
26 | test("select sum([cursor]) as t from table") {
27 | context.buildFromString("select sum() as t from table")
28 | assert(LexerUtils.toTokenPos(context.rawTokens, 1, 11) == TokenPos(2, TokenPosType.NEXT, 0))
29 | }
30 |
31 | test("select from (select table2.abc as abc from table1 left join table2 on table1.column1 == table2.[cursor]) t1") {
32 | context.buildFromString("select from (select table2.abc as abc from table1 left join table2 on table1.column1 == table2.) t1")
33 | assert(LexerUtils.toTokenPos(context.rawTokens, 1, 96) == TokenPos(21, TokenPosType.NEXT, 0))
34 | }
35 |
36 | test("select sum(abc[cursor]) as t from table") {
37 | context.buildFromString("select sum(abc) as t from table")
38 | assert(LexerUtils.toTokenPos(context.rawTokens, 1, 14) == TokenPos(3, TokenPosType.CURRENT, 3))
39 | }
40 |
41 | test("load csv.") {
42 | context.buildFromString("load csv.")
43 | assert(LexerUtils.toTokenPos(context.rawTokens, 1, 9) == TokenPos(2, TokenPosType.NEXT, 0))
44 | }
45 |
46 | }
47 |
--------------------------------------------------------------------------------
/src/test/java/com/intigua/antlr4/autosuggest/LoadSuggesterTest.scala:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest
2 |
3 | import tech.mlsql.autosuggest.statement.LoadSuggester
4 | import tech.mlsql.autosuggest.{TokenPos, TokenPosType}
5 |
6 | import scala.collection.JavaConverters._
7 |
8 | /**
9 | * 2/6/2020 WilliamZhu(allwefantasy@gmail.com)
10 | */
11 | class LoadSuggesterTest extends BaseTest {
12 | test("load hiv[cursor]") {
13 | val wow = context.lexer.tokenizeNonDefaultChannel(
14 | """
15 | | -- yes
16 | | load hiv
17 | |""".stripMargin).tokens.asScala.toList
18 | val loadSuggester = new LoadSuggester(context, wow, TokenPos(1, TokenPosType.CURRENT, 3)).suggest()
19 | assert(loadSuggester.map(_.name) == List("hive"))
20 | }
21 |
22 | test("load [cursor]") {
23 | val wow = context.lexer.tokenizeNonDefaultChannel(
24 | """
25 | | -- yes
26 | | load
27 | |""".stripMargin).tokens.asScala.toList
28 | val loadSuggester = new LoadSuggester(context, wow, TokenPos(0, TokenPosType.NEXT, 0)).suggest()
29 | println(loadSuggester)
30 | assert(loadSuggester.size > 1)
31 | }
32 |
33 | test("load csv.`` where [cursor]") {
34 | val wow = context.lexer.tokenizeNonDefaultChannel(
35 | """
36 | | -- yes
37 | | load csv.`` where
38 | |""".stripMargin).tokens.asScala.toList
39 | val result = new LoadSuggester(context, wow, TokenPos(4, TokenPosType.NEXT, 0)).suggest()
40 | println(result)
41 |
42 | }
43 |
44 |
45 | }
46 |
--------------------------------------------------------------------------------
/src/test/java/com/intigua/antlr4/autosuggest/MatchTokenTest.scala:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest
2 |
3 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer
4 | import streaming.dsl.parser.DSLSQLLexer
5 | import tech.mlsql.autosuggest.dsl.{Food, MLSQLTokenTypeWrapper, TokenMatcher}
6 | import tech.mlsql.autosuggest.statement.LexerUtils
7 |
8 | import scala.collection.JavaConverters._
9 |
10 | /**
11 | * 9/6/2020 WilliamZhu(allwefantasy@gmail.com)
12 | */
13 | class MatchTokenTest extends BaseTest {
14 | test("orIndex back") {
15 | val wow = context.lexer.tokenizeNonDefaultChannel(
16 | """
17 | |select a.k from jack.drugs_bad_case_di as a
18 | |""".stripMargin).tokens.asScala.toList
19 |
20 | val tokens = LexerUtils.toRawSQLTokens(context, wow)
21 | val temp = TokenMatcher(tokens, 6).back.orIndex(Array(Food(None, SqlBaseLexer.FROM), Food(None, SqlBaseLexer.SELECT)))
22 | assert(temp == 4)
23 | }
24 |
25 | test("orIndex forward") {
26 | val wow = context.lexer.tokenizeNonDefaultChannel(
27 | """
28 | |select a.k from jack.drugs_bad_case_di as a
29 | |""".stripMargin).tokens.asScala.toList
30 |
31 | val tokens = LexerUtils.toRawSQLTokens(context, wow)
32 | val temp = TokenMatcher(tokens, 0).forward.orIndex(Array(Food(None, SqlBaseLexer.FROM), Food(None, SqlBaseLexer.SELECT)))
33 | assert(temp == 0)
34 | }
35 |
36 | test("forward out of bound success") {
37 | val wow = context.lexer.tokenizeNonDefaultChannel(
38 | """
39 | |load csv.
40 | |""".stripMargin).tokens.asScala.toList
41 |
42 | val temp = TokenMatcher(wow, 0)
43 | .forward
44 | .eat(Food(None, DSLSQLLexer.LOAD))
45 | .eat(Food(None, DSLSQLLexer.IDENTIFIER))
46 | .eat(Food(None, MLSQLTokenTypeWrapper.DOT)).build
47 | assert(temp.isSuccess)
48 | }
49 | test("forward out of bound fail") {
50 | val wow = context.lexer.tokenizeNonDefaultChannel(
51 | """
52 | |load csv
53 | |""".stripMargin).tokens.asScala.toList
54 |
55 | val temp = TokenMatcher(wow, 0)
56 | .forward
57 | .eat(Food(None, DSLSQLLexer.LOAD))
58 | .eat(Food(None, DSLSQLLexer.IDENTIFIER))
59 | .eat(Food(None, MLSQLTokenTypeWrapper.DOT)).build
60 | assert(!temp.isSuccess)
61 | }
62 |
63 | test("back out of bound success") {
64 | val wow = context.lexer.tokenizeNonDefaultChannel(
65 | """
66 | |load csv.
67 | |""".stripMargin).tokens.asScala.toList
68 |
69 | val temp = TokenMatcher(wow, 2)
70 | .back
71 | .eat(Food(None, MLSQLTokenTypeWrapper.DOT))
72 | .eat(Food(None, DSLSQLLexer.IDENTIFIER))
73 | .eat(Food(None, DSLSQLLexer.LOAD)).build
74 | assert(temp.isSuccess)
75 | }
76 | test("back out of bound fail") {
77 | val wow = context.lexer.tokenizeNonDefaultChannel(
78 | """
79 | |csv.
80 | |""".stripMargin).tokens.asScala.toList
81 |
82 | val temp = TokenMatcher(wow, 1)
83 | .back
84 | .eat(Food(None, MLSQLTokenTypeWrapper.DOT))
85 | .eat(Food(None, DSLSQLLexer.IDENTIFIER))
86 | .eat(Food(None, DSLSQLLexer.LOAD)).build
87 | assert(!temp.isSuccess)
88 | }
89 | }
90 |
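91 | // TokenMatcher DSL as exercised above (inferred from these tests): `forward`
92 | // and `back` set the scan direction from the start index; each
93 | // `eat(Food(text, tokenType))` consumes one token; `build` finalizes the match
94 | // (`isSuccess` reports whether every eat matched before running out of
95 | // tokens); `orIndex(foods)` returns the index of the first token matching any
96 | // of the given foods.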
--------------------------------------------------------------------------------
/src/test/java/com/intigua/antlr4/autosuggest/SelectSuggesterTest.scala:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest
2 |
3 | import tech.mlsql.autosuggest.meta.{MetaProvider, MetaTable, MetaTableColumn, MetaTableKey}
4 | import tech.mlsql.autosuggest.statement.SelectSuggester
5 | import tech.mlsql.autosuggest.{DataType, MLSQLSQLFunction, TokenPos, TokenPosType}
6 |
7 | import scala.collection.JavaConverters._
8 |
9 | /**
10 | * 2/6/2020 WilliamZhu(allwefantasy@gmail.com)
11 | */
12 | class SelectSuggesterTest extends BaseTest {
13 |
14 | def buildMetaProvider = {
15 | context.setUserDefinedMetaProvider(new MetaProvider {
16 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
17 | Option(MetaTable(key, List(
18 | MetaTableColumn("no_result_type", null, true, Map()),
19 | MetaTableColumn("keywords", null, true, Map()),
20 | MetaTableColumn("search_num", null, true, Map()),
21 | MetaTableColumn("hp_stat_date", null, true, Map()),
22 | MetaTableColumn("action_dt", null, true, Map()),
23 | MetaTableColumn("action_type", null, true, Map()),
24 | MetaTableColumn("av", null, true, Map())
25 | )))
26 |
27 | }
28 |
29 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = List()
30 | })
31 |
32 |
33 | }
34 |
35 | lazy val wow = context.lexer.tokenizeNonDefaultChannel(
36 | """
37 | |select no_result_type, keywords, search_num, rank
38 | |from(
39 | | select no_result_type, keywords, search_num, row_number() over (PARTITION BY no_result_type order by search_num desc) as rank
40 | | from(
41 | | select no_result_type, keywords, sum(search_num) AS search_num
42 | | from jack.drugs_bad_case_di,jack.abc jack
43 | | where hp_stat_date >= date_sub(current_date,30)
44 | | and action_dt >= date_sub(current_date,30)
45 | | and action_type = 'search'
46 | | and length(keywords) > 1
47 | | and (split(av, '\\.')[0] >= 11 OR (split(av, '\\.')[0] = 10 AND split(av, '\\.')[1] = 9))
48 | | --and no_result_type = 'indication'
49 | | group by no_result_type, keywords
50 | | )a
51 | |)b
52 | |where rank <=
53 | |""".stripMargin).tokens.asScala.toList
54 |
55 | test("select") {
56 |
57 | context.setUserDefinedMetaProvider(new MetaProvider {
58 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = None
59 |
60 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = List()
61 | })
62 |
63 |
64 | val suggester = new SelectSuggester(context, wow, TokenPos(0, TokenPosType.NEXT, 0))
65 | suggester.suggest()
66 |
67 |
68 | }
69 |
70 | test("project: complex attribute suggest") {
71 |
72 | buildMetaProvider
73 |
74 | lazy val wow2 = context.lexer.tokenizeNonDefaultChannel(
75 | """
76 | |select key no_result_type, keywords, search_num, rank
77 | |from(
78 | | select keywords, search_num, row_number() over (PARTITION BY no_result_type order by search_num desc) as rank
79 | | from(
80 | | select *,no_result_type, keywords, sum(search_num) AS search_num
81 | | from jack.drugs_bad_case_di,jack.abc jack
82 | | where hp_stat_date >= date_sub(current_date,30)
83 | | and action_dt >= date_sub(current_date,30)
84 | | and action_type = 'search'
85 | | and length(keywords) > 1
86 | | and (split(av, '\\.')[0] >= 11 OR (split(av, '\\.')[0] = 10 AND split(av, '\\.')[1] = 9))
87 | | --and no_result_type = 'indication'
88 | | group by no_result_type, keywords
89 | | )a
90 | |)b
91 | |where rank <=
92 | |""".stripMargin).tokens.asScala.toList
93 |
94 | val suggester = new SelectSuggester(context, wow2, TokenPos(1, TokenPosType.CURRENT, 2))
95 | assert(suggester.suggest().map(_.name) == List("keywords"))
96 |
97 | }
98 |
99 | test("project: second level select ") {
100 |
101 | buildMetaProvider
102 |
103 | lazy val wow2 = context.lexer.tokenizeNonDefaultChannel(
104 | """
105 | |select key no_result_type, keywords, search_num, rank
106 | |from(
107 | | select sea keywords, search_num, row_number() over (PARTITION BY no_result_type order by search_num desc) as rank
108 | | from(
109 | | select *,no_result_type, keywords, sum(search_num) AS search_num
110 | | from jack.drugs_bad_case_di,jack.abc jack
111 | | where hp_stat_date >= date_sub(current_date,30)
112 | | and action_dt >= date_sub(current_date,30)
113 | | and action_type = 'search'
114 | | and length(keywords) > 1
115 | | and (split(av, '\\.')[0] >= 11 OR (split(av, '\\.')[0] = 10 AND split(av, '\\.')[1] = 9))
116 | | --and no_result_type = 'indication'
117 | | group by no_result_type, keywords
118 | | )a
119 | |)b
120 | |where rank <=
121 | |""".stripMargin).tokens.asScala.toList
122 |
123 | // wow2.zipWithIndex.foreach{case (token,index)=>
124 | // println(s"${index} $token")}
125 | val suggester = new SelectSuggester(context, wow2, TokenPos(12, TokenPosType.CURRENT, 3))
126 | assert(suggester.suggest().distinct.map(_.name) == List("search_num"))
127 |
128 | }
129 |
130 | test("project: single query with alias table name") {
131 |
132 | buildMetaProvider
133 | lazy val wow = context.lexer.tokenizeNonDefaultChannel(
134 | """
135 | |select a.k from jack.drugs_bad_case_di as a
136 | |""".stripMargin).tokens.asScala.toList
137 |
138 | val suggester = new SelectSuggester(context, wow, TokenPos(3, TokenPosType.CURRENT, 1))
139 |
140 | assert(suggester.suggest().map(_.name) == List("keywords"))
141 | }
142 |
143 | test("project: table or attribute") {
144 |
145 | buildMetaProvider
146 | lazy val wow = context.lexer.tokenizeNonDefaultChannel(
147 | """
148 | |select from jack.drugs_bad_case_di as a
149 | |""".stripMargin).tokens.asScala.toList
150 |
151 | val suggester = new SelectSuggester(context, wow, TokenPos(0, TokenPosType.NEXT, 0))
152 |
153 | assert(suggester.suggest().map(_.name) == List("a",
154 | "no_result_type",
155 | "keywords",
156 | "search_num",
157 | "hp_stat_date",
158 | "action_dt",
159 | "action_type",
160 | "av"))
161 | }
162 |
163 | test("project: complex table attribute ") {
164 |
165 | buildMetaProvider
166 |
167 | lazy val wow2 = context.lexer.tokenizeNonDefaultChannel(
168 | """
169 | |select no_result_type, keywords, search_num, rank
170 | |from(
171 | | select keywords, search_num, row_number() over (PARTITION BY no_result_type order by search_num desc) as rank
172 | | from(
173 | | select *,no_result_type, keywords, sum(search_num) AS search_num
174 | | from jack.drugs_bad_case_di,jack.abc jack
175 | | where hp_stat_date >= date_sub(current_date,30)
176 | | and action_dt >= date_sub(current_date,30)
177 | | and action_type = 'search'
178 | | and length(keywords) > 1
179 | | and (split(av, '\\.')[0] >= 11 OR (split(av, '\\.')[0] = 10 AND split(av, '\\.')[1] = 9))
180 | | --and no_result_type = 'indication'
181 | | group by no_result_type, keywords
182 | | )a
183 | |)b
184 | |where rank <=
185 | |""".stripMargin).tokens.asScala.toList
186 |
187 | val suggester = new SelectSuggester(context, wow2, TokenPos(0, TokenPosType.NEXT, 0))
188 | println(suggester.sqlAST.printAsStr(suggester.tokens, 0))
189 | suggester.table_info.foreach { case (level, item) =>
190 | println(level + ":")
191 | println(item.map(_._1).toList)
192 | }
193 | assert(suggester.suggest().map(_.name) == List("b", "keywords", "search_num", "rank"))
194 |
195 | }
196 |
197 | test("table layer") {
198 | buildMetaProvider
199 | val sql =
200 | """
201 | |select from (select no_result_type from db1.table1) b;
202 | |""".stripMargin
203 | val tokens = getMLSQLTokens(sql)
204 |
205 | val suggester = new SelectSuggester(context, tokens, TokenPos(0, TokenPosType.NEXT, 0))
206 | println("=======")
207 | println(suggester.suggest())
208 | assert(suggester.suggest().head.name == "b")
209 | }
210 |
211 |
212 |
213 | test("project: function suggester") {
214 |
215 | val func = MLSQLSQLFunction.apply("split").
216 | funcParam.
217 | param("str", DataType.STRING).
218 | param("splitter", DataType.STRING).
219 | func.
220 | returnParam(DataType.ARRAY, true, Map()).
221 | build
222 | MLSQLSQLFunction.funcMetaProvider.register(func)
223 |
224 | val tableKey = MetaTableKey(None, Option("jack"), "drugs_bad_case_di")
225 |
226 | val metas = Map(tableKey ->
227 | Option(MetaTable(tableKey, List(
228 | MetaTableColumn("no_result_type", null, true, Map()),
229 | MetaTableColumn("keywords", null, true, Map()),
230 | MetaTableColumn("search_num", null, true, Map()),
231 | MetaTableColumn("hp_stat_date", null, true, Map()),
232 | MetaTableColumn("action_dt", null, true, Map()),
233 | MetaTableColumn("action_type", null, true, Map()),
234 | MetaTableColumn("av", null, true, Map())
235 | )))
236 |
237 | )
238 | context.setUserDefinedMetaProvider(new MetaProvider {
239 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
240 | metas(key)
241 | }
242 |
243 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = List()
244 | })
245 |
246 |
247 | lazy val wow = context.lexer.tokenizeNonDefaultChannel(
248 | """
249 | |select spl from jack.drugs_bad_case_di as a
250 | |""".stripMargin).tokens.asScala.toList
251 |
252 | val suggester = new SelectSuggester(context, wow, TokenPos(1, TokenPosType.CURRENT, 3))
253 | println(suggester.suggest())
254 | assert(suggester.suggest().map(_.name) == List("split"))
255 | }
256 |
257 |
258 |
259 |
260 | }
261 |
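262 | // Function-metadata registration as used in the last test (restated sketch):
263 | //   MLSQLSQLFunction.apply("split").funcParam.
264 | //     param("str", DataType.STRING).param("splitter", DataType.STRING).
265 | //     func.returnParam(DataType.ARRAY, true, Map()).build
266 | // registers `split` so the partial name `spl` in a projection resolves to it.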
--------------------------------------------------------------------------------
/src/test/java/com/intigua/antlr4/autosuggest/TablePreprocessorTest.scala:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest
2 |
3 | import tech.mlsql.autosuggest.meta.{MetaProvider, MetaTable, MetaTableColumn, MetaTableKey}
4 | import tech.mlsql.autosuggest.preprocess.TablePreprocessor
5 | import tech.mlsql.autosuggest.{DataType, SpecialTableConst}
6 |
7 | import scala.collection.JavaConverters._
8 |
9 | /**
10 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com)
11 | */
12 | class TablePreprocessorTest extends BaseTest {
13 |
14 | test("load/select table") {
15 | context.setUserDefinedMetaProvider(new MetaProvider {
16 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
17 | if (key.prefix == Option("hive")) {
18 | Option(MetaTable(key, List(
19 | MetaTableColumn("a", DataType.STRING, true, Map()),
20 | MetaTableColumn("b", DataType.STRING, true, Map()),
21 | MetaTableColumn("c", DataType.STRING, true, Map()),
22 | MetaTableColumn("d", DataType.STRING, true, Map())
23 | )))
24 | } else None
25 | }
26 |
27 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = ???
28 | })
29 | val wow = context.lexer.tokenizeNonDefaultChannel(
30 | """
31 | | -- yes
32 | | load hive.`db.table1` as table2;
33 | | select a,b,c from table2 as table3;
34 | |""".stripMargin).tokens.asScala.toList
35 | context.build(wow)
36 | val processor = new TablePreprocessor(context)
37 | context.statements.foreach(processor.process(_))
38 |
39 | val targetTable = context.tempTableProvider.search(SpecialTableConst.tempTable("table2").key).get
40 | assert(targetTable.key == MetaTableKey(Some("hive"), Some("db"), "table1"))
41 |
42 | val targetTable3 = context.tempTableProvider.search(SpecialTableConst.tempTable("table3").key).get
43 | assert(targetTable3 == MetaTable(MetaTableKey(None, Some(SpecialTableConst.TEMP_TABLE_DB_KEY), "table3"),
44 | List(MetaTableColumn("a", null, true, Map()), MetaTableColumn("b", null, true, Map()), MetaTableColumn("c", null, true, Map()))))
45 | }
46 |
47 | test("load/select table with star") {
48 | context.setUserDefinedMetaProvider(new MetaProvider {
49 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
50 | if (key.prefix == Option("hive")) {
51 | Option(MetaTable(key, List(
52 | MetaTableColumn("a", DataType.STRING, true, Map()),
53 | MetaTableColumn("b", DataType.STRING, true, Map()),
54 | MetaTableColumn("c", DataType.STRING, true, Map()),
55 | MetaTableColumn("d", DataType.STRING, true, Map())
56 | )))
57 | } else None
58 | }
59 |
60 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = ???
61 | })
62 | val wow = context.lexer.tokenizeNonDefaultChannel(
63 | """
64 | | -- yes
65 | | load hive.`db.table1` as table2;
66 | | select * from table2 as table3;
67 | |""".stripMargin).tokens.asScala.toList
68 | context.build(wow)
69 | val processor = new TablePreprocessor(context)
70 | context.statements.foreach(processor.process(_))
71 |
72 | val targetTable = context.tempTableProvider.search(SpecialTableConst.tempTable("table2").key).get
73 | assert(targetTable.key == MetaTableKey(Some("hive"), Some("db"), "table1"))
74 |
75 | val targetTable3 = context.tempTableProvider.search(SpecialTableConst.tempTable("table3").key).get
76 | assert(targetTable3 == MetaTable(MetaTableKey(None, Some(SpecialTableConst.TEMP_TABLE_DB_KEY), "table3"),
77 | List(MetaTableColumn("a", null, true, Map()),
78 | MetaTableColumn("b", null, true, Map()),
79 | MetaTableColumn("c", null, true, Map()), MetaTableColumn("d", null, true, Map()))))
80 | }
81 |
82 | }
83 |
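84 | // What these assertions pin down: TablePreprocessor registers each
85 | // statement's output table in `context.tempTableProvider` under
86 | // TEMP_TABLE_DB_KEY, with columns narrowed to the explicit projection
87 | // (`a,b,c`) or expanded from the source table's schema when `*` is used.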
--------------------------------------------------------------------------------
/src/test/java/com/intigua/antlr4/autosuggest/TableStructureTest.scala:
--------------------------------------------------------------------------------
1 | package com.intigua.antlr4.autosuggest
2 |
3 | import tech.mlsql.autosuggest.meta.{MetaProvider, MetaTable, MetaTableColumn, MetaTableKey}
4 | import tech.mlsql.autosuggest.statement.{MetaTableKeyWrapper, SelectSuggester}
5 | import tech.mlsql.autosuggest.{TokenPos, TokenPosType}
6 | import tech.mlsql.common.utils.log.Logging
7 |
8 | import scala.collection.JavaConverters._
9 | import scala.collection.mutable.ArrayBuffer
10 |
11 | /**
12 | * 22/6/2020 WilliamZhu(allwefantasy@gmail.com)
13 | */
14 | class TableStructureTest extends BaseTest with Logging {
15 |
16 |
17 | test("s1") {
18 | buildMetaProvider
19 | val sql =
20 | """
21 | |select from (select no_result_type from db1.table1) b;
22 | |""".stripMargin
23 | val tokens = getMLSQLTokens(sql)
24 |
25 | val suggester = new SelectSuggester(context, tokens, TokenPos(0, TokenPosType.NEXT, 0))
26 | println(suggester.sqlAST)
27 | }
28 |
29 | test("s2") {
30 | buildMetaProvider
31 | val sql =
32 | """
33 | |select from (select no_result_type from (select no_result_type from db1.table1) b left join db2.table2) c;
34 | |""".stripMargin
35 | val tokens = getMLSQLTokens(sql)
36 |
37 | val suggester = new SelectSuggester(context, tokens, TokenPos(0, TokenPosType.NEXT, 0))
38 | printAST(suggester)
39 | }
40 |
41 | def printAST(suggester: SelectSuggester) = {
42 | suggester.sqlAST // touch the AST so it is built before the structure is logged
43 | logInfo(s"SQL[${suggester.tokens.map(_.getText).mkString(" ")}]")
44 | logInfo("STRUCTURE:")
45 | suggester.table_info.foreach { item =>
46 | logInfo(s"Level:${item._1}")
47 | item._2.foreach { table =>
48 | logInfo(s"${table._1} => ${table._2.copy(columns = List())}")
49 | }
50 | }
51 | }
52 |
53 |
54 | def buildMetaProvider = {
55 | context.setUserDefinedMetaProvider(new MetaProvider {
56 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
57 | Option(MetaTable(key, List(
58 | MetaTableColumn("no_result_type", null, true, Map()),
59 | MetaTableColumn("keywords", null, true, Map()),
60 | MetaTableColumn("search_num", null, true, Map()),
61 | MetaTableColumn("hp_stat_date", null, true, Map()),
62 | MetaTableColumn("action_dt", null, true, Map()),
63 | MetaTableColumn("action_type", null, true, Map()),
64 | MetaTableColumn("av", null, true, Map())
65 | )))
66 |
67 | }
68 |
69 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = List()
70 | })
71 |
72 |
73 | }
74 | test("single select build") {
75 |
76 | buildMetaProvider
77 | lazy val wow = context.lexer.tokenizeNonDefaultChannel(
78 | """
79 | |select a.k from jack.drugs_bad_case_di as a
80 | |""".stripMargin).tokens.asScala.toList
81 |
82 | val suggester = new SelectSuggester(context, wow, TokenPos(3, TokenPosType.CURRENT, 1))
83 | val root = suggester.sqlAST
84 | root.visitDown(0) { case (ast, level) =>
85 | println(s"${ast.name(suggester.tokens)} ${ast.output(suggester.tokens)}")
86 | }
87 |
88 | assert(suggester.suggest().map(_.name) == List("keywords"))
89 | }
90 |
91 | test("subquery build") {
92 | buildMetaProvider
93 |
94 | lazy val wow = context.lexer.tokenizeNonDefaultChannel(
95 | """
96 | |select a.k from (select * from jack.drugs_bad_case_di ) a;
97 | |""".stripMargin).tokens.asScala.toList
98 |
99 | val suggester = new SelectSuggester(context, wow, TokenPos(3, TokenPosType.CURRENT, 1))
100 | val root = suggester.sqlAST
101 | root.visitDown(0) { case (ast, level) =>
102 | println(s"${ast.name(suggester.tokens)} ${ast.output(suggester.tokens)}")
103 | }
104 |
105 | assert(suggester.suggest().map(_.name) == List("keywords"))
106 | }
107 |
108 | test("subquery build without prefix") {
109 | buildMetaProvider
110 |
111 | lazy val wow = context.lexer.tokenizeNonDefaultChannel(
112 | """
113 | |select k from (select * from jack.drugs_bad_case_di ) a;
114 | |""".stripMargin).tokens.asScala.toList
115 |
116 | val suggester = new SelectSuggester(context, wow, TokenPos(1, TokenPosType.CURRENT, 1))
117 | val root = suggester.sqlAST
118 | val buffer = ArrayBuffer[String]()
119 | root.visitDown(0) { case (ast, level) =>
120 |
121 | buffer += suggester._tokens.slice(ast.start, ast.stop).map(_.getText).mkString(" ")
122 |
123 | }
124 | assert(buffer(0) == "select k from ( select * from jack . drugs_bad_case_di ) a ;")
125 | assert(buffer(1) == "select * from jack . drugs_bad_case_di ) a")
126 |
127 | suggester.table_info.map {
128 | case (level, table) =>
129 | if (level == 0) {
130 | assert(table.map(_._1).toList == List(MetaTableKeyWrapper(MetaTableKey(None, None, null), Some("a"))))
131 | }
132 | if (level == 1) {
133 | val tables = table.map(_._1).toList
134 | assert(tables == List(MetaTableKeyWrapper(MetaTableKey(None, Some("jack"), "drugs_bad_case_di"), None)))
135 |
136 | }
137 | }
138 |
139 | }
140 |
141 | }
142 |
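143 | // `visitDown` as used above walks the select AST top-down, handing each node
144 | // and its nesting level to the callback; `ast.start`/`ast.stop` are indices
145 | // into the token list, which is how the slice assertions recover the raw text
146 | // of each (sub-)query.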
--------------------------------------------------------------------------------