├── .gitignore
├── README.md
├── config
│   ├── application.example.yml
│   ├── application.yml
│   └── logging.yml
├── dev
│   ├── build-package.sh
│   └── start.sh
├── pom.xml
└── src
    ├── main
    │   ├── java
    │   │   ├── com
    │   │   │   └── intigua
    │   │   │       └── antlr4
    │   │   │           └── autosuggest
    │   │   │               ├── AutoSuggester.java
    │   │   │               ├── CasePreference.java
    │   │   │               ├── DefaultToCharStream.scala
    │   │   │               ├── LexerAndParserFactory.java
    │   │   │               ├── LexerFactory.java
    │   │   │               ├── LexerWrapper.java
    │   │   │               ├── ParserFactory.java
    │   │   │               ├── ParserWrapper.java
    │   │   │               ├── ReflectionLexerAndParserFactory.java
    │   │   │               ├── ToCharStream.java
    │   │   │               ├── TokenSuggester.java
    │   │   │               └── TransitionWrapper.java
    │   │   └── tech
    │   │       └── mlsql
    │   │           └── autosuggest
    │   │               ├── AttributeExtractor.scala
    │   │               ├── AutoSuggestContext.scala
    │   │               ├── AutoSuggester.scala
    │   │               ├── FuncReg.scala
    │   │               ├── FunctionUtils.scala
    │   │               ├── POrCLiterals.scala
    │   │               ├── SpecialTableConst.scala
    │   │               ├── TokenPos.scala
    │   │               ├── app
    │   │               │   ├── Constants.scala
    │   │               │   ├── MLSQLAutoSuggestApp.scala
    │   │               │   ├── MysqlType.java
    │   │               │   ├── RDSchema.scala
    │   │               │   ├── RawDBTypeToJavaType.scala
    │   │               │   ├── SchemaRegistry.scala
    │   │               │   └── Standalone.scala
    │   │               ├── ast
    │   │               │   ├── NoneToken.scala
    │   │               │   └── TableTree.scala
    │   │               ├── dsl
    │   │               │   └── TokenMatcher.scala
    │   │               ├── funcs
    │   │               │   ├── Count.scala
    │   │               │   └── Splitter.scala
    │   │               ├── meta
    │   │               │   ├── LayeredMetaProvider.scala
    │   │               │   ├── MLSQLEngineMetaProvider.scala
    │   │               │   ├── MemoryMetaProvider.scala
    │   │               │   ├── MetaProvider.scala
    │   │               │   ├── RestMetaProvider.scala
    │   │               │   ├── StatementTempTableProvider.scala
    │   │               │   └── meta_protocal.scala
    │   │               ├── preprocess
    │   │               │   └── TablePreprocessor.scala
    │   │               ├── statement
    │   │               │   ├── LexerUtils.scala
    │   │               │   ├── LoadSuggester.scala
    │   │               │   ├── MLSQLStatementSplitter.scala
    │   │               │   ├── MatchAndExtractor.scala
    │   │               │   ├── PreProcessStatement.scala
    │   │               │   ├── RegisterSuggester.scala
    │   │               │   ├── SelectStatementUtils.scala
    │   │               │   ├── SelectSuggester.scala
    │   │               │   ├── StatementSplitter.scala
    │   │               │   ├── StatementSuggester.scala
    │   │               │   ├── StatementUtils.scala
    │   │               │   ├── SuggesterRegister.scala
    │   │               │   ├── TableExtractor.scala
    │   │               │   ├── TemplateSuggester.scala
    │   │               │   └── single_statement.scala
    │   │               └── utils
    │   │                   └── SchemaUtils.scala
    │   └── resources
    │       └── log4j.properties
    └── test
        └── java
            └── com
                └── intigua
                    └── antlr4
                        └── autosuggest
                            ├── AutoSuggestContextTest.scala
                            ├── BaseTest.scala
                            ├── BaseTestWithoutSparkSession.scala
                            ├── LexerUtilsTest.scala
                            ├── LoadSuggesterTest.scala
                            ├── MatchTokenTest.scala
                            ├── SelectSuggesterTest.scala
                            ├── TablePreprocessorTest.scala
                            └── TableStructureTest.scala

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
target/
.idea/
spark-binlog.iml
release.sh
build
logs
dependency-reduced-pom.xml
sql-code-intelligence.iml

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SQL Code Intelligence

SQL Code Intelligence is a code-completion backend engine. It can complete MLSQL syntax, and it can also complete standard Spark SQL (select statements).

## Current status

[Under active development; no stable release yet]

## Distribution

### Maven dependency

```xml
<dependency>
    <groupId>tech.mlsql</groupId>
    <artifactId>sql-code-intelligence</artifactId>
    <version>0.1.0</version>
</dependency>
```

With this dependency, users can easily embed the functionality into their own web application, for example one built with a web framework such as Spring.
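As a minimal sketch of such an embedding (the `buildFromString`/`suggest` calls are the ones shown in the "Programmatic usage" section below; the `AutoSuggestController.mlsqlLexer`/`sqlLexer` references are copied from that section, and their import path is not spelled out in this README):

```scala
import org.apache.spark.sql.SparkSession
import tech.mlsql.autosuggest.AutoSuggestContext

object EmbeddedSuggest {
  lazy val spark: SparkSession =
    SparkSession.builder().appName("suggest").master("local[*]").getOrCreate()

  // Call this from your own web controller (Spring or anything else).
  // AutoSuggestController.mlsqlLexer / sqlLexer are the lexer instances used
  // in the "Programmatic usage" snippet later in this README.
  def suggest(sql: String, lineNum: Int, columnNum: Int) = {
    val context = new AutoSuggestContext(spark,
      AutoSuggestController.mlsqlLexer,
      AutoSuggestController.sqlLexer)
    context.buildFromString(sql).suggest(lineNum, columnNum)
  }
}
```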
### Pre-built package

Download: [sql-code-intelligence](http://download.mlsql.tech/sql-code-intelligence/).

After downloading, unpack it with `tar xvf sql-code-intelligence-0.1.0.tar` and run:

```
./start.sh
```

You can also start it directly with the Java command:

```
java -cp .:sql-code-intelligence-0.1.0.jar tech.mlsql.autosuggest.app.Standalone -config ./config/application.yml
```

For application.yml, refer to the example in mlsql-autosuggest/config. The default port is 9004.

## Goals

【SQL Code Intelligence】has two goals. The first is standard SQL completion:

1. SQL keyword completion
2. Completion of tables, columns and functions
3. Extensibility: you can plug in any Schema Provider

The second is MLSQL syntax completion (built on top of the standard SQL completion):

1. Suggestions for all kinds of data sources
2. Suggestions for temp tables (including their columns)
3. Parameter and name suggestions for the various ET components

For table, column and function completion, unlike some other SQL suggestion tools, this plugin can make precise inferences from the information already available.

## Demo

[Click the animation below to watch a video of standard SQL completion]
[![](http://docs.mlsql.tech/upload_images/sql-code.gif)](https://www.bilibili.com/video/BV1xk4y1z7tV)

### Standard SQL suggestions

```sql
select no_result_type, keywords, search_num, rank
from(
select [cursor position] row_number() over (PARTITION BY no_result_type order by search_num desc) as rank
from(
select jack1.*,no_result_type, keywords, sum(search_num) AS search_num
from jack.drugs_bad_case_di as jack1,jack.abc jack2
where hp_stat_date >= date_sub(current_date,30)
and action_dt >= date_sub(current_date,30)
and action_type = 'search'
and length(keywords) > 1
and (split(av, '\\.')[0] >= 11 OR (split(av, '\\.')[0] = 10 AND split(av, '\\.')[1] = 9))
--and no_result_type = 'indication'
group by no_result_type, keywords
)a
)b
where rank <=
```

With the cursor at line 3, column 10, the system automatically suggests:
1. a [table name]
2. all columns expanded from jack1
3. no_result_type
4. keywords
5. search_num

### Multi-statement MLSQL suggestions

```sql
load hive.`db.table1` as table2;
select * from table2 as table3;
select [cursor position] from table3
```

Assume table db.table1 has columns a, b, c and d. With the cursor at line 3, column 7, the suggestions are:

1. table3
2. a
3. b
4. c
5. d

As you can see, the system works across statements: it automatically expands `*` and infers the schema of every table in order to complete.

## MLSQL data source / ET component parameter suggestions

```sql
select spl from jack.drugs_bad_case_di as a;
load csv.`/tmp.csv` where [cursor position]
```

When loading a CSV file we usually need to specify whether it has a header, what the delimiter is, and so on, and we would normally have to look these parameters up in the documentation. Now 【SQL Code Intelligence】suggests them:

```json
[
{"name":"codec","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"dateFormat","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"delimiter","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"emptyValue","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"escape","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"header","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"inferSchema","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}},
{"name":"quote","metaTable":{"key":{"table":"__OPTION__"},"columns":[]},"extra":{}}]
```

## User guide

### Deployment

If you run it as an MLSQL plugin, see the deployment docs: [MLSQL deployment](http://docs.mlsql.tech/zh/installation/). The plugin is bundled with MLSQL by default, so it works out of the box.

If you prefer to run it as a standalone application, download [sql-code-intelligence](http://download.mlsql.tech/sql-code-intelligence/), unpack it with `tar xvf sql-code-intelligence-0.1.0.tar`, and run:

```
java -cp .:sql-code-intelligence-0.1.0.jar tech.mlsql.autosuggest.app.Standalone -config ./config/application.yml
```

For application.yml, refer to the example in mlsql-autosuggest/config. The default port is 9004.
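Once the standalone server is up, a quick smoke test can be written in the same fluent-HttpClient style as the client examples later in this document. A sketch, assuming the standalone server exposes the same `/run/script` endpoint described in the API section below, on the port configured in application.yml (9004 by default):

```scala
import org.apache.http.client.fluent.{Form, Request}

object StandaloneSmokeTest {
  def main(args: Array[String]): Unit = {
    // Endpoint and parameters mirror the "API usage" section below;
    // 9004 is the standalone default port assumed from application.yml.
    val response = Request.Post("http://127.0.0.1:9004/run/script").bodyForm(
      Form.form().add("executeMode", "autoSuggest").add("sql",
        """
          |select spl from jack.drugs_bad_case_di as a
          |""".stripMargin).add("lineNum", "2").add("columnNum", "10").build()
    ).execute().returnContent().asString()
    println(response)
  }
}
```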
### Editor integration (three minutes)

Write an HTTP request class:

![](http://docs.mlsql.tech/upload_images/WX20200508-115001.png)

Set the autocompleter callback in your editor:

![](http://docs.mlsql.tech/upload_images/WechatIMG1001.jpeg)

Done. Here I'm using React, and the editor is Ace Editor.

### Schema information

【SQL Code Intelligence】needs the schema of the underlying tables. There are currently three ways to provide it:

1. Register the schema explicitly (good for trying things out and for debugging)
2. Provide a REST interface that follows the expected contract; the system calls it to fetch schema information (recommended; requires no changes to this project)
3. Extend the MetaProvider of 【SQL Code Intelligence】so the system can obtain the schema information (the class must be registered when this project starts)

Option 1 is the simplest: register table information through the HTTP interface. Below I do it in Scala; you can also use a tool such as Postman.

```scala
def registerTable(port: Int = 9003) = {
  val time = System.currentTimeMillis()
  val response = Request.Post(s"http://127.0.0.1:${port}/run/script").bodyForm(
    Form.form().add("executeMode", "registerTable").add("schema",
      """
        |CREATE TABLE emps(
        | empid INT NOT NULL,
        | deptno INT NOT NULL,
        | locationid INT NOT NULL,
        | empname VARCHAR(20) NOT NULL,
        | salary DECIMAL (18, 2),
        | PRIMARY KEY (empid),
        | FOREIGN KEY (deptno) REFERENCES depts(deptno),
        | FOREIGN KEY (locationid) REFERENCES locations(locationid)
        |);
        |""".stripMargin).add("db", "db1").add("table", "emps").
      add("isDebug", "true").build()
  ).execute().returnContent().asString()
  println(response)
}
```

Three statement types are supported for the table-creation text: db, hive and json, corresponding to MySQL syntax, Hive syntax, and the Spark SQL schema JSON format respectively. The default is MySQL syntax.
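As a sketch of what the json format refers to, the text below is a standard Spark SQL StructType in its JSON form (the output of `StructType.json`); how the json type is selected in the registerTable request is not spelled out in this README, so only the schema string itself is shown:

```scala
import org.apache.spark.sql.types._

// The emps table from the MySQL example above, expressed as a Spark schema.
val empsSchema = StructType(Seq(
  StructField("empid", IntegerType, nullable = false),
  StructField("deptno", IntegerType, nullable = false),
  StructField("locationid", IntegerType, nullable = false),
  StructField("empname", StringType, nullable = false),
  StructField("salary", DecimalType(18, 2), nullable = true)
))

println(empsSchema.json) // the string you would pass as the "schema" field
```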
The system can then make suggestions:

```scala
def testSuggest(port: Int = 9003) = {
  val time = System.currentTimeMillis()
  val response = Request.Post(s"http://127.0.0.1:${port}/run/script").bodyForm(
    Form.form().add("executeMode", "autoSuggest").add("sql",
      """
        |select emp from db1.emps as a;
        |-- load csv.`/tmp.csv` where
        |""".stripMargin).add("lineNum", "2").add("columnNum", "10").
      add("isDebug", "true").build()
  ).execute().returnContent().asString()
  println(response)
}
```

For option 2, pass searchUrl and listUrl in the request parameters; the inputs and outputs of your interface must follow the contract defined in `tech.mlsql.autosuggest.meta.RestMetaProvider`.

For option 3, implement a custom MetaProvider and take full advantage of your own schema system (a minimal implementation sketch appears at the end of this section):

```scala
trait MetaProvider {
  def search(key: MetaTableKey): Option[MetaTable]

  def list: List[MetaTable]
}
```

To make it take effect, set it on the AutoSuggestContext:

```
context.setUserDefinedMetaProvider(yourMetaProviderInstance)
```

MetaTableKey is defined as:

```scala
case class MetaTableKey(prefix: Option[String], db: Option[String], table: String)
```

The prefix is there to identify the data source: the same table might be a Hive table or a MySQL table, for instance. If you only have one data warehouse and never access other data sources, simply set it to None. For the statement:

```sql
load hive.`db.table1` as table2;
```

【SQL Code Intelligence】sends the following MetaTableKey to your MetaProvider.search method:

```scala
MetaTableKey(Option("hive"), Option("db"), "table1")
```

For a plain SQL statement rather than an MLSQL one, such as:

```sql
select * from db.table1
```

the MetaTableKey passed to MetaProvider.search looks like this:

```scala
MetaTableKey(None, Option("db"), "table1")
```
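A minimal MetaProvider sketch, backed by an in-memory map (the project ships a MemoryMetaProvider with a similar purpose; this sketch is illustrative rather than that class):

```scala
import tech.mlsql.autosuggest.meta.{MetaProvider, MetaTable, MetaTableKey}

// Illustrative only: answers schema lookups from a prebuilt map.
class MapMetaProvider(tables: Map[MetaTableKey, MetaTable]) extends MetaProvider {
  override def search(key: MetaTableKey): Option[MetaTable] = tables.get(key)

  override def list: List[MetaTable] = tables.values.toList
}

// context.setUserDefinedMetaProvider(new MapMetaProvider(myTables))
```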
### API usage

Endpoint: http://127.0.0.1:9003/run/script?executeMode=autoSuggest

Parameter 1: sql — the SQL script
Parameter 2: lineNum — the line the cursor is on, counting from 1
Parameter 3: columnNum — the column the cursor is on, counting from 1

For example, a client written in Scala:

```
object Test {
  def main(args: Array[String]): Unit = {
    val time = System.currentTimeMillis()
    val response = Request.Post("http://127.0.0.1:9003/run/script").bodyForm(
      Form.form().add("executeMode", "autoSuggest").add("sql",
        """
          |select spl from jack.drugs_bad_case_di as a
          |""".stripMargin).add("lineNum", "2").add("columnNum", "10").build()
    ).execute().returnContent().asString()
    println(System.currentTimeMillis() - time)
    println(response)
  }

}
```

The result looks like this:

```json
[{"name":"split",
"metaTable":{"key":{"db":"__FUNC__","table":"split"},
"columns":[
{"name":null,"dataType":"array","isNull":true,"extra":{"zhDoc":"\nsplit函数。用于切割字符串,返回字符串数组\n"}},{"name":"str","dataType":"string","isNull":false,"extra":{"zhDoc":"待切割字符"}},
{"name":"pattern","dataType":"string","isNull":false,"extra":{"zhDoc":"分隔符"}}]},
"extra":{}}]
```

You can see that split is suggested, that it is a function, and that both its parameters and its return value are documented.

### Programmatic usage

Create an AutoSuggestContext, feed the string to buildFromString, then call suggest to get the recommendations.

```scala

val sql = params("sql")
val lineNum = params("lineNum").toInt
val columnNum = params("columnNum").toInt

val sparkSession = SparkSession.builder().appName("local").master("local[*]").getOrCreate()
val context = new AutoSuggestContext(sparkSession,
  AutoSuggestController.mlsqlLexer,
  AutoSuggestController.sqlLexer)

JSONTool.toJsonStr(context.buildFromString(sql).suggest(lineNum,columnNum))
```

sparkSession may also be set to null, at the cost of some features such as data source suggestions.


## Developer guide

### Parsing flow

【SQL Code Intelligence】reuses the MLSQL/Spark SQL lexers but rewrites the parser part. Code completion has its own characteristics: while a statement is being typed it is, most of the time, syntactically invalid, so a strict parser cannot be used.

Two lexers are used because the MLSQL lexer splits the whole MLSQL script into statements, while the Spark SQL lexer handles the select statements of standard SQL. Since the project is highly extensible, users can extend it to other standard SQL statements as well.

Taking suggestions inside a select statement as an example, the overall flow is:

1. Split the script into multiple statements with the MLSQL lexer
2. Hand each statement to a different Suggester for further parsing
3. Parse the select statement with SelectSuggester
4. First build a very coarse-grained AST of the select statement, with one node per subquery, and build a layered table-structure cache TABLE_INFO at the same time
5. Convert the cursor position into a global TokenPos
6. Convert the global TokenPos into a TokenPos relative to the select statement
7. Walk the select AST with the TokenPos and locate the innermost simple sub-statement
8. Match with the project/where/groupby/on/having sub-suggesters; the matching suggester produces the suggestions

In the AST, every sub-statement may be incomplete. As the flow shows, we first work at statement granularity, and for complex select statements we then narrow the working context down to a single subquery. This makes the implementation much easier.


### Contributing quickly

【SQL Code Intelligence】needs a large number of function definitions so it can assist users while they type. Below is my implementation of the `split` function:

```scala
class Splitter extends FuncReg {

  override def register = {
    val func = MLSQLSQLFunction.apply("split").
      funcParam.
      param("str", DataType.STRING, false, Map("zhDoc" -> "待切割字符")).
      param("pattern", DataType.STRING, false, Map("zhDoc" -> "分隔符")).
      func.
      returnParam(DataType.ARRAY, true, Map(
        "zhDoc" ->
          """
            |split函数。用于切割字符串,返回字符串数组
            |""".stripMargin
      )).
      build
    func
  }

}
```

All you need to do is build the function signature with FunctionBuilder. Users then get detailed usage and parameter documentation when they use the function, and the signature also lets us infer the type of a column after it has been processed by nested function calls.

Add more functions under the tech.mlsql.autosuggest.funcs package in the same way; the system automatically scans that package and registers every implementation it finds.

### The TokenMatcher utility

Most of the work in 【SQL Code Intelligence】is token matching, for which we provide TokenMatcher. It supports forward and backward matching. Given the token sequence:

```
select a , b , c from jack
```

Suppose that, starting from token index 3 (b), I want to match forward a comma followed by an identifier. This can be written as:

```scala
val tokenMatcher = TokenMatcher(tokens,4).forward.eat(Food(None, TokenTypeWrapper.DOT), Food(None, SqlBaseLexer.IDENTIFIER)).build
```

Then call tokenMatcher.isSuccess to check whether the match succeeded, tokenMatcher.get to obtain the index after a successful match, and tokenMatcher.getMatchTokens to obtain the matched tokens.

Note that the TokenMatcher start position is inclusive: the token at the start position takes part in the match, which is why start is 4 rather than 3 in the example above. See the source code for more examples.
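Putting those three calls together, a sketch (token types as in the example above; this assumes getMatchTokens returns a Scala collection of ANTLR tokens):

```scala
val m = TokenMatcher(tokens, 4).forward.
  eat(Food(None, TokenTypeWrapper.DOT), Food(None, SqlBaseLexer.IDENTIFIER)).
  build

if (m.isSuccess) {
  println(m.get)                           // index reached after the match
  m.getMatchTokens.foreach(t => println(t.getText)) // the matched tokens
}
```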
### Subquery hierarchy

For the statement:

```sql
select no_result_type, keywords, search_num, rank
from(
select keywords, search_num, row_number() over (PARTITION BY no_result_type order by search_num desc) as rank
from(
select *,no_result_type, keywords, sum(search_num) AS search_num
from jack.drugs_bad_case_di,jack.abc jack
where hp_stat_date >= date_sub(current_date,30)
and action_dt >= date_sub(current_date,30)
and action_type = 'search'
and length(keywords) > 1
and (split(av, '\\.')[0] >= 11 OR (split(av, '\\.')[0] = 10 AND split(av, '\\.')[1] = 9))
--and no_result_type = 'indication'
group by no_result_type, keywords
)a
)b
where rank <=
```

the AST built looks like this:

```sql
select no_result_type , keywords , search_num , rank from ( select keywords , search_num , row_number ( ) over
( PARTITION BY no_result_type order by search_num desc ) as rank from ( select * , no_result_type , keywords ,
sum ( search_num ) AS search_num from jack . drugs_bad_case_di , jack . abc jack where hp_stat_date >= date_sub (
current_date , 30 ) and action_dt >= date_sub ( current_date , 30 ) and action_type = 'search' and length (
keywords ) > 1 and ( split ( av , '\\.' ) [ 0 ] >= 11 OR ( split
( av , '\\.' ) [ 0 ] = 10 AND split ( av , '\\.' ) [ 1 ]
= 9 ) ) group by no_result_type , keywords ) a ) b where rank <=


=>select keywords , search_num , row_number ( ) over ( PARTITION BY no_result_type order by search_num desc ) as
rank from ( select * , no_result_type , keywords , sum ( search_num ) AS search_num from jack . drugs_bad_case_di
, jack . abc jack where hp_stat_date >= date_sub ( current_date , 30 ) and action_dt >= date_sub ( current_date
, 30 ) and action_type = 'search' and length ( keywords ) > 1 and ( split ( av ,
'\\.' ) [ 0 ] >= 11 OR ( split ( av , '\\.' ) [ 0 ] = 10
AND split ( av , '\\.' ) [ 1 ] = 9 ) ) group by no_result_type , keywords )
a ) b


==>select * , no_result_type , keywords , sum ( search_num ) AS search_num from jack . drugs_bad_case_di , jack
. abc jack where hp_stat_date >= date_sub ( current_date , 30 ) and action_dt >= date_sub ( current_date , 30
) and action_type = 'search' and length ( keywords ) > 1 and ( split ( av , '\\.' )
[ 0 ] >= 11 OR ( split ( av , '\\.' ) [ 0 ] = 10 AND split
( av , '\\.' ) [ 1 ] = 9 ) ) group by no_result_type , keywords ) a
```

You can see two levels of nesting, each level containing one subquery.

The TABLE_INFO structure built for it is:

```
2:
List(
MetaTableKeyWrapper(MetaTableKey(None,Some(jack),drugs_bad_case_di),None),
MetaTableKeyWrapper(MetaTableKey(None,None,null),Some(a)),
MetaTableKeyWrapper(MetaTableKey(None,Some(jack),abc),Some(jack)))
1:
List(MetaTableKeyWrapper(MetaTableKey(None,None,null),Some(b)))
0:
List()
```

Level 0 is the outermost statement; level 1 is the first subquery; level 2 is the second subquery, which records the subquery alias as well as all the physical tables inside that subquery.

The output above is trimmed for display; in reality all column information is included as well. This means that to complete the projection at level 0, I only need the level-1 information: the table name b, or the columns of table b. The same reasoning applies to the other levels.

--------------------------------------------------------------------------------
/config/application.example.yml:
--------------------------------------------------------------------------------
1 | #mode 2 | mode: 3 | development 4 | #mode=production 5 | 6 | path: 7 | conf: /Users/allwefantasy/CSDNWorkSpace/streamingpro-spark-2.4.x/external/mlsql-autosuggest/config/ 8 | 9 | ###############datasource config################## 10 | #mysql,mongodb,redis等数据源配置方式 11 | development: 12 | datasources: 13 | mysql: 14 | host: 127.0.0.1 15 | port: 3306 16 | database: mlsql_console 17 | username: xxx 18 | password: xxxx 19 | initialSize: 8 20 | disable: true 21 | removeAbandoned: true 22 | testWhileIdle: true 23 | removeAbandonedTimeout: 30 24 | maxWait: 100 25 | filters: stat,log4j 26 | mongodb: 27 | disable: true 28 | redis: 29 | disable: true 30 | test: 31 | datasources: 32 | mysql: 33 | host: 127.0.0.1 34 | port: 3306 35 | database: wow 36 | username: root 37 | password: mlsql 38 | disable: true 39 | 40 | production: 41 | datasources: 42 | mysql: 43 | host: 127.0.0.1 44 | port: 3306 45 | database: wow 46 | username: root 47 | password: mlsql 48 | disable: false 49 | 50 | ###############application config################## 51 | #'model' for relational database like MySQL 52 | #'document' for NoSQL database model configuration, MongoDB 53 | auth_secret: "mlsql" 54 | application: 55 | controller: tech.mlsql.autosuggest.app 56 | model: tech.mlsql.model 57 | test: test.com.example 58 | static: 59 | enable: false 60 | template: 61 | engine: 62 | enable: false 63 | 64 | serviceframework: 65 | template: 66 | loader: 67 | classpath: 68 | enable: true 69 | static: 70 | loader: 71 | classpath: 72 | enable: true 73 
| dir: "streamingpro/assets" 74 | ###############http config################## 75 | http: 76 | port: 9004 77 | disable: false 78 | host: 0.0.0.0 79 | server: 80 | idleTimeout: 6000000 81 | client: 82 | accept: 83 | timeout: 43200000 84 | 85 | #thrift: 86 | # disable: true 87 | # services: 88 | # net_csdn_controller_thrift_impl_CBayesianQueryServiceImpl: 89 | # port: 9001 90 | # min_threads: 100 91 | # max_threads: 1000 92 | # 93 | # servers: 94 | # spam_bayes: ["127.0.0.1:9001"] 95 | 96 | 97 | 98 | ###############validator config################## 99 | #如果需要添加验证器,只要配置好类全名即可 100 | #替换验证器实现,则替换相应的类名即可 101 | #warning: 自定义验证器实现需要线程安全 102 | 103 | validator: 104 | format: net.csdn.validate.impl.Format 105 | numericality: net.csdn.validate.impl.Numericality 106 | presence: net.csdn.validate.impl.Presence 107 | uniqueness: net.csdn.validate.impl.Uniqueness 108 | length: net.csdn.validate.impl.Length 109 | associated: net.csdn.validate.impl.Associated 110 | 111 | mongo_validator: 112 | format: net.csdn.mongo.validate.impl.Format 113 | numericality: net.csdn.mongo.validate.impl.Numericality 114 | presence: net.csdn.mongo.validate.impl.Presence 115 | uniqueness: net.csdn.mongo.validate.impl.Uniqueness 116 | length: net.csdn.mongo.validate.impl.Length 117 | associated: net.csdn.mongo.validate.impl.Associated 118 | 119 | ################ 数据库类型映射 #################### 120 | type_mapping: net.csdn.jpa.type.impl.MysqlType 121 | 122 | qps: 123 | /say/hello: 10 124 | 125 | qpslimit: 126 | enable: true 127 | dubbo: 128 | disable: true 129 | server: true -------------------------------------------------------------------------------- /config/application.yml: -------------------------------------------------------------------------------- 1 | #mode 2 | mode: 3 | development 4 | #mode=production 5 | 6 | path: 7 | conf: /Users/allwefantasy/CSDNWorkSpace/streamingpro-spark-2.4.x/external/mlsql-autosuggest/config/ 8 | 9 | ###############datasource config################## 10 | #mysql,mongodb,redis等数据源配置方式 11 | development: 12 | datasources: 13 | mysql: 14 | host: 127.0.0.1 15 | port: 3306 16 | database: mlsql_console 17 | username: xxx 18 | password: xxxx 19 | initialSize: 8 20 | disable: true 21 | removeAbandoned: true 22 | testWhileIdle: true 23 | removeAbandonedTimeout: 30 24 | maxWait: 100 25 | filters: stat,log4j 26 | mongodb: 27 | disable: true 28 | redis: 29 | disable: true 30 | test: 31 | datasources: 32 | mysql: 33 | host: 127.0.0.1 34 | port: 3306 35 | database: wow 36 | username: root 37 | password: mlsql 38 | disable: true 39 | 40 | production: 41 | datasources: 42 | mysql: 43 | host: 127.0.0.1 44 | port: 3306 45 | database: wow 46 | username: root 47 | password: mlsql 48 | disable: false 49 | 50 | ###############application config################## 51 | #'model' for relational database like MySQL 52 | #'document' for NoSQL database model configuration, MongoDB 53 | auth_secret: "mlsql" 54 | application: 55 | controller: tech.mlsql.autosuggest.app 56 | model: tech.mlsql.model 57 | test: test.com.example 58 | static: 59 | enable: false 60 | template: 61 | engine: 62 | enable: false 63 | 64 | serviceframework: 65 | template: 66 | loader: 67 | classpath: 68 | enable: true 69 | static: 70 | loader: 71 | classpath: 72 | enable: true 73 | dir: "streamingpro/assets" 74 | ###############http config################## 75 | http: 76 | port: 9004 77 | disable: false 78 | host: 0.0.0.0 79 | server: 80 | idleTimeout: 6000000 81 | client: 82 | accept: 83 | timeout: 43200000 84 | 85 | #thrift: 86 | # disable: true 87 | 
# services: 88 | # net_csdn_controller_thrift_impl_CBayesianQueryServiceImpl: 89 | # port: 9001 90 | # min_threads: 100 91 | # max_threads: 1000 92 | # 93 | # servers: 94 | # spam_bayes: ["127.0.0.1:9001"] 95 | 96 | 97 | 98 | ###############validator config################## 99 | #如果需要添加验证器,只要配置好类全名即可 100 | #替换验证器实现,则替换相应的类名即可 101 | #warning: 自定义验证器实现需要线程安全 102 | 103 | validator: 104 | format: net.csdn.validate.impl.Format 105 | numericality: net.csdn.validate.impl.Numericality 106 | presence: net.csdn.validate.impl.Presence 107 | uniqueness: net.csdn.validate.impl.Uniqueness 108 | length: net.csdn.validate.impl.Length 109 | associated: net.csdn.validate.impl.Associated 110 | 111 | mongo_validator: 112 | format: net.csdn.mongo.validate.impl.Format 113 | numericality: net.csdn.mongo.validate.impl.Numericality 114 | presence: net.csdn.mongo.validate.impl.Presence 115 | uniqueness: net.csdn.mongo.validate.impl.Uniqueness 116 | length: net.csdn.mongo.validate.impl.Length 117 | associated: net.csdn.mongo.validate.impl.Associated 118 | 119 | ################ 数据库类型映射 #################### 120 | type_mapping: net.csdn.jpa.type.impl.MysqlType 121 | 122 | qps: 123 | /say/hello: 10 124 | 125 | qpslimit: 126 | enable: true 127 | dubbo: 128 | disable: true 129 | server: true -------------------------------------------------------------------------------- /config/logging.yml: -------------------------------------------------------------------------------- 1 | rootLogger: INFO,console 2 | 3 | 4 | appender: 5 | console: 6 | type: console 7 | threshold: INFO 8 | layout: 9 | type: consolePattern 10 | conversionPattern: "[%d{ISO8601}][%-5p][%-25c] %m%n" 11 | 12 | file: 13 | type: dailyRollingFile 14 | file: ${path.logs}/${cluster.name}.log 15 | datePattern: "'.'yyyy-MM-dd" 16 | layout: 17 | type: pattern 18 | conversionPattern: "[%d{ISO8601}][%-5p][%-25c] %m%n" -------------------------------------------------------------------------------- /dev/build-package.sh: -------------------------------------------------------------------------------- 1 | mvn package -DskipTests -Pshade 2 | rm -rf build 3 | mkdir -p build/sql-code-intelligence-0.1.0 4 | cp target/sql-code-intelligence-0.1.0.jar build/sql-code-intelligence-0.1.0 5 | cp -r config build/sql-code-intelligence-0.1.0 6 | cp dev/start.sh build/sql-code-intelligence-0.1.0 7 | cd build 8 | tar cvf sql-code-intelligence-0.1.0.tar sql-code-intelligence-0.1.0 9 | scp sql-code-intelligence-0.1.0.tar mlsql2:/data/mlsql/releases/sql-code-intelligence/ -------------------------------------------------------------------------------- /dev/start.sh: -------------------------------------------------------------------------------- 1 | java -cp .:sql-code-intelligence-0.1.0.jar tech.mlsql.autosuggest.app.Standalone -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | tech.mlsql 8 | sql-code-intelligence 9 | 0.1.0 10 | SQL Code Intelligence 11 | https://github.com/allwefantasy/sql-code-intelligence 12 | 13 | SQL code autocomplete engine 14 | 15 | 16 | 17 | Apache 2.0 License 18 | http://www.apache.org/licenses/LICENSE-2.0.html 19 | repo 20 | 21 | 22 | 23 | 24 | allwefantasy 25 | ZhuHaiLin 26 | allwefantasy@gmail.com 27 | 28 | 29 | 30 | 31 | scm:git:git@github.com:allwefantasy/sql-code-intelligence.git 32 | 33 | 34 | scm:git:git@github.com:allwefantasy/sql-code-intelligence.git 35 | 36 | 
https://github.com/allwefantasy/sql-code-intelligence 37 | 38 | 39 | https://github.com/allwefantasy/sql-code-intelligence/issues 40 | 41 | 42 | 43 | UTF-8 44 | 2.11.8 45 | 2.11 46 | 2.11.0-M3 47 | 48 | 2.4.3 49 | 2.4 50 | 51 | 16.0 52 | 4.5.3 53 | 0.3.5 54 | 1.6.0 55 | 56 | provided 57 | 1.6.0 58 | 59 | 60 | 61 | tech.mlsql 62 | common-utils_${scala.binary.version} 63 | ${common-utils-version} 64 | 65 | 66 | org.python 67 | jython-standalone 68 | 69 | 70 | org.fusesource.leveldbjni 71 | leveldbjni-all 72 | 73 | 74 | 75 | 76 | tech.mlsql 77 | streamingpro-dsl-${spark.bigversion}_${scala.binary.version} 78 | ${mlsql.version} 79 | 80 | 81 | 82 | tech.mlsql 83 | streamingpro-spark-${spark.bigversion}.0-adaptor_${scala.binary.version} 84 | ${mlsql.version} 85 | 86 | 87 | tech.mlsql 88 | spark-adhoc-kafka_2.11 89 | 90 | 91 | tech.mlsql 92 | mysql-binlog_2.11 93 | 94 | 95 | 96 | 97 | 98 | tech.mlsql 99 | streamingpro-common-${spark.bigversion}_${scala.binary.version} 100 | ${mlsql.version} 101 | 102 | 103 | 104 | org.pegdown 105 | pegdown 106 | ${pegdown.version} 107 | test 108 | 109 | 110 | 111 | tech.mlsql 112 | streamingpro-core-${spark.bigversion}_${scala.binary.version} 113 | 1.6.0-SNAPSHOT 114 | 115 | 116 | cn.edu.hfut.dmic.webcollector 117 | WebCollector 118 | 119 | 120 | tech.mlsql 121 | pyjava-2.4_2.11 122 | 123 | 124 | tech.mlsql 125 | mlsql-scheduler 126 | 127 | 128 | org.apache.spark 129 | spark-streaming-kafka-0-8_2.11 130 | 131 | 132 | 133 | 134 | 135 | com.alibaba 136 | druid 137 | 1.1.16 138 | 139 | 140 | org.scalactic 141 | scalactic_${scala.binary.version} 142 | 3.0.0 143 | test 144 | 145 | 146 | 147 | org.scalatest 148 | scalatest_${scala.binary.version} 149 | 3.0.0 150 | test 151 | 152 | 153 | 154 | org.apache.spark 155 | spark-sql_${scala.binary.version} 156 | ${spark.version} 157 | 158 | 159 | org.apache.parquet 160 | parquet-hadoop 161 | 162 | 163 | org.apache.parquet 164 | parquet-column 165 | 166 | 167 | org.apache.arrow 168 | arrow-vector 169 | 170 | 171 | com.github.luben 172 | zstd-jni 173 | 174 | 175 | org.apache.orc 176 | orc-core 177 | 178 | 179 | org.apache.hadoop 180 | hadoop-client 181 | 182 | 183 | 184 | 185 | 186 | org.apache.spark 187 | spark-catalyst_${scala.binary.version} 188 | ${spark.version} 189 | tests 190 | test 191 | 192 | 193 | 194 | org.apache.spark 195 | spark-core_${scala.binary.version} 196 | ${spark.version} 197 | tests 198 | test 199 | 200 | 201 | 202 | org.apache.spark 203 | spark-sql_${scala.binary.version} 204 | ${spark.version} 205 | tests 206 | test 207 | 208 | 209 | 210 | 211 | shade 212 | 213 | 214 | 215 | org.apache.maven.plugins 216 | maven-shade-plugin 217 | 2.4.3 218 | 219 | 220 | 221 | *:* 222 | 223 | META-INF/*.SF 224 | META-INF/*.DSA 225 | META-INF/*.RSA 226 | 227 | 228 | 229 | 230 | 231 | 232 | package 233 | 234 | shade 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | disable-java8-doclint 244 | 245 | [1.8,) 246 | 247 | 248 | -Xdoclint:none 249 | none 250 | 251 | 252 | 253 | release-sign-artifacts 254 | 255 | 256 | performRelease 257 | true 258 | 259 | 260 | 261 | 262 | 263 | org.apache.maven.plugins 264 | maven-gpg-plugin 265 | 1.1 266 | 267 | 268 | sign-artifacts 269 | verify 270 | 271 | sign 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | src/main/resources 285 | 286 | 287 | 288 | 289 | org.apache.maven.plugins 290 | maven-surefire-plugin 291 | 3.0.0-M1 292 | 293 | 1 294 | true 295 | -Xmx4024m 296 | 297 | **/*.java 298 | **/*.scala 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 
org.scala-tools 307 | maven-scala-plugin 308 | 2.15.2 309 | 310 | 311 | 312 | -g:vars 313 | 314 | 315 | true 316 | 317 | 318 | 319 | compile 320 | 321 | compile 322 | 323 | compile 324 | 325 | 326 | testCompile 327 | 328 | testCompile 329 | 330 | test 331 | 332 | 333 | process-resources 334 | 335 | compile 336 | 337 | 338 | 339 | 340 | 341 | 342 | org.apache.maven.plugins 343 | maven-compiler-plugin 344 | 2.3.2 345 | 346 | 347 | -g 348 | true 349 | 1.8 350 | 1.8 351 | 352 | 353 | 354 | 355 | 356 | 357 | maven-source-plugin 358 | 2.1 359 | 360 | true 361 | 362 | 363 | 364 | compile 365 | 366 | jar 367 | 368 | 369 | 370 | 371 | 372 | org.apache.maven.plugins 373 | maven-javadoc-plugin 374 | 375 | 376 | attach-javadocs 377 | 378 | jar 379 | 380 | 381 | 382 | 383 | 384 | org.sonatype.plugins 385 | nexus-staging-maven-plugin 386 | 1.6.7 387 | true 388 | 389 | sonatype-nexus-staging 390 | https://oss.sonatype.org/ 391 | true 392 | 393 | 394 | 395 | 396 | org.scalatest 397 | scalatest-maven-plugin 398 | 2.0.0 399 | 400 | streaming.core.NotToRunTag 401 | ${project.build.directory}/surefire-reports 402 | . 403 | WDF TestSuite.txt 404 | ${project.build.directory}/html/scalatest 405 | false 406 | 407 | 408 | 409 | test 410 | 411 | test 412 | 413 | 414 | 415 | 416 | 417 | 418 | 419 | 420 | sonatype-nexus-snapshots 421 | https://oss.sonatype.org/content/repositories/snapshots 422 | 423 | 424 | sonatype-nexus-staging 425 | https://oss.sonatype.org/service/local/staging/deploy/maven2/ 426 | 427 | 428 | 429 | -------------------------------------------------------------------------------- /src/main/java/com/intigua/antlr4/autosuggest/AutoSuggester.java: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest; 2 | 3 | import com.intigua.antlr4.autosuggest.LexerWrapper.TokenizationResult; 4 | import org.antlr.v4.runtime.Token; 5 | import org.antlr.v4.runtime.atn.ATNState; 6 | import org.antlr.v4.runtime.atn.AtomTransition; 7 | import org.antlr.v4.runtime.atn.SetTransition; 8 | import org.antlr.v4.runtime.atn.Transition; 9 | import org.antlr.v4.runtime.misc.Interval; 10 | import org.slf4j.Logger; 11 | import org.slf4j.LoggerFactory; 12 | 13 | import java.util.*; 14 | 15 | /** 16 | * Suggests completions for given text, using a given ANTLR4 grammar. 
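 * Typical usage, as the constructor and methods below suggest: construct an
 * AutoSuggester with a {@link LexerAndParserFactory} and the (partial) input
 * text, optionally call setCasePreference(...), then call suggestCompletions().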
17 | */ 18 | public class AutoSuggester { 19 | private static final Logger logger = LoggerFactory.getLogger(AutoSuggester.class); 20 | 21 | private final ParserWrapper parserWrapper; 22 | private final LexerWrapper lexerWrapper; 23 | private final String input; 24 | private final Set collectedSuggestions = new HashSet<>(); 25 | 26 | private List inputTokens; 27 | private String untokenizedText = ""; 28 | private String indent = ""; 29 | private CasePreference casePreference = CasePreference.BOTH; 30 | 31 | private Map parserStateToTokenListIndexWhereLastVisited = new HashMap<>(); 32 | 33 | public AutoSuggester(LexerAndParserFactory lexerAndParserFactory, String input) { 34 | this.lexerWrapper = new LexerWrapper(lexerAndParserFactory, new DefaultToCharStream()); 35 | this.parserWrapper = new ParserWrapper(lexerAndParserFactory, lexerWrapper.getVocabulary()); 36 | this.input = input; 37 | } 38 | 39 | public void setCasePreference(CasePreference casePreference) { 40 | this.casePreference = casePreference; 41 | } 42 | 43 | public Collection suggestCompletions() { 44 | tokenizeInput(); 45 | runParserAtnAndCollectSuggestions(); 46 | return collectedSuggestions; 47 | } 48 | 49 | private void tokenizeInput() { 50 | TokenizationResult tokenizationResult = lexerWrapper.tokenizeNonDefaultChannel(this.input); 51 | this.inputTokens = tokenizationResult.tokens; 52 | this.untokenizedText = tokenizationResult.untokenizedText; 53 | if (logger.isDebugEnabled()) { 54 | logger.debug("TOKENS FOUND IN FIRST PASS:"); 55 | for (Token token : this.inputTokens) { 56 | logger.debug(token.toString()); 57 | } 58 | } 59 | } 60 | 61 | private void runParserAtnAndCollectSuggestions() { 62 | ATNState initialState = this.parserWrapper.getAtnState(0); 63 | logger.debug("Parser initial state: " + initialState); 64 | parseAndCollectTokenSuggestions(initialState, 0); 65 | } 66 | 67 | /** 68 | * Recursive through the parser ATN to process all tokens. When successful (out of tokens) - collect completion 69 | * suggestions. 
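 * Epsilon transitions are followed without consuming a token, while atom/set
 * transitions consume the current token when it matches the transition label;
 * once the token list is exhausted, suggestNextTokensForParserState() turns the
 * transitions reachable from the current state into suggestions.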
70 | */ 71 | private void parseAndCollectTokenSuggestions(ATNState parserState, int tokenListIndex) { 72 | indent = indent + " "; 73 | if (didVisitParserStateOnThisTokenIndex(parserState, tokenListIndex)) { 74 | logger.debug(indent + "State " + parserState + " had already been visited while processing token " 75 | + tokenListIndex + ", backtracking to avoid infinite loop."); 76 | return; 77 | } 78 | Integer previousTokenListIndexForThisState = setParserStateLastVisitedOnThisTokenIndex(parserState, tokenListIndex); 79 | try { 80 | if (logger.isDebugEnabled()) { 81 | logger.debug(indent + "State: " + parserWrapper.toString(parserState)); 82 | logger.debug(indent + "State available transitions: " + parserWrapper.transitionsStr(parserState)); 83 | } 84 | 85 | if (!haveMoreTokens(tokenListIndex)) { // stop condition for recursion 86 | suggestNextTokensForParserState(parserState); 87 | return; 88 | } 89 | for (Transition trans : parserState.getTransitions()) { 90 | if (trans.isEpsilon()) { 91 | handleEpsilonTransition(trans, tokenListIndex); 92 | } else if (trans instanceof AtomTransition) { 93 | handleAtomicTransition((AtomTransition) trans, tokenListIndex); 94 | } else { 95 | handleSetTransition((SetTransition) trans, tokenListIndex); 96 | } 97 | } 98 | } finally { 99 | indent = indent.substring(2); 100 | setParserStateLastVisitedOnThisTokenIndex(parserState, previousTokenListIndexForThisState); 101 | } 102 | } 103 | 104 | private boolean didVisitParserStateOnThisTokenIndex(ATNState parserState, Integer currentTokenListIndex) { 105 | Integer lastVisitedThisStateAtTokenListIndex = parserStateToTokenListIndexWhereLastVisited.get(parserState); 106 | return currentTokenListIndex.equals(lastVisitedThisStateAtTokenListIndex); 107 | } 108 | 109 | private Integer setParserStateLastVisitedOnThisTokenIndex(ATNState parserState, Integer tokenListIndex) { 110 | if (tokenListIndex == null) { 111 | return parserStateToTokenListIndexWhereLastVisited.remove(parserState); 112 | } else { 113 | return parserStateToTokenListIndexWhereLastVisited.put(parserState, tokenListIndex); 114 | } 115 | } 116 | 117 | private boolean haveMoreTokens(int tokenListIndex) { 118 | return tokenListIndex < inputTokens.size(); 119 | } 120 | 121 | private void handleEpsilonTransition(Transition trans, int tokenListIndex) { 122 | // Epsilon transitions don't consume a token, so don't move the index 123 | parseAndCollectTokenSuggestions(trans.target, tokenListIndex); 124 | } 125 | 126 | private void handleAtomicTransition(AtomTransition trans, int tokenListIndex) { 127 | Token nextToken = inputTokens.get(tokenListIndex); 128 | int nextTokenType = inputTokens.get(tokenListIndex).getType(); 129 | boolean nextTokenMatchesTransition = (trans.label == nextTokenType); 130 | if (nextTokenMatchesTransition) { 131 | logger.debug(indent + "Token " + nextToken + " following transition: " + parserWrapper.toString(trans)); 132 | parseAndCollectTokenSuggestions(trans.target, tokenListIndex + 1); 133 | } else { 134 | logger.debug(indent + "Token " + nextToken + " NOT following transition: " + parserWrapper.toString(trans)); 135 | } 136 | } 137 | 138 | private void handleSetTransition(SetTransition trans, int tokenListIndex) { 139 | Token nextToken = inputTokens.get(tokenListIndex); 140 | int nextTokenType = nextToken.getType(); 141 | for (int transitionTokenType : trans.label().toList()) { 142 | boolean nextTokenMatchesTransition = (transitionTokenType == nextTokenType); 143 | if (nextTokenMatchesTransition) { 144 | logger.debug(indent + "Token " + 
nextToken + " following transition: " + parserWrapper.toString(trans) + " to " + transitionTokenType); 145 | parseAndCollectTokenSuggestions(trans.target, tokenListIndex + 1); 146 | } else { 147 | logger.debug(indent + "Token " + nextToken + " NOT following transition: " + parserWrapper.toString(trans) + " to " + transitionTokenType); 148 | } 149 | } 150 | } 151 | 152 | private void suggestNextTokensForParserState(ATNState parserState) { 153 | Set transitionLabels = new HashSet<>(); 154 | fillParserTransitionLabels(parserState, transitionLabels, new HashSet<>()); 155 | TokenSuggester tokenSuggester = new TokenSuggester(this.untokenizedText, lexerWrapper, this.casePreference); 156 | Collection suggestions = tokenSuggester.suggest(transitionLabels); 157 | parseSuggestionsAndAddValidOnes(parserState, suggestions); 158 | logger.debug(indent + "WILL SUGGEST TOKENS FOR STATE: " + parserState); 159 | } 160 | 161 | private void fillParserTransitionLabels(ATNState parserState, Collection result, Set visitedTransitions) { 162 | for (Transition trans : parserState.getTransitions()) { 163 | TransitionWrapper transWrapper = new TransitionWrapper(parserState, trans); 164 | if (visitedTransitions.contains(transWrapper)) { 165 | logger.debug(indent + "Not following visited " + transWrapper); 166 | continue; 167 | } 168 | if (trans.isEpsilon()) { 169 | try { 170 | visitedTransitions.add(transWrapper); 171 | fillParserTransitionLabels(trans.target, result, visitedTransitions); 172 | } finally { 173 | visitedTransitions.remove(transWrapper); 174 | } 175 | } else if (trans instanceof AtomTransition) { 176 | int label = ((AtomTransition) trans).label; 177 | if (label >= 1) { // EOF would be -1 178 | result.add(label); 179 | } 180 | } else if (trans instanceof SetTransition) { 181 | for (Interval interval : ((SetTransition) trans).label().getIntervals()) { 182 | for (int i = interval.a; i <= interval.b; ++i) { 183 | result.add(i); 184 | } 185 | } 186 | } 187 | } 188 | } 189 | 190 | private void parseSuggestionsAndAddValidOnes(ATNState parserState, Collection suggestions) { 191 | for (String suggestion : suggestions) { 192 | logger.debug("CHECKING suggestion: " + suggestion); 193 | Token addedToken = getAddedToken(suggestion); 194 | if (isParseableWithAddedToken(parserState, addedToken, new HashSet())) { 195 | collectedSuggestions.add(suggestion); 196 | } else { 197 | logger.debug("DROPPING non-parseable suggestion: " + suggestion); 198 | } 199 | } 200 | } 201 | 202 | private Token getAddedToken(String suggestedCompletion) { 203 | String completedText = this.input + suggestedCompletion; 204 | List completedTextTokens = this.lexerWrapper.tokenizeNonDefaultChannel(completedText).tokens; 205 | if (completedTextTokens.size() <= inputTokens.size()) { 206 | return null; // Completion didn't yield whole token, could be just a token fragment 207 | } 208 | logger.debug("TOKENS IN COMPLETED TEXT: " + completedTextTokens); 209 | Token newToken = completedTextTokens.get(completedTextTokens.size() - 1); 210 | return newToken; 211 | } 212 | 213 | private boolean isParseableWithAddedToken(ATNState parserState, Token newToken, Set visitedTransitions) { 214 | if (newToken == null) { 215 | return false; 216 | } 217 | for (Transition parserTransition : parserState.getTransitions()) { 218 | if (parserTransition.isEpsilon()) { // Recurse through any epsilon transitionsStr 219 | TransitionWrapper transWrapper = new TransitionWrapper(parserState, parserTransition); 220 | if (visitedTransitions.contains(transWrapper)) { 221 | continue; 
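 // this epsilon transition is already on the current exploration path; skip it to avoid cycling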
222 | } 223 | visitedTransitions.add(transWrapper); 224 | try { 225 | if (isParseableWithAddedToken(parserTransition.target, newToken, visitedTransitions)) { 226 | return true; 227 | } 228 | } finally { 229 | visitedTransitions.remove(transWrapper); 230 | } 231 | } else if (parserTransition instanceof AtomTransition) { 232 | AtomTransition parserAtomTransition = (AtomTransition) parserTransition; 233 | int transitionTokenType = parserAtomTransition.label; 234 | if (transitionTokenType == newToken.getType()) { 235 | return true; 236 | } 237 | } else if (parserTransition instanceof SetTransition) { 238 | SetTransition parserSetTransition = (SetTransition) parserTransition; 239 | for (int transitionTokenType : parserSetTransition.label().toList()) { 240 | if (transitionTokenType == newToken.getType()) { 241 | return true; 242 | } 243 | } 244 | } else { 245 | throw new IllegalStateException("Unexpected: " + parserWrapper.toString(parserTransition)); 246 | } 247 | } 248 | return false; 249 | } 250 | 251 | 252 | } 253 | -------------------------------------------------------------------------------- /src/main/java/com/intigua/antlr4/autosuggest/CasePreference.java: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest; 2 | 3 | 4 | public enum CasePreference { 5 | /** 6 | * Suggest both uppercase and lowercase completions 7 | */ 8 | BOTH, 9 | 10 | /** 11 | * In case both uppercase and lowercase are supported by the grammar, suggest only lowercase alternative 12 | */ 13 | LOWER, 14 | 15 | /** 16 | * In case both uppercase and lowercase are supported by the grammar, suggest only uppercase alternative 17 | */ 18 | UPPER 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/com/intigua/antlr4/autosuggest/DefaultToCharStream.scala: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest 2 | 3 | import java.io.StringReader 4 | 5 | import org.antlr.v4.runtime.{CharStream, CharStreams} 6 | import tech.mlsql.autosuggest 7 | 8 | /** 9 | * 3/6/2020 WilliamZhu(allwefantasy@gmail.com) 10 | */ 11 | class DefaultToCharStream extends ToCharStream { 12 | override def toCharStream(text: String): CharStream = { 13 | CharStreams.fromReader(new StringReader(text)) 14 | 15 | } 16 | } 17 | 18 | class RawSQLToCharStream extends ToCharStream { 19 | override def toCharStream(text: String): CharStream = { 20 | new autosuggest.UpperCaseCharStream(CharStreams.fromString(text)) 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/main/java/com/intigua/antlr4/autosuggest/LexerAndParserFactory.java: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest; 2 | 3 | public interface LexerAndParserFactory extends LexerFactory, ParserFactory { 4 | } 5 | -------------------------------------------------------------------------------- /src/main/java/com/intigua/antlr4/autosuggest/LexerFactory.java: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest; 2 | 3 | import org.antlr.v4.runtime.CharStream; 4 | import org.antlr.v4.runtime.Lexer; 5 | 6 | public interface LexerFactory { 7 | 8 | Lexer createLexer(CharStream input); 9 | } 10 | -------------------------------------------------------------------------------- 
/src/main/java/com/intigua/antlr4/autosuggest/LexerWrapper.java: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest; 2 | 3 | import org.antlr.v4.runtime.*; 4 | import org.antlr.v4.runtime.atn.ATNState; 5 | import org.antlr.v4.runtime.misc.ParseCancellationException; 6 | 7 | import java.util.List; 8 | import java.util.stream.Collectors; 9 | 10 | public class LexerWrapper { 11 | private final LexerFactory lexerFactory; 12 | private Lexer cachedLexer; 13 | private ToCharStream toCharStream; 14 | 15 | public static class TokenizationResult { 16 | public List tokens; 17 | public String untokenizedText = ""; 18 | } 19 | 20 | public LexerWrapper(LexerFactory lexerFactory, ToCharStream toCharStream) { 21 | super(); 22 | this.lexerFactory = lexerFactory; 23 | this.toCharStream = toCharStream; 24 | } 25 | 26 | public TokenizationResult tokenizeNonDefaultChannel(String input) { 27 | TokenizationResult result = this.tokenize(input); 28 | result.tokens = result.tokens.stream().filter(t -> t.getChannel() == 0).collect(Collectors.toList()); 29 | return result; 30 | } 31 | 32 | public String[] getRuleNames() { 33 | return getCachedLexer().getRuleNames(); 34 | } 35 | 36 | public ATNState findStateByRuleNumber(int ruleNumber) { 37 | return getCachedLexer().getATN().ruleToStartState[ruleNumber]; 38 | } 39 | 40 | public Vocabulary getVocabulary() { 41 | return getCachedLexer().getVocabulary(); 42 | } 43 | 44 | private Lexer getCachedLexer() { 45 | if (cachedLexer == null) { 46 | cachedLexer = createLexer(""); 47 | } 48 | return cachedLexer; 49 | } 50 | 51 | private TokenizationResult tokenize(String input) { 52 | Lexer lexer = this.createLexer(input); 53 | lexer.removeErrorListeners(); 54 | final TokenizationResult result = new TokenizationResult(); 55 | ANTLRErrorListener newErrorListener = new BaseErrorListener() { 56 | @Override 57 | public void syntaxError(Recognizer recognizer, Object offendingSymbol, int line, 58 | int charPositionInLine, String msg, RecognitionException e) throws ParseCancellationException { 59 | result.untokenizedText = input.substring(charPositionInLine); // intended side effect 60 | } 61 | }; 62 | lexer.addErrorListener(newErrorListener); 63 | result.tokens = lexer.getAllTokens(); 64 | return result; 65 | } 66 | 67 | private Lexer createLexer(CharStream input) { 68 | return this.lexerFactory.createLexer(input); 69 | } 70 | 71 | private Lexer createLexer(String lexerInput) { 72 | return this.createLexer(toCharStream.toCharStream(lexerInput)); 73 | } 74 | 75 | 76 | } 77 | -------------------------------------------------------------------------------- /src/main/java/com/intigua/antlr4/autosuggest/ParserFactory.java: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest; 2 | 3 | import org.antlr.v4.runtime.Parser; 4 | import org.antlr.v4.runtime.TokenStream; 5 | 6 | public interface ParserFactory { 7 | 8 | Parser createParser(TokenStream tokenStream); 9 | 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/com/intigua/antlr4/autosuggest/ParserWrapper.java: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest; 2 | 3 | import org.antlr.v4.runtime.Parser; 4 | import org.antlr.v4.runtime.Vocabulary; 5 | import org.antlr.v4.runtime.atn.ATN; 6 | import org.antlr.v4.runtime.atn.ATNState; 7 | import 
org.antlr.v4.runtime.atn.AtomTransition; 8 | import org.antlr.v4.runtime.atn.Transition; 9 | import org.apache.commons.lang3.StringUtils; 10 | import org.slf4j.Logger; 11 | import org.slf4j.LoggerFactory; 12 | 13 | import java.util.Arrays; 14 | import java.util.List; 15 | import java.util.stream.Collectors; 16 | import java.util.stream.Stream; 17 | 18 | class ParserWrapper { 19 | private static final Logger logger = LoggerFactory.getLogger(ParserWrapper.class); 20 | private final Vocabulary lexerVocabulary; 21 | 22 | private final ATN parserAtn; 23 | private final String[] parserRuleNames; 24 | 25 | public ParserWrapper(ParserFactory parserFactory, Vocabulary lexerVocabulary) { 26 | this.lexerVocabulary = lexerVocabulary; 27 | 28 | Parser parserForAtnOnly = parserFactory.createParser(null); 29 | this.parserAtn = parserForAtnOnly.getATN(); 30 | this.parserRuleNames = parserForAtnOnly.getRuleNames(); 31 | logger.debug("Parser rule names: " + StringUtils.join(parserForAtnOnly.getRuleNames(), ", ")); 32 | } 33 | 34 | public String toString(ATNState parserState) { 35 | String ruleName = this.parserRuleNames[parserState.ruleIndex]; 36 | return "*" + ruleName + "* " + parserState.getClass().getSimpleName() + " " + parserState; 37 | } 38 | 39 | public String toString(Transition t) { 40 | String nameOrLabel = t.getClass().getSimpleName(); 41 | if (t instanceof AtomTransition) { 42 | nameOrLabel += ' ' + this.lexerVocabulary.getDisplayName(((AtomTransition) t).label); 43 | } 44 | return nameOrLabel + " -> " + toString(t.target); 45 | } 46 | 47 | public String transitionsStr(ATNState state) { 48 | Stream transitionsStream = Arrays.asList(state.getTransitions()).stream(); 49 | List transitionStrings = transitionsStream.map(this::toString).collect(Collectors.toList()); 50 | return StringUtils.join(transitionStrings, ", "); 51 | } 52 | 53 | public ATNState getAtnState(int stateNumber) { 54 | return parserAtn.states.get(stateNumber); 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/main/java/com/intigua/antlr4/autosuggest/ReflectionLexerAndParserFactory.java: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest; 2 | 3 | import org.antlr.v4.runtime.CharStream; 4 | import org.antlr.v4.runtime.Lexer; 5 | import org.antlr.v4.runtime.Parser; 6 | import org.antlr.v4.runtime.TokenStream; 7 | 8 | import java.lang.reflect.Constructor; 9 | import java.lang.reflect.InvocationTargetException; 10 | 11 | public class ReflectionLexerAndParserFactory implements LexerAndParserFactory { 12 | 13 | private final Constructor lexerCtr; 14 | private final Constructor parserCtr; 15 | 16 | public ReflectionLexerAndParserFactory(Class lexerClass, Class parserClass) { 17 | lexerCtr = getConstructor(lexerClass, CharStream.class); 18 | parserCtr = getConstructor(parserClass, TokenStream.class); 19 | } 20 | 21 | @Override 22 | public Lexer createLexer(CharStream input) { 23 | return create(lexerCtr, input); 24 | } 25 | 26 | @Override 27 | public Parser createParser(TokenStream tokenStream) { 28 | return create(parserCtr, tokenStream); 29 | } 30 | 31 | private static Constructor getConstructor(Class givenClass, Class argClass) { 32 | try { 33 | return givenClass.getConstructor(argClass); 34 | } catch (NoSuchMethodException | SecurityException e) { 35 | throw new IllegalArgumentException( 36 | givenClass.getSimpleName() + " must have constructor from " + argClass.getSimpleName() + "."); 37 | } 38 | } 39 | 40 | 
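 // Instantiates the lexer or parser by invoking the cached reflective constructor with the given stream argument.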
private T create(Constructor contructor, Object arg) { 41 | try { 42 | return contructor.newInstance(arg); 43 | } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { 44 | throw new IllegalArgumentException(e); 45 | } 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/com/intigua/antlr4/autosuggest/ToCharStream.java: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest; 2 | 3 | import org.antlr.v4.runtime.CharStream; 4 | 5 | /** 6 | * 3/6/2020 WilliamZhu(allwefantasy@gmail.com) 7 | */ 8 | public interface ToCharStream { 9 | public CharStream toCharStream(String text); 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/com/intigua/antlr4/autosuggest/TokenSuggester.java: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest; 2 | 3 | import org.antlr.v4.runtime.atn.ATNState; 4 | import org.antlr.v4.runtime.atn.AtomTransition; 5 | import org.antlr.v4.runtime.atn.SetTransition; 6 | import org.antlr.v4.runtime.atn.Transition; 7 | import org.slf4j.Logger; 8 | import org.slf4j.LoggerFactory; 9 | 10 | import java.util.*; 11 | import java.util.stream.Collectors; 12 | 13 | /** 14 | * Given an ATN state and the lexer ATN, suggests auto-completion texts. 15 | */ 16 | class TokenSuggester { 17 | private static final Logger logger = LoggerFactory.getLogger(TokenSuggester.class); 18 | 19 | private final LexerWrapper lexerWrapper; 20 | private final CasePreference casePreference; 21 | 22 | private final Set suggestions = new TreeSet(); 23 | private final List visitedLexerStates = new ArrayList<>(); 24 | private String origPartialToken; 25 | 26 | public TokenSuggester(LexerWrapper lexerWrapper, String input) { 27 | this(input, lexerWrapper, CasePreference.BOTH); 28 | } 29 | 30 | public TokenSuggester(String origPartialToken, LexerWrapper lexerWrapper, CasePreference casePreference) { 31 | this.origPartialToken = origPartialToken; 32 | this.lexerWrapper = lexerWrapper; 33 | this.casePreference = casePreference; 34 | } 35 | 36 | public Collection suggest(Collection nextParserTransitionLabels) { 37 | logTokensUsedForSuggestion(nextParserTransitionLabels); 38 | for (int nextParserTransitionLabel : nextParserTransitionLabels) { 39 | int nextTokenRuleNumber = nextParserTransitionLabel - 1; // Count from 0 not from 1 40 | ATNState lexerState = this.lexerWrapper.findStateByRuleNumber(nextTokenRuleNumber); 41 | suggest("", lexerState, origPartialToken); 42 | } 43 | return suggestions; 44 | // return suggestions.stream().filter(s -> this.lexerWrapper.isValidSuggestion(input, s)).collect(Collectors.toList()); 45 | } 46 | 47 | private void logTokensUsedForSuggestion(Collection ruleIndices) { 48 | if (!logger.isDebugEnabled()) { 49 | return; 50 | } 51 | String ruleNames = ruleIndices.stream().map(r -> lexerWrapper.getRuleNames()[r-1]).collect(Collectors.joining(" ")); 52 | logger.debug("Suggesting tokens for lexer rules: " + ruleNames, " "); 53 | } 54 | 55 | 56 | private void suggest(String tokenSoFar, ATNState lexerState, String remainingText) { 57 | logger.debug( 58 | "SUGGEST: tokenSoFar=" + tokenSoFar + " remainingText=" + remainingText + " lexerState=" + toString(lexerState)); 59 | if (visitedLexerStates.contains(lexerState.stateNumber)) { 60 | return; // avoid infinite loop and stack overflow 61 | } 62 | 
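 // record this lexer state for the current exploration branch; the finally block below removes it on the way back up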
visitedLexerStates.add(lexerState.stateNumber); 63 | try { 64 | Transition[] transitions = lexerState.getTransitions(); 65 | boolean tokenNotEmpty = tokenSoFar.length() > 0; 66 | boolean noMoreCharactersInToken = (transitions.length == 0); 67 | if (tokenNotEmpty && noMoreCharactersInToken) { 68 | addSuggestedToken(tokenSoFar); 69 | return; 70 | } 71 | for (Transition trans : transitions) { 72 | suggestViaLexerTransition(tokenSoFar, remainingText, trans); 73 | } 74 | } finally { 75 | visitedLexerStates.remove(visitedLexerStates.size() - 1); 76 | } 77 | } 78 | 79 | private String toString(ATNState lexerState) { 80 | String ruleName = this.lexerWrapper.getRuleNames()[lexerState.ruleIndex]; 81 | return ruleName + " " + lexerState.getClass().getSimpleName() + " " + lexerState; 82 | } 83 | 84 | private void suggestViaLexerTransition(String tokenSoFar, String remainingText, Transition trans) { 85 | if (trans.isEpsilon()) { 86 | suggest(tokenSoFar, trans.target, remainingText); 87 | } else if (trans instanceof AtomTransition) { 88 | String newTokenChar = getAddedTextFor((AtomTransition) trans); 89 | if (remainingText.isEmpty() || remainingText.startsWith(newTokenChar)) { 90 | logger.debug("LEXER TOKEN: " + newTokenChar + " remaining=" + remainingText); 91 | suggestViaNonEpsilonLexerTransition(tokenSoFar, remainingText, newTokenChar, trans.target); 92 | } else { 93 | logger.debug("NONMATCHING LEXER TOKEN: " + newTokenChar + " remaining=" + remainingText); 94 | } 95 | } else if (trans instanceof SetTransition) { 96 | List symbols = ((SetTransition) trans).label().toList(); 97 | for (Integer symbol : symbols) { 98 | char[] charArr = Character.toChars(symbol); 99 | String charStr = new String(charArr); 100 | boolean shouldIgnoreCase = shouldIgnoreThisCase(charArr[0], symbols); // TODO: check for non-BMP 101 | if (!shouldIgnoreCase && (remainingText.isEmpty() || remainingText.startsWith(charStr))) { 102 | suggestViaNonEpsilonLexerTransition(tokenSoFar, remainingText, charStr, trans.target); 103 | } 104 | } 105 | } 106 | } 107 | 108 | private void suggestViaNonEpsilonLexerTransition(String tokenSoFar, String remainingText, 109 | String newTokenChar, ATNState targetState) { 110 | String newRemainingText = (remainingText.length() > 0) ? 
remainingText.substring(1) : remainingText; 111 | suggest(tokenSoFar + newTokenChar, targetState, newRemainingText); 112 | } 113 | 114 | private void addSuggestedToken(String tokenToAdd) { 115 | String justTheCompletionPart = chopOffCommonStart(tokenToAdd, this.origPartialToken); 116 | suggestions.add(justTheCompletionPart); 117 | } 118 | 119 | private String chopOffCommonStart(String a, String b) { 120 | int charsToChopOff = Math.min(b.length(), a.length()); 121 | return a.substring(charsToChopOff); 122 | } 123 | 124 | private String getAddedTextFor(AtomTransition transition) { 125 | return new String(Character.toChars(transition.label)); 126 | } 127 | 128 | private boolean shouldIgnoreThisCase(char transChar, List<Integer> allTransChars) { 129 | if (this.casePreference == null) { 130 | return false; 131 | } 132 | switch (this.casePreference) { 133 | case BOTH: 134 | return false; 135 | case LOWER: 136 | return Character.isUpperCase(transChar) && allTransChars.contains((int) Character.toLowerCase(transChar)); 137 | case UPPER: 138 | return Character.isLowerCase(transChar) && allTransChars.contains((int) Character.toUpperCase(transChar)); 139 | default: 140 | return false; 141 | } 142 | } 143 | 144 | } 145 | -------------------------------------------------------------------------------- /src/main/java/com/intigua/antlr4/autosuggest/TransitionWrapper.java: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest; 2 | 3 | import org.antlr.v4.runtime.atn.ATNState; 4 | import org.antlr.v4.runtime.atn.Transition; 5 | 6 | public class TransitionWrapper { 7 | private final ATNState source; 8 | private final Transition transition; 9 | 10 | public TransitionWrapper(ATNState source, Transition transition) { 11 | super(); 12 | this.source = source; 13 | this.transition = transition; 14 | } 15 | 16 | @Override 17 | public int hashCode() { 18 | final int prime = 31; 19 | int result = 1; 20 | result = prime * result + ((source == null) ? 0 : source.hashCode()); 21 | result = prime * result + ((transition == null) ? 
0 : transition.hashCode()); 22 | return result; 23 | } 24 | 25 | @Override 26 | public boolean equals(Object obj) { 27 | if (this == obj) 28 | return true; 29 | if (obj == null) 30 | return false; 31 | if (getClass() != obj.getClass()) 32 | return false; 33 | TransitionWrapper other = (TransitionWrapper) obj; 34 | if (source == null) { 35 | if (other.source != null) 36 | return false; 37 | } else if (!source.equals(other.source)) 38 | return false; 39 | if (transition == null) { 40 | if (other.transition != null) 41 | return false; 42 | } else if (!transition.equals(other.transition)) 43 | return false; 44 | return true; 45 | } 46 | 47 | @Override 48 | public String toString() { 49 | return transition.getClass().getSimpleName() + " from " + source + " to " + transition.target; 50 | } 51 | 52 | 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/AttributeExtractor.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest 2 | 3 | import org.antlr.v4.runtime.Token 4 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer 5 | import tech.mlsql.autosuggest.dsl.{Food, TokenMatcher} 6 | import tech.mlsql.autosuggest.meta.MetaTableKey 7 | import tech.mlsql.autosuggest.statement.{MatchAndExtractor, MetaTableKeyWrapper, SingleStatementAST} 8 | 9 | import scala.collection.mutable.ArrayBuffer 10 | 11 | /** 12 | * 4/6/2020 WilliamZhu(allwefantasy@gmail.com) 13 | */ 14 | class AttributeExtractor(autoSuggestContext: AutoSuggestContext, ast: SingleStatementAST, tokens: List[Token]) extends MatchAndExtractor[String] { 15 | 16 | override def matcher(start: Int): TokenMatcher = { 17 | return asterriskM(start) 18 | } 19 | 20 | private def attributeM(start: Int): TokenMatcher = { 21 | val temp = TokenMatcher(tokens, start). 22 | eat(Food(None, SqlBaseLexer.IDENTIFIER), Food(None, SqlBaseLexer.T__3)).optional. 23 | eat(Food(None, SqlBaseLexer.IDENTIFIER)). 24 | eat(Food(None, SqlBaseLexer.AS)).optional. 25 | eat(Food(None, SqlBaseLexer.IDENTIFIER)).optional. 26 | build 27 | temp.isSuccess match { 28 | case true => temp 29 | case false => 30 | TokenMatcher(tokens, start). 31 | eatOneAny. 32 | eat(Food(None, SqlBaseLexer.AS)). 33 | eat(Food(None, SqlBaseLexer.IDENTIFIER)). 34 | build 35 | } 36 | } 37 | 38 | private def funcitonM(start: Int): TokenMatcher = { 39 | // deal with something like: sum(a[1],fun(b)) as a 40 | val temp = TokenMatcher(tokens, start).eat(Food(None, SqlBaseLexer.IDENTIFIER), Food(None, SqlBaseLexer.T__0)).build 41 | // function match 42 | if (temp.isSuccess) { 43 | // try to find AS 44 | // we need to take care of situation like this: cast(a as int) as b 45 | // In future, we should get first function and get the return type so we can get the b type. 46 | val index = TokenMatcher(tokens, start).index(Array(Food(None, SqlBaseLexer.T__1), Food(None, SqlBaseLexer.AS), Food(None, SqlBaseLexer.IDENTIFIER))) 47 | if (index != -1) { 48 | //index + 1 to skip ) 49 | val aliasName = TokenMatcher(tokens, index + 1).eat(Food(None, SqlBaseLexer.AS)). 50 | eat(Food(None, SqlBaseLexer.IDENTIFIER)).build 51 | if (aliasName.isSuccess) { 52 | return TokenMatcher.resultMatcher(tokens, start, aliasName.get) 53 | } 54 | 55 | } 56 | // if no AS, do nothing 57 | null 58 | } 59 | return attributeM(start) 60 | } 61 | 62 | private def asterriskM(start: Int): TokenMatcher = { 63 | val temp = TokenMatcher(tokens, start). 
64 | eat(Food(None, SqlBaseLexer.IDENTIFIER), Food(None, SqlBaseLexer.T__3)).optional. 65 | eat(Food(None, SqlBaseLexer.ASTERISK)).build 66 | if (temp.isSuccess) { 67 | return TokenMatcher.resultMatcher(tokens, start, temp.get) 68 | } 69 | return funcitonM(start) 70 | } 71 | 72 | override def extractor(start: Int, end: Int): List[String] = { 73 | 74 | 75 | val attrTokens = tokens.slice(start, end) 76 | val token = attrTokens.last 77 | if (token.getType == SqlBaseLexer.ASTERISK) { 78 | return attrTokens match { 79 | case List(tableName, _, _) => 80 | //expand output 81 | ast.selectSuggester.table_info(ast.level). 82 | get(MetaTableKeyWrapper(MetaTableKey(None, None, null), Option(tableName.getText))).orElse{ 83 | ast.selectSuggester.table_info(ast.level). 84 | get(MetaTableKeyWrapper(MetaTableKey(None, None, tableName.getText), None)) 85 | } match { 86 | case Some(table) => 87 | //if this is a temp table, we need to expand it further 88 | val columns = if (table.key.db == Option(SpecialTableConst.TEMP_TABLE_DB_KEY)) { 89 | autoSuggestContext.metaProvider.search(table.key) match { 90 | case Some(item) => item.columns 91 | case None => List() 92 | } 93 | } else table.columns 94 | columns.map(_.name).toList 95 | case None => List() 96 | } 97 | case List(starAttr) => 98 | val table = ast.tables(tokens).head 99 | ast.selectSuggester.table_info(ast.level). 100 | get(table) match { 101 | case Some(table) => 102 | //if this is a temp table, we need to expand it further 103 | val columns = if (table.key.db == Option(SpecialTableConst.TEMP_TABLE_DB_KEY)) { 104 | autoSuggestContext.metaProvider.search(table.key) match { 105 | case Some(item) => item.columns 106 | case None => List() 107 | } 108 | } else table.columns 109 | columns.map(_.name).toList 110 | case None => List() 111 | } 112 | } 113 | } 114 | List(token.getText) 115 | } 116 | 117 | override def iterate(start: Int, end: Int, limit: Int): List[String] = { 118 | val attributes = ArrayBuffer[String]() 119 | var matchRes = matcher(start) 120 | var whileLimit = 1000 121 | while (matchRes.isSuccess && whileLimit > 0) { 122 | attributes ++= extractor(matchRes.start, matchRes.get) 123 | whileLimit -= 1 124 | val temp = TokenMatcher(tokens, matchRes.get).eat(Food(None, SqlBaseLexer.T__2)).build 125 | if (temp.isSuccess) { 126 | matchRes = matcher(temp.get) 127 | } else whileLimit = 0 128 | } 129 | attributes.toList 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/AutoSuggestContext.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest 2 | 3 | import com.intigua.antlr4.autosuggest.LexerWrapper 4 | import org.antlr.v4.runtime.misc.Interval 5 | import org.antlr.v4.runtime.{CharStream, CodePointCharStream, IntStream, Token} 6 | import org.apache.spark.sql.SparkSession 7 | import tech.mlsql.autosuggest.meta._ 8 | import tech.mlsql.autosuggest.preprocess.TablePreprocessor 9 | import tech.mlsql.autosuggest.statement._ 10 | import tech.mlsql.common.utils.log.Logging 11 | import tech.mlsql.common.utils.reflect.ClassPath 12 | import tech.mlsql.common.utils.serder.json.JSONTool 13 | 14 | import scala.collection.JavaConverters._ 15 | import scala.collection.mutable.ArrayBuffer 16 | 17 | object AutoSuggestContext { 18 | private[this] val autoSuggestContext: ThreadLocal[AutoSuggestContext] = new ThreadLocal[AutoSuggestContext] 19 | 20 | def context(): AutoSuggestContext = autoSuggestContext.get 21 | def setContext(ec: AutoSuggestContext): Unit = autoSuggestContext.set(ec) 22 | 
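// NOTE (added commentary; a minimal usage sketch, not part of the original source):
//   val ctx = new AutoSuggestContext(null, mlsqlLexer, sqlLexer)
//   ctx.buildFromString("load hive.`db.table1` as table2;").suggest(1, 6)
// where mlsqlLexer/sqlLexer are the LexerWrapper instances built in
// AutoSuggestController; session may be null when schema inference is not needed.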
23 | val memoryMetaProvider = new MemoryMetaProvider() 24 | var isInit = false 25 | 26 | def init: Unit = { 27 | val funcRegs = ClassPath.from(classOf[AutoSuggestContext].getClassLoader).getTopLevelClasses(POrCLiterals.FUNCTION_PACKAGE).iterator() 28 | while (funcRegs.hasNext) { 29 | val wow = funcRegs.next() 30 | val funcMetaTable = wow.load().newInstance().asInstanceOf[FuncReg].register 31 | MLSQLSQLFunction.funcMetaProvider.register(funcMetaTable) 32 | } 33 | isInit = true 34 | } 35 | 36 | if (!isInit) { 37 | init 38 | } 39 | 40 | } 41 | 42 | /** 43 | * A new instance must be created for every request 44 | */ 45 | class AutoSuggestContext(val session: SparkSession, 46 | val lexer: LexerWrapper, 47 | val rawSQLLexer: LexerWrapper, val options: Map[String, String] = Map()) extends Logging { 48 | 49 | AutoSuggestContext.setContext(this) 50 | private var _debugMode = false 51 | 52 | private var _rawTokens: List[Token] = List() 53 | private var _statements = List[List[Token]]() 54 | private val _tempTableProvider: StatementTempTableProvider = new StatementTempTableProvider() 55 | private var _rawLineNum = 0 56 | private var _rawColumnNum = 0 57 | private var userDefinedProvider: MetaProvider = new MetaProvider { 58 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = None 59 | 60 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = List() 61 | } 62 | private var _metaProvider: MetaProvider = new LayeredMetaProvider(tempTableProvider, userDefinedProvider) 63 | 64 | private val _statementProcessors = ArrayBuffer[PreProcessStatement]() 65 | addStatementProcessor(new TablePreprocessor(this)) 66 | 67 | private var _statementSplitter: StatementSplitter = new MLSQLStatementSplitter() 68 | 69 | def setDebugMode(isDebug: Boolean) = { 70 | this._debugMode = isDebug 71 | } 72 | 73 | def isSchemaInferEnabled = { 74 | !options.getOrElse("schemaInferUrl","").isEmpty && session != null 75 | } 76 | 77 | def isInDebugMode = _debugMode 78 | 79 | def metaProvider = _metaProvider 80 | 81 | def statements = { 82 | _statements 83 | } 84 | 85 | def reqParams = { 86 | JSONTool.parseJson[Map[String,String]](AutoSuggestContext.context().options("params")) 87 | } 88 | 89 | def rawTokens = _rawTokens 90 | 91 | def tempTableProvider = { 92 | _tempTableProvider 93 | } 94 | 95 | def setStatementSplitter(_statementSplitter: StatementSplitter) = { 96 | this._statementSplitter = _statementSplitter 97 | this 98 | } 99 | 100 | def setUserDefinedMetaProvider(_metaProvider: MetaProvider) = { 101 | userDefinedProvider = _metaProvider 102 | this._metaProvider = new LayeredMetaProvider(tempTableProvider, userDefinedProvider) 103 | this 104 | } 105 | 106 | def setRootMetaProvider(_metaProvider: MetaProvider) = { 107 | this._metaProvider = _metaProvider 108 | this 109 | } 110 | 111 | def addStatementProcessor(item: PreProcessStatement) = { 112 | this._statementProcessors += item 113 | this 114 | } 115 | 116 | def buildFromString(str: String): AutoSuggestContext = { 117 | build(lexer.tokenizeNonDefaultChannel(str).tokens.asScala.toList) 118 | this 119 | } 120 | 121 | def build(_tokens: List[Token]): AutoSuggestContext = { 122 | _rawTokens = _tokens 123 | _statements = _statementSplitter.split(rawTokens) 124 | // preprocess 125 | _statementProcessors.foreach { sta => 126 | _statements.foreach(sta.process(_)) 127 | } 128 | return this 129 | } 130 | 131 | def toRelativePos(tokenPos: TokenPos): (TokenPos, Int) = { 132 | var skipSize = 0 133 | var targetIndex = 0 134 | var targetPos: TokenPos = null 135 | var 
targetStaIndex = 0 136 | _statements.zipWithIndex.foreach { case (sta, index) => 137 | val relativePos = tokenPos.pos - skipSize 138 | if (relativePos >= 0 && relativePos < sta.size) { 139 | targetPos = tokenPos.copy(pos = tokenPos.pos - skipSize) 140 | targetStaIndex = index 141 | } 142 | skipSize += sta.size 143 | targetIndex += 1 144 | } 145 | return (targetPos, targetStaIndex) 146 | } 147 | 148 | def suggest(lineNum: Int, columnNum: Int): List[SuggestItem] = { 149 | if (isInDebugMode) { 150 | logInfo(s"lineNum:${lineNum} columnNum:${columnNum}") 151 | } 152 | _rawLineNum = lineNum 153 | _rawColumnNum = columnNum 154 | val tokenPos = LexerUtils.toTokenPos(rawTokens, lineNum, columnNum) 155 | _suggest(tokenPos) 156 | } 157 | 158 | /** 159 | * Notice that the pos in tokenPos is relative to the whole script. 160 | * We need to convert it to the relative pos within each statement 161 | */ 162 | private[autosuggest] def _suggest(tokenPos: TokenPos): List[SuggestItem] = { 163 | assert(_rawLineNum != 0 || _rawColumnNum != 0, "lineNum and columnNum should be set") 164 | if (isInDebugMode) { 165 | logInfo("Global Pos::" + tokenPos.str + s"::${rawTokens(tokenPos.pos)}") 166 | } 167 | if (tokenPos.pos == -1) { 168 | return firstWords 169 | } 170 | val (relativeTokenPos, index) = toRelativePos(tokenPos) 171 | 172 | if (isInDebugMode) { 173 | logInfo(s"Relative Pos in ${index}-statement ::" + relativeTokenPos.str) 174 | logInfo(s"${index}-statement:\n${_statements(index).map(_.getText).mkString(" ")}") 175 | } 176 | val items = _statements(index).headOption.map(_.getText) match { 177 | case Some("load") => 178 | val suggester = new LoadSuggester(this, _statements(index), relativeTokenPos) 179 | suggester.suggest() 180 | case Some("select") => 181 | // we should recompute the token pos since select statements use the spark sql lexer instead of the MLSQL lexer 182 | val selectTokens = _statements(index) 183 | val startLineNum = selectTokens.head.getLine 184 | val relativeLineNum = _rawLineNum - startLineNum + 1 // startLineNum starts from 1 185 | val relativeColumnNum = _rawColumnNum - selectTokens.head.getCharPositionInLine // charPos starts from 0 186 | 187 | if (isInDebugMode) { 188 | logInfo(s"select[${index}] relativeLineNum:${relativeLineNum} relativeColumnNum:${relativeColumnNum}") 189 | } 190 | val relativeTokenPos = LexerUtils.toTokenPosForSparkSQL(LexerUtils.toRawSQLTokens(this, selectTokens), relativeLineNum, relativeColumnNum) 191 | if (isInDebugMode) { 192 | logInfo(s"select[${index}] relativeTokenPos:${relativeTokenPos.str}") 193 | } 194 | val suggester = new SelectSuggester(this, selectTokens, relativeTokenPos) 195 | suggester.suggest() 196 | case Some(value) => firstWords.filter(_.name.startsWith(value)) 197 | case None => firstWords 198 | } 199 | items.distinct 200 | } 201 | 202 | private val firstWords = List("load", "select", "include", "register", "run", "train", "save", "set").map(SuggestItem(_, SpecialTableConst.KEY_WORD_TABLE, Map())).toList 203 | } 204 | 205 | class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { 206 | override def consume(): Unit = wrapped.consume 207 | 208 | override def getSourceName(): String = wrapped.getSourceName 209 | 210 | override def index(): Int = wrapped.index 211 | 212 | override def mark(): Int = wrapped.mark 213 | 214 | override def release(marker: Int): Unit = wrapped.release(marker) 215 | 216 | override def seek(where: Int): Unit = wrapped.seek(where) 217 | 218 | override def size(): Int = wrapped.size 219 | 220 | override def getText(interval: Interval): String = { 
221 | // ANTLR 4.7's CodePointCharStream implementations have bugs when 222 | // getText() is called with an empty stream, or intervals where 223 | // the start > end. See 224 | // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix 225 | // that is not yet in a released ANTLR artifact. 226 | if (size() > 0 && (interval.b - interval.a >= 0)) { 227 | wrapped.getText(interval) 228 | } else { 229 | "" 230 | } 231 | } 232 | 233 | override def LA(i: Int): Int = { 234 | val la = wrapped.LA(i) 235 | if (la == 0 || la == IntStream.EOF) la 236 | else Character.toUpperCase(la) 237 | } 238 | } 239 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/AutoSuggester.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest 2 | 3 | /** 4 | * 1/6/2020 WilliamZhu(allwefantasy@gmail.com) 5 | */ 6 | class AutoSuggester { 7 | 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/FuncReg.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest 2 | 3 | import tech.mlsql.autosuggest.meta.MetaTable 4 | 5 | /** 6 | * 9/6/2020 WilliamZhu(allwefantasy@gmail.com) 7 | */ 8 | trait FuncReg { 9 | val DOC = "doc" 10 | val COLUMN = "column" 11 | val IS_AGG = "agg" 12 | val YES = "yes" 13 | val NO = "no" 14 | val DEFAULT_VALUE = "default" 15 | 16 | def register: MetaTable 17 | } 18 | 19 | 20 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/FunctionUtils.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest 2 | 3 | import tech.mlsql.autosuggest.MLSQLSQLFunction.DB_KEY 4 | import tech.mlsql.autosuggest.meta.{MetaProvider, MetaTable, MetaTableColumn, MetaTableKey} 5 | 6 | import scala.collection.mutable.ArrayBuffer 7 | 8 | 9 | object MLSQLSQLFunction { 10 | val DB_KEY = "__FUNC__" 11 | val RETURN_KEY = "__RETURN__" 12 | val funcMetaProvider = new FuncMetaProvider() 13 | 14 | def apply(name: String): MLSQLSQLFunction = new MLSQLSQLFunction(name) 15 | } 16 | 17 | class MLSQLSQLFunction(name: String) { 18 | 19 | private val _params = ArrayBuffer[MetaTableColumn]() 20 | private val _returnParam = ArrayBuffer[MetaTableColumn]() 21 | private var _tableKey: MetaTableKey = MetaTableKey(None, Option(DB_KEY), name) 22 | private val _funcDescParam = ArrayBuffer[MetaTableColumn]() 23 | 24 | def funcName(name: String) = { 25 | _tableKey = MetaTableKey(None, Option(DB_KEY), name) 26 | this 27 | } 28 | 29 | def funcParam = { 30 | new MLSQLFuncParam(this) 31 | } 32 | 33 | def desc(extra: Map[String, String]) = { 34 | assert(_funcDescParam.size == 0, "desc can only be invoked once") 35 | _funcDescParam += MetaTableColumn(MLSQLSQLFunction.DB_KEY, "", false, extra) 36 | this 37 | } 38 | 39 | def returnParam(dataType: String, isNull: Boolean, extra: Map[String, String]) = { 40 | assert(_returnParam.size == 0, "returnParam can only be invoked once") 41 | _returnParam += MetaTableColumn(MLSQLSQLFunction.RETURN_KEY, dataType, isNull, extra) 42 | this 43 | } 44 | 45 | def addColumn(column: MetaTableColumn) = { 46 | _params += column 47 | this 48 | } 49 | 50 | def build = { 51 | MetaTable(_tableKey, (_funcDescParam ++ _returnParam ++ _params).toList) 52 | } 53 | 54 | } 55 | 56 | class MLSQLFuncParam(_func: MLSQLSQLFunction) { 57 | def 
param(name: String, dataType: String) = { 58 | _func.addColumn(MetaTableColumn(name, dataType, true, Map())) 59 | this 60 | } 61 | 62 | def param(name: String, dataType: String, isNull: Boolean) = { 63 | _func.addColumn(MetaTableColumn(name, dataType, isNull, Map())) 64 | this 65 | } 66 | 67 | def param(name: String, dataType: String, isNull: Boolean, extra: Map[String, String]) = { 68 | _func.addColumn(MetaTableColumn(name, dataType, isNull, extra)) 69 | this 70 | } 71 | 72 | def func = { 73 | _func 74 | } 75 | } 76 | 77 | class FuncMetaProvider extends MetaProvider { 78 | private val funcs = scala.collection.mutable.HashMap[MetaTableKey, MetaTable]() 79 | 80 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = { 81 | funcs.get(key) 82 | } 83 | 84 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = { 85 | funcs.map(_._2).toList 86 | } 87 | 88 | def register(func: MetaTable) = { 89 | this.funcs.put(func.key, func) 90 | this 91 | } 92 | 93 | // for test 94 | def clear = { 95 | funcs.clear() 96 | } 97 | } 98 | 99 | 100 | object DataType { 101 | val STRING = "string" 102 | val INT = "integer" 103 | val LONG = "long" 104 | val DOUBLE = "double" 105 | val NUMBER = "number" 106 | val DATE = "date" 107 | val DATE_TIMESTAMP = "date_timestamp" 108 | val ARRAY = "array" 109 | val MAP = "map" 110 | } 111 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/POrCLiterals.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest 2 | 3 | /** 4 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com) 5 | */ 6 | object POrCLiterals { 7 | val FUNCTION_PACKAGE = "tech.mlsql.autosuggest.funcs" 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/SpecialTableConst.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest 2 | 3 | import tech.mlsql.autosuggest.meta.{MetaTable, MetaTableKey} 4 | 5 | /** 6 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com) 7 | */ 8 | object SpecialTableConst { 9 | val KEY_WORD = "__KEY__WORD__" 10 | val DATA_SOURCE_KEY = "__DATA__SOURCE__" 11 | val OPTION_KEY = "__OPTION__" 12 | val TEMP_TABLE_DB_KEY = "__TEMP_TABLE__" 13 | 14 | val OTHER_TABLE_KEY = "__OTHER__TABLE__" 15 | 16 | val TOP_LEVEL_KEY = "__TOP_LEVEL__" 17 | 18 | def KEY_WORD_TABLE = MetaTable(MetaTableKey(None, None, SpecialTableConst.KEY_WORD), List()) 19 | 20 | def DATA_SOURCE_TABLE = MetaTable(MetaTableKey(None, None, SpecialTableConst.DATA_SOURCE_KEY), List()) 21 | 22 | def OPTION_TABLE = MetaTable(MetaTableKey(None, None, SpecialTableConst.OPTION_KEY), List()) 23 | 24 | def OTHER_TABLE = MetaTable(MetaTableKey(None, None, SpecialTableConst.OTHER_TABLE_KEY), List()) 25 | 26 | def tempTable(name: String) = MetaTable(MetaTableKey(None, Option(TEMP_TABLE_DB_KEY), name), List()) 27 | 28 | def subQueryAliasTable = { 29 | MetaTableKey(None, None, null) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/TokenPos.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest 2 | 3 | /** 4 | * TokenPos mark the cursor position 5 | */ 6 | object TokenPosType { 7 | val END = -1 8 | val CURRENT = -2 9 | val NEXT = -3 10 | } 11 | 12 | /** 13 | * input "load" 
the cursor can be (situation 1) at the end of "load" 14 | * or (situation 2) followed by a white space 15 | * 16 | * Looks like this: 17 | * 1. loa[cursor] => tab hit => loa[d] => pos:0 currentOrNext:TokenPosType.CURRENT offsetInToken:3 18 | * 2. load [cursor] => tab hit => load [DataSource list] => pos:0 currentOrNext:TokenPosType.NEXT offsetInToken:0 19 | * 2.1 load hi[cursor]ve.`db.table` => tab hit => load hi[ve] => pos:1 currentOrNext:TokenPosType.CURRENT offsetInToken:2 20 | * 21 | * the second situation can also look like this: 22 | * 23 | * load [cursor]hive.`db.table` 24 | * 25 | * we should still show the DataSource list. This means we ignore the token after the cursor when we provide 26 | * the suggestion list. 27 | * 28 | */ 29 | case class TokenPos(pos: Int, currentOrNext: Int, offsetInToken: Int = -1) { 30 | def str = { 31 | val posType = currentOrNext match { 32 | case TokenPosType.NEXT => "next" 33 | case TokenPosType.CURRENT => "current" 34 | } 35 | s"TokenPos: Index(${pos}) currentOrNext(${posType}) offsetInToken(${offsetInToken})" 36 | } 37 | } 38 | 39 | 40 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/app/Constants.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.app 2 | 3 | /** 4 | * 2019-07-21 WilliamZhu(allwefantasy@gmail.com) 5 | */ 6 | object Constants { 7 | val DB = "db" 8 | val HIVE = "hive" 9 | val JSON = "json" 10 | val TEXT = "text" 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/app/MLSQLAutoSuggestApp.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.app 2 | 3 | import com.intigua.antlr4.autosuggest.{DefaultToCharStream, LexerWrapper, RawSQLToCharStream, ReflectionLexerAndParserFactory} 4 | import org.apache.spark.sql.catalyst.parser.{SqlBaseLexer, SqlBaseParser} 5 | import streaming.dsl.ScriptSQLExec 6 | import streaming.dsl.parser.{DSLSQLLexer, DSLSQLParser} 7 | import tech.mlsql.app.CustomController 8 | import tech.mlsql.autosuggest.meta.{MLSQLEngineMetaProvider, RestMetaProvider} 9 | import tech.mlsql.autosuggest.{AutoSuggestContext, MLSQLSQLFunction} 10 | import tech.mlsql.common.utils.serder.json.JSONTool 11 | import tech.mlsql.runtime.AppRuntimeStore 12 | import tech.mlsql.version.VersionCompatibility 13 | 14 | /** 15 | * 9/6/2020 WilliamZhu(allwefantasy@gmail.com) 16 | */ 17 | class MLSQLAutoSuggestApp extends tech.mlsql.app.App with VersionCompatibility { 18 | override def run(args: Seq[String]): Unit = { 19 | AutoSuggestContext.init 20 | AppRuntimeStore.store.registerController("autoSuggest", classOf[AutoSuggestController].getName) 21 | AppRuntimeStore.store.registerController("registerTable", classOf[RegisterTableController].getName) 22 | AppRuntimeStore.store.registerController("sqlFunctions", classOf[SQLFunctionController].getName) 23 | } 24 | 25 | override def supportedVersions: Seq[String] = { 26 | Seq("1.5.0-SNAPSHOT", "1.5.0", "1.6.0-SNAPSHOT", "1.6.0") 27 | } 28 | } 29 | 30 | object AutoSuggestController { 31 | val lexerAndParserfactory = new ReflectionLexerAndParserFactory(classOf[DSLSQLLexer], classOf[DSLSQLParser]); 32 | val mlsqlLexer = new LexerWrapper(lexerAndParserfactory, new DefaultToCharStream) 33 | 34 | val lexerAndParserfactory2 = new ReflectionLexerAndParserFactory(classOf[SqlBaseLexer], classOf[SqlBaseParser]); 35 | val sqlLexer = new 
LexerWrapper(lexerAndParserfactory2, new RawSQLToCharStream) 36 | 37 | def getSchemaRegistry = { 38 | 39 | new SchemaRegistry(getSession) 40 | } 41 | 42 | def getSession = { 43 | val session = if (ScriptSQLExec.context() != null) ScriptSQLExec.context().execListener.sparkSession else Standalone.sparkSession 44 | session 45 | } 46 | } 47 | 48 | class SQLFunctionController extends CustomController { 49 | override def run(params: Map[String, String]): String = { 50 | JSONTool.toJsonStr(MLSQLSQLFunction.funcMetaProvider.list(Map())) 51 | } 52 | } 53 | 54 | class RegisterTableController extends CustomController { 55 | override def run(params: Map[String, String]): String = { 56 | def hasParam(str: String) = params.contains(str) 57 | 58 | def paramOpt(name: String) = { 59 | if (hasParam(name)) Option(params(name)) else None 60 | } 61 | 62 | val prefix = paramOpt("prefix") 63 | val db = paramOpt("db") 64 | require(hasParam("table"), "table is required") 65 | require(hasParam("schema"), "schema is required") 66 | val table = params("table") 67 | 68 | val registry = AutoSuggestController.getSchemaRegistry 69 | params.getOrElse("schemaType", "") match { 70 | case Constants.DB => registry.createTableFromDBSQL(prefix, db, table, params("schema")) 71 | case Constants.HIVE => registry.createTableFromHiveSQL(prefix, db, table, params("schema")) 72 | case Constants.JSON => registry.createTableFromJson(prefix, db, table, params("schema")) 73 | case _ => registry.createTableFromDBSQL(prefix, db, table, params("schema")) 74 | } 75 | JSONTool.toJsonStr(Map("success" -> true)) 76 | } 77 | } 78 | 79 | class AutoSuggestController extends CustomController { 80 | override def run(params: Map[String, String]): String = { 81 | val sql = params("sql") 82 | val lineNum = params("lineNum").toInt 83 | val columnNum = params("columnNum").toInt 84 | val isDebug = params.getOrElse("isDebug", "false").toBoolean 85 | val size = params.getOrElse("size", "30").toInt 86 | val includeTableMeta = params.getOrElse("includeTableMeta", "false").toBoolean 87 | 88 | val schemaInferUrl = params.getOrElse("schemaInferUrl", "") 89 | 90 | val enableMemoryProvider = params.getOrElse("enableMemoryProvider", "true").toBoolean 91 | val session = AutoSuggestController.getSession 92 | val context = new AutoSuggestContext(session, 93 | AutoSuggestController.mlsqlLexer, 94 | AutoSuggestController.sqlLexer,Map("schemaInferUrl"->schemaInferUrl,"params"->JSONTool.toJsonStr(params))) 95 | context.setDebugMode(isDebug) 96 | if (enableMemoryProvider) { 97 | context.setUserDefinedMetaProvider(AutoSuggestContext.memoryMetaProvider) 98 | } 99 | 100 | if (!schemaInferUrl.isEmpty) { 101 | context.setUserDefinedMetaProvider(new MLSQLEngineMetaProvider()) 102 | } 103 | 104 | val searchUrl = params.get("searchUrl") 105 | val listUrl = params.get("listUrl") 106 | (searchUrl, listUrl) match { 107 | case (Some(searchUrl), Some(listUrl)) => 108 | context.setUserDefinedMetaProvider(new RestMetaProvider(searchUrl, listUrl)) 109 | case _ => // do nothing unless both searchUrl and listUrl are provided 110 | } 111 | 112 | 113 | 114 | var resItems = context.buildFromString(sql).suggest(lineNum, columnNum).take(size) 115 | if (!includeTableMeta) { 116 | resItems = resItems.map { item => 117 | item.copy(metaTable = null) 118 | }.take(size) 119 | } 120 | JSONTool.toJsonStr(resItems) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/app/RDSchema.scala: -------------------------------------------------------------------------------- 1 | 
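// NOTE (added commentary, not in the original source): RDSchema below parses a
// CREATE TABLE statement with Druid's SchemaRepository and maps the raw JDBC
// column types onto Spark SQL types, so a table schema can be registered from
// plain DDL without connecting to a real database.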
package tech.mlsql.autosuggest.app 2 | 3 | import java.sql.{JDBCType, SQLException} 4 | 5 | import com.alibaba.druid.sql.SQLUtils 6 | import com.alibaba.druid.sql.ast.SQLDataType 7 | import com.alibaba.druid.sql.ast.statement.{SQLColumnDefinition, SQLCreateTableStatement} 8 | import com.alibaba.druid.sql.repository.SchemaRepository 9 | import org.apache.spark.sql.jdbc.JdbcDialects 10 | import org.apache.spark.sql.types.DecimalType.{MAX_PRECISION, MAX_SCALE} 11 | import org.apache.spark.sql.types._ 12 | 13 | import scala.collection.JavaConverters._ 14 | import scala.math.min 15 | 16 | /** 17 | * 2019-07-18 WilliamZhu(allwefantasy@gmail.com) 18 | */ 19 | class RDSchema(dbType: String) { 20 | 21 | private val repository = new SchemaRepository(dbType) 22 | 23 | def createTable(sql: String) = { 24 | repository.console(sql) 25 | val tableName = SQLUtils.parseStatements(sql, dbType).get(0).asInstanceOf[SQLCreateTableStatement]. 26 | getTableSource.getName.getSimpleName 27 | SQLUtils.normalize(tableName) 28 | } 29 | 30 | def getTableSchema(table: String) = { 31 | val dialect = JdbcDialects.get(s"jdbc:${dbType}") 32 | 33 | 34 | def extractfieldSize = (dataType: SQLDataType) => { 35 | dataType.getArguments.asScala.map { f => 36 | try { 37 | f.toString.toInt 38 | } catch { 39 | case e: Exception => 0 40 | } 41 | 42 | }.headOption 43 | } 44 | 45 | val fields = repository.findTable(table).getStatement.asInstanceOf[SQLCreateTableStatement]. 46 | getTableElementList().asScala.filter(f => f.isInstanceOf[SQLColumnDefinition]). 47 | map { 48 | _.asInstanceOf[SQLColumnDefinition] 49 | }.map { column => 50 | 51 | val columnName = column.getName.getSimpleName 52 | val dataType = RawDBTypeToJavaType.convert(dbType, column.getDataType.getName) 53 | val isNullable = !column.containsNotNullConstaint() 54 | 55 | val fieldSize = extractfieldSize(column.getDataType) match { 56 | case Some(i) => i 57 | case None => 0 58 | } 59 | val fieldScale = 0 60 | 61 | val columnType = dialect.getCatalystType(dataType, column.getDataType.getName, fieldSize, new MetadataBuilder()). 
62 | getOrElse( 63 | getCatalystType(dataType, fieldSize, fieldScale, false)) 64 | 65 | StructField(columnName, columnType, isNullable) 66 | } 67 | new StructType(fields.toArray) 68 | 69 | } 70 | 71 | private def getCatalystType( 72 | sqlType: Int, 73 | precision: Int, 74 | scale: Int, 75 | signed: Boolean): DataType = { 76 | val answer = sqlType match { 77 | // scalastyle:off 78 | case java.sql.Types.ARRAY => null 79 | case java.sql.Types.BIGINT => if (signed) { 80 | LongType 81 | } else { 82 | DecimalType(20, 0) 83 | } 84 | case java.sql.Types.BINARY => BinaryType 85 | case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks 86 | case java.sql.Types.BLOB => BinaryType 87 | case java.sql.Types.BOOLEAN => BooleanType 88 | case java.sql.Types.CHAR => StringType 89 | case java.sql.Types.CLOB => StringType 90 | case java.sql.Types.DATALINK => null 91 | case java.sql.Types.DATE => DateType 92 | case java.sql.Types.DECIMAL 93 | if precision != 0 || scale != 0 => DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) 94 | case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT 95 | case java.sql.Types.DISTINCT => null 96 | case java.sql.Types.DOUBLE => DoubleType 97 | case java.sql.Types.FLOAT => FloatType 98 | case java.sql.Types.INTEGER => if (signed) { 99 | IntegerType 100 | } else { 101 | LongType 102 | } 103 | case java.sql.Types.JAVA_OBJECT => null 104 | case java.sql.Types.LONGNVARCHAR => StringType 105 | case java.sql.Types.LONGVARBINARY => BinaryType 106 | case java.sql.Types.LONGVARCHAR => StringType 107 | case java.sql.Types.NCHAR => StringType 108 | case java.sql.Types.NCLOB => StringType 109 | case java.sql.Types.NULL => null 110 | case java.sql.Types.NUMERIC 111 | if precision != 0 || scale != 0 => DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) 112 | case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT 113 | case java.sql.Types.NVARCHAR => StringType 114 | case java.sql.Types.OTHER => null 115 | case java.sql.Types.REAL => DoubleType 116 | case java.sql.Types.REF => StringType 117 | case java.sql.Types.REF_CURSOR => null 118 | case java.sql.Types.ROWID => LongType 119 | case java.sql.Types.SMALLINT => IntegerType 120 | case java.sql.Types.SQLXML => StringType 121 | case java.sql.Types.STRUCT => StringType 122 | case java.sql.Types.TIME => TimestampType 123 | case java.sql.Types.TIME_WITH_TIMEZONE 124 | => null 125 | case java.sql.Types.TIMESTAMP => TimestampType 126 | case java.sql.Types.TIMESTAMP_WITH_TIMEZONE 127 | => null 128 | case java.sql.Types.TINYINT => IntegerType 129 | case java.sql.Types.VARBINARY => BinaryType 130 | case java.sql.Types.VARCHAR => StringType 131 | case _ => 132 | throw new SQLException("Unrecognized SQL type " + sqlType) 133 | // scalastyle:on 134 | } 135 | 136 | if (answer == null) { 137 | throw new SQLException("Unsupported type " + JDBCType.valueOf(sqlType).getName) 138 | } 139 | answer 140 | } 141 | } 142 | 143 | 144 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/app/RawDBTypeToJavaType.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.app 2 | 3 | import com.alibaba.druid.util.JdbcConstants 4 | 5 | 6 | /** 7 | * 2019-07-18 WilliamZhu(allwefantasy@gmail.com) 8 | */ 9 | object RawDBTypeToJavaType { 10 | def convert(dbType: String, name: String) = { 11 | dbType match { 12 | case JdbcConstants.MYSQL => MysqlType.valueOf(name.toUpperCase()).getJdbcType 13 | case 
_ => throw new RuntimeException(s"dbType ${dbType} is not supported yet") 14 | } 15 | 16 | } 17 | 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/app/SchemaRegistry.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.app 2 | 3 | import com.alibaba.druid.util.JdbcConstants 4 | import org.apache.spark.sql.SparkSession 5 | import org.apache.spark.sql.types.{DataType, StructType} 6 | import tech.mlsql.autosuggest.AutoSuggestContext 7 | import tech.mlsql.autosuggest.meta.MetaTableKey 8 | import tech.mlsql.autosuggest.utils.SchemaUtils 9 | 10 | /** 11 | * 2019-07-18 WilliamZhu(allwefantasy@gmail.com) 12 | */ 13 | class SchemaRegistry(_spark: SparkSession) { 14 | val spark = _spark 15 | 16 | def createTableFromDBSQL(prefix: Option[String], db: Option[String], tableName: String, createSQL: String) = { 17 | val rd = new RDSchema(JdbcConstants.MYSQL) 18 | val tableName = rd.createTable(createSQL) 19 | val schema = rd.getTableSchema(tableName) 20 | val table = MetaTableKey(prefix, db, tableName) 21 | AutoSuggestContext.memoryMetaProvider.register(table, SchemaUtils.toMetaTable(table, schema)); 22 | } 23 | 24 | def createTableFromHiveSQL(prefix: Option[String], db: Option[String], tableName: String, createSQL: String) = { 25 | spark.sql(createSQL) 26 | val schema = spark.table(tableName).schema 27 | val table = MetaTableKey(prefix, db, tableName) 28 | AutoSuggestContext.memoryMetaProvider.register(table, SchemaUtils.toMetaTable(table, schema)); 29 | 30 | } 31 | 32 | 33 | def createTableFromJson(prefix: Option[String], db: Option[String], tableName: String, schemaJson: String) = { 34 | val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType] 35 | val table = MetaTableKey(prefix, db, tableName) 36 | AutoSuggestContext.memoryMetaProvider.register(table, SchemaUtils.toMetaTable(table, schema)); 37 | } 38 | 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/app/Standalone.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.app 2 | 3 | import net.csdn.ServiceFramwork 4 | import net.csdn.annotation.rest.At 5 | import net.csdn.bootstrap.Application 6 | import net.csdn.modules.http.RestRequest.Method 7 | import net.csdn.modules.http.{ApplicationController, ViewType} 8 | import org.apache.spark.sql.SparkSession 9 | import streaming.dsl.ScriptSQLExec 10 | import tech.mlsql.autosuggest.AutoSuggestContext 11 | import tech.mlsql.common.utils.shell.command.ParamsUtil 12 | 13 | import scala.collection.JavaConverters._ 14 | 15 | /** 16 | * 11/6/2020 WilliamZhu(allwefantasy@gmail.com) 17 | */ 18 | object Standalone extends { 19 | var sparkSession: SparkSession = null 20 | 21 | def main(args: Array[String]): Unit = { 22 | AutoSuggestContext.init 23 | val params = new ParamsUtil(args) 24 | val enableMLSQL = params.getParam("enableMLSQL", "false").toBoolean 25 | if (enableMLSQL) { 26 | sparkSession = SparkSession.builder().appName("local").master("local[*]").getOrCreate() 27 | } 28 | val applicationYamlName = params.getParam("config", "application.yml") 29 | 30 | ServiceFramwork.applicaionYamlName(applicationYamlName) 31 | ServiceFramwork.scanService.setLoader(classOf[Standalone]) 32 | Application.main(args) 33 | 34 | } 35 | } 36 | 37 | class Standalone 38 | 39 | class SuggestController extends 
ApplicationController { 40 | @At(path = Array("/run/script"), types = Array(Method.POST)) 41 | def runScript = { 42 | 43 | if (param("executeMode", "") == "registerTable") { 44 | val respStr = new RegisterTableController().run(params().asScala.toMap) 45 | render(200, respStr, ViewType.json) 46 | 47 | } 48 | if (param("executeMode", "") == "sqlFunctions") { 49 | val respStr = new SQLFunctionController().run(params().asScala.toMap) 50 | render(200, respStr, ViewType.json) 51 | 52 | } 53 | val respStr = new AutoSuggestController().run(params.asScala.toMap) 54 | render(200, respStr, ViewType.json) 55 | } 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/ast/NoneToken.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.ast 2 | 3 | import org.antlr.v4.runtime.{CharStream, Token, TokenSource} 4 | 5 | /** 6 | * 24/6/2020 WilliamZhu(allwefantasy@gmail.com) 7 | */ 8 | class NoneToken(token: Token) extends Token { 9 | override def getText: String = NoneToken.TEXT 10 | 11 | override def getType: Int = NoneToken.TYPE 12 | 13 | override def getLine: Int = token.getLine 14 | 15 | override def getCharPositionInLine: Int = token.getCharPositionInLine 16 | 17 | override def getChannel: Int = token.getChannel 18 | 19 | override def getTokenIndex: Int = token.getTokenIndex 20 | 21 | override def getStartIndex: Int = token.getStartIndex 22 | 23 | override def getStopIndex: Int = token.getStopIndex 24 | 25 | override def getTokenSource: TokenSource = token.getTokenSource 26 | 27 | override def getInputStream: CharStream = token.getInputStream 28 | } 29 | 30 | object NoneToken { 31 | val TEXT = "__NONE__" 32 | val TYPE = -3 33 | } 34 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/ast/TableTree.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.ast 2 | 3 | import tech.mlsql.autosuggest.meta.{MetaTable, MetaTableKey} 4 | import tech.mlsql.autosuggest.statement.MetaTableKeyWrapper 5 | 6 | import scala.collection.mutable.ArrayBuffer 7 | 8 | object TableTree { 9 | def ROOT = { 10 | TableTree(-1, MetaTableKeyWrapper(MetaTableKey(None, None, null), None), None, ArrayBuffer()) 11 | } 12 | 13 | } 14 | 15 | case class TableTree(level: Int, key: MetaTableKeyWrapper, table: Option[MetaTable], subNodes: ArrayBuffer[TableTree]) { 16 | 17 | 18 | def children = subNodes 19 | 20 | def collectByLevel(level: Int) = { 21 | val buffer = ArrayBuffer[TableTree]() 22 | visitDown(0) { case (table, _level) => 23 | if (_level == level) { 24 | buffer += table 25 | } 26 | } 27 | buffer.toList 28 | } 29 | 30 | def visitDown(level: Int)(rule: PartialFunction[(TableTree, Int), Unit]): Unit = { 31 | rule.apply((this, level)) 32 | this.children.map(_.visitDown(level + 1)(rule)) 33 | } 34 | 35 | def visitUp(level: Int)(rule: PartialFunction[(TableTree, Int), Unit]): Unit = { 36 | this.children.map(_.visitUp(level + 1)(rule)) 37 | rule.apply((this, level)) 38 | } 39 | 40 | def fastEquals(other: TableTree): Boolean = { 41 | this.eq(other) || this == other 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/dsl/TokenMatcher.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.dsl 2 | 3 | import 
org.antlr.v4.runtime.Token 4 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer 5 | import streaming.dsl.parser.DSLSQLLexer 6 | 7 | import scala.collection.mutable.ArrayBuffer 8 | 9 | /** 10 | * 4/6/2020 WilliamZhu(allwefantasy@gmail.com) 11 | * 12 | */ 13 | class TokenMatcher(tokens: List[Token], val start: Int) { 14 | val foods = ArrayBuffer[FoodWrapper]() 15 | var cacheResult = -2 16 | private var direction: String = MatcherDirection.FORWARD 17 | 18 | def forward = { 19 | assert(foods.size == 0, "this function should be invoked before eat") 20 | direction = MatcherDirection.FORWARD 21 | this 22 | } 23 | 24 | def back = { 25 | assert(foods.size == 0, "this function should be invoked before eat") 26 | direction = MatcherDirection.BACK 27 | this 28 | } 29 | 30 | def eat(food: Food*) = { 31 | foods += FoodWrapper(AndOrFood(food.toList, true), false) 32 | this 33 | } 34 | 35 | def eatOneAny = { 36 | foods += FoodWrapper(AndOrFood(List(Food(None, -2)), true), false) 37 | this 38 | } 39 | 40 | /** 41 | * 42 | * Keep moving forward until we hit any of the given foods; on success return that index, otherwise return -1 43 | */ 44 | def orIndex(_foods: Array[Food], upperBound: Int = tokens.size) = { 45 | if (foods.size != 0) { 46 | throw new RuntimeException("eat/optional/asStart should not be called before index") 47 | } 48 | direction match { 49 | case MatcherDirection.FORWARD => 50 | var targetIndex = -1 51 | (start until upperBound).foreach { idx => 52 | if (targetIndex == -1) { 53 | // step by step until success 54 | var matchValue = -1 55 | _foods.zipWithIndex.foreach { case (food, _) => 56 | if (matchValue == -1 && matchToken(food, idx) != -1) { 57 | matchValue = 0 58 | } 59 | } 60 | if (matchValue != -1) { 61 | targetIndex = idx 62 | } 63 | } 64 | 65 | } 66 | targetIndex 67 | case MatcherDirection.BACK => 68 | var _start = start 69 | var targetIndex = -1 70 | while (_start >= 0) { 71 | if (targetIndex == -1) { 72 | // step by step until success 73 | var matchValue = -1 74 | _foods.zipWithIndex.foreach { case (food, _) => 75 | if (matchValue == -1 && matchToken(food, _start) != -1) { 76 | matchValue = 0 77 | } 78 | } 79 | if (matchValue != -1) { 80 | targetIndex = _start 81 | } 82 | } 83 | _start = _start - 1 84 | } 85 | targetIndex 86 | } 87 | 88 | } 89 | 90 | // find the first match 91 | def index(_foods: Array[Food], upperBound: Int = tokens.size) = { 92 | if (foods.size != 0) { 93 | throw new RuntimeException("eat/optional/asStart should not be called before index") 94 | } 95 | assert(direction == MatcherDirection.FORWARD, "index only supports forward") 96 | var targetIndex = -1 97 | (start until upperBound).foreach { idx => 98 | if (targetIndex == -1) { 99 | // step by step until success 100 | var matchValue = 0 101 | _foods.zipWithIndex.foreach { case (food, idx2) => 102 | if (matchValue == 0 && matchToken(food, idx + idx2) == -1) { 103 | matchValue = -1 104 | } 105 | } 106 | if (matchValue != -1) { 107 | targetIndex = idx 108 | } 109 | } 110 | 111 | } 112 | targetIndex 113 | 114 | } 115 | 116 | def asStart(food: Food, offset: Int = 0) = { 117 | if (foods.size != 0) { 118 | throw new RuntimeException("eat/optional should not be called before asStart") 119 | } 120 | var targetIndex = -1 121 | (start until tokens.size).foreach { idx => 122 | if (targetIndex == -1) { 123 | val index = matchToken(food, idx) 124 | if (index != -1) { 125 | targetIndex = index 126 | } 127 | } 128 | 129 | } 130 | TokenMatcher(tokens, targetIndex + offset) 131 | } 132 | 133 | def optional = { 134 | foods.lastOption.foreach(_.optional = true) 135 | this 136 | } 137 | 138 | private def matchToken(food: 
Food, currentIndex: Int): Int = { 139 | if (currentIndex < 0) return -1 140 | if (currentIndex >= tokens.size) return -1 141 | if (food.tp == -2) { 142 | return currentIndex 143 | } 144 | food.name match { 145 | case Some(name) => if (tokens(currentIndex).getType == food.tp && tokens(currentIndex).getText == name) { 146 | currentIndex 147 | } else -1 148 | case None => 149 | if (tokens(currentIndex).getType == food.tp) { 150 | currentIndex 151 | } else -1 152 | } 153 | } 154 | 155 | private def forwardBuild: TokenMatcher = { 156 | var currentIndex = start 157 | var isFail = false 158 | 159 | 160 | foods.foreach { foodw => 161 | 162 | if (currentIndex >= tokens.size && !foodw.optional) { 163 | isFail = true 164 | } else { 165 | val stepSize = foodw.foods.count 166 | var matchValue = 0 167 | foodw.foods.foods.zipWithIndex.foreach { case (food, idx) => 168 | if (matchValue == 0 && matchToken(food, currentIndex + idx) == -1) { 169 | matchValue = -1 170 | } 171 | } 172 | if (foodw.optional) { 173 | if (matchValue != -1) { 174 | currentIndex = currentIndex + stepSize 175 | } 176 | } else { 177 | if (matchValue != -1) { 178 | currentIndex = currentIndex + stepSize 179 | 180 | } else { 181 | //mark fail 182 | isFail = true 183 | } 184 | } 185 | } 186 | } 187 | 188 | val targetIndex = if (isFail) -1 else currentIndex 189 | cacheResult = targetIndex 190 | this 191 | } 192 | 193 | private def backBuild: TokenMatcher = { 194 | var currentIndex = start 195 | var isFail = false 196 | 197 | 198 | foods.foreach { foodw => 199 | // if out of bound then mark fail 200 | if (currentIndex <= -1 && !foodw.optional) { 201 | isFail = true 202 | } else { 203 | val stepSize = foodw.foods.count 204 | var matchValue = 0 205 | foodw.foods.foods.zipWithIndex.foreach { case (food, idx) => 206 | if (matchValue == 0 && matchToken(food, currentIndex - idx) == -1) { 207 | matchValue = -1 208 | } 209 | } 210 | if (foodw.optional) { 211 | if (matchValue != -1) { 212 | currentIndex = currentIndex - stepSize 213 | } 214 | } else { 215 | if (matchValue != -1) { 216 | currentIndex = currentIndex - stepSize 217 | 218 | } else { 219 | //mark fail 220 | isFail = true 221 | } 222 | } 223 | } 224 | 225 | } 226 | 227 | if (!isFail && currentIndex == -1) { 228 | currentIndex = 0 229 | } 230 | val targetIndex = if (isFail) -1 else currentIndex 231 | cacheResult = targetIndex 232 | this 233 | } 234 | 235 | def build: TokenMatcher = { 236 | direction match { 237 | case MatcherDirection.FORWARD => 238 | forwardBuild 239 | case MatcherDirection.BACK => 240 | backBuild 241 | } 242 | } 243 | 244 | def get = { 245 | if (this.cacheResult == -2) this.build 246 | this.cacheResult 247 | } 248 | 249 | def isSuccess = { 250 | if (this.cacheResult == -2) this.build 251 | this.cacheResult != -1 252 | } 253 | 254 | def getMatchTokens = { 255 | direction match { 256 | case MatcherDirection.BACK => tokens.slice(get + 1, start + 1) 257 | case MatcherDirection.FORWARD => tokens.slice(start, get) 258 | } 259 | 260 | } 261 | } 262 | 263 | object MatcherDirection { 264 | val FORWARD = "forward" 265 | val BACK = "back" 266 | } 267 | 268 | object TokenTypeWrapper { 269 | val LEFT_BRACKET = SqlBaseLexer.T__0 //( 270 | val RIGHT_BRACKET = SqlBaseLexer.T__1 //) 271 | val COMMA = SqlBaseLexer.T__2 //, 272 | val DOT = SqlBaseLexer.T__3 //. 
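// NOTE (added commentary; an illustrative sketch, not part of the original source):
// these constants make TokenMatcher chains readable, e.g. matching `db.table`:
//   TokenMatcher(tokens, 0).
//     eat(Food(None, SqlBaseLexer.IDENTIFIER), Food(None, TokenTypeWrapper.DOT)).
//     eat(Food(None, SqlBaseLexer.IDENTIFIER)).
//     build
// isSuccess then tells whether all three tokens were consumed in order.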
273 | val LEFT_SQUARE_BRACKET = SqlBaseLexer.T__7 //[ 274 | val RIGHT_SQUARE_BRACKET = SqlBaseLexer.T__8 //] 275 | val COLON = SqlBaseLexer.T__9 //: 276 | 277 | val LIST = List(LEFT_BRACKET, RIGHT_BRACKET, COMMA, DOT, LEFT_SQUARE_BRACKET, RIGHT_SQUARE_BRACKET, COLON) 278 | val MAP = LIST.map((_, 1)).toMap 279 | } 280 | 281 | object MLSQLTokenTypeWrapper { 282 | val DOT = DSLSQLLexer.T__0 283 | } 284 | 285 | object TokenMatcher { 286 | def apply(tokens: List[Token], start: Int): TokenMatcher = new TokenMatcher(tokens, start) 287 | 288 | def resultMatcher(tokens: List[Token], start: Int, stop: Int) = { 289 | val temp = new TokenMatcher(tokens, start) 290 | temp.cacheResult = stop 291 | temp 292 | } 293 | 294 | def SQL_SPLITTER_KEY_WORDS = List( 295 | SqlBaseLexer.SELECT, 296 | SqlBaseLexer.FROM, 297 | SqlBaseLexer.JOIN, 298 | SqlBaseLexer.WHERE, 299 | SqlBaseLexer.GROUP, 300 | SqlBaseLexer.ON, 301 | SqlBaseLexer.BY, 302 | SqlBaseLexer.LIMIT, 303 | SqlBaseLexer.ORDER 304 | ) 305 | } 306 | 307 | case class Food(name: Option[String], tp: Int) 308 | 309 | case class FoodWrapper(foods: AndOrFood, var optional: Boolean) 310 | 311 | case class AndOrFood(foods: List[Food], var and: Boolean) { 312 | def count = { 313 | if (and) foods.size 314 | else 1 315 | } 316 | } 317 | 318 | object TokenWalker { 319 | def apply(tokens: List[Token], start: Int): TokenWalker = new TokenWalker(tokens, start) 320 | } 321 | 322 | class TokenWalker(tokens: List[Token], start: Int) { 323 | 324 | var currentToken: Option[Token] = Option(tokens(start)) 325 | var currentIndex = start 326 | 327 | def nextSafe: TokenWalker = { 328 | val token = if ((currentIndex + 1) < tokens.size) { 329 | currentIndex += 1 330 | Option(tokens(currentIndex)) 331 | } else None 332 | currentToken = token 333 | this 334 | } 335 | 336 | def range: TokenCharRange = { 337 | if (currentToken.isEmpty) return TokenCharRange(-1, -1) 338 | val start = currentToken.get.getCharPositionInLine 339 | val end = currentToken.get.getCharPositionInLine + currentToken.get.getText.size 340 | TokenCharRange(start, end) 341 | } 342 | } 343 | 344 | case class TokenCharRange(start: Int, end: Int) 345 | 346 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/funcs/Count.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.funcs 2 | 3 | import tech.mlsql.autosuggest.{DataType, FuncReg, MLSQLSQLFunction} 4 | 5 | /** 6 | * 14/7/2020 WilliamZhu(allwefantasy@gmail.com) 7 | */ 8 | class Count extends FuncReg { 9 | override def register = { 10 | 11 | val func = MLSQLSQLFunction.apply("count").desc(Map( 12 | "zhDoc" -> 13 | """ 14 | |count,统计数目,可单独或者配合group by 使用 15 | |""".stripMargin, 16 | IS_AGG -> YES 17 | )). 18 | funcParam. 19 | param("column", DataType.STRING, false, Map("zhDoc" -> "列名或者*或者常熟",COLUMN->YES)). 20 | func. 21 | returnParam(DataType.NUMBER, true, Map( 22 | "zhDoc" -> 23 | """ 24 | |long类型数字 25 | |""".stripMargin 26 | )). 
27 | build 28 | func 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/funcs/Splitter.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.funcs 2 | 3 | import tech.mlsql.autosuggest.{DataType, FuncReg, MLSQLSQLFunction} 4 | 5 | /** 6 | * 9/6/2020 WilliamZhu(allwefantasy@gmail.com) 7 | */ 8 | class Splitter extends FuncReg { 9 | 10 | override def register = { 11 | val func = MLSQLSQLFunction.apply("split").desc(Map( 12 | "zhDoc" -> 13 | """ 14 | |split函数。用于切割字符串,返回字符串数组. 15 | |示例: split("a,b",",") == [a,b] 16 | |""".stripMargin 17 | )). 18 | funcParam. 19 | param("str", DataType.STRING, false, Map("zhDoc" -> "待切割字符", COLUMN -> YES)). 20 | param("pattern", DataType.STRING, false, Map("zhDoc" -> "分隔符", DEFAULT_VALUE -> ",")). 21 | func. 22 | returnParam(DataType.ARRAY, true, Map( 23 | "zhDoc" -> 24 | """ 25 | |返回值是数组类型 26 | |""".stripMargin 27 | )). 28 | build 29 | func 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/meta/LayeredMetaProvider.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.meta 2 | 3 | /** 4 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com) 5 | */ 6 | class LayeredMetaProvider(tempTableProvider: StatementTempTableProvider, userDefinedProvider: MetaProvider) extends MetaProvider { 7 | def search(key: MetaTableKey,extra: Map[String, String] = Map()): Option[MetaTable] = { 8 | tempTableProvider.search(key).orElse { 9 | userDefinedProvider.search(key) 10 | } 11 | } 12 | 13 | def list(extra: Map[String, String] = Map()): List[MetaTable] = { 14 | tempTableProvider.list(extra) ++ userDefinedProvider.list(extra) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/meta/MLSQLEngineMetaProvider.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.meta 2 | 3 | import java.util.UUID 4 | 5 | import org.apache.http.client.fluent.{Form, Request} 6 | import org.apache.http.util.EntityUtils 7 | import tech.mlsql.autosuggest.AutoSuggestContext 8 | import tech.mlsql.common.utils.serder.json.JSONTool 9 | 10 | 11 | /** 12 | * 26/7/2020 WilliamZhu(allwefantasy@gmail.com) 13 | */ 14 | class MLSQLEngineMetaProvider() extends MetaProvider { 15 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = { 16 | val form = Form.form() 17 | 18 | if (key.prefix.isEmpty) return None 19 | var path = "" 20 | 21 | if (key.db.isDefined) { 22 | path += key.db.get + "." 
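// NOTE (added commentary, not in the original source): below, the provider builds a
// throwaway "load ...; !desc ..." script and posts it to the engine's schemaInferUrl,
// letting the MLSQL engine itself infer the table schema.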
23 | } 24 | 25 | path += key.table 26 | val tableName = UUID.randomUUID().toString.replaceAll("-", "") 27 | 28 | val sql = 29 | s""" 30 | |load ${key.prefix.get}.`${path}` where header="true" as ${tableName};!desc ${tableName}; 31 | |""".stripMargin 32 | val params = JSONTool.parseJson[Map[String, String]](AutoSuggestContext.context().options("params")) 33 | params.foreach { case (k, v) => 34 | if (k != "sql" && k != "executeMode") { 35 | form.add(k, v) 36 | } 37 | } 38 | form.add("sql", sql) 39 | val resp = Request.Post(params("schemaInferUrl")).bodyForm(form.build()).execute().returnResponse() 40 | if (resp.getStatusLine.getStatusCode == 200) { 41 | val str = EntityUtils.toString(resp.getEntity) 42 | val columns = JSONTool.parseJson[List[TableSchemaColumn]](str) 43 | val table = MetaTable(key, columns.map { item => 44 | MetaTableColumn(item.col_name, item.data_type, true, Map()) 45 | }) 46 | Option(table) 47 | } else None 48 | } 49 | 50 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = { 51 | List() 52 | } 53 | } 54 | 55 | case class TableSchemaColumn(col_name: String, data_type: String, comment: Option[String]) 56 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/meta/MemoryMetaProvider.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.meta 2 | 3 | import scala.collection.JavaConverters._ 4 | 5 | /** 6 | * 15/6/2020 WilliamZhu(allwefantasy@gmail.com) 7 | */ 8 | class MemoryMetaProvider extends MetaProvider { 9 | private val cache = new java.util.concurrent.ConcurrentHashMap[MetaTableKey, MetaTable]() 10 | 11 | override def search(key: MetaTableKey,extra: Map[String, String] = Map()): Option[MetaTable] = { 12 | if (cache.containsKey(key)) Option(cache.get(key)) else None 13 | } 14 | 15 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = { 16 | cache.values().asScala.toList 17 | } 18 | 19 | def register(key: MetaTableKey, value: MetaTable) = { 20 | cache.put(key, value) 21 | this 22 | } 23 | 24 | def unRegister(key: MetaTableKey) = { 25 | cache.remove(key) 26 | this 27 | } 28 | 29 | def clear = { 30 | cache.clear() 31 | this 32 | } 33 | 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/meta/MetaProvider.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.meta 2 | 3 | /** 4 | * A function is also modeled as a table, and 5 | * its columns are treated as the function parameters. 6 | * Without a MetaProvider we cannot 7 | * suggest columns and functions. 8 | * 9 | * If search returns None, the key is not a table 10 | * or the table we are searching for does not exist. 
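 *
 * An illustrative contract (added sketch, not part of the original doc): after
 * MemoryMetaProvider.register(key, table), search(key) returns Some(table) for
 * that exact MetaTableKey and None for any unknown key, while list() enumerates
 * all registered tables.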
11 |  */
12 | trait MetaProvider {
13 |   def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable]
14 | 
15 |   def list(extra: Map[String, String] = Map()): List[MetaTable]
16 | }
17 | 
18 | 
19 | 
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/meta/RestMetaProvider.scala:
--------------------------------------------------------------------------------
 1 | package tech.mlsql.autosuggest.meta
 2 | 
 3 | import org.apache.http.client.fluent.{Form, Request}
 4 | import org.apache.http.util.EntityUtils
 5 | import tech.mlsql.common.utils.serder.json.JSONTool
 6 | 
 7 | /**
 8 |  * 15/6/2020 WilliamZhu(allwefantasy@gmail.com)
 9 |  */
10 | class RestMetaProvider(searchUrl: String, listUrl: String) extends MetaProvider {
11 |   override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
12 |     val form = Form.form()
13 |     if (key.prefix.isDefined) {
14 |       form.add("prefix", key.prefix.get)
15 |     }
16 |     if (key.db.isDefined) {
17 |       form.add("db", key.db.get)
18 |     }
19 |     form.add("table", key.table)
20 |     val resp = Request.Post(searchUrl).bodyForm(form.build()).execute().returnResponse()
21 |     if (resp.getStatusLine.getStatusCode == 200) {
22 |       val metaTable = JSONTool.parseJson[MetaTable](EntityUtils.toString(resp.getEntity))
23 |       Option(metaTable)
24 |     } else None
25 |   }
26 | 
27 |   override def list(extra: Map[String, String] = Map()): List[MetaTable] = {
28 |     val form = Form.form()
29 |     extra.foreach { case (k, v) =>
30 |       form.add(k, v)
31 |     }
32 |     val resp = Request.Post(listUrl).bodyForm(form.build()).execute().returnResponse()
33 |     if (resp.getStatusLine.getStatusCode == 200) {
34 |       val metaTables = JSONTool.parseJson[List[MetaTable]](EntityUtils.toString(resp.getEntity))
35 |       metaTables
36 |     } else List()
37 |   }
38 | }
39 | 
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/meta/StatementTempTableProvider.scala:
--------------------------------------------------------------------------------
 1 | package tech.mlsql.autosuggest.meta
 2 | 
 3 | /**
 4 |  * 3/6/2020 WilliamZhu(allwefantasy@gmail.com)
 5 |  */
 6 | class StatementTempTableProvider extends MetaProvider {
 7 |   private val cache = scala.collection.mutable.HashMap[String, MetaTable]()
 8 | 
 9 |   override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = {
10 |     cache.get(key.table)
11 |   }
12 | 
13 |   def register(name: String, metaTable: MetaTable) = {
14 |     cache += (name -> metaTable)
15 |     this
16 |   }
17 | 
18 |   override def list(extra: Map[String, String] = Map()): List[MetaTable] = cache.values.toList
19 | }
20 | 
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/meta/meta_protocal.scala:
--------------------------------------------------------------------------------
 1 | package tech.mlsql.autosuggest.meta
 2 | 
 3 | case class MetaTableKey(prefix: Option[String], db: Option[String], table: String)
 4 | 
 5 | case class MetaTable(key: MetaTableKey, columns: List[MetaTableColumn])
 6 | 
 7 | case class MetaTableColumn(name: String, dataType: String, isNull: Boolean, extra: Map[String, String])
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/preprocess/TablePreprocessor.scala:
--------------------------------------------------------------------------------
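// TablePreprocessor runs over each completed statement before suggestion time and
// registers the statement's output table, e.g. `load hive.`db.tbl` as t1;` makes
// `t1` and its schema resolvable from any later statement.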
 1 | package tech.mlsql.autosuggest.preprocess
 2 | 
 3 | import org.antlr.v4.runtime.Token
 4 | import streaming.dsl.parser.DSLSQLLexer
 5 | import tech.mlsql.autosuggest.SpecialTableConst.TEMP_TABLE_DB_KEY
 6 | import tech.mlsql.autosuggest.dsl.{Food, MLSQLTokenTypeWrapper, TokenMatcher}
 7 | import tech.mlsql.autosuggest.meta.{MetaTable, MetaTableColumn, MetaTableKey}
 8 | import tech.mlsql.autosuggest.statement.{PreProcessStatement, SelectSuggester}
 9 | import tech.mlsql.autosuggest.{AutoSuggestContext, SpecialTableConst, TokenPos, TokenPosType}
10 | 
11 | /**
12 |  * 10/6/2020 WilliamZhu(allwefantasy@gmail.com)
13 |  */
14 | class TablePreprocessor(context: AutoSuggestContext) extends PreProcessStatement {
15 | 
16 |   def cleanStr(str: String) = {
17 |     if (str.startsWith("`") || str.startsWith("\"") || (str.startsWith("'") && !str.startsWith("'''")))
18 |       str.substring(1, str.length - 1)
19 |     else str
20 |   }
21 | 
22 |   /**
23 |    * load and select statements are special cases:
24 |    * a load statement has to resolve the real (source) table,
25 |    * while a select statement has to resolve its final select.
26 |    *
27 |    * When resolving the real table of a load statement, a prefix is attached
28 |    * whose value equals the format used in the load statement.
29 |    */
30 |   def process(statement: List[Token]): Unit = {
31 |     val tempTableProvider = context.tempTableProvider
32 |     val tempMatcher = TokenMatcher(statement, statement.size - 2).back.eat(Food(None, DSLSQLLexer.IDENTIFIER), Food(None, DSLSQLLexer.AS)).build
33 | 
34 |     if (tempMatcher.isSuccess) {
35 |       val tableName = tempMatcher.getMatchTokens.last.getText
36 |       val defaultTable = SpecialTableConst.tempTable(tableName)
37 |       val table = statement(0).getText.toLowerCase match {
38 |         case "load" =>
39 |           val formatMatcher = TokenMatcher(statement, 1).
40 |             eat(Food(None, DSLSQLLexer.IDENTIFIER),
41 |               Food(None, MLSQLTokenTypeWrapper.DOT),
42 |               Food(None, DSLSQLLexer.BACKQUOTED_IDENTIFIER)).build
43 |           if (formatMatcher.isSuccess) {
44 | 
45 |             formatMatcher.getMatchTokens.map(_.getText) match {
46 |               case List(format, _, path) =>
47 |                 cleanStr(path).split("\\.", 2) match {
48 |                   case Array(db, table) =>
49 |                     // if(context.isSchemaInferEnabled){
50 |                     //
51 |                     // }
52 |                     context.metaProvider.search(MetaTableKey(Option(format), Option(db), table)).getOrElse(defaultTable)
53 |                   case Array(table) =>
54 |                     context.metaProvider.search(MetaTableKey(Option(format), None, table)).getOrElse(defaultTable)
55 |                 }
56 |             }
57 |           } else {
58 |             defaultTable
59 |           }
60 |         case "select" =>
61 |           // statement.size - 3 drops the trailing `as <table>` clause
62 |           val selectSuggester = new SelectSuggester(context, statement.slice(0, statement.size - 3), TokenPos(0, TokenPosType.NEXT, -1))
63 |           val columns = selectSuggester.sqlAST.output(selectSuggester.tokens).map { name =>
64 |             MetaTableColumn(name, null, true, Map())
65 |           }
66 |           MetaTable(MetaTableKey(None, Option(TEMP_TABLE_DB_KEY), tableName), columns)
67 |         case _ => defaultTable
68 |       }
69 | 
70 |       tempTableProvider.register(tableName, table)
71 |     }
72 |   }
73 | }
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/LexerUtils.scala:
--------------------------------------------------------------------------------
 1 | package tech.mlsql.autosuggest.statement
 2 | 
 3 | import org.antlr.v4.runtime.Token
 4 | import org.antlr.v4.runtime.misc.Interval
 5 | import streaming.dsl.parser.DSLSQLLexer
 6 | import tech.mlsql.autosuggest.dsl.{MLSQLTokenTypeWrapper, TokenTypeWrapper}
 7 | import tech.mlsql.autosuggest.{AutoSuggestContext, TokenPos, TokenPosType}
 8 | 
 9 | import scala.collection.JavaConverters._
10 | 
11 | /**
12 |  * 1/6/2020 WilliamZhu(allwefantasy@gmail.com)
13 |  */
14 | object LexerUtils {
15 | 
16 |   def toRawSQLTokens(autoSuggestContext: AutoSuggestContext, wow: List[Token]): List[Token] = {
17 |     val originalText = toRawSQLStr(autoSuggestContext, wow)
18 |     val newTokens = autoSuggestContext.rawSQLLexer.tokenizeNonDefaultChannel(originalText).tokens.asScala.toList
19 |     return newTokens
20 |   }
21 | 
22 |   def toRawSQLStr(autoSuggestContext: AutoSuggestContext, wow: List[Token]): String = {
23 |     val start = wow.head.getStartIndex
24 |     val stop = wow.last.getStopIndex
25 | 
26 |     val input = wow.head.getTokenSource.asInstanceOf[DSLSQLLexer]._input
27 |     val interval = new Interval(start, stop)
28 |     val originalText = input.getText(interval)
29 |     originalText
30 |   }
31 | 
32 |   def filterPrefixIfNeeded(candidates: List[SuggestItem], tokens: List[Token], tokenPos: TokenPos) = {
33 |     if (tokenPos.offsetInToken != 0) {
34 |       candidates.filter(s => s.name.startsWith(tokens(tokenPos.pos).getText.substring(0, tokenPos.offsetInToken)))
35 |     } else candidates
36 |   }
37 | 
38 |   def tableTokenPrefix(tokens: List[Token], tokenPos: TokenPos): String = {
39 |     var temp = tokens(tokenPos.pos).getText.substring(0, tokenPos.offsetInToken)
40 |     if (tokenPos.pos > 1 && tokens(tokenPos.pos - 1).getType == TokenTypeWrapper.DOT) {
41 |       temp = tokens(tokenPos.pos - 2).getText + "." + temp
42 |     }
43 |     temp
44 |   }
45 | 
46 | 
47 |   /**
48 |    *
49 |    * @param tokens
50 |    * @param lineNum line number, counted from 1
51 |    * @param colNum  column number, counted from 1
52 |    * @return the pos field of the returned TokenPos is counted from 0
53 |    */
54 |   def toTokenPos(tokens: List[Token], lineNum: Int, colNum: Int): TokenPos = {
55 |     /**
56 |      * load hi[cursor]...   in token
57 |      * load [cursor]        out token
58 |      * load[cursor]         in token
59 |      */
60 | 
61 |     if (tokens.size == 0) {
62 |       return TokenPos(-1, TokenPosType.NEXT, -1)
63 |     }
64 | 
65 |     val oneLineTokens = tokens.zipWithIndex.filter { case (token, index) =>
66 |       token.getLine == lineNum
67 |     }
68 | 
69 |     val firstToken = oneLineTokens.headOption match {
70 |       case Some(head) => head
71 |       case None =>
72 |         tokens.zipWithIndex.filter { case (token, index) =>
73 |           token.getLine == lineNum - 1
74 |         }.head
75 |     }
76 |     val lastToken = oneLineTokens.lastOption match {
77 |       case Some(last) => last
78 |       case None =>
79 |         tokens.zipWithIndex.filter { case (token, index) =>
80 |           token.getLine == lineNum + 1
81 |         }.last
82 |     }
83 | 
84 |     if (colNum < firstToken._1.getCharPositionInLine) {
85 |       return TokenPos(firstToken._2 - 1, TokenPosType.NEXT, 0)
86 |     }
87 | 
88 |     if (colNum > lastToken._1.getCharPositionInLine + lastToken._1.getText.size) {
89 |       return TokenPos(lastToken._2, TokenPosType.NEXT, 0)
90 |     }
91 | 
92 |     if (colNum > lastToken._1.getCharPositionInLine
93 |       && colNum <= lastToken._1.getCharPositionInLine + lastToken._1.getText.size
94 |       &&
95 |       (lastToken._1.getType != DSLSQLLexer.UNRECOGNIZED
96 |         && lastToken._1.getType != MLSQLTokenTypeWrapper.DOT)
97 |     ) {
98 |       return TokenPos(lastToken._2, TokenPosType.CURRENT, colNum - lastToken._1.getCharPositionInLine)
99 |     }
100 |     oneLineTokens.map { case (token, index) =>
101 |       val start = token.getCharPositionInLine
102 |       val end = token.getCharPositionInLine + token.getText.size
103 |       // A cursor glued to the end of a token (no space in between) is normally treated
104 |       // as part of that token, still being typed -- unless the token is one of [(,).] etc.
105 |       if (colNum == end && (1 <= token.getType)
106 |         && (
107 |         token.getType == DSLSQLLexer.UNRECOGNIZED
108 |           || token.getType == MLSQLTokenTypeWrapper.DOT
109 |         )) {
110 |         TokenPos(index, TokenPosType.NEXT, 0)
111 |       } else if (start < colNum && colNum <= end) {
112 |         // in token
113 |         TokenPos(index, TokenPosType.CURRENT, colNum - start)
114 |       } else if (colNum <= start) {
115 |         TokenPos(index - 1, TokenPosType.NEXT, 0)
116 |       } else {
117 |         TokenPos(-2, -2, -2)
118 |       }
119 | 
120 | 
121 |     }.filterNot(_.pos == -2).head
122 |   }
123 | 
124 |   def toTokenPosForSparkSQL(tokens: List[Token], lineNum: Int, colNum: Int): TokenPos = {
125 |     /**
126 |      * load hi[cursor]...   in token
127 |      * load [cursor]        out token
128 |      * load[cursor]         in token
129 |      */
130 | 
131 |     if (tokens.size == 0) {
132 |       return TokenPos(-1, TokenPosType.NEXT, -1)
133 |     }
134 | 
135 |     val oneLineTokens = tokens.zipWithIndex.filter { case (token, index) =>
136 |       token.getLine == lineNum
137 |     }
138 | 
139 |     val firstToken = oneLineTokens.headOption match {
140 |       case Some(head) => head
141 |       case None =>
142 |         tokens.zipWithIndex.filter { case (token, index) =>
143 |           token.getLine == lineNum - 1
144 |         }.head
145 |     }
146 |     val lastToken = oneLineTokens.lastOption match {
147 |       case Some(last) => last
148 |       case None =>
149 |         tokens.zipWithIndex.filter { case (token, index) =>
150 |           token.getLine == lineNum + 1
151 |         }.last
152 |     }
153 | 
154 |     if (colNum < firstToken._1.getCharPositionInLine) {
155 |       return TokenPos(firstToken._2 - 1, TokenPosType.NEXT, 0)
156 |     }
157 | 
158 |     if (colNum > lastToken._1.getCharPositionInLine + lastToken._1.getText.size) {
159 |       return TokenPos(lastToken._2, TokenPosType.NEXT, 0)
160 |     }
161 | 
162 |     if (colNum > lastToken._1.getCharPositionInLine
163 |       && colNum <= lastToken._1.getCharPositionInLine + lastToken._1.getText.size
164 |       && !TokenTypeWrapper.MAP.contains(lastToken._1.getType)
165 | 
166 |     ) {
167 |       return TokenPos(lastToken._2, TokenPosType.CURRENT, colNum - lastToken._1.getCharPositionInLine)
168 |     }
169 |     oneLineTokens.map { case (token, index) =>
170 |       val start = token.getCharPositionInLine
171 |       val end = token.getCharPositionInLine + token.getText.size
172 |       // Same rule as above: a cursor glued to a token's end belongs to that token,
173 |       // still being typed -- unless the token is one of the punctuation tokens [(,).] etc.
174 |       if (colNum == end && (1 <= token.getType)
175 |         && (
176 |         TokenTypeWrapper.MAP.contains(token.getType)
177 |         )) {
178 |         TokenPos(index, TokenPosType.NEXT, 0)
179 |       } else if (start < colNum && colNum <= end) {
180 |         // in token
181 |         TokenPos(index, TokenPosType.CURRENT, colNum - start)
182 |       } else if (colNum <= start) {
183 |         TokenPos(index - 1, TokenPosType.NEXT, 0)
184 |       } else {
185 |         TokenPos(-2, -2, -2)
186 |       }
187 | 
188 | 
189 |     }.filterNot(_.pos == -2).head
190 |   }
191 | 
192 |   def isInWhereContext(tokens: List[Token], tokenPos: Int): Boolean = {
193 |     if (tokenPos < 1) return false
194 |     var wherePos = -1
195 |     (1 until tokenPos).foreach { index =>
196 |       if (tokens(index).getType == DSLSQLLexer.WHERE || tokens(index).getType == DSLSQLLexer.OPTIONS) {
197 |         wherePos = index
198 |       }
199 |     }
200 |     if (wherePos != -1) {
201 |       if (wherePos == tokenPos || wherePos == tokenPos - 1) return true
202 |       val noEnd = (wherePos until tokenPos).filter(index =>
203 |         tokens(index).getType != DSLSQLLexer.AS && tokens(index).getType != DSLSQLLexer.PARTITIONBY).isEmpty
204 |       if (noEnd) return true
205 | 
206 |     }
207 | 
208 |     return false
209 |   }
210 | 
211 |   def isWhereKey(tokens: List[Token], tokenPos: Int): Boolean = {
212 |     LexerUtils.isInWhereContext(tokens, tokenPos) && (tokens(tokenPos).getText == "and" || tokens(tokenPos - 1).getText == "and")
213 | 
214 |   }
215 | }
216 | 
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/LoadSuggester.scala:
-------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.statement 2 | 3 | import org.antlr.v4.runtime.Token 4 | import streaming.core.datasource.{DataSourceRegistry, MLSQLSourceInfo} 5 | import streaming.dsl.parser.DSLSQLLexer 6 | import tech.mlsql.autosuggest.dsl.{Food, MLSQLTokenTypeWrapper, TokenMatcher} 7 | import tech.mlsql.autosuggest.{AutoSuggestContext, SpecialTableConst, TokenPos, TokenPosType} 8 | 9 | import scala.collection.mutable 10 | 11 | /** 12 | * 1/6/2020 WilliamZhu(allwefantasy@gmail.com) 13 | * 14 | * 15 | */ 16 | class LoadSuggester(val context: AutoSuggestContext, val _tokens: List[Token], val _tokenPos: TokenPos) extends StatementSuggester 17 | with SuggesterRegister { 18 | 19 | private val subInstances = new mutable.HashMap[String, StatementSuggester]() 20 | 21 | register(classOf[LoadPathSuggester]) 22 | register(classOf[LoadFormatSuggester]) 23 | register(classOf[LoadOptionsSuggester]) 24 | register(classOf[LoadPathQuoteSuggester]) 25 | 26 | override def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister = { 27 | val instance = clzz.getConstructor(classOf[LoadSuggester]).newInstance(this).asInstanceOf[StatementSuggester] 28 | subInstances.put(instance.name, instance) 29 | this 30 | } 31 | 32 | override def isMatch(): Boolean = { 33 | _tokens.headOption.map(_.getType) match { 34 | case Some(DSLSQLLexer.LOAD) => true 35 | case _ => false 36 | } 37 | } 38 | 39 | private def keywordSuggest: List[SuggestItem] = { 40 | _tokenPos match { 41 | case TokenPos(pos, TokenPosType.NEXT, offsetInToken) => 42 | var items = List[SuggestItem]() 43 | val temp = TokenMatcher(_tokens, pos).back. 44 | eat(Food(None, DSLSQLLexer.BACKQUOTED_IDENTIFIER)). 45 | eat(Food(None, MLSQLTokenTypeWrapper.DOT)). 
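        // reading backwards from the cursor: a backquoted path preceded by a dot
        // (i.e. format.`path`) means the source clause is complete, so the
        // keywords `where`/`as` become legal here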
46 | build 47 | if (temp.isSuccess) { 48 | items = List(SuggestItem("where", SpecialTableConst.KEY_WORD_TABLE, Map()), SuggestItem("as", SpecialTableConst.KEY_WORD_TABLE, Map())) 49 | } 50 | items 51 | 52 | case _ => List() 53 | } 54 | 55 | } 56 | 57 | override def suggest(): List[SuggestItem] = { 58 | keywordSuggest ++ defaultSuggest(subInstances.toMap) 59 | } 60 | 61 | 62 | override def name: String = "load" 63 | } 64 | 65 | class LoadFormatSuggester(loadSuggester: LoadSuggester) extends StatementSuggester with StatementUtils { 66 | override def isMatch(): Boolean = { 67 | 68 | (tokenPos.pos, tokenPos.currentOrNext) match { 69 | case (0, TokenPosType.NEXT) => true 70 | case (1, TokenPosType.CURRENT) => true 71 | case (_, _) => false 72 | } 73 | 74 | } 75 | 76 | override def suggest(): List[SuggestItem] = { 77 | // datasource type suggest 78 | val sources = (DataSourceRegistry.allSourceNames.toSet.toSeq ++ Seq( 79 | "parquet", "csv", "jsonStr", "csvStr", "json", "text", "orc", "kafka", "kafka8", "kafka9", "crawlersql", "image", 80 | "script", "hive", "xml", "mlsqlAPI", "mlsqlConf" 81 | )).toList 82 | LexerUtils.filterPrefixIfNeeded( 83 | sources.map(SuggestItem(_, SpecialTableConst.DATA_SOURCE_TABLE, 84 | Map("desc" -> "DataSource"))), 85 | tokens, tokenPos) 86 | 87 | } 88 | 89 | 90 | override def tokens: List[Token] = loadSuggester._tokens 91 | 92 | override def tokenPos: TokenPos = loadSuggester._tokenPos 93 | 94 | override def name: String = "format" 95 | } 96 | 97 | class LoadOptionsSuggester(loadSuggester: LoadSuggester) extends StatementSuggester with StatementUtils { 98 | override def isMatch(): Boolean = { 99 | backAndFirstIs(DSLSQLLexer.OPTIONS) || backAndFirstIs(DSLSQLLexer.WHERE) 100 | } 101 | 102 | override def suggest(): List[SuggestItem] = { 103 | val source = tokens(1) 104 | val datasources = DataSourceRegistry.fetch(source.getText, Map[String, String]()) match { 105 | case Some(ds) => ds.asInstanceOf[MLSQLSourceInfo]. 106 | explainParams(loadSuggester.context.session).collect(). 107 | map(row => (row.getString(0), row.getString(1))). 108 | toList 109 | case None => List() 110 | } 111 | LexerUtils.filterPrefixIfNeeded(datasources.map(tuple => 112 | SuggestItem(tuple._1, SpecialTableConst.OPTION_TABLE, Map("desc" -> tuple._2))), 113 | tokens, tokenPos) 114 | 115 | } 116 | 117 | override def name: String = "options" 118 | 119 | override def tokens: List[Token] = loadSuggester._tokens 120 | 121 | override def tokenPos: TokenPos = loadSuggester._tokenPos 122 | } 123 | 124 | class LoadPathQuoteSuggester(loadSuggester: LoadSuggester) extends StatementSuggester with StatementUtils { 125 | override def name: String = "pathQuote" 126 | 127 | override def isMatch(): Boolean = { 128 | val temp = TokenMatcher(tokens, tokenPos.pos).back. 129 | eat(Food(None, MLSQLTokenTypeWrapper.DOT)). 130 | eat(Food(None, DSLSQLLexer.IDENTIFIER)). 
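      // backwards from the cursor the tokens must read `load <format>.`,
      // i.e. the user has just typed the dot and needs the backquote pair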
131 | eat(Food(None, DSLSQLLexer.LOAD)).build 132 | temp.isSuccess 133 | } 134 | 135 | override def suggest(): List[SuggestItem] = { 136 | LexerUtils.filterPrefixIfNeeded(List(SuggestItem("``", SpecialTableConst.OTHER_TABLE, Map("desc" -> "path or table"))), 137 | tokens, tokenPos) 138 | } 139 | 140 | override def tokens: List[Token] = loadSuggester._tokens 141 | 142 | override def tokenPos: TokenPos = loadSuggester._tokenPos 143 | } 144 | 145 | //Here you can implement Hive table / HDFS Path auto suggestion 146 | class LoadPathSuggester(loadSuggester: LoadSuggester) extends StatementSuggester with StatementUtils { 147 | override def isMatch(): Boolean = { 148 | false 149 | } 150 | 151 | override def suggest(): List[SuggestItem] = { 152 | List() 153 | } 154 | 155 | override def name: String = "path" 156 | 157 | 158 | override def tokens: List[Token] = loadSuggester._tokens 159 | 160 | override def tokenPos: TokenPos = loadSuggester._tokenPos 161 | } 162 | 163 | 164 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/statement/MLSQLStatementSplitter.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.statement 2 | 3 | import org.antlr.v4.runtime.Token 4 | import streaming.dsl.parser.DSLSQLLexer 5 | 6 | import scala.collection.mutable.ArrayBuffer 7 | 8 | /** 9 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com) 10 | */ 11 | class MLSQLStatementSplitter extends StatementSplitter { 12 | override def split(_tokens: List[Token]): List[List[Token]] = { 13 | val _statements = ArrayBuffer[List[Token]]() 14 | val tokens = _tokens.zipWithIndex 15 | var start = 0 16 | var end = 0 17 | tokens.foreach { case (token, index) => 18 | // statement end 19 | if (token.getType == DSLSQLLexer.T__1) { 20 | end = index 21 | _statements.append(tokens.filter(p => p._2 >= start && p._2 <= end).map(_._1)) 22 | start = index + 1 23 | } 24 | 25 | } 26 | // clean the last statement without ender 27 | val theLeft = tokens.filter(p => p._2 >= start && p._2 <= tokens.size).map(_._1).toList 28 | if (theLeft.size > 0) { 29 | _statements.append(theLeft) 30 | } 31 | _statements.toList 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/statement/MatchAndExtractor.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.statement 2 | 3 | import tech.mlsql.autosuggest.dsl.TokenMatcher 4 | 5 | /** 6 | * 4/6/2020 WilliamZhu(allwefantasy@gmail.com) 7 | */ 8 | trait MatchAndExtractor[T] { 9 | def matcher(start: Int): TokenMatcher 10 | 11 | def extractor(start: Int, end: Int): List[T] 12 | 13 | def iterate(start: Int, end: Int, limit: Int = 100): List[T] 14 | } 15 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/statement/PreProcessStatement.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.statement 2 | 3 | import org.antlr.v4.runtime.Token 4 | 5 | /** 6 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com) 7 | */ 8 | trait PreProcessStatement { 9 | def process(statement: List[Token]): Unit 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/statement/RegisterSuggester.scala: 
-------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.statement 2 | 3 | import org.antlr.v4.runtime.Token 4 | import streaming.dsl.parser.DSLSQLLexer 5 | import tech.mlsql.autosuggest.{AutoSuggestContext, TokenPos} 6 | 7 | import scala.collection.mutable 8 | 9 | /** 10 | * 30/6/2020 WilliamZhu(allwefantasy@gmail.com) 11 | */ 12 | class RegisterSuggester(val context: AutoSuggestContext, val _tokens: List[Token], val _tokenPos: TokenPos) extends StatementSuggester 13 | with SuggesterRegister { 14 | private val subInstances = new mutable.HashMap[String, StatementSuggester]() 15 | 16 | 17 | override def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister = { 18 | val instance = clzz.getConstructor(classOf[LoadSuggester]).newInstance(this).asInstanceOf[StatementSuggester] 19 | subInstances.put(instance.name, instance) 20 | this 21 | } 22 | 23 | override def isMatch(): Boolean = { 24 | _tokens.headOption.map(_.getType) match { 25 | case Some(DSLSQLLexer.REGISTER) => true 26 | case _ => false 27 | } 28 | } 29 | 30 | 31 | override def suggest(): List[SuggestItem] = { 32 | List() 33 | } 34 | 35 | 36 | override def name: String = "register" 37 | } 38 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/statement/SelectStatementUtils.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.statement 2 | 3 | import org.antlr.v4.runtime.Token 4 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer 5 | import tech.mlsql.autosuggest.dsl.{Food, TokenMatcher, TokenTypeWrapper} 6 | import tech.mlsql.autosuggest._ 7 | import tech.mlsql.common.utils.log.Logging 8 | 9 | /** 10 | * 8/6/2020 WilliamZhu(allwefantasy@gmail.com) 11 | */ 12 | trait SelectStatementUtils extends Logging { 13 | def selectSuggester: SelectSuggester 14 | 15 | def tokenPos: TokenPos 16 | 17 | def tokens: List[Token] 18 | 19 | 20 | def levelFromTokenPos = { 21 | var targetLevel = 0 22 | selectSuggester.sqlAST.visitDown(0) { case (ast, _level) => 23 | if (tokenPos.pos >= ast.start && tokenPos.pos < ast.stop) targetLevel = _level 24 | } 25 | targetLevel 26 | } 27 | 28 | def getASTFromTokenPos: Option[SingleStatementAST] = { 29 | var targetAst: Option[SingleStatementAST] = None 30 | selectSuggester.sqlAST.visitUp(0) { case (ast, level) => 31 | if (targetAst == None && (ast.start <= tokenPos.pos && tokenPos.pos < ast.stop)) { 32 | targetAst = Option(ast) 33 | } 34 | } 35 | targetAst 36 | } 37 | 38 | def table_info = { 39 | selectSuggester.table_info.get(levelFromTokenPos) 40 | } 41 | 42 | 43 | def tableSuggest(): List[SuggestItem] = { 44 | val tempStart = tokenPos.currentOrNext match { 45 | case TokenPosType.CURRENT => 46 | tokenPos.pos - 1 47 | case TokenPosType.NEXT => 48 | tokenPos.pos 49 | } 50 | 51 | val temp = TokenMatcher(tokens, tempStart).back.eat(Food(None, TokenTypeWrapper.DOT), Food(None, SqlBaseLexer.IDENTIFIER)).build 52 | 53 | if (selectSuggester.context.isInDebugMode) { 54 | logInfo(s"tableSuggest, table_info:\n") 55 | logInfo("========tableSuggest start=======") 56 | if (table_info.isDefined) { 57 | table_info.get.foreach { case (key, metaTable) => 58 | logInfo(key.toString + s"=>\n key: ${metaTable.key} \n columns: ${metaTable.columns.map(_.name).mkString(",")}") 59 | } 60 | } 61 | 62 | logInfo(s"Final: suggest ${!temp.isSuccess}") 63 | 64 | logInfo("========tableSuggest end =======") 65 | } 66 | 67 | if 
(temp.isSuccess) return List()
68 | 
69 |     table_info match {
70 |       case Some(tb) => tb.map { case (key, value) =>
71 |         (key.aliasName.getOrElse(key.metaTableKey.table), value)
72 |       }.map { case (name, table) =>
73 |         SuggestItem(name, table, Map())
74 |       }.toList
75 |       case None =>
76 |         val tokenPrefix = LexerUtils.tableTokenPrefix(tokens, tokenPos)
77 |         val owner = AutoSuggestContext.context().reqParams.getOrElse("owner", "")
78 |         val extraParam = Map("searchPrefix" -> tokenPrefix, "owner" -> owner)
79 |         selectSuggester.context.metaProvider.list(extraParam).map { item =>
80 |           SuggestItem(item.key.table, item, Map())
81 |         }
82 |     }
83 |   }
84 | 
85 |   def attributeSuggest(): List[SuggestItem] = {
86 |     val tempStart = tokenPos.currentOrNext match {
87 |       case TokenPosType.CURRENT =>
88 |         tokenPos.pos - 1
89 |       case TokenPosType.NEXT =>
90 |         tokenPos.pos
91 |     }
92 |     if (selectSuggester.context.isInDebugMode) {
93 |       logInfo(s"attributeSuggest:\n")
94 |       logInfo("========attributeSuggest start=======")
95 |     }
96 | 
97 |     def allOutput = {
98 |       /**
99 |        * Prefer columns coming from subquery aliases first
100 |        */
101 |       val res = table_info.get.filter { item => item._2.key == SpecialTableConst.subQueryAliasTable && item._1.aliasName.isDefined }.flatMap { table =>
102 |         if (selectSuggester.context.isInDebugMode) {
103 |           val columns = table._2.columns.map { item => s"${item.name} ${item}" }.mkString("\n")
104 |           logInfo(s"TARGET table: ${table._1} \n columns: \n[${columns}] ")
105 |         }
106 |         table._2.columns.map(column => SuggestItem(column.name, table._2, Map()))
107 | 
108 |       }.toList
109 | 
110 |       if (res.isEmpty) {
111 |         if (selectSuggester.context.isInDebugMode) {
112 |           val tables = table_info.get.map { case (key, table) =>
113 |             val columns = table.columns.map { item => s"${item.name} ${item}" }.mkString("\n")
114 |             s"${key}:\n ${columns}"
115 |           }.mkString("\n")
116 |           logInfo(s"ALL tables: \n ${tables}")
117 |         }
118 |         table_info.get.flatMap { case (_, metaTable) =>
119 |           metaTable.columns.map(column => SuggestItem(column.name, metaTable, Map())).toList
120 |         }.toList
121 |       } else res
122 | 
123 | 
124 |     }
125 | 
126 |     val temp = TokenMatcher(tokens, tempStart).back.eat(Food(None, TokenTypeWrapper.DOT), Food(None, SqlBaseLexer.IDENTIFIER)).build
127 |     if (selectSuggester.context.isInDebugMode) {
128 |       logInfo(s"Try to match attribute db prefix: Status(${temp.isSuccess})")
129 |     }
130 |     val res = if (temp.isSuccess) {
131 |       val table = temp.getMatchTokens.head.getText
132 |       table_info.get.filter { case (key, value) =>
133 |         (key.aliasName.isDefined && key.aliasName.get == table) || key.metaTableKey.table == table
134 |       }.headOption match {
135 |         case Some(table) =>
136 |           if (selectSuggester.context.isInDebugMode) {
137 |             logInfo(s"table[${table._1}] found, return ${table._2.key} columns.")
138 |           }
139 |           table._2.columns.map(column => SuggestItem(column.name, table._2, Map())).toList
140 |         case None =>
141 |           if (selectSuggester.context.isInDebugMode) {
142 |             logInfo(s"No table found, so return all table[${table_info.get.map { case (_, metaTable) => metaTable.key.toString }}] columns.")
143 |           }
144 |           allOutput
145 |       }
146 |     } else allOutput
147 | 
148 |     if (selectSuggester.context.isInDebugMode) {
149 |       logInfo("========attributeSuggest end=======")
150 |     }
151 |     res
152 | 
153 |   }
154 | 
155 |   def functionSuggest(): List[SuggestItem] = {
156 |     if (selectSuggester.context.isInDebugMode) {
157 |       logInfo(s"functionSuggest:\n")
158 |       logInfo("========functionSuggest start=======")
159 |     }
160 | 
161 |     def allOutput = {
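      // every registered function, exposed as a suggestion item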
162 |       MLSQLSQLFunction.funcMetaProvider.list(Map()).map(item => SuggestItem(item.key.table, item, Map()))
163 |     }
164 | 
165 |     val tempStart = tokenPos.currentOrNext match {
166 |       case TokenPosType.CURRENT =>
167 |         tokenPos.pos - 1
168 |       case TokenPosType.NEXT =>
169 |         tokenPos.pos
170 |     }
171 | 
172 |     // if this matches, the token is a column reference, so functions should not be suggested
173 |     val temp = TokenMatcher(tokens, tempStart).back.eat(Food(None, TokenTypeWrapper.DOT), Food(None, SqlBaseLexer.IDENTIFIER)).build
174 |     val res = if (temp.isSuccess) {
175 |       List()
176 |     } else allOutput
177 | 
178 |     if (selectSuggester.context.isInDebugMode) {
179 |       logInfo(s"functions: ${allOutput.map(_.name).mkString(",")}")
180 |       logInfo("========functionSuggest end=======")
181 |     }
182 |     res
183 | 
184 |   }
185 | }
186 | 
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/SelectSuggester.scala:
--------------------------------------------------------------------------------
 1 | package tech.mlsql.autosuggest.statement
 2 | 
 3 | import org.antlr.v4.runtime.Token
 4 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer
 5 | import streaming.dsl.parser.DSLSQLLexer
 6 | import tech.mlsql.autosuggest.dsl.{Food, TokenMatcher, TokenTypeWrapper}
 7 | import tech.mlsql.autosuggest.meta.{MetaTable, MetaTableColumn, MetaTableKey}
 8 | import tech.mlsql.autosuggest.{AutoSuggestContext, SpecialTableConst, TokenPos}
 9 | 
10 | import scala.collection.mutable
11 | 
12 | /**
13 |  * 3/6/2020 WilliamZhu(allwefantasy@gmail.com)
14 |  */
15 | class SelectSuggester(val context: AutoSuggestContext, val _tokens: List[Token], val tokenPos: TokenPos) extends StatementSuggester with SuggesterRegister {
16 | 
17 |   private val subInstances = new mutable.HashMap[String, StatementSuggester]()
18 |   register(classOf[ProjectSuggester])
19 |   register(classOf[FromSuggester])
20 |   register(classOf[FilterSuggester])
21 |   register(classOf[JoinSuggester])
22 |   register(classOf[JoinOnSuggester])
23 |   register(classOf[OrderSuggester])
24 | 
25 |   override def name: String = "select"
26 | 
27 |   private lazy val newTokens = LexerUtils.toRawSQLTokens(context, _tokens)
28 |   private lazy val TABLE_INFO = mutable.HashMap[Int, mutable.HashMap[MetaTableKeyWrapper, MetaTable]]()
29 |   private lazy val selectTree: SingleStatementAST = buildTree()
30 | 
31 |   def sqlAST = selectTree
32 | 
33 |   def tokens = newTokens
34 | 
35 |   def table_info = TABLE_INFO
36 | 
37 |   override def isMatch(): Boolean = {
38 |     _tokens.headOption.map(_.getType) match {
39 |       case Some(DSLSQLLexer.SELECT) => true
40 |       case _ => false
41 |     }
42 |   }
43 | 
44 |   private def buildTree() = {
45 |     val root = SingleStatementAST.build(this, newTokens)
46 |     import scala.collection.mutable
47 | 
48 |     root.visitUp(level = 0) { case (ast: SingleStatementAST, level: Int) =>
49 |       if (!TABLE_INFO.contains(level)) {
50 |         TABLE_INFO.put(level, new mutable.HashMap[MetaTableKeyWrapper, MetaTable]())
51 |       }
52 | 
53 |       if (level != 0 && !TABLE_INFO.contains(level - 1)) {
54 |         TABLE_INFO.put(level - 1, new mutable.HashMap[MetaTableKeyWrapper, MetaTable]())
55 |       }
56 | 
57 | 
58 |       ast.tables(newTokens).foreach { item =>
59 |         if (item.aliasName.isEmpty || item.metaTableKey != MetaTableKey(None, None, null)) {
60 |           context.metaProvider.search(item.metaTableKey) match {
61 |             case Some(res) =>
62 |               TABLE_INFO(level) += (item -> res)
63 |             case None =>
64 |           }
65 |         }
66 |       }
67 | 
68 |       val nameOpt = ast.name(newTokens)
69 |       if (nameOpt.isDefined) {
70 | 
71 |         val metaTableKey = MetaTableKey(None, None, null)
72 |         val metaTableKeyWrapper = MetaTableKeyWrapper(metaTableKey, nameOpt)
73 |         val metaColumns = ast.output(newTokens).map { attr =>
74 |           MetaTableColumn(attr, null, true, Map())
75 |         }
76 |         TABLE_INFO(level - 1) += (metaTableKeyWrapper -> MetaTable(metaTableKey, metaColumns))
77 |       }
78 | 
79 |     }
80 | 
81 |     if (context.isInDebugMode) {
82 |       logInfo(s"SQL[${newTokens.map(_.getText).mkString(" ")}]")
83 |       logInfo(s"STRUCTURE: \n")
84 |       TABLE_INFO.foreach { item =>
85 |         logInfo(s"Level:${item._1}")
86 |         item._2.foreach { table =>
87 |           logInfo(s"${table._1} => ${table._2}")
88 |         }
89 |       }
90 | 
91 |     }
92 | 
93 |     root
94 |   }
95 | 
96 |   override def suggest(): List[SuggestItem] = {
97 |     var instance: StatementSuggester = null
98 |     subInstances.foreach { _instance =>
99 |       if (instance == null && _instance._2.isMatch()) {
100 |         instance = _instance._2
101 |       }
102 |     }
103 |     if (instance == null) List()
104 |     else instance.suggest()
105 | 
106 |   }
107 | 
108 |   override def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister = {
109 |     val instance = clzz.getConstructor(classOf[SelectSuggester]).newInstance(this).asInstanceOf[StatementSuggester]
110 |     subInstances.put(instance.name, instance)
111 |     this
112 |   }
113 | }
114 | 
115 | 
116 | class ProjectSuggester(_selectSuggester: SelectSuggester) extends StatementSuggester with SelectStatementUtils with SuggesterRegister {
117 | 
118 |   def tokens = _selectSuggester.tokens
119 | 
120 |   def tokenPos = _selectSuggester.tokenPos
121 | 
122 |   def selectSuggester = _selectSuggester
123 | 
124 |   def backAndFirstIs(t: Int, keywords: List[Int] = TokenMatcher.SQL_SPLITTER_KEY_WORDS): Boolean = {
125 |     // we must be able to locate the enclosing subquery (possibly the outermost query)
126 |     val ast = getASTFromTokenPos
127 |     if (ast.isEmpty) return false
128 | 
129 |     // scan backwards from the cursor position for the first keyword
130 |     val temp = TokenMatcher(tokens, tokenPos.pos).back.orIndex(keywords.map(Food(None, _)).toArray)
131 |     if (temp == -1) return false
132 |     // the first keyword found must be the specified one, and it must sit inside the subquery
133 |     if (tokens(temp).getType == t && temp >= ast.get.start && temp < ast.get.stop) return true
134 |     return false
135 |   }
136 | 
137 | 
138 |   override def name: String = "project"
139 | 
140 |   override def isMatch(): Boolean = {
141 |     val temp = backAndFirstIs(SqlBaseLexer.SELECT)
142 |     if (temp && selectSuggester.context.isInDebugMode) {
143 |       logInfo(s"${name} is matched")
144 |     }
145 |     temp
146 |   }
147 | 
148 |   override def suggest(): List[SuggestItem] = {
149 |     LexerUtils.filterPrefixIfNeeded(tableSuggest() ++ attributeSuggest() ++ functionSuggest(), tokens, tokenPos)
150 |   }
151 | 
152 |   override def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister = ???
153 | }
154 | 
155 | class FilterSuggester(_selectSuggester: SelectSuggester) extends ProjectSuggester(_selectSuggester) {
156 | 
157 | 
158 |   override def name: String = "filter"
159 | 
160 |   override def isMatch(): Boolean = {
161 |     backAndFirstIs(SqlBaseLexer.WHERE)
162 | 
163 |   }
164 | 
165 |   override def suggest(): List[SuggestItem] = {
166 |     LexerUtils.filterPrefixIfNeeded(tableSuggest() ++ attributeSuggest() ++ functionSuggest(), tokens, tokenPos)
167 |   }
168 | 
169 |   override def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister = ???
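  // sub-suggesters are leaves, so nested registration is intentionally unsupported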
170 | }
171 | 
172 | class JoinOnSuggester(_selectSuggester: SelectSuggester) extends ProjectSuggester(_selectSuggester) {
173 |   override def name: String = "join_on"
174 | 
175 |   override def isMatch(): Boolean = {
176 |     backAndFirstIs(SqlBaseLexer.ON)
177 |   }
178 | 
179 |   override def suggest(): List[SuggestItem] = {
180 |     LexerUtils.filterPrefixIfNeeded(tableSuggest() ++ attributeSuggest() ++ functionSuggest(), tokens, tokenPos)
181 |   }
182 | }
183 | 
184 | class JoinSuggester(_selectSuggester: SelectSuggester) extends ProjectSuggester(_selectSuggester) {
185 |   override def name: String = "join"
186 | 
187 |   override def isMatch(): Boolean = {
188 |     backAndFirstIs(SqlBaseLexer.JOIN)
189 |   }
190 | 
191 |   override def suggest(): List[SuggestItem] = {
192 |     LexerUtils.filterPrefixIfNeeded(tableSuggest(), tokens, tokenPos)
193 |   }
194 | }
195 | 
196 | class FromSuggester(_selectSuggester: SelectSuggester) extends ProjectSuggester(_selectSuggester) {
197 |   override def name: String = "from"
198 | 
199 |   override def isMatch(): Boolean = {
200 |     backAndFirstIs(SqlBaseLexer.FROM)
201 |   }
202 | 
203 |   override def suggest(): List[SuggestItem] = {
204 | 
205 |     val tokenPrefix = LexerUtils.tableTokenPrefix(tokens, tokenPos)
206 |     val owner = AutoSuggestContext.context().reqParams.getOrElse("owner", "")
207 |     val extraParam = Map("searchPrefix" -> tokenPrefix, "owner" -> owner)
208 | 
209 |     val allTables = _selectSuggester.context.metaProvider.list(extraParam).map { item =>
210 |       val prefix = (item.key.prefix, item.key.db) match {
211 |         case (Some(prefix), Some(db)) => prefix
212 |         case (Some(prefix), None) => prefix
213 |         case (None, Some(SpecialTableConst.TEMP_TABLE_DB_KEY)) => "temp table"
214 |         case (None, Some(db)) => db
215 |         case (None, None) => "" // plain table name with no prefix/db info; avoids a MatchError
216 |       }
217 |       SuggestItem(item.key.table, item, Map("desc" -> prefix))
218 |     }
219 |     LexerUtils.filterPrefixIfNeeded(tableSuggest() ++ allTables, tokens, tokenPos)
220 |   }
221 | }
222 | 
223 | class OrderSuggester(_selectSuggester: SelectSuggester) extends ProjectSuggester(_selectSuggester) {
224 |   override def name: String = "order"
225 | 
226 |   override def isMatch(): Boolean = {
227 |     backAndFirstIs(SqlBaseLexer.ORDER)
228 |   }
229 | 
230 |   override def suggest(): List[SuggestItem] = {
231 |     LexerUtils.filterPrefixIfNeeded(attributeSuggest() ++ functionSuggest(), tokens, tokenPos)
232 |   }
233 | }
234 | 
235 | 
236 | 
237 | 
238 | 
239 | 
240 | 
241 | 
242 | 
243 | 
244 | 
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/StatementSplitter.scala:
--------------------------------------------------------------------------------
 1 | package tech.mlsql.autosuggest.statement
 2 | 
 3 | import org.antlr.v4.runtime.Token
 4 | 
 5 | /**
 6 |  * 10/6/2020 WilliamZhu(allwefantasy@gmail.com)
 7 |  */
 8 | trait StatementSplitter {
 9 |   def split(_tokens: List[Token]): List[List[Token]]
10 | }
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/StatementSuggester.scala:
--------------------------------------------------------------------------------
 1 | package tech.mlsql.autosuggest.statement
 2 | 
 3 | import tech.mlsql.autosuggest.meta.MetaTable
 4 | import tech.mlsql.common.utils.log.Logging
 5 | 
 6 | /**
 7 |  * 1/6/2020 WilliamZhu(allwefantasy@gmail.com)
 8 |  */
 9 | trait StatementSuggester extends Logging {
10 |   def name: String
11 | 
12 |   def isMatch(): Boolean
13 | 
14 |   def suggest(): List[SuggestItem]
15 | 
16 |   def defaultSuggest(subInstances: Map[String, StatementSuggester]): List[SuggestItem] = {
17 |     var instance: StatementSuggester = null
18 |     subInstances.foreach { _instance =>
19 |       if (instance == null && _instance._2.isMatch()) {
20 |         instance = _instance._2
21 |       }
22 |     }
23 |     if (instance == null) List()
24 |     else instance.suggest()
25 | 
26 |   }
27 | }
28 | 
29 | case class SuggestItem(name: String, metaTable: MetaTable, extra: Map[String, String])
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/StatementUtils.scala:
--------------------------------------------------------------------------------
 1 | package tech.mlsql.autosuggest.statement
 2 | 
 3 | import org.antlr.v4.runtime.Token
 4 | import streaming.dsl.parser.DSLSQLLexer
 5 | import tech.mlsql.autosuggest.TokenPos
 6 | import tech.mlsql.autosuggest.dsl.{Food, TokenMatcher}
 7 | 
 8 | /**
 9 |  * 9/6/2020 WilliamZhu(allwefantasy@gmail.com)
10 |  */
11 | trait StatementUtils {
12 | 
13 |   def tokens: List[Token]
14 | 
15 |   def tokenPos: TokenPos
16 | 
17 |   def SPLIT_KEY_WORDS = {
18 |     List(DSLSQLLexer.OPTIONS, DSLSQLLexer.WHERE, DSLSQLLexer.AS)
19 |   }
20 | 
21 |   def backAndFirstIs(t: Int, keywords: List[Int] = SPLIT_KEY_WORDS): Boolean = {
22 | 
23 | 
24 |     // scan backwards from the cursor position for the first keyword
25 |     val temp = TokenMatcher(tokens, tokenPos.pos).back.orIndex(keywords.map(Food(None, _)).toArray)
26 |     if (temp == -1) return false
27 |     // the first keyword found must be the specified one
28 |     if (tokens(temp).getType == t) return true
29 |     return false
30 |   }
31 | }
32 | 
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/SuggesterRegister.scala:
--------------------------------------------------------------------------------
 1 | package tech.mlsql.autosuggest.statement
 2 | 
 3 | /**
 4 |  * 2/6/2020 WilliamZhu(allwefantasy@gmail.com)
 5 |  */
 6 | trait SuggesterRegister {
 7 |   def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister
 8 | }
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/TableExtractor.scala:
--------------------------------------------------------------------------------
 1 | package tech.mlsql.autosuggest.statement
 2 | 
 3 | import org.antlr.v4.runtime.Token
 4 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer
 5 | import tech.mlsql.autosuggest.AutoSuggestContext
 6 | import tech.mlsql.autosuggest.dsl.{Food, TokenMatcher, TokenTypeWrapper}
 7 | import tech.mlsql.autosuggest.meta.MetaTableKey
 8 | 
 9 | import scala.collection.mutable.ArrayBuffer
10 | 
11 | /**
12 |  * 4/6/2020 WilliamZhu(allwefantasy@gmail.com)
13 |  */
14 | class TableExtractor(autoSuggestContext: AutoSuggestContext, ast: SingleStatementAST, tokens: List[Token]) extends MatchAndExtractor[MetaTableKeyWrapper] {
15 |   override def matcher(start: Int): TokenMatcher = {
16 |     val temp = TokenMatcher(tokens, start).
17 |       eat(Food(None, SqlBaseLexer.IDENTIFIER), Food(None, TokenTypeWrapper.DOT)).optional.
18 |       eat(Food(None, SqlBaseLexer.IDENTIFIER)).
19 |       eat(Food(None, SqlBaseLexer.AS)).optional.
20 |       eat(Food(None, SqlBaseLexer.IDENTIFIER)).optional.
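      // accepted shapes starting at `start`: [db .] table [AS] [alias], i.e.
      // `table`, `table alias`, `table AS alias`, `db.table`, `db.table alias`, `db.table AS alias`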
21 |       build
22 |     temp
23 |   }
24 | 
25 |   override def extractor(start: Int, end: Int): List[MetaTableKeyWrapper] = {
26 |     val dbTableTokens = tokens.slice(start, end)
27 |     val dbTable = dbTableTokens.length match {
28 |       case 2 =>
29 |         val List(tableToken, aliasToken) = dbTableTokens
30 |         if (aliasToken.getText.toLowerCase() == "as") {
31 |           // a dangling `AS`: the alias has not been typed yet
32 |           MetaTableKeyWrapper(MetaTableKey(None, None, tableToken.getText), None)
33 |         } else {
34 |           // `table alias`
35 |           MetaTableKeyWrapper(MetaTableKey(None, None, tableToken.getText), Option(aliasToken.getText))
36 |         }
37 |       case 3 =>
38 |         val List(firstToken, sepToken, thirdToken) = dbTableTokens
39 |         if (sepToken.getType == TokenTypeWrapper.DOT) {
40 |           // `db.table`
41 |           MetaTableKeyWrapper(MetaTableKey(None, Option(firstToken.getText), thirdToken.getText), None)
42 |         } else {
43 |           // `table AS alias`
44 |           MetaTableKeyWrapper(MetaTableKey(None, None, firstToken.getText), Option(thirdToken.getText))
45 |         }
46 |       case 4 =>
47 |         val List(dbToken, _, tableToken, aliasToken) = dbTableTokens
48 |         if (aliasToken.getText.toLowerCase() == "as") {
49 |           // `db.table AS` with the alias still missing
50 |           MetaTableKeyWrapper(MetaTableKey(None, Option(dbToken.getText), tableToken.getText), None)
51 |         } else {
52 |           // `db.table alias`
53 |           MetaTableKeyWrapper(MetaTableKey(None, Option(dbToken.getText), tableToken.getText), Option(aliasToken.getText))
54 |         }
55 |       case 5 =>
56 |         val List(dbToken, _, tableToken, _, aliasToken) = dbTableTokens
57 |         // `db.table AS alias`
58 |         MetaTableKeyWrapper(MetaTableKey(None, Option(dbToken.getText), tableToken.getText), Option(aliasToken.getText))
59 |       case _ => MetaTableKeyWrapper(MetaTableKey(None, None, dbTableTokens.head.getText), None)
60 |     }
61 | 
62 |     List(dbTable)
63 |   }
64 | 
65 |   override def iterate(start: Int, end: Int, limit: Int = 100): List[MetaTableKeyWrapper] = {
66 |     val tables = ArrayBuffer[MetaTableKeyWrapper]()
67 |     var matchRes = matcher(start)
68 |     var whileLimit = limit
69 |     while (matchRes.isSuccess && whileLimit > 0) {
70 |       tables ++= extractor(matchRes.start, matchRes.get)
71 |       whileLimit -= 1
72 |       val temp = TokenMatcher(tokens, matchRes.get).eat(Food(None, SqlBaseLexer.T__2)).build
73 |       if (temp.isSuccess) {
74 |         matchRes = matcher(temp.get)
75 |       } else whileLimit = 0
76 |     }
77 | 
78 |     tables.toList
79 |   }
80 | }
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/TemplateSuggester.scala:
--------------------------------------------------------------------------------
 1 | package tech.mlsql.autosuggest.statement
 2 | 
 3 | import org.antlr.v4.runtime.Token
 4 | import streaming.dsl.parser.DSLSQLLexer
 5 | import tech.mlsql.autosuggest.{AutoSuggestContext, TokenPos}
 6 | 
 7 | import scala.collection.mutable
 8 | 
 9 | /**
10 |  * 30/6/2020 WilliamZhu(allwefantasy@gmail.com)
11 |  */
12 | class TemplateSuggester(val context: AutoSuggestContext, val _tokens: List[Token], val _tokenPos: TokenPos) extends StatementSuggester
13 |   with SuggesterRegister {
14 |   private val subInstances = new mutable.HashMap[String, StatementSuggester]()
15 | 
16 | 
17 |   override def register(clzz: Class[_ <: StatementSuggester]): SuggesterRegister = {
18 |     val instance = clzz.getConstructor(classOf[LoadSuggester]).newInstance(this).asInstanceOf[StatementSuggester]
19 |     subInstances.put(instance.name, instance)
20 |     this
21 |   }
22 | 
23 |   override def isMatch(): Boolean = {
24 |     _tokens.headOption.map(_.getType) match {
25 |       case Some(DSLSQLLexer.REGISTER) => true
26 |       case _ => false
27 |     }
28 |   }
29 | 
30 | 
31 |   override def suggest(): List[SuggestItem] = {
32 |     List()
33 |   }
34 | 
35 | 
36 |   override def name: String = "template"
37 | }
38 | 
39 | 
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/autosuggest/statement/single_statement.scala:
--------------------------------------------------------------------------------
 1 | package tech.mlsql.autosuggest.statement
 2 | 
 3 | import org.antlr.v4.runtime.Token
 4 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer
 5 | import tech.mlsql.autosuggest.AttributeExtractor
 6 | import tech.mlsql.autosuggest.ast.NoneToken
 7 | import tech.mlsql.autosuggest.dsl.{Food, TokenMatcher, TokenTypeWrapper}
 8 | import tech.mlsql.autosuggest.meta.MetaTableKey
 9 | 
10 | import scala.collection.mutable
11 | import scala.collection.mutable.ArrayBuffer
12 | 
13 | case class MetaTableKeyWrapper(metaTableKey: MetaTableKey, aliasName: Option[String])
14 | 
15 | /**
16 |  * An atomic query statement contains only:
17 |  * select, from, group by, where, join, limit.
18 |  * Note that we do not verify that the SQL is valid.
19 |  *
20 |  */
21 | class SingleStatementAST(val selectSuggester: SelectSuggester, var start: Int, var stop: Int, var parent: SingleStatementAST) {
22 |   val children = ArrayBuffer[SingleStatementAST]()
23 | 
24 |   def isLeaf = {
25 |     children.length == 0
26 |   }
27 | 
28 |   def name(tokens: List[Token]): Option[String] = {
29 |     if (parent == null) None
30 |     else Option(tokens.slice(start, stop).last.getText)
31 |   }
32 | 
33 | 
34 |   //  private def isInSubquery(holes: List[(Int, Int)],) = {
35 |   //
36 |   //  }
37 | 
38 |   def tables(tokens: List[Token]) = {
39 | 
40 |     // replace token
41 |     val range = children.map { ast =>
42 |       (ast.start, ast.stop)
43 |     }
44 | 
45 |     def inRange(index: Int) = {
46 |       range.filter { item =>
47 |         item._1 <= index && index <= item._2
48 |       }.headOption.isDefined
49 |     }
50 | 
51 |     val tokensWithoutSubQuery = tokens.zipWithIndex.map { case (token, index) =>
52 |       if (inRange(index)) new NoneToken(token)
53 |       else token
54 |     }
55 | 
56 | 
57 |     // collect table first
58 |     // T__3 == .
59 |     // extract from `from`
60 |     val fromTables = new TableExtractor(selectSuggester.context, this, tokensWithoutSubQuery)
61 |     val fromStart = TokenMatcher(tokensWithoutSubQuery.slice(0, stop), start).asStart(Food(None, SqlBaseLexer.FROM), 1).start
62 |     val tempTables = fromTables.iterate(fromStart, tokensWithoutSubQuery.size)
63 | 
64 |     // extract from `join`
65 |     val joinTables = new TableExtractor(selectSuggester.context, ast = this, tokensWithoutSubQuery)
66 |     val joinStart = TokenMatcher(tokensWithoutSubQuery.slice(0, stop), start).asStart(Food(None, SqlBaseLexer.JOIN), offset = 1).start
67 |     val tempJoinTables = joinTables.iterate(joinStart, tokens.size)
68 | 
69 |     // extract subquery name
70 |     val subqueryTables = children.map(_.name(tokens).get).map { name =>
71 |       MetaTableKeyWrapper(MetaTableKey(None, None, null), Option(name))
72 |     }.toList
73 | 
74 |     tempTables ++ tempJoinTables ++ subqueryTables
75 |   }
76 | 
77 |   def level = {
78 |     var count = 0
79 |     var temp = this.parent
80 |     while (temp != null) {
81 |       temp = temp.parent
82 |       count += 1
83 |     }
84 |     count
85 |   }
86 | 
87 | 
88 |   def output(tokens: List[Token]): List[String] = {
89 |     val selectStart = TokenMatcher(tokens.slice(0, stop), start).asStart(Food(None, SqlBaseLexer.SELECT), 1).start
90 |     val extractor = new AttributeExtractor(selectSuggester.context, this, tokens)
91 |     extractor.iterate(selectStart, tokens.size)
92 |   }
93 | 
94 |   def visitDown(level: Int)(rule: PartialFunction[(SingleStatementAST, Int), Unit]): Unit = {
95 |     rule.apply((this, level))
96 |     this.children.map(_.visitDown(level + 1)(rule))
97 |   }
98 | 
99 |   def visitUp(level: Int)(rule: PartialFunction[(SingleStatementAST, Int), Unit]): Unit = {
100 |     this.children.map(_.visitUp(level + 1)(rule))
101 |     rule.apply((this, level))
102 |   }
103 | 
104 |   def fastEquals(other: SingleStatementAST): Boolean = {
105 |     this.eq(other) || this == other
106 |   }
107 | 
108 | 
def printAsStr(_tokens: List[Token], _level: Int): String = { 109 | val tokens = _tokens.slice(start, stop) 110 | val stringBuilder = new mutable.StringBuilder() 111 | var count = 1 112 | stringBuilder.append(tokens.map { item => 113 | count += 1 114 | val suffix = if (count % 20 == 0) "\n" else "" 115 | item.getText + suffix 116 | }.mkString(" ")) 117 | stringBuilder.append("\n") 118 | stringBuilder.append("\n") 119 | stringBuilder.append("\n") 120 | children.zipWithIndex.foreach { case (item, index) => stringBuilder.append("=" * (_level + 1) + ">" + item.printAsStr(_tokens, _level + 1)) } 121 | stringBuilder.toString() 122 | } 123 | } 124 | 125 | 126 | object SingleStatementAST { 127 | 128 | def matchTableAlias(tokens: List[Token], start: Int) = { 129 | tokens(start) 130 | } 131 | 132 | def build(selectSuggester: SelectSuggester, tokens: List[Token]) = { 133 | _build(selectSuggester, tokens, 0, tokens.size, false) 134 | } 135 | 136 | def _build(selectSuggester: SelectSuggester, tokens: List[Token], start: Int, stop: Int, isSub: Boolean = false): SingleStatementAST = { 137 | val ROOT = new SingleStatementAST(selectSuggester, start, stop, null) 138 | // context start: ( select 139 | // context end: ) 140 | 141 | var bracketStart = 0 142 | var jumpIndex = -1 143 | for (index <- (start until stop) if index >= jumpIndex) { 144 | val token = tokens(index) 145 | if (token.getType == TokenTypeWrapper.LEFT_BRACKET && index < stop - 1 && tokens(index + 1).getType == SqlBaseLexer.SELECT) { 146 | // println(s"enter: ${tokens.slice(index, index + 5).map(_.getText).mkString(" ")}") 147 | val item = SingleStatementAST._build(selectSuggester, tokens, index + 1, stop, true) 148 | jumpIndex = item.stop 149 | ROOT.children += item 150 | item.parent = ROOT 151 | 152 | } else { 153 | if (isSub) { 154 | if (token.getType == TokenTypeWrapper.LEFT_BRACKET) { 155 | bracketStart += 1 156 | } 157 | if (token.getType == TokenTypeWrapper.RIGHT_BRACKET && bracketStart != 0) { 158 | bracketStart -= 1 159 | } 160 | else if (token.getType == TokenTypeWrapper.RIGHT_BRACKET && bracketStart == 0) { 161 | // check the alias 162 | val matcher = TokenMatcher(tokens, index + 1).eat(Food(None, SqlBaseLexer.AS)).optional.eat(Food(None, SqlBaseLexer.IDENTIFIER)).build 163 | val stepSize = if (matcher.isSuccess) matcher.get else index 164 | ROOT.start = start 165 | ROOT.stop = stepSize 166 | // println(s"out: ${tokens.slice(stepSize - 5, stepSize).map(_.getText).mkString(" ")}") 167 | return ROOT 168 | } 169 | 170 | } else { 171 | // do nothing 172 | } 173 | } 174 | 175 | } 176 | ROOT 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/autosuggest/utils/SchemaUtils.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.autosuggest.utils 2 | 3 | import org.apache.spark.sql.types.StructType 4 | import tech.mlsql.autosuggest.meta.{MetaTable, MetaTableColumn, MetaTableKey} 5 | 6 | /** 7 | * 15/6/2020 WilliamZhu(allwefantasy@gmail.com) 8 | */ 9 | object SchemaUtils { 10 | def toMetaTable(table: MetaTableKey, st: StructType) = { 11 | val columns = st.fields.map { item => 12 | MetaTableColumn(item.name, item.dataType.typeName, item.nullable, Map()) 13 | }.toList 14 | MetaTable(table, columns) 15 | } 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/main/resources/log4j.properties: 
-------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Set everything to be logged to the console 18 | log4j.rootCategory=INFO, console,file 19 | log4j.appender.console=org.apache.log4j.ConsoleAppender 20 | log4j.appender.console.target=System.err 21 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 22 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %X{owner} %p %c{1}: %m%n 23 | log4j.appender.file=org.apache.log4j.RollingFileAppender 24 | log4j.appender.file.File=./logs/mlsql_engine.log 25 | log4j.appender.file.rollingPolicy=org.apache.log4j.rolling.TimeBasedRollingPolicy 26 | log4j.appender.file.rollingPolicy.fileNamePattern=./logs/mlsql_engine.%d.gz 27 | log4j.appender.file.layout=org.apache.log4j.PatternLayout 28 | log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %X{owner} %p %c{1}: %m%n 29 | log4j.appender.file.MaxBackupIndex=5 30 | # Set the default spark-shell log level to WARN. When running the spark-shell, the 31 | # log level for this class is used to overwrite the root logger's log level, so that 32 | # the user can have different defaults for the shell and regular Spark apps. 
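# To trace the suggester itself, raise the level for its package, e.g.
# (the package name below matches this repo's sources):
#log4j.logger.tech.mlsql.autosuggest=DEBUG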
33 | log4j.logger.org.apache.spark=WARN 34 | #log4j.logger.org.apache.spark=WARN 35 | # Settings to quiet third party logs that are too verbose 36 | log4j.logger.org.spark_project.jetty=WARN 37 | log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR 38 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO 39 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO 40 | log4j.logger.org.apache.parquet=ERROR 41 | log4j.logger.org.apache.spark.ContextCleaner=ERROR 42 | log4j.logger.org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator=ERROR 43 | log4j.logger.parquet=ERROR 44 | # SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support 45 | log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL 46 | log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR -------------------------------------------------------------------------------- /src/test/java/com/intigua/antlr4/autosuggest/AutoSuggestContextTest.scala: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest 2 | 3 | import org.antlr.v4.runtime.Token 4 | import org.scalatest.BeforeAndAfterEach 5 | import tech.mlsql.autosuggest.meta.{MetaProvider, MetaTable, MetaTableColumn, MetaTableKey} 6 | import tech.mlsql.autosuggest.statement.{LexerUtils, SuggestItem} 7 | import tech.mlsql.autosuggest.{DataType, SpecialTableConst, TokenPos, TokenPosType} 8 | 9 | import scala.collection.JavaConverters._ 10 | 11 | /** 12 | * 2/6/2020 WilliamZhu(allwefantasy@gmail.com) 13 | */ 14 | class AutoSuggestContextTest extends BaseTest with BeforeAndAfterEach { 15 | override def afterEach(): Unit = { 16 | // context.statements.clear() 17 | } 18 | 19 | test("parse") { 20 | val wow = context.lexer.tokenizeNonDefaultChannel( 21 | """ 22 | | -- yes 23 | | load hive.`` as -- jack 24 | | table1; 25 | | select * from table1 as table2; 26 | |""".stripMargin).tokens.asScala.toList 27 | context.build(wow) 28 | 29 | assert(context.statements.size == 2) 30 | 31 | } 32 | test("parse partial") { 33 | val wow = context.lexer.tokenizeNonDefaultChannel( 34 | """ 35 | | -- yes 36 | | load hive.`` as -- jack 37 | | table1; 38 | | select * from table1 39 | |""".stripMargin).tokens.asScala.toList 40 | context.build(wow) 41 | printStatements(context.statements) 42 | assert(context.statements.size == 2) 43 | } 44 | 45 | def printStatements(items: List[List[Token]]) = { 46 | items.foreach { item => 47 | println(item.map(_.getText).mkString(" ")) 48 | println() 49 | } 50 | } 51 | 52 | test("relative pos convert") { 53 | val wow = context.lexer.tokenizeNonDefaultChannel( 54 | """ 55 | | -- yes 56 | | load hive.`` as -- jack 57 | | table1; 58 | | select * from table1 59 | |""".stripMargin).tokens.asScala.toList 60 | context.build(wow) 61 | 62 | assert(context.statements.size == 2) 63 | // select * f[cursor]rom table1 64 | val tokenPos = LexerUtils.toTokenPos(wow, 5, 11) 65 | assert(tokenPos == TokenPos(9, TokenPosType.CURRENT, 1)) 66 | assert(context.toRelativePos(tokenPos)._1 == TokenPos(2, TokenPosType.CURRENT, 1)) 67 | } 68 | 69 | test("keyword") { 70 | val wow = context.lexer.tokenizeNonDefaultChannel( 71 | """ 72 | | -- yes 73 | | loa 74 | |""".stripMargin).tokens.asScala.toList 75 | context.build(wow) 76 | val tokenPos = LexerUtils.toTokenPos(wow, 3, 4) 77 | assert(tokenPos == TokenPos(0, TokenPosType.CURRENT, 3)) 78 | assert(context.suggest(3, 4) == List(SuggestItem("load", 
SpecialTableConst.KEY_WORD_TABLE, Map()))) 79 | } 80 | 81 | test("spark sql") { 82 | val wow = context.rawSQLLexer.tokenizeNonDefaultChannel( 83 | """ 84 | |SELECT CAST(25.65 AS int) from jack; 85 | |""".stripMargin).tokens.asScala.toList 86 | 87 | wow.foreach(item => println(s"${item.getText} ${item.getType}")) 88 | } 89 | 90 | test("load/select 4/10 select ke[cursor] from") { 91 | val wow = 92 | """ 93 | | -- yes 94 | | load hive.`jack.db` as table1; 95 | | select ke from (select keywords,search_num,c from table1) table2 96 | |""".stripMargin 97 | val items = context.buildFromString(wow).suggest(4, 10) 98 | assert(items.map(_.name) == List("keywords")) 99 | } 100 | 101 | test("load/select 4/22 select from (select [cursor]keywords") { 102 | context.setUserDefinedMetaProvider(new MetaProvider { 103 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = { 104 | val key = MetaTableKey(None, None, "table1") 105 | val value = Option(MetaTable( 106 | key, List( 107 | MetaTableColumn("keywords", DataType.STRING, true, Map()), 108 | MetaTableColumn("search_num", DataType.STRING, true, Map()), 109 | MetaTableColumn("c", DataType.STRING, true, Map()), 110 | MetaTableColumn("d", DataType.STRING, true, Map()) 111 | ) 112 | )) 113 | value 114 | } 115 | 116 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = List() 117 | }) 118 | val wow = context.lexer.tokenizeNonDefaultChannel( 119 | """ 120 | | -- yes 121 | | load hive.`jack.db` as table1; 122 | | select from (select keywords,search_num,c from table1) table2 123 | |""".stripMargin).tokens.asScala.toList 124 | val items = context.build(wow).suggest(4, 8) 125 | // items.foreach(println(_)) 126 | assert(items.map(_.name) == List("table2", "keywords", "search_num", "c")) 127 | 128 | } 129 | 130 | test("load/select table with star") { 131 | context.setUserDefinedMetaProvider(new MetaProvider { 132 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = { 133 | if (key.prefix == Option("hive")) { 134 | Option(MetaTable(key, List( 135 | MetaTableColumn("a", DataType.STRING, true, Map()), 136 | MetaTableColumn("b", DataType.STRING, true, Map()), 137 | MetaTableColumn("c", DataType.STRING, true, Map()), 138 | MetaTableColumn("d", DataType.STRING, true, Map()) 139 | ))) 140 | } else None 141 | } 142 | 143 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = ??? 144 | }) 145 | val wow = context.lexer.tokenizeNonDefaultChannel( 146 | """ 147 | | -- yes 148 | | load hive.`db.table1` as table2; 149 | | select * from table2 as table3; 150 | | select from table3 151 | |""".stripMargin).tokens.asScala.toList 152 | val items = context.build(wow).suggest(5, 8) 153 | println(items) 154 | 155 | } 156 | 157 | test("load/select table with star and func") { 158 | context.setDebugMode(true) 159 | context.setUserDefinedMetaProvider(new MetaProvider { 160 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = { 161 | if (key.prefix == Option("hive")) { 162 | Option(MetaTable(key, List( 163 | MetaTableColumn("a", DataType.STRING, true, Map()), 164 | MetaTableColumn("b", DataType.STRING, true, Map()), 165 | MetaTableColumn("c", DataType.STRING, true, Map()), 166 | MetaTableColumn("d", DataType.STRING, true, Map()) 167 | ))) 168 | } else None 169 | } 170 | 171 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = ??? 
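// ??? expands to throw new NotImplementedError, so this stub assumes the
// suggestion path under test resolves schemas only through search() and
// never calls list().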
172 | }) 173 | val sql = 174 | """ 175 | | -- yes 176 | | load hive.`db.table1` as table2; 177 | | select * from table2 as table3; 178 | | select sum() from table3 179 | |""".stripMargin 180 | val items = context.buildFromString(sql).suggest(5, 12) 181 | println(items) 182 | 183 | } 184 | test("table alias with temp table") { 185 | val sql = 186 | """ 187 | |select a,b,c from table1 as table1; 188 | |select aa,bb,cc from table2 as table2; 189 | |select from table1 t1 left join table2 t2 on t1.a = t2. 190 | |""".stripMargin 191 | 192 | val items = context.buildFromString(sql).suggest(4, 58) 193 | assert(items.map(_.name) == List("aa", "bb", "cc")) 194 | 195 | } 196 | } 197 | 198 | 199 | -------------------------------------------------------------------------------- /src/test/java/com/intigua/antlr4/autosuggest/BaseTest.scala: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest 2 | 3 | import org.antlr.v4.runtime.Token 4 | import org.apache.spark.sql.SparkSession 5 | import org.apache.spark.sql.catalyst.parser.{SqlBaseLexer, SqlBaseParser} 6 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} 7 | import streaming.dsl.parser.{DSLSQLLexer, DSLSQLParser} 8 | import tech.mlsql.autosuggest.{AutoSuggestContext, MLSQLSQLFunction} 9 | 10 | import scala.collection.JavaConverters._ 11 | 12 | /** 13 | * 2/6/2020 WilliamZhu(allwefantasy@gmail.com) 14 | */ 15 | class BaseTest extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { 16 | val lexerAndParserfactory = new ReflectionLexerAndParserFactory(classOf[DSLSQLLexer], classOf[DSLSQLParser]); 17 | val loadLexer = new LexerWrapper(lexerAndParserfactory, new DefaultToCharStream) 18 | 19 | val lexerAndParserfactory2 = new ReflectionLexerAndParserFactory(classOf[SqlBaseLexer], classOf[SqlBaseParser]); 20 | val rawSQLloadLexer = new LexerWrapper(lexerAndParserfactory2, new RawSQLToCharStream) 21 | 22 | var context: AutoSuggestContext = _ 23 | var tokens: List[Token] = _ 24 | var sparkSession: SparkSession = _ 25 | 26 | override def beforeAll(): Unit = { 27 | sparkSession = SparkSession.builder().appName("local").master("local[*]").getOrCreate() 28 | 29 | } 30 | 31 | override def afterAll(): Unit = { 32 | context.session.close() 33 | } 34 | 35 | 36 | override def beforeEach(): Unit = { 37 | MLSQLSQLFunction.funcMetaProvider.clear 38 | context = new AutoSuggestContext(sparkSession, loadLexer, rawSQLloadLexer) 39 | context.setDebugMode(true) 40 | var tr = loadLexer.tokenizeNonDefaultChannel( 41 | """ 42 | | -- yes 43 | | load hive.`` as -- jack 44 | | table1; 45 | |""".stripMargin) 46 | tokens = tr.tokens.asScala.toList 47 | } 48 | 49 | def getMLSQLTokens(sql: String) = { 50 | context.lexer.tokenizeNonDefaultChannel(sql).tokens.asScala.toList 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /src/test/java/com/intigua/antlr4/autosuggest/BaseTestWithoutSparkSession.scala: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest 2 | 3 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} 4 | import tech.mlsql.autosuggest.app.AutoSuggestController 5 | import tech.mlsql.autosuggest.{AutoSuggestContext, MLSQLSQLFunction} 6 | 7 | /** 8 | * 11/6/2020 WilliamZhu(allwefantasy@gmail.com) 9 | */ 10 | class BaseTestWithoutSparkSession extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { 11 | 12 | var context: 
AutoSuggestContext = _ 13 | 14 | 15 | override def beforeAll(): Unit = { 16 | 17 | } 18 | 19 | override def afterAll(): Unit = { 20 | context.session.close() 21 | } 22 | 23 | 24 | override def beforeEach(): Unit = { 25 | MLSQLSQLFunction.funcMetaProvider.clear 26 | context = new AutoSuggestContext(null, AutoSuggestController.mlsqlLexer, AutoSuggestController.sqlLexer) 27 | context.setDebugMode(true) 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /src/test/java/com/intigua/antlr4/autosuggest/LexerUtilsTest.scala: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest 2 | 3 | import tech.mlsql.autosuggest.statement.LexerUtils 4 | import tech.mlsql.autosuggest.{TokenPos, TokenPosType} 5 | 6 | /** 7 | * 2/6/2020 WilliamZhu(allwefantasy@gmail.com) 8 | */ 9 | class LexerUtilsTest extends BaseTest { 10 | 11 | test(" load [cursor]hive.`` as -- jack") { 12 | assert(LexerUtils.toTokenPos(tokens, 3, 6) == TokenPos(0, TokenPosType.NEXT, 0)) 13 | 14 | } 15 | test(" load h[cursor]ive.`` as -- jack") { 16 | assert(LexerUtils.toTokenPos(tokens.toList, 3, 7) == TokenPos(1, TokenPosType.CURRENT, 1)) 17 | } 18 | 19 | test("[cursor] load hive.`` as -- jack") { 20 | assert(LexerUtils.toTokenPos(tokens.toList, 3, 0) == TokenPos(-1, TokenPosType.NEXT, 0)) 21 | } 22 | test(" load hive.`` as -- jack [cursor]") { 23 | assert(LexerUtils.toTokenPos(tokens.toList, 3, 23) == TokenPos(4, TokenPosType.NEXT, 0)) 24 | } 25 | 26 | test("select sum([cursor]) as t from table") { 27 | context.buildFromString("select sum() as t from table") 28 | assert(LexerUtils.toTokenPos(context.rawTokens, 1, 11) == TokenPos(2, TokenPosType.NEXT, 0)) 29 | } 30 | 31 | test("select from (select table2.abc as abc from table1 left join table2 on table1.column1 == table2.[cursor]) t1") { 32 | context.buildFromString("select from (select table2.abc as abc from table1 left join table2 on table1.column1 == table2.) 
t1") 33 | assert(LexerUtils.toTokenPos(context.rawTokens, 1, 96) == TokenPos(21, TokenPosType.NEXT, 0)) 34 | } 35 | 36 | test("select sum(abc[cursor]) as t from table") { 37 | context.buildFromString("select sum(abc) as t from table") 38 | assert(LexerUtils.toTokenPos(context.rawTokens, 1, 14) == TokenPos(3, TokenPosType.CURRENT, 3)) 39 | } 40 | 41 | test("load csv.") { 42 | context.buildFromString("load csv.") 43 | assert(LexerUtils.toTokenPos(context.rawTokens, 1, 9) == TokenPos(2, TokenPosType.NEXT, 0)) 44 | } 45 | 46 | } 47 | -------------------------------------------------------------------------------- /src/test/java/com/intigua/antlr4/autosuggest/LoadSuggesterTest.scala: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest 2 | 3 | import tech.mlsql.autosuggest.statement.LoadSuggester 4 | import tech.mlsql.autosuggest.{TokenPos, TokenPosType} 5 | 6 | import scala.collection.JavaConverters._ 7 | 8 | /** 9 | * 2/6/2020 WilliamZhu(allwefantasy@gmail.com) 10 | */ 11 | class LoadSuggesterTest extends BaseTest { 12 | test("load hiv[cursor]") { 13 | val wow = context.lexer.tokenizeNonDefaultChannel( 14 | """ 15 | | -- yes 16 | | load hiv 17 | |""".stripMargin).tokens.asScala.toList 18 | val loadSuggester = new LoadSuggester(context, wow, TokenPos(1, TokenPosType.CURRENT, 3)).suggest() 19 | assert(loadSuggester.map(_.name) == List("hive")) 20 | } 21 | 22 | test("load [cursor]") { 23 | val wow = context.lexer.tokenizeNonDefaultChannel( 24 | """ 25 | | -- yes 26 | | load 27 | |""".stripMargin).tokens.asScala.toList 28 | val loadSuggester = new LoadSuggester(context, wow, TokenPos(0, TokenPosType.NEXT, 0)).suggest() 29 | println(loadSuggester) 30 | assert(loadSuggester.size > 1) 31 | } 32 | 33 | test("load csv.`` where [cursor]") { 34 | val wow = context.lexer.tokenizeNonDefaultChannel( 35 | """ 36 | | -- yes 37 | | load csv.`` where 38 | |""".stripMargin).tokens.asScala.toList 39 | val result = new LoadSuggester(context, wow, TokenPos(4, TokenPosType.NEXT, 0)).suggest() 40 | println(result) 41 | 42 | } 43 | 44 | 45 | } 46 | -------------------------------------------------------------------------------- /src/test/java/com/intigua/antlr4/autosuggest/MatchTokenTest.scala: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest 2 | 3 | import org.apache.spark.sql.catalyst.parser.SqlBaseLexer 4 | import streaming.dsl.parser.DSLSQLLexer 5 | import tech.mlsql.autosuggest.dsl.{Food, MLSQLTokenTypeWrapper, TokenMatcher} 6 | import tech.mlsql.autosuggest.statement.LexerUtils 7 | 8 | import scala.collection.JavaConverters._ 9 | 10 | /** 11 | * 9/6/2020 WilliamZhu(allwefantasy@gmail.com) 12 | */ 13 | class MatchTokenTest extends BaseTest { 14 | test("orIndex back") { 15 | val wow = context.lexer.tokenizeNonDefaultChannel( 16 | """ 17 | |select a.k from jack.drugs_bad_case_di as a 18 | |""".stripMargin).tokens.asScala.toList 19 | 20 | val tokens = LexerUtils.toRawSQLTokens(context, wow) 21 | val temp = TokenMatcher(tokens, 6).back.orIndex(Array(Food(None, SqlBaseLexer.FROM), Food(None, SqlBaseLexer.SELECT))) 22 | assert(temp == 4) 23 | } 24 | 25 | test("orIndex forward") { 26 | val wow = context.lexer.tokenizeNonDefaultChannel( 27 | """ 28 | |select a.k from jack.drugs_bad_case_di as a 29 | |""".stripMargin).tokens.asScala.toList 30 | 31 | val tokens = LexerUtils.toRawSQLTokens(context, wow) 32 | val temp = TokenMatcher(tokens, 0).forward.orIndex(Array(Food(None, 
SqlBaseLexer.FROM), Food(None, SqlBaseLexer.SELECT))) 33 | assert(temp == 0) 34 | } 35 | 36 | test("forward out of bound success") { 37 | val wow = context.lexer.tokenizeNonDefaultChannel( 38 | """ 39 | |load csv. 40 | |""".stripMargin).tokens.asScala.toList 41 | 42 | val temp = TokenMatcher(wow, 0) 43 | .forward 44 | .eat(Food(None, DSLSQLLexer.LOAD)) 45 | .eat(Food(None, DSLSQLLexer.IDENTIFIER)) 46 | .eat(Food(None, MLSQLTokenTypeWrapper.DOT)).build 47 | assert(temp.isSuccess) 48 | } 49 | test("forward out of bound fail") { 50 | val wow = context.lexer.tokenizeNonDefaultChannel( 51 | """ 52 | |load csv 53 | |""".stripMargin).tokens.asScala.toList 54 | 55 | val temp = TokenMatcher(wow, 0) 56 | .forward 57 | .eat(Food(None, DSLSQLLexer.LOAD)) 58 | .eat(Food(None, DSLSQLLexer.IDENTIFIER)) 59 | .eat(Food(None, MLSQLTokenTypeWrapper.DOT)).build 60 | assert(!temp.isSuccess) 61 | } 62 | 63 | test("back out of bound success") { 64 | val wow = context.lexer.tokenizeNonDefaultChannel( 65 | """ 66 | |load csv. 67 | |""".stripMargin).tokens.asScala.toList 68 | 69 | val temp = TokenMatcher(wow, 2) 70 | .back 71 | .eat(Food(None, MLSQLTokenTypeWrapper.DOT)) 72 | .eat(Food(None, DSLSQLLexer.IDENTIFIER)) 73 | .eat(Food(None, DSLSQLLexer.LOAD)).build 74 | assert(temp.isSuccess) 75 | } 76 | test("back out of bound fail") { 77 | val wow = context.lexer.tokenizeNonDefaultChannel( 78 | """ 79 | |csv. 80 | |""".stripMargin).tokens.asScala.toList 81 | 82 | val temp = TokenMatcher(wow, 1) 83 | .back 84 | .eat(Food(None, MLSQLTokenTypeWrapper.DOT)) 85 | .eat(Food(None, DSLSQLLexer.IDENTIFIER)) 86 | .eat(Food(None, DSLSQLLexer.LOAD)).build 87 | assert(!temp.isSuccess) 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /src/test/java/com/intigua/antlr4/autosuggest/SelectSuggesterTest.scala: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest 2 | 3 | import tech.mlsql.autosuggest.meta.{MetaProvider, MetaTable, MetaTableColumn, MetaTableKey} 4 | import tech.mlsql.autosuggest.statement.SelectSuggester 5 | import tech.mlsql.autosuggest.{DataType, MLSQLSQLFunction, TokenPos, TokenPosType} 6 | 7 | import scala.collection.JavaConverters._ 8 | 9 | /** 10 | * 2/6/2020 WilliamZhu(allwefantasy@gmail.com) 11 | */ 12 | class SelectSuggesterTest extends BaseTest { 13 | 14 | def buildMetaProvider = { 15 | context.setUserDefinedMetaProvider(new MetaProvider { 16 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = { 17 | Option(MetaTable(key, List( 18 | MetaTableColumn("no_result_type", null, true, Map()), 19 | MetaTableColumn("keywords", null, true, Map()), 20 | MetaTableColumn("search_num", null, true, Map()), 21 | MetaTableColumn("hp_stat_date", null, true, Map()), 22 | MetaTableColumn("action_dt", null, true, Map()), 23 | MetaTableColumn("action_type", null, true, Map()), 24 | MetaTableColumn("av", null, true, Map()) 25 | ))) 26 | 27 | } 28 | 29 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = List() 30 | }) 31 | 32 | 33 | } 34 | 35 | lazy val wow = context.lexer.tokenizeNonDefaultChannel( 36 | """ 37 | |select no_result_type, keywords, search_num, rank 38 | |from( 39 | | select no_result_type, keywords, search_num, row_number() over (PARTITION BY no_result_type order by search_num desc) as rank 40 | | from( 41 | | select no_result_type, keywords, sum(search_num) AS search_num 42 | | from jack.drugs_bad_case_di,jack.abc jack 43 | | where 
hp_stat_date >= date_sub(current_date,30) 44 | | and action_dt >= date_sub(current_date,30) 45 | | and action_type = 'search' 46 | | and length(keywords) > 1 47 | | and (split(av, '\\.')[0] >= 11 OR (split(av, '\\.')[0] = 10 AND split(av, '\\.')[1] = 9)) 48 | | --and no_result_type = 'indication' 49 | | group by no_result_type, keywords 50 | | )a 51 | |)b 52 | |where rank <= 53 | |""".stripMargin).tokens.asScala.toList 54 | 55 | test("select") { 56 | 57 | context.setUserDefinedMetaProvider(new MetaProvider { 58 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = None 59 | 60 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = List() 61 | }) 62 | 63 | 64 | val suggester = new SelectSuggester(context, wow, TokenPos(0, TokenPosType.NEXT, 0)) 65 | suggester.suggest() 66 | 67 | 68 | } 69 | 70 | test("project: complex attribute suggest") { 71 | 72 | buildMetaProvider 73 | 74 | lazy val wow2 = context.lexer.tokenizeNonDefaultChannel( 75 | """ 76 | |select key no_result_type, keywords, search_num, rank 77 | |from( 78 | | select keywords, search_num, row_number() over (PARTITION BY no_result_type order by search_num desc) as rank 79 | | from( 80 | | select *,no_result_type, keywords, sum(search_num) AS search_num 81 | | from jack.drugs_bad_case_di,jack.abc jack 82 | | where hp_stat_date >= date_sub(current_date,30) 83 | | and action_dt >= date_sub(current_date,30) 84 | | and action_type = 'search' 85 | | and length(keywords) > 1 86 | | and (split(av, '\\.')[0] >= 11 OR (split(av, '\\.')[0] = 10 AND split(av, '\\.')[1] = 9)) 87 | | --and no_result_type = 'indication' 88 | | group by no_result_type, keywords 89 | | )a 90 | |)b 91 | |where rank <= 92 | |""".stripMargin).tokens.asScala.toList 93 | 94 | val suggester = new SelectSuggester(context, wow2, TokenPos(1, TokenPosType.CURRENT, 2)) 95 | assert(suggester.suggest().map(_.name) == List("keywords")) 96 | 97 | } 98 | 99 | test("project: second level select ") { 100 | 101 | buildMetaProvider 102 | 103 | lazy val wow2 = context.lexer.tokenizeNonDefaultChannel( 104 | """ 105 | |select key no_result_type, keywords, search_num, rank 106 | |from( 107 | | select sea keywords, search_num, row_number() over (PARTITION BY no_result_type order by search_num desc) as rank 108 | | from( 109 | | select *,no_result_type, keywords, sum(search_num) AS search_num 110 | | from jack.drugs_bad_case_di,jack.abc jack 111 | | where hp_stat_date >= date_sub(current_date,30) 112 | | and action_dt >= date_sub(current_date,30) 113 | | and action_type = 'search' 114 | | and length(keywords) > 1 115 | | and (split(av, '\\.')[0] >= 11 OR (split(av, '\\.')[0] = 10 AND split(av, '\\.')[1] = 9)) 116 | | --and no_result_type = 'indication' 117 | | group by no_result_type, keywords 118 | | )a 119 | |)b 120 | |where rank <= 121 | |""".stripMargin).tokens.asScala.toList 122 | 123 | // wow2.zipWithIndex.foreach{case (token,index)=> 124 | // println(s"${index} $token")} 125 | val suggester = new SelectSuggester(context, wow2, TokenPos(12, TokenPosType.CURRENT, 3)) 126 | assert(suggester.suggest().distinct.map(_.name) == List("search_num")) 127 | 128 | } 129 | 130 | test("project: single query with alias table name") { 131 | 132 | buildMetaProvider 133 | lazy val wow = context.lexer.tokenizeNonDefaultChannel( 134 | """ 135 | |select a.k from jack.drugs_bad_case_di as a 136 | |""".stripMargin).tokens.asScala.toList 137 | 138 | val suggester = new SelectSuggester(context, wow, TokenPos(3, TokenPosType.CURRENT, 1)) 139 | 140 
| assert(suggester.suggest().map(_.name) == List(("keywords"))) 141 | } 142 | 143 | test("project: table or attribute") { 144 | 145 | buildMetaProvider 146 | lazy val wow = context.lexer.tokenizeNonDefaultChannel( 147 | """ 148 | |select from jack.drugs_bad_case_di as a 149 | |""".stripMargin).tokens.asScala.toList 150 | 151 | val suggester = new SelectSuggester(context, wow, TokenPos(0, TokenPosType.NEXT, 0)) 152 | 153 | assert(suggester.suggest().map(_.name) == List(("a"), 154 | ("no_result_type"), 155 | ("keywords"), 156 | ("search_num"), 157 | ("hp_stat_date"), 158 | ("action_dt"), 159 | ("action_type"), 160 | ("av"))) 161 | } 162 | 163 | test("project: complex table attribute ") { 164 | 165 | buildMetaProvider 166 | 167 | lazy val wow2 = context.lexer.tokenizeNonDefaultChannel( 168 | """ 169 | |select no_result_type, keywords, search_num, rank 170 | |from( 171 | | select keywords, search_num, row_number() over (PARTITION BY no_result_type order by search_num desc) as rank 172 | | from( 173 | | select *,no_result_type, keywords, sum(search_num) AS search_num 174 | | from jack.drugs_bad_case_di,jack.abc jack 175 | | where hp_stat_date >= date_sub(current_date,30) 176 | | and action_dt >= date_sub(current_date,30) 177 | | and action_type = 'search' 178 | | and length(keywords) > 1 179 | | and (split(av, '\\.')[0] >= 11 OR (split(av, '\\.')[0] = 10 AND split(av, '\\.')[1] = 9)) 180 | | --and no_result_type = 'indication' 181 | | group by no_result_type, keywords 182 | | )a 183 | |)b 184 | |where rank <= 185 | |""".stripMargin).tokens.asScala.toList 186 | 187 | val suggester = new SelectSuggester(context, wow2, TokenPos(0, TokenPosType.NEXT, 0)) 188 | println(suggester.sqlAST.printAsStr(suggester.tokens, 0)) 189 | suggester.table_info.foreach { case (level, item) => 190 | println(level + ":") 191 | println(item.map(_._1).toList) 192 | } 193 | assert(suggester.suggest().map(_.name) == List(("b"), ("keywords"), ("search_num"), ("rank"))) 194 | 195 | } 196 | 197 | test("table layer") { 198 | buildMetaProvider 199 | val sql = 200 | """ 201 | |select from (select no_result_type from db1.table1) b; 202 | |""".stripMargin 203 | val tokens = getMLSQLTokens(sql) 204 | 205 | val suggester = new SelectSuggester(context, tokens, TokenPos(0, TokenPosType.NEXT, 0)) 206 | println("=======") 207 | println(suggester.suggest()) 208 | assert(suggester.suggest().head.name=="b") 209 | } 210 | 211 | 212 | 213 | test("project: function suggester") { 214 | 215 | val func = MLSQLSQLFunction.apply("split"). 216 | funcParam. 217 | param("str", DataType.STRING). 218 | param("splitter", DataType.STRING). 219 | func. 220 | returnParam(DataType.ARRAY, true, Map()). 
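// build finalizes the registration; the signature is roughly
// split(str: string, splitter: string) -> array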
221 | build 222 | MLSQLSQLFunction.funcMetaProvider.register(func) 223 | 224 | val tableKey = MetaTableKey(None, Option("jack"), "drugs_bad_case_di") 225 | 226 | val metas = Map(tableKey -> 227 | Option(MetaTable(tableKey, List( 228 | MetaTableColumn("no_result_type", null, true, Map()), 229 | MetaTableColumn("keywords", null, true, Map()), 230 | MetaTableColumn("search_num", null, true, Map()), 231 | MetaTableColumn("hp_stat_date", null, true, Map()), 232 | MetaTableColumn("action_dt", null, true, Map()), 233 | MetaTableColumn("action_type", null, true, Map()), 234 | MetaTableColumn("av", null, true, Map()) 235 | ))) 236 | 237 | ) 238 | context.setUserDefinedMetaProvider(new MetaProvider { 239 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = { 240 | metas(key) 241 | } 242 | 243 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = List() 244 | }) 245 | 246 | 247 | lazy val wow = context.lexer.tokenizeNonDefaultChannel( 248 | """ 249 | |select spl from jack.drugs_bad_case_di as a 250 | |""".stripMargin).tokens.asScala.toList 251 | 252 | val suggester = new SelectSuggester(context, wow, TokenPos(1, TokenPosType.CURRENT, 3)) 253 | println(suggester.suggest()) 254 | assert(suggester.suggest().map(_.name) == List(("split"))) 255 | } 256 | 257 | 258 | 259 | 260 | } 261 | -------------------------------------------------------------------------------- /src/test/java/com/intigua/antlr4/autosuggest/TablePreprocessorTest.scala: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest 2 | 3 | import tech.mlsql.autosuggest.meta.{MetaProvider, MetaTable, MetaTableColumn, MetaTableKey} 4 | import tech.mlsql.autosuggest.preprocess.TablePreprocessor 5 | import tech.mlsql.autosuggest.{DataType, SpecialTableConst} 6 | 7 | import scala.collection.JavaConverters._ 8 | 9 | /** 10 | * 10/6/2020 WilliamZhu(allwefantasy@gmail.com) 11 | */ 12 | class TablePreprocessorTest extends BaseTest { 13 | 14 | test("load/select table") { 15 | context.setUserDefinedMetaProvider(new MetaProvider { 16 | override def search(key: MetaTableKey,extra: Map[String, String] = Map()): Option[MetaTable] = { 17 | if (key.prefix == Option("hive")) { 18 | Option(MetaTable(key, List( 19 | MetaTableColumn("a", DataType.STRING, true, Map()), 20 | MetaTableColumn("b", DataType.STRING, true, Map()), 21 | MetaTableColumn("c", DataType.STRING, true, Map()), 22 | MetaTableColumn("d", DataType.STRING, true, Map()) 23 | ))) 24 | } else None 25 | } 26 | 27 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = ??? 
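// Returning None for any non-hive prefix models a source this provider does
// not know about; the preprocessor only needs search(), so list() stays
// unimplemented.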
28 | }) 29 | val wow = context.lexer.tokenizeNonDefaultChannel( 30 | """ 31 | | -- yes 32 | | load hive.`db.table1` as table2; 33 | | select a,b,c from table2 as table3; 34 | |""".stripMargin).tokens.asScala.toList 35 | context.build(wow) 36 | val processor = new TablePreprocessor(context) 37 | context.statements.foreach(processor.process(_)) 38 | 39 | val targetTable = context.tempTableProvider.search(SpecialTableConst.tempTable("table2").key).get 40 | assert(targetTable.key == MetaTableKey(Some("hive"), Some("db"), "table1")) 41 | 42 | val targetTable3 = context.tempTableProvider.search(SpecialTableConst.tempTable("table3").key).get 43 | assert(targetTable3 == MetaTable(MetaTableKey(None, Some(SpecialTableConst.TEMP_TABLE_DB_KEY), "table3"), 44 | List(MetaTableColumn("a", null, true, Map()), MetaTableColumn("b", null, true, Map()), MetaTableColumn("c", null, true, Map())))) 45 | } 46 | 47 | test("load/select table with star") { 48 | context.setUserDefinedMetaProvider(new MetaProvider { 49 | override def search(key: MetaTableKey,extra: Map[String, String] = Map()): Option[MetaTable] = { 50 | if (key.prefix == Option("hive")) { 51 | Option(MetaTable(key, List( 52 | MetaTableColumn("a", DataType.STRING, true, Map()), 53 | MetaTableColumn("b", DataType.STRING, true, Map()), 54 | MetaTableColumn("c", DataType.STRING, true, Map()), 55 | MetaTableColumn("d", DataType.STRING, true, Map()) 56 | ))) 57 | } else None 58 | } 59 | 60 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = ??? 61 | }) 62 | val wow = context.lexer.tokenizeNonDefaultChannel( 63 | """ 64 | | -- yes 65 | | load hive.`db.table1` as table2; 66 | | select * from table2 as table3; 67 | |""".stripMargin).tokens.asScala.toList 68 | context.build(wow) 69 | val processor = new TablePreprocessor(context) 70 | context.statements.foreach(processor.process(_)) 71 | 72 | val targetTable = context.tempTableProvider.search(SpecialTableConst.tempTable("table2").key).get 73 | assert(targetTable.key == MetaTableKey(Some("hive"), Some("db"), "table1")) 74 | 75 | val targetTable3 = context.tempTableProvider.search(SpecialTableConst.tempTable("table3").key).get 76 | assert(targetTable3 == MetaTable(MetaTableKey(None, Some(SpecialTableConst.TEMP_TABLE_DB_KEY), "table3"), 77 | List(MetaTableColumn("a", null, true, Map()), 78 | MetaTableColumn("b", null, true, Map()), 79 | MetaTableColumn("c", null, true, Map()), MetaTableColumn("d", null, true, Map())))) 80 | } 81 | 82 | } 83 | -------------------------------------------------------------------------------- /src/test/java/com/intigua/antlr4/autosuggest/TableStructureTest.scala: -------------------------------------------------------------------------------- 1 | package com.intigua.antlr4.autosuggest 2 | 3 | import tech.mlsql.autosuggest.meta.{MetaProvider, MetaTable, MetaTableColumn, MetaTableKey} 4 | import tech.mlsql.autosuggest.statement.{MetaTableKeyWrapper, SelectSuggester} 5 | import tech.mlsql.autosuggest.{TokenPos, TokenPosType} 6 | import tech.mlsql.common.utils.log.Logging 7 | import scala.collection.JavaConverters._ 8 | 9 | import scala.collection.mutable.ArrayBuffer 10 | 11 | /** 12 | * 22/6/2020 WilliamZhu(allwefantasy@gmail.com) 13 | */ 14 | class TableStructureTest extends BaseTest with Logging { 15 | 16 | 17 | test("s1") { 18 | buildMetaProvider 19 | val sql = 20 | """ 21 | |select from (select no_result_type from db1.table1) b; 22 | |""".stripMargin 23 | val tokens = getMLSQLTokens(sql) 24 | 25 | val suggester = new SelectSuggester(context, tokens, 
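// Cursor position: TokenPos(0, TokenPosType.NEXT, 0) means the cursor sits
// just after token 0 ("select"), asking what may follow the SELECT keyword.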
TokenPos(0, TokenPosType.NEXT, 0)) 26 | println(suggester.sqlAST) 27 | } 28 | 29 | test("s2") { 30 | buildMetaProvider 31 | val sql = 32 | """ 33 | |select from (select no_result_type from (select no_result_type from db1.table1) b left join db2.table2) c; 34 | |""".stripMargin 35 | val tokens = getMLSQLTokens(sql) 36 | 37 | val suggester = new SelectSuggester(context, tokens, TokenPos(0, TokenPosType.NEXT, 0)) 38 | printAST(suggester) 39 | } 40 | 41 | def printAST(suggester: SelectSuggester) = { 42 | suggester.sqlAST 43 | logInfo(s"SQL[${suggester.tokens.map(_.getText).mkString(" ")}]") 44 | logInfo(s"STRUCTURE: \n") 45 | suggester.table_info.foreach { item => 46 | logInfo(s"Level:${item._1}") 47 | item._2.foreach { table => 48 | logInfo(s"${table._1} => ${table._2.copy(columns = List())}") 49 | } 50 | } 51 | } 52 | 53 | 54 | def buildMetaProvider = { 55 | context.setUserDefinedMetaProvider(new MetaProvider { 56 | override def search(key: MetaTableKey, extra: Map[String, String] = Map()): Option[MetaTable] = { 57 | Option(MetaTable(key, List( 58 | MetaTableColumn("no_result_type", null, true, Map()), 59 | MetaTableColumn("keywords", null, true, Map()), 60 | MetaTableColumn("search_num", null, true, Map()), 61 | MetaTableColumn("hp_stat_date", null, true, Map()), 62 | MetaTableColumn("action_dt", null, true, Map()), 63 | MetaTableColumn("action_type", null, true, Map()), 64 | MetaTableColumn("av", null, true, Map()) 65 | ))) 66 | 67 | } 68 | 69 | override def list(extra: Map[String, String] = Map()): List[MetaTable] = List() 70 | }) 71 | 72 | 73 | } 74 | test("single select build") { 75 | 76 | buildMetaProvider 77 | lazy val wow = context.lexer.tokenizeNonDefaultChannel( 78 | """ 79 | |select a.k from jack.drugs_bad_case_di as a 80 | |""".stripMargin).tokens.asScala.toList 81 | 82 | val suggester = new SelectSuggester(context, wow, TokenPos(3, TokenPosType.CURRENT, 1)) 83 | val root = suggester.sqlAST 84 | root.visitDown(0) { case (ast, level) => 85 | println(s"${ast.name(suggester.tokens)} ${ast.output(suggester.tokens)}") 86 | } 87 | 88 | assert(suggester.suggest().map(_.name) == List("keywords")) 89 | } 90 | 91 | test("subquery build") { 92 | buildMetaProvider 93 | 94 | lazy val wow = context.lexer.tokenizeNonDefaultChannel( 95 | """ 96 | |select a.k from (select * from jack.drugs_bad_case_di ) a; 97 | |""".stripMargin).tokens.asScala.toList 98 | 99 | val suggester = new SelectSuggester(context, wow, TokenPos(3, TokenPosType.CURRENT, 1)) 100 | val root = suggester.sqlAST 101 | root.visitDown(0) { case (ast, level) => 102 | println(s"${ast.name(suggester.tokens)} ${ast.output(suggester.tokens)}") 103 | } 104 | 105 | assert(suggester.suggest().map(_.name) == List("keywords")) 106 | } 107 | 108 | test("subquery build without prefix") { 109 | buildMetaProvider 110 | 111 | lazy val wow = context.lexer.tokenizeNonDefaultChannel( 112 | """ 113 | |select k from (select * from jack.drugs_bad_case_di ) a; 114 | |""".stripMargin).tokens.asScala.toList 115 | 116 | val suggester = new SelectSuggester(context, wow, TokenPos(1, TokenPosType.CURRENT, 1)) 117 | val root = suggester.sqlAST 118 | val buffer = ArrayBuffer[String]() 119 | root.visitDown(0) { case (ast, level) => 120 | 121 | buffer += suggester._tokens.slice(ast.start, ast.stop).map(_.getText).mkString(" ") 122 | 123 | } 124 | assert(buffer(0) == "select k from ( select * from jack . drugs_bad_case_di ) a ;") 125 | assert(buffer(1) == "select * from jack . 
drugs_bad_case_di ) a") 126 | 127 | suggester.table_info.map { 128 | case (level, table) => 129 | if (level == 0) { 130 | assert(table.map(_._1).toList == List(MetaTableKeyWrapper(MetaTableKey(None, None, null), Some("a")))) 131 | } 132 | if (level == 1) { 133 | val tables = table.map(_._1).toList 134 | assert(tables == List(MetaTableKeyWrapper(MetaTableKey(None, Some("jack"), "drugs_bad_case_di"), None))) 135 | 136 | } 137 | } 138 | 139 | } 140 | 141 | } 142 | --------------------------------------------------------------------------------
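Taken together, the tests above show the full call pattern for the suggester. The sketch below distills it into a standalone driver; the wiring mirrors BaseTest and BaseTestWithoutSparkSession, RawSQLToCharStream is assumed to live alongside DefaultToCharStream in the same package, and the cursor coordinates are illustrative assumptions rather than documented behavior.

```scala
import com.intigua.antlr4.autosuggest.{DefaultToCharStream, LexerWrapper, RawSQLToCharStream, ReflectionLexerAndParserFactory}
import org.apache.spark.sql.catalyst.parser.{SqlBaseLexer, SqlBaseParser}
import streaming.dsl.parser.{DSLSQLLexer, DSLSQLParser}
import tech.mlsql.autosuggest.AutoSuggestContext

object SuggestDemo {
  def main(args: Array[String]): Unit = {
    // MLSQL statements are tokenized with the DSL lexer ...
    val mlsqlLexer = new LexerWrapper(
      new ReflectionLexerAndParserFactory(classOf[DSLSQLLexer], classOf[DSLSQLParser]),
      new DefaultToCharStream)
    // ... while embedded select statements use Spark's SqlBase lexer.
    val sparkSqlLexer = new LexerWrapper(
      new ReflectionLexerAndParserFactory(classOf[SqlBaseLexer], classOf[SqlBaseParser]),
      new RawSQLToCharStream)

    // A null SparkSession is tolerated, as in BaseTestWithoutSparkSession.
    val context = new AutoSuggestContext(null, mlsqlLexer, sparkSqlLexer)
    // Cursor at line 2, column 9, i.e. right after "fr".
    val items = context.buildFromString(
      "load hive.`db.table1` as table2;\nselect fr").suggest(2, 9)
    items.foreach(item => println(item.name))
  }
}
```

Without a user-defined MetaProvider only keyword and temp-table suggestions can resolve; plugging in a provider, as the tests do with setUserDefinedMetaProvider, is what enables column completion.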