├── .gitignore ├── README.md ├── docs └── sql-booster-api.md ├── pom.xml └── src ├── main ├── java │ └── tech │ │ └── mlsql │ │ └── sqlbooster │ │ └── JavaDoc.java └── scala │ ├── org │ └── apache │ │ └── spark │ │ └── sql │ │ └── catalyst │ │ ├── SessionUtil.scala │ │ ├── optimizer │ │ ├── PreOrderOptimize.scala │ │ ├── RewriteHelper.scala │ │ ├── RewriteTableToView.scala │ │ └── rewrite │ │ │ ├── component │ │ │ ├── AggMatcher.scala │ │ │ ├── GroupByMatcher.scala │ │ │ ├── JoinMatcher.scala │ │ │ ├── PredicateMatcher.scala │ │ │ ├── ProjectMatcher.scala │ │ │ ├── TableNonOpMatcher.scala │ │ │ ├── rewrite │ │ │ │ ├── AggRewrite.scala │ │ │ │ ├── GroupByRewrite.scala │ │ │ │ ├── JoinRewrite.scala │ │ │ │ ├── PredicateRewrite.scala │ │ │ │ ├── ProjectRewrite.scala │ │ │ │ ├── SPGJPredicateRewrite.scala │ │ │ │ └── TableOrViewRewrite.scala │ │ │ └── util │ │ │ │ └── ExpressionSemanticEquals.scala │ │ │ └── rule │ │ │ ├── RewriteMatchRule.scala │ │ │ ├── SPGJRule.scala │ │ │ ├── WithoutJoinGroupRule.scala │ │ │ └── WithoutJoinRule.scala │ │ └── sqlgenerator │ │ ├── BasicSQLDialect.scala │ │ ├── LogicalPlanSQL.scala │ │ └── SQLDialect.scala │ └── tech │ └── mlsql │ └── sqlbooster │ ├── DataLineageExtractor.scala │ ├── MaterializedViewOptimizeRewrite.scala │ ├── SchemaRegistry.scala │ ├── analysis │ └── protocals.scala │ ├── db │ ├── RDSchema.scala │ └── RawDBTypeToJavaType.scala │ └── meta │ └── ViewCatalyst.scala └── test └── scala └── org └── apache └── spark └── sql └── catalyst ├── BaseSuite.scala ├── DataLineageSuite.scala ├── NewMVSuite.scala └── RangeSuite.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | *.iml
2 | .idea/
3 | target/
4 |
5 | .settings/
6 | .cache
7 | .project
8 | .classpath
9 | metastore_db/
10 | spark-warehouse/
11 | derby.log
12 | .cache-main
13 | .cache-tests
14 | */.cache-main
15 | */.cache-tests
16 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # sql-booster
2 |
3 | This is a library for SQL optimization and rewriting.
4 | The current version (0.4.0) already supports:
5 |
6 | 1. Materialized View rewrite
7 | 2. Data Lineage analysis
8 |
9 | This project is under active development and *** NOT READY FOR PRODUCTION ***.
10 |
11 | # Online APIs
12 |
13 | We provide a free API at `http://sql-booster.mlsql.tech`. You can visit `http://sql-booster.mlsql.tech/api` to
14 | get all available APIs. We recommend using Postman to play with these APIs. Notice that most of the APIs only
15 | support the POST method.
16 |
17 | Please check the [HTTP API Tutorial](https://github.com/aistack/sql-booster/blob/master/docs/sql-booster-api.md).
18 |
19 | # Linking
20 | You can link against this library in your program at the following coordinates:
21 |
22 | ## Scala 2.11
23 |
24 | ```
25 | groupId: tech.mlsql
26 | artifactId: sql-booster_2.11
27 | version: 0.4.0
28 | ```
29 | ## Deployment
30 |
31 | We recommend wrapping sql-booster with Spring Boot (or another web framework) and exposing it as an HTTP service.
32 |
33 | ## View-based query rewriting usage
34 |
35 | To do view-based query rewriting, you need to register the schemas of your concrete tables and views manually.
36 | Note that sql-booster only needs the following information to work:
37 |
38 | 1. table create statement
39 | 2. view create statement
40 | 3.
view schema (inferred automatically from the view create statement)
41 |
42 |
43 | sql-booster supports four kinds of create statements:
44 |
45 | 1. MySQL/Oracle
46 | 2. Hive
47 | 3. [SimpleSchema](https://github.com/allwefantasy/simple-schema)
48 | 4. Spark StructType json
49 |
50 |
51 | Steps:
52 |
53 | 1. Initialize sql-booster (this only needs to be done once):
54 |
55 | ```scala
56 | ViewCatalyst.createViewCatalyst()
57 | val schemaReg = new SchemaRegistry(spark)
58 | ```
59 |
60 | 2. Register tables:
61 |
62 | ```scala
63 | schemaReg.createTableFromDBSQL(
64 | """
65 | |CREATE TABLE depts(
66 | | deptno INT NOT NULL,
67 | | deptname VARCHAR(20),
68 | | PRIMARY KEY (deptno)
69 | |);
70 | """.stripMargin)
71 |
72 | schemaReg.createTableFromDBSQL(
73 | """
74 | |CREATE TABLE locations(
75 | | locationid INT NOT NULL,
76 | | state CHAR(2),
77 | | PRIMARY KEY (locationid)
78 | |);
79 | """.stripMargin)
80 |
81 | schemaReg.createTableFromDBSQL(
82 | """
83 | |CREATE TABLE emps(
84 | | empid INT NOT NULL,
85 | | deptno INT NOT NULL,
86 | | locationid INT NOT NULL,
87 | | empname VARCHAR(20) NOT NULL,
88 | | salary DECIMAL (18, 2),
89 | | PRIMARY KEY (empid),
90 | | FOREIGN KEY (deptno) REFERENCES depts(deptno),
91 | | FOREIGN KEY (locationid) REFERENCES locations(locationid)
92 | |);
93 | """.stripMargin)
94 |
95 | schemaReg.createTableFromHiveSQL("src",
96 | """
97 | |CREATE TABLE IF NOT EXISTS src (key INT, value STRING) USING hive
98 | """.stripMargin)
99 |
100 |
101 | schemaReg.createTableFromSimpleSchema("table1","""st(field(a,string),field(b,string))""")
102 |
103 | schemaReg.createTableFromJson("table2",
104 | """
105 | |{"type":"struct","fields":[{"name":"a","type":"string","nullable":true,"metadata":{}},{"name":"b","type":"string","nullable":true,"metadata":{}}]}
106 | """.stripMargin)
107 | ```
108 |
109 |
110 | 3. Register the materialized view (MV):
111 |
112 | ```scala
113 | schemaReg.createMV("emps_mv",
114 | """
115 | |SELECT empid
116 | |FROM emps
117 | |JOIN depts ON depts.deptno = emps.deptno
118 | """.stripMargin)
119 |
120 | ```
121 |
122 | 4. Use MaterializedViewOptimizeRewrite to execute the rewrite:
123 |
124 |
125 | ```scala
126 | val rewrite3 = MaterializedViewOptimizeRewrite.execute(schemaReg.toLogicalPlan(
127 | """
128 | |select * from (SELECT e.empid
129 | |FROM emps e
130 | |JOIN depts d
131 | |ON e.deptno = d.deptno
132 | |where e.empid=1) as a where a.empid=2
133 | """.stripMargin))
134 | ```
135 |
136 | 5. Generate the rewritten SQL:
137 |
138 | ```scala
139 | assert(schemaReg.genSQL(rewrite3)
140 | == "SELECT a.`empid` FROM (SELECT `empid` FROM emps_mv WHERE `empid` = CAST(1 AS BIGINT)) a WHERE a.`empid` = CAST(2 AS BIGINT)")
141 | ```
142 |
143 |
144 |
145 |
146 | ## Data Lineage Usage
147 |
148 | Given a SQL statement, sql-booster can help you analyze:
149 |
150 | 1. the tables this SQL depends on and their corresponding columns, including the columns used in WHERE, SELECT and JOIN conditions.
151 | 2. which columns in the original tables each output column of this SQL is composed of.
152 |
153 | NOTICE: sql-booster requires you to register table schemas first, as described in **View-based query rewriting usage**.
154 |
155 | Here is the example code:
156 |
157 | ```scala
158 | val result = DataLineageExtractor.execute(schemaReg.toLogicalPlan(
159 | """
160 | |select * from (SELECT e.empid
161 | |FROM emps e
162 | |JOIN depts d
163 | |ON e.deptno = d.deptno
164 | |where e.empid=1) as a where a.empid=2
165 | """.stripMargin))
166 | println(JSONTool.pretty(result))
167 | ```
168 |
169 | The output looks like this:
170 |
171 | ```json
172 | {
173 | "outputMapToSourceTable":[{
174 | "name":"empid",
175 | "sources":[{
176 | "tableName":"emps",
177 | "columns":["empid"],
178 | "locates":[["PROJECT","FILTER"]]
179 | },{
180 | "tableName":"depts",
181 | "columns":[],
182 | "locates":[]
183 | }]
184 | }],
185 | "dependences":[{
186 | "tableName":"emps",
187 | "columns":["empid","deptno"],
188 | "locates":[["PROJECT","FILTER"],["JOIN"]]
189 | },{
190 | "tableName":"depts",
191 | "columns":["deptno"],
192 | "locates":[["JOIN"]]
193 | }]
194 | }
195 | ```
196 |
197 | This means the result has only one output column, named empid, and it depends on empid in table emps.
198 | The query depends on the tables emps and depts, and the columns empid and deptno are required.
199 |
200 | Also, sql-booster tells you which part of the SQL each column appears in. There are 4 parts:
201 |
202 | ```
203 | FILTER
204 | GROUP_BY
205 | JOIN
206 | PROJECT
207 | ```
208 |
209 | In this example, depts.deptno appears in JOIN, and emps.empid appears in PROJECT and FILTER (the where condition).
210 |
211 |
212 |
213 |
--------------------------------------------------------------------------------
/docs/sql-booster-api.md:
--------------------------------------------------------------------------------
1 | ## How to use sql-booster.mlsql.tech
2 |
3 | This tutorial shows you how to use the APIs of sql-booster.mlsql.tech.
4 |
5 | ## APIs
6 |
7 | API help information:
8 |
9 | ```
10 | GET/POST http://sql-booster.mlsql.tech/api
11 | ```
12 |
13 | Table Register:
14 |
15 | ```
16 | POST http://sql-booster.mlsql.tech/api_v1/table/register
17 | ```
18 |
19 | View Register:
20 |
21 | ```
22 | POST http://sql-booster.mlsql.tech/api_v1/view/register
23 | ```
24 |
25 | Data Lineage Analysis:
26 |
27 | ```
28 | POST http://sql-booster.mlsql.tech/api_v1/dataLineage
29 | ```
30 |
31 | View Based SQL Rewriting:
32 |
33 | ```
34 | POST http://sql-booster.mlsql.tech/api_v1/mv/rewrite
35 | ```
36 |
37 | ## Summary
38 |
39 | The design and behavior of sql-booster allow you to analyze or rewrite SQL without the real tables existing.
40 | The only thing you need to do is register the table/view create statements before using functions like Lineage Analysis
41 | and View Based SQL Rewriting.
42 |
43 | Also, you should identify who invokes the API, and the system will create a unique session for you. This ensures the tables
44 | you register do not get mixed up with anyone else's. In the examples below I use the name `allwefantasy@gmail.com` to identify my
45 | requests.
46 |
47 | We strongly recommend using Postman to play with these APIs since most of them only support the POST method.
48 |
49 |
50 | ## Register tables
51 |
52 | We will register three tables first; you can use Postman to post the following data
53 | to `http://sql-booster.mlsql.tech/api_v1/table/register`.
54 |
55 |
56 | table depts:
57 |
58 | ```
59 | name:allwefantasy@gmail.com
60 | tableName:depts
61 | schema:CREATE TABLE depts(↵ deptno INT NOT NULL,↵ deptname VARCHAR(20),↵ PRIMARY KEY (deptno)↵);
62 | ```
63 |
64 | table locations:
65 |
66 | ```
67 | name:allwefantasy@gmail.com
68 | tableName:locations
69 | schema:CREATE TABLE locations(↵ locationid INT NOT NULL,↵ state CHAR(2),↵ PRIMARY KEY (locationid)↵);
70 | ```
71 |
72 | table emps:
73 |
74 | ```
75 | name:allwefantasy@gmail.com
76 | tableName:emps
77 | schema:CREATE TABLE emps(↵ empid INT NOT NULL,↵ deptno INT NOT NULL,↵ locationid INT NOT NULL,↵ empname VARCHAR(20) NOT NULL,↵ salary DECIMAL (18, 2),↵ PRIMARY KEY (empid),↵ FOREIGN KEY (deptno) REFERENCES depts(deptno),↵ FOREIGN KEY (locationid) REFERENCES locations(locationid)↵);
78 | ```
79 |
80 | ## Data Lineage
81 |
82 | Visit `http://sql-booster.mlsql.tech/api_v1/dataLineage` to analyze the data lineage of any SQL:
83 |
84 | ```
85 | sql:select * from (SELECT e.empid↵FROM emps e↵JOIN depts d↵ON e.deptno = d.deptno↵where e.empid=1) as a where a.empid=2
86 | name:allwefantasy@gmail.com
87 | ```
88 |
89 | The response looks like this:
90 |
91 | ```json
92 | {
93 |   "outputMapToSourceTable": [
94 |     {
95 |       "name": "empid",
96 |       "sources": [
97 |         {
98 |           "tableName": "emps",
99 |           "columns": [
100 |             "empid"
101 |           ],
102 |           "locates": [
103 |             [
104 |               "PROJECT",
105 |               "FILTER"
106 |             ]
107 |           ]
108 |         },
109 |         {
110 |           "tableName": "depts",
111 |           "columns": [],
112 |           "locates": []
113 |         }
114 |       ]
115 |     }
116 |   ],
117 |   "dependences": [
118 |     {
119 |       "tableName": "emps",
120 |       "columns": [
121 |         "empid",
122 |         "deptno"
123 |       ],
124 |       "locates": [
125 |         [
126 |           "PROJECT",
127 |           "FILTER"
128 |         ],
129 |         [
130 |           "JOIN"
131 |         ]
132 |       ]
133 |     },
134 |     {
135 |       "tableName": "depts",
136 |       "columns": [
137 |         "deptno"
138 |       ],
139 |       "locates": [
140 |         [
141 |           "JOIN"
142 |         ]
143 |       ]
144 |     }
145 |   ]
146 | }
147 | ```
148 |
149 | ## View Based SQL Rewrite
150 |
151 | Register the view with the API `http://sql-booster.mlsql.tech/api_v1/view/register`:
152 |
153 | ```
154 | viewName:emps_mv
155 | name:allwefantasy@gmail.com
156 | sql:SELECT empid↵FROM emps↵JOIN depts ON depts.deptno = emps.deptno
157 | ```
158 |
159 | Then send a SQL query to `http://sql-booster.mlsql.tech/api_v1/mv/rewrite`:
160 |
161 | ```
162 | name:allwefantasy@gmail.com
163 | sql:select * from (SELECT e.empid↵FROM emps e↵JOIN depts d↵ON e.deptno = d.deptno↵where e.empid=1) as a where a.empid=2
164 | ```
165 |
166 | The response looks as follows:
167 |
168 | ```sql
169 | SELECT a.`empid`
170 | FROM (
171 |   SELECT `empid`
172 |   FROM emps_mv
173 |   WHERE `empid` = CAST(1 AS BIGINT)
174 | ) a
175 | WHERE a.`empid` = CAST(2 AS BIGINT)
176 | ```
177 |
178 | Notice that emps and depts have been replaced by the view emps_mv we created earlier.
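If you prefer to call the service from code instead of Postman, the requests are plain HTTP POSTs. The sketch below is illustrative only and is not part of sql-booster: it assumes the endpoints accept standard URL-encoded form fields with the field names shown in the Postman examples above (`name`, `tableName`, `schema`, `viewName`, `sql`), it uses the JDK 11+ `java.net.http` client, and the `post` helper is a hypothetical name chosen here.

```scala
import java.net.{URI, URLEncoder}
import java.net.http.{HttpClient, HttpRequest, HttpResponse}
import java.nio.charset.StandardCharsets

object SqlBoosterApiSketch {
  private val client = HttpClient.newHttpClient()

  // Hypothetical helper: sends the given fields as application/x-www-form-urlencoded,
  // mirroring the form fields used in the Postman walkthrough above.
  def post(url: String, fields: Map[String, String]): String = {
    val body = fields
      .map { case (k, v) =>
        URLEncoder.encode(k, StandardCharsets.UTF_8) + "=" + URLEncoder.encode(v, StandardCharsets.UTF_8)
      }
      .mkString("&")
    val request = HttpRequest.newBuilder(URI.create(url))
      .header("Content-Type", "application/x-www-form-urlencoded")
      .POST(HttpRequest.BodyPublishers.ofString(body))
      .build()
    client.send(request, HttpResponse.BodyHandlers.ofString()).body()
  }

  def main(args: Array[String]): Unit = {
    val base = "http://sql-booster.mlsql.tech"
    val me = "allwefantasy@gmail.com" // identifies your session, as described in the Summary

    // 1. Register a table (locations and emps would be registered the same way).
    post(s"$base/api_v1/table/register", Map(
      "name" -> me,
      "tableName" -> "depts",
      "schema" -> "CREATE TABLE depts(deptno INT NOT NULL, deptname VARCHAR(20), PRIMARY KEY (deptno));"
    ))

    // 2. Register the materialized view.
    post(s"$base/api_v1/view/register", Map(
      "name" -> me,
      "viewName" -> "emps_mv",
      "sql" -> "SELECT empid FROM emps JOIN depts ON depts.deptno = emps.deptno"
    ))

    // 3. Ask for a view-based rewrite of a query and print the rewritten SQL.
    val rewritten = post(s"$base/api_v1/mv/rewrite", Map(
      "name" -> me,
      "sql" -> "SELECT e.empid FROM emps e JOIN depts d ON e.deptno = d.deptno where e.empid=1"
    ))
    println(rewritten)
  }
}
```

As in the Postman walkthrough, all referenced tables must be registered before registering the view or requesting a rewrite, and the `name` field keeps your registrations in their own session.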
179 | 180 | 181 | 182 | 183 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | tech.mlsql 8 | sql-booster 9 | 0.4.0 10 | SQL Booster 11 | https://github.com/aistack/sql-booster 12 | 13 | A library for optimizing/rewriting/auditing SQL and easy to plugin new strategies 14 | 15 | 16 | 17 | Apache 2.0 License 18 | http://www.apache.org/licenses/LICENSE-2.0.html 19 | repo 20 | 21 | 22 | 23 | 24 | allwefantasy 25 | ZhuHaiLin 26 | allwefantasy@gmail.com 27 | 28 | 29 | jackylk 30 | Jacky Li 31 | jacky.likun@qq.com 32 | 33 | 34 | 35 | 36 | scm:git:git@github.com:aistack/sql-booster.git 37 | 38 | 39 | scm:git:git@github.com:aistack/sql-booster.git 40 | 41 | https://github.com/aistack/sql-booster 42 | 43 | 44 | https://github.com/aistack/sql-booster/issues 45 | 46 | 47 | 48 | UTF-8 49 | 2.11.8 50 | 2.11 51 | 2.11.0-M3 52 | 53 | 2.4.3 54 | 2.4 55 | 1.2.0 56 | 57 | 16.0 58 | 4.5.3 59 | 60 | 2.0.0 61 | provided 62 | 2.6.5 63 | 64 | 65 | 66 | 67 | com.alibaba 68 | druid 69 | 1.1.16 70 | 71 | 72 | 73 | tech.mlsql 74 | simple-schema_${scala.binary.version} 75 | 0.2.0 76 | 77 | 78 | org.scalactic 79 | scalactic_${scala.binary.version} 80 | 3.0.0 81 | test 82 | 83 | 84 | org.scalatest 85 | scalatest_${scala.binary.version} 86 | 3.0.0 87 | test 88 | 89 | 90 | 91 | org.apache.spark 92 | spark-core_${scala.binary.version} 93 | ${spark.version} 94 | ${scope} 95 | 96 | 97 | org.apache.spark 98 | spark-sql_${scala.binary.version} 99 | ${spark.version} 100 | ${scope} 101 | 102 | 103 | 104 | org.apache.spark 105 | spark-mllib_${scala.binary.version} 106 | ${spark.version} 107 | ${scope} 108 | 109 | 110 | 111 | org.apache.spark 112 | spark-catalyst_${scala.binary.version} 113 | ${spark.version} 114 | tests 115 | test 116 | 117 | 118 | 119 | org.apache.spark 120 | spark-core_${scala.binary.version} 121 | ${spark.version} 122 | tests 123 | test 124 | 125 | 126 | 127 | org.apache.spark 128 | spark-sql_${scala.binary.version} 129 | ${spark.version} 130 | tests 131 | test 132 | 133 | 134 | 135 | org.pegdown 136 | pegdown 137 | 1.6.0 138 | test 139 | 140 | 141 | 142 | mysql 143 | mysql-connector-java 144 | 8.0.16 145 | 146 | 147 | 148 | 149 | org.apache.spark 150 | spark-hive_${scala.binary.version} 151 | ${spark.version} 152 | ${scope} 153 | 154 | 155 | org.spark-project.hive 156 | hive-exec 157 | 1.2.1.spark2 158 | 159 | 160 | org.apache.spark 161 | spark-core_2.10 162 | 163 | 164 | org.apache.spark 165 | spark-network-common_2.10 166 | 167 | 168 | org.apache.spark 169 | spark-network-shuffle_2.10 170 | 171 | 172 | org.json4s 173 | json4s-jackson_2.10 174 | 175 | 176 | org.json4s 177 | json4s-core_2.10 178 | 179 | 180 | org.json4s 181 | json4s-module-scala_2.10 182 | 183 | 184 | org.json4s 185 | json4s-ast_2.10 186 | 187 | 188 | com.twitter 189 | chill_2.10 190 | 191 | 192 | com.typesafe.akka 193 | akka-actor_2.10.10 194 | 195 | 196 | com.typesafe.akka 197 | akka-remote_2.10 198 | 199 | 200 | com.typesafe.akka 201 | akka-slf4j_2.10 202 | 203 | 204 | ${scope} 205 | 206 | 207 | org.spark-project.hive 208 | hive-metastore 209 | 1.2.1.spark2 210 | 211 | 212 | org.apache.spark 213 | spark-core_2.10 214 | 215 | 216 | org.apache.spark 217 | spark-network-common_2.10 218 | 219 | 220 | org.apache.spark 221 | spark-network-shuffle_2.10 222 | 223 | 224 | org.json4s 225 | json4s-jackson_2.10 226 | 227 | 228 | org.json4s 229 | json4s-core_2.10 230 | 231 | 232 | org.json4s 233 | 
json4s-module-scala_2.10 234 | 235 | 236 | org.json4s 237 | json4s-ast_2.10 238 | 239 | 240 | com.twitter 241 | chill_2.10 242 | 243 | 244 | com.typesafe.akka 245 | akka-actor_2.10.10 246 | 247 | 248 | com.typesafe.akka 249 | akka-remote_2.10 250 | 251 | 252 | com.typesafe.akka 253 | akka-slf4j_2.10 254 | 255 | 256 | ${scope} 257 | 258 | 259 | 260 | net.liftweb 261 | lift-json_${scala.binary.version} 262 | 2.6.2 263 | test 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | spark-2.4.3 272 | 273 | 2.4.3 274 | 2.4 275 | 276 | 277 | 278 | spark-2.3.2 279 | 280 | 2.3.2 281 | 2.3 282 | 283 | 284 | 285 | disable-java8-doclint 286 | 287 | [1.8,) 288 | 289 | 290 | -Xdoclint:none 291 | none 292 | 293 | 294 | 295 | release-sign-artifacts 296 | 297 | 298 | performRelease 299 | true 300 | 301 | 302 | 303 | 304 | 305 | org.apache.maven.plugins 306 | maven-gpg-plugin 307 | 1.1 308 | 309 | 310 | sign-artifacts 311 | verify 312 | 313 | sign 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | src/main/resources 327 | 328 | 329 | 330 | 331 | org.apache.maven.plugins 332 | maven-surefire-plugin 333 | 3.0.0-M1 334 | 335 | 1 336 | true 337 | -Xmx4024m 338 | 339 | **/*.java 340 | **/*.scala 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | org.scala-tools 349 | maven-scala-plugin 350 | 2.15.2 351 | 352 | 353 | 354 | -g:vars 355 | 356 | 357 | true 358 | 359 | 360 | 361 | compile 362 | 363 | compile 364 | 365 | compile 366 | 367 | 368 | testCompile 369 | 370 | testCompile 371 | 372 | test 373 | 374 | 375 | process-resources 376 | 377 | compile 378 | 379 | 380 | 381 | 382 | 383 | 384 | org.apache.maven.plugins 385 | maven-compiler-plugin 386 | 2.3.2 387 | 388 | 389 | -g 390 | true 391 | 1.8 392 | 1.8 393 | 394 | 395 | 396 | 397 | 398 | 399 | maven-source-plugin 400 | 2.1 401 | 402 | true 403 | 404 | 405 | 406 | compile 407 | 408 | jar 409 | 410 | 411 | 412 | 413 | 414 | org.apache.maven.plugins 415 | maven-javadoc-plugin 416 | 417 | 418 | attach-javadocs 419 | 420 | jar 421 | 422 | 423 | 424 | 425 | 426 | org.sonatype.plugins 427 | nexus-staging-maven-plugin 428 | 1.6.7 429 | true 430 | 431 | sonatype-nexus-staging 432 | https://oss.sonatype.org/ 433 | true 434 | 435 | 436 | 437 | 438 | org.scalatest 439 | scalatest-maven-plugin 440 | 2.0.0 441 | 442 | streaming.core.NotToRunTag 443 | ${project.build.directory}/surefire-reports 444 | . 
445 | WDF TestSuite.txt 446 | ${project.build.directory}/html/scalatest 447 | false 448 | 449 | 450 | 451 | test 452 | 453 | test 454 | 455 | 456 | 457 | 458 | 459 | 460 | 461 | 462 | sonatype-nexus-snapshots 463 | https://oss.sonatype.org/content/repositories/snapshots 464 | 465 | 466 | sonatype-nexus-staging 467 | https://oss.sonatype.org/service/local/staging/deploy/maven2/ 468 | 469 | 470 | 471 | 472 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/sqlbooster/JavaDoc.java: -------------------------------------------------------------------------------- 1 | package tech.mlsql.sqlbooster; 2 | 3 | /** 4 | * 2019-07-19 WilliamZhu(allwefantasy@gmail.com) 5 | */ 6 | public class JavaDoc { 7 | } 8 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/SessionUtil.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst 2 | 3 | import org.apache.spark.sql.SparkSession 4 | 5 | /** 6 | * 2019-07-19 WilliamZhu(allwefantasy@gmail.com) 7 | */ 8 | object SessionUtil { 9 | def cloneSession(session: SparkSession) = { 10 | session.cloneSession() 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/PreOrderOptimize.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer 2 | 3 | import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan 4 | import org.apache.spark.sql.catalyst.rules.RuleExecutor 5 | 6 | object PreOptimizeRewrite extends RuleExecutor[LogicalPlan] { 7 | val batches = 8 | Batch("Before join rewrite", FixedPoint(100), 9 | EliminateOuterJoin, PushPredicateThroughJoin) :: Nil 10 | } 11 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteHelper.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer 2 | 3 | import org.apache.spark.sql.catalyst.catalog.HiveTableRelation 4 | import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression 5 | import org.apache.spark.sql.catalyst.expressions.{Alias, And, AttributeReference, EqualNullSafe, EqualTo, Exists, ExprId, Expression, ListQuery, NamedLambdaVariable, PredicateHelper, ScalarSubquery} 6 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule.{ProcessedComponent, RewriteContext} 7 | import org.apache.spark.sql.catalyst.plans.logical._ 8 | import org.apache.spark.sql.execution.LogicalRDD 9 | import org.apache.spark.sql.execution.datasources.LogicalRelation 10 | import tech.mlsql.sqlbooster.meta.TableHolder 11 | 12 | import scala.collection.mutable.ArrayBuffer 13 | 14 | /** 15 | * 2019-07-12 WilliamZhu(allwefantasy@gmail.com) 16 | */ 17 | trait RewriteHelper extends PredicateHelper { 18 | 19 | /** 20 | * Since attribute references are given globally unique ids during analysis, 21 | * we must normalize them to check if two different queries are identical. 
22 | */ 23 | protected def normalizeExprIds(plan: LogicalPlan) = { 24 | plan transformAllExpressions { 25 | case s: ScalarSubquery => 26 | s.copy(exprId = ExprId(0)) 27 | case e: Exists => 28 | e.copy(exprId = ExprId(0)) 29 | case l: ListQuery => 30 | l.copy(exprId = ExprId(0)) 31 | case a: AttributeReference => 32 | AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) 33 | case a: Alias => 34 | Alias(a.child, a.name)(exprId = ExprId(0)) 35 | case ae: AggregateExpression => 36 | ae.copy(resultId = ExprId(0)) 37 | case lv: NamedLambdaVariable => 38 | lv.copy(exprId = ExprId(0), value = null) 39 | } 40 | } 41 | 42 | /** 43 | * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be 44 | * equivalent: 45 | * 1. (a = b), (b = a); 46 | * 2. (a <=> b), (b <=> a). 47 | */ 48 | private def rewriteEqual(condition: Expression): Expression = condition match { 49 | case eq@EqualTo(l: Expression, r: Expression) => 50 | Seq(l, r).sortBy(hashCode).reduce(EqualTo) 51 | case eq@EqualNullSafe(l: Expression, r: Expression) => 52 | Seq(l, r).sortBy(hashCode).reduce(EqualNullSafe) 53 | case _ => condition // Don't reorder. 54 | } 55 | 56 | def hashCode(_ar: Expression): Int = { 57 | // See http://stackoverflow.com/questions/113511/hash-code-implementation 58 | _ar match { 59 | case ar@AttributeReference(_, _, _, _) => 60 | var h = 17 61 | h = h * 37 + ar.name.hashCode() 62 | h = h * 37 + ar.dataType.hashCode() 63 | h = h * 37 + ar.nullable.hashCode() 64 | h = h * 37 + ar.metadata.hashCode() 65 | h = h * 37 + ar.exprId.hashCode() 66 | h 67 | case _ => _ar.hashCode() 68 | } 69 | 70 | } 71 | 72 | /** 73 | * Normalizes plans: 74 | * - Filter the filter conditions that appear in a plan. For instance, 75 | * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) 76 | * etc., will all now be equivalent. 77 | * - Sample the seed will replaced by 0L. 78 | * - Join conditions will be resorted by hashCode. 79 | * 80 | * we use new hash function to avoid `ar.qualifier` from alias affect the final order. 81 | * 82 | */ 83 | protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { 84 | 85 | 86 | plan transform { 87 | case Filter(condition: Expression, child: LogicalPlan) => 88 | Filter(splitConjunctivePredicates(condition).map(rewriteEqual).sortBy(hashCode) 89 | .reduce(And), child) 90 | case sample: Sample => 91 | sample.copy(seed = 0L) 92 | case Join(left, right, joinType, condition) if condition.isDefined => 93 | val newCondition = 94 | splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(hashCode) 95 | .reduce(And) 96 | Join(left, right, joinType, Some(newCondition)) 97 | } 98 | } 99 | 100 | /** Consider symmetry for joins when comparing plans. */ 101 | def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { 102 | (plan1, plan2) match { 103 | case (j1: Join, j2: Join) => 104 | (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || 105 | (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) 106 | case (p1: Project, p2: Project) => 107 | p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) 108 | case _ => 109 | plan1 == plan2 110 | } 111 | } 112 | 113 | def extractTablesFromPlan(plan: LogicalPlan) = { 114 | extractTableHolderFromPlan(plan).map { holder => 115 | if (holder.db != null) holder.db + "." 
+ holder.table 116 | else holder.table 117 | }.filterNot(f => f == null) 118 | } 119 | 120 | def extractTableHolderFromPlan(plan: LogicalPlan) = { 121 | var tables = Set[TableHolder]() 122 | plan transformDown { 123 | case a@SubqueryAlias(_, LogicalRelation(_, _, _, _)) => 124 | tables += TableHolder(null, a.name.unquotedString, a.output, a) 125 | a 126 | case a@SubqueryAlias(_, LogicalRDD(_, _, _, _, _)) => 127 | tables += TableHolder(null, a.name.unquotedString, a.output, a) 128 | a 129 | case a@SubqueryAlias(_, m@HiveTableRelation(tableMeta, _, _)) => 130 | tables += TableHolder(null, a.name.unquotedString, a.output, a) 131 | a 132 | case m@HiveTableRelation(tableMeta, _, _) => 133 | tables += TableHolder(tableMeta.database, tableMeta.identifier.table, m.output, m) 134 | m 135 | case m@LogicalRelation(_, output, catalogTable, _) => 136 | val tableIdentifier = catalogTable.map(_.identifier) 137 | val database = tableIdentifier.map(_.database).flatten.getOrElse(null) 138 | val table = tableIdentifier.map(_.table).getOrElse(null) 139 | tables += TableHolder(database, table, output, m) 140 | m 141 | } 142 | tables.toList 143 | } 144 | 145 | /** Fails the test if the join order in the two plans do not match */ 146 | protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan) = { 147 | val normalized1 = normalizePlan(normalizeExprIds(plan1)) 148 | val normalized2 = normalizePlan(normalizeExprIds(plan2)) 149 | sameJoinPlan(normalized1, normalized2) 150 | } 151 | 152 | /** Fails the test if the two plans do not match */ 153 | protected def comparePlans( 154 | plan1: LogicalPlan, 155 | plan2: LogicalPlan) = { 156 | 157 | val normalized1 = normalizePlan(normalizeExprIds(plan1)) 158 | val normalized2 = normalizePlan(normalizeExprIds(plan2)) 159 | normalized1 == normalized2 160 | } 161 | 162 | /** Fails the test if the two expressions do not match */ 163 | protected def compareExpressions(e1: Expression, e2: Expression) = { 164 | comparePlans(Filter(e1, OneRowRelation()), Filter(e2, OneRowRelation())) 165 | } 166 | 167 | def mergeConjunctiveExpressions(e: Seq[Expression]) = { 168 | e.reduce { (a, b) => 169 | And(a, b) 170 | } 171 | } 172 | 173 | def extractTheSameExpressions(view: Seq[Expression], query: Seq[Expression]) = { 174 | val viewLeft = ArrayBuffer[Expression](view: _*) 175 | val queryLeft = ArrayBuffer[Expression](query: _*) 176 | val common = ArrayBuffer[Expression]() 177 | query.foreach { itemInQuery => 178 | view.foreach { itemInView => 179 | if (itemInView.semanticEquals(itemInQuery)) { 180 | common += itemInQuery 181 | viewLeft -= itemInView 182 | queryLeft -= itemInQuery 183 | } 184 | } 185 | } 186 | (viewLeft, queryLeft, common) 187 | } 188 | 189 | 190 | def extractAttributeReference(expr: Expression) = { 191 | val columns = ArrayBuffer[AttributeReference]() 192 | expr transformDown { 193 | case a@AttributeReference(name, dataType, _, _) => 194 | columns += a 195 | a 196 | } 197 | columns 198 | } 199 | 200 | def extractAttributeReferenceFromFirstLevel(exprs: Seq[Expression]) = { 201 | exprs.map { expr => 202 | expr match { 203 | case a@AttributeReference(name, dataType, _, _) => Option(a) 204 | case _ => None 205 | } 206 | }.filter(_.isDefined).map(_.get) 207 | } 208 | 209 | /** 210 | * Sometimes we compare two tables with column name and dataType 211 | */ 212 | def attributeReferenceEqual(a: AttributeReference, b: AttributeReference) = { 213 | a.name == b.name && a.dataType == b.dataType 214 | } 215 | 216 | def isJoinExists(plan: LogicalPlan) = { 217 | var _isJoinExists 
= false 218 | plan transformDown { 219 | case a@Join(_, _, _, _) => 220 | _isJoinExists = true 221 | a 222 | } 223 | _isJoinExists 224 | } 225 | 226 | def isAggExistsExists(plan: LogicalPlan) = { 227 | var _isAggExistsExists = false 228 | plan transformDown { 229 | case a@Aggregate(_, _, _) => 230 | _isAggExistsExists = true 231 | a 232 | } 233 | _isAggExistsExists 234 | } 235 | 236 | def generateRewriteContext(plan: LogicalPlan, rewriteContext: RewriteContext) = { 237 | var queryConjunctivePredicates: Seq[Expression] = Seq() 238 | var viewConjunctivePredicates: Seq[Expression] = Seq() 239 | 240 | var queryProjectList: Seq[Expression] = Seq() 241 | var viewProjectList: Seq[Expression] = Seq() 242 | 243 | var queryGroupingExpressions: Seq[Expression] = Seq() 244 | var viewGroupingExpressions: Seq[Expression] = Seq() 245 | 246 | var queryAggregateExpressions: Seq[Expression] = Seq() 247 | var viewAggregateExpressions: Seq[Expression] = Seq() 248 | 249 | val viewJoins = ArrayBuffer[Join]() 250 | val queryJoins = ArrayBuffer[Join]() 251 | 252 | val queryNormalizePlan = normalizePlan(plan) 253 | val viewNormalizePlan = normalizePlan(rewriteContext.viewLogicalPlan.get.viewCreateLogicalPlan) 254 | //collect all predicates 255 | viewNormalizePlan transformDown { 256 | case a@Filter(condition, _) => 257 | viewConjunctivePredicates ++= splitConjunctivePredicates(condition) 258 | a 259 | } 260 | 261 | queryNormalizePlan transformDown { 262 | case a@Filter(condition, _) => 263 | queryConjunctivePredicates ++= splitConjunctivePredicates(condition) 264 | a 265 | } 266 | 267 | // check projectList and where condition 268 | normalizePlan(plan) match { 269 | case Project(projectList, Filter(condition, _)) => 270 | queryConjunctivePredicates = splitConjunctivePredicates(condition) 271 | queryProjectList = projectList 272 | case Project(projectList, _) => 273 | queryProjectList = projectList 274 | 275 | case Aggregate(groupingExpressions, aggregateExpressions, Filter(condition, _)) => { 276 | queryConjunctivePredicates = splitConjunctivePredicates(condition) 277 | queryGroupingExpressions = groupingExpressions 278 | queryAggregateExpressions = aggregateExpressions 279 | } 280 | case Aggregate(groupingExpressions, aggregateExpressions, _) => 281 | queryGroupingExpressions = groupingExpressions 282 | queryAggregateExpressions = aggregateExpressions 283 | 284 | 285 | } 286 | 287 | normalizePlan(rewriteContext.viewLogicalPlan.get().viewCreateLogicalPlan) match { 288 | case Project(projectList, Filter(condition, _)) => 289 | viewConjunctivePredicates = splitConjunctivePredicates(condition) 290 | viewProjectList = projectList 291 | case Project(projectList, _) => 292 | viewProjectList = projectList 293 | 294 | case Aggregate(groupingExpressions, aggregateExpressions, Filter(condition, _)) => 295 | viewConjunctivePredicates = splitConjunctivePredicates(condition) 296 | viewGroupingExpressions = groupingExpressions 297 | viewAggregateExpressions = aggregateExpressions 298 | 299 | case Aggregate(groupingExpressions, aggregateExpressions, _) => 300 | viewGroupingExpressions = groupingExpressions 301 | viewAggregateExpressions = aggregateExpressions 302 | } 303 | 304 | if (isJoinExists(plan)) { 305 | // get the first level join 306 | viewJoins += extractFirstLevelJoin(viewNormalizePlan) 307 | queryJoins += extractFirstLevelJoin(queryNormalizePlan) 308 | } 309 | 310 | 311 | rewriteContext.processedComponent.set(ProcessedComponent( 312 | queryConjunctivePredicates, 313 | viewConjunctivePredicates, 314 | queryProjectList, 
315 | viewProjectList, 316 | queryGroupingExpressions, 317 | viewGroupingExpressions, 318 | queryAggregateExpressions, 319 | viewAggregateExpressions, 320 | viewJoins, 321 | queryJoins 322 | )) 323 | 324 | } 325 | 326 | def extractFirstLevelJoin(plan: LogicalPlan) = { 327 | plan match { 328 | case p@Project(_, join@Join(_, _, _, _)) => join 329 | case p@Project(_, Filter(_, join@Join(_, _, _, _))) => join 330 | case p@Aggregate(_, _, Filter(_, join@Join(_, _, _, _))) => join 331 | case p@Aggregate(_, _, join@Join(_, _, _, _)) => join 332 | } 333 | } 334 | 335 | 336 | } 337 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteTableToView.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer 2 | 3 | import java.util.concurrent.atomic.AtomicReference 4 | 5 | import org.apache.spark.sql.catalyst.expressions._ 6 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule._ 7 | import org.apache.spark.sql.catalyst.plans.logical._ 8 | import org.apache.spark.sql.catalyst.rules.Rule 9 | import org.apache.spark.sql.execution.LogicalRDD 10 | import org.apache.spark.sql.execution.datasources.LogicalRelation 11 | 12 | import scala.collection.mutable.ArrayBuffer 13 | 14 | /** 15 | * References: 16 | * [GL01] Jonathan Goldstein and Per-åke Larson. 17 | * Optimizing queries using materialized views: A practical, scalable solution. In Proc. ACM SIGMOD Conf., 2001. 18 | * 19 | * The Rule should be used on the resolved analyzer, for 20 | * example: 21 | * {{{ 22 | * object OptimizeRewrite extends RuleExecutor[LogicalPlan] { 23 | * val batches = 24 | * Batch("User Rewriter", Once, 25 | * RewriteTableToViews) :: Nil 26 | * } 27 | * 28 | * ViewCatalyst.createViewCatalyst() 29 | * ViewCatalyst.meta.registerFromLogicalPlan("viewTable1", viewTable1.logicalPlan, createViewTable1.logicalPlan) 30 | * 31 | * val analyzed = spark.sql(""" select * from at where a="jack" and b="wow" """).queryExecution.analyzed 32 | * val mvRewrite = OptimizeRewrite.execute(analyzed) 33 | * 34 | * // do other stuff to mvRewrite 35 | * }}} 36 | * 37 | * 38 | */ 39 | object RewriteTableToViews extends Rule[LogicalPlan] with PredicateHelper { 40 | val batches = ArrayBuffer[RewriteMatchRule]( 41 | WithoutJoinGroupRule.apply, 42 | WithoutJoinRule.apply, 43 | SPGJRule.apply 44 | ) 45 | 46 | def apply(plan: LogicalPlan): LogicalPlan = { 47 | var lastPlan = plan 48 | var shouldStop = false 49 | var count = 100 50 | val rewriteContext = new RewriteContext(new AtomicReference[ViewLogicalPlan](), new AtomicReference[ProcessedComponent]()) 51 | while (!shouldStop && count > 0) { 52 | count -= 1 53 | var currentPlan = if (isSPJG(plan)) { 54 | rewrite(plan, rewriteContext) 55 | } else { 56 | plan.transformUp { 57 | case a if isSPJG(a) => 58 | rewrite(a, rewriteContext) 59 | } 60 | } 61 | if (currentPlan != lastPlan) { 62 | //fix all attributeRef in finalPlan 63 | currentPlan = currentPlan transformAllExpressions { 64 | case ar@AttributeReference(_, _, _, _) => 65 | val qualifier = ar.qualifier 66 | rewriteContext.replacedARMapping.getOrElse(ar.withQualifier(Seq()), ar).withQualifier(qualifier) 67 | } 68 | } else { 69 | shouldStop = true 70 | } 71 | 72 | lastPlan = currentPlan 73 | } 74 | lastPlan 75 | } 76 | 77 | private def rewrite(plan: LogicalPlan, rewriteContext: RewriteContext) = { 78 | // this plan is SPJG, but the first step is check whether we can rewrite it 79 | 
var rewritePlan = plan 80 | batches.foreach { rewriter => 81 | rewritePlan = rewriter.rewrite(rewritePlan, rewriteContext) 82 | } 83 | 84 | rewritePlan match { 85 | case RewritedLogicalPlan(_, true) => 86 | logInfo(s"=====try to rewrite but fail ======:\n\n${plan} ") 87 | plan 88 | case RewritedLogicalPlan(inner, false) => 89 | logInfo(s"=====try to rewrite and success ======:\n\n${plan} \n\n ${inner}") 90 | inner 91 | case _ => 92 | logInfo(s"=====try to rewrite but fail ======:\n\n${plan} ") 93 | rewritePlan 94 | } 95 | } 96 | 97 | /** 98 | * check the plan is whether a basic sql pattern 99 | * only contains select(filter)/agg/project/join/group. 100 | * 101 | * @param plan 102 | * @return 103 | */ 104 | private def isSPJG(plan: LogicalPlan): Boolean = { 105 | println(plan) 106 | var isMatch = true 107 | plan transformDown { 108 | case a@SubqueryAlias(_, Project(_, _)) => 109 | isMatch = false 110 | a 111 | case a@Union(_) => 112 | isMatch = false 113 | a 114 | } 115 | 116 | if (!isMatch) { 117 | return false 118 | } 119 | 120 | plan match { 121 | case p@Project(_, Join(_, _, _, _)) => true 122 | case p@Project(_, Filter(_, Join(_, _, _, _))) => true 123 | case p@Aggregate(_, _, Filter(_, Join(_, _, _, _))) => true 124 | case p@Aggregate(_, _, Filter(_, _)) => true 125 | case p@Project(_, Filter(_, _)) => true 126 | case p@Aggregate(_, _, Join(_, _, _, _)) => true 127 | case p@Aggregate(_, _, SubqueryAlias(_, LogicalRDD(_, _, _, _, _))) => true 128 | case p@Aggregate(_, _, SubqueryAlias(_, LogicalRelation(_, _, _, _))) => true 129 | case p@Project(_, SubqueryAlias(_, LogicalRDD(_, _, _, _, _))) => true 130 | case p@Project(_, SubqueryAlias(_, LogicalRelation(_, _, _, _))) => true 131 | case _ => false 132 | } 133 | } 134 | } 135 | 136 | 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/AggMatcher.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component 2 | 3 | import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Sum} 4 | import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Divide, Expression, Literal} 5 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule._ 6 | import org.apache.spark.sql.types.IntegerType 7 | 8 | import scala.collection.mutable.ArrayBuffer 9 | 10 | /** 11 | * 2019-07-15 WilliamZhu(allwefantasy@gmail.com) 12 | */ 13 | class AggMatcher(rewriteContext: RewriteContext 14 | ) extends ExpressionMatcher { 15 | /** 16 | * when the view/query both has count(*), and the group by condition in query isSubset(not equal) of view, 17 | * then we should replace query count(*) with SUM(view_count(*)). 
for example: 18 | * 19 | * view: select count(*) as a from table1 group by m; 20 | * query: view1 => select count(*) as a from table1 group by m,c 21 | * 22 | * target: select sum(a) from view1 group by c 23 | * 24 | * Another situation we should take care is AVG: 25 | * 26 | * view: select count(*) as a from table1 group by m; 27 | * query: view1 => select avg(k) from table1 group by m,c 28 | * 29 | * target: select sum(k)/a from view1 group by c 30 | * 31 | * 32 | */ 33 | override def compare: CompensationExpressions = { 34 | 35 | val query = rewriteContext.processedComponent.get().queryAggregateExpressions 36 | val view = rewriteContext.processedComponent.get().viewAggregateExpressions 37 | 38 | // let's take care the first situation, if there are count(*) in query, then 39 | // count(*) should also be in view and we should replace it with sum(count_view) 40 | val queryCountStar = getCountStartList(query) 41 | val viewCountStar = getCountStartList(view) 42 | 43 | if (queryCountStar.size > 0 && viewCountStar == 0) return RewriteFail.AGG_NUMBER_UNMATCH(this) 44 | 45 | val viewProjectOrAggList = rewriteContext.viewLogicalPlan.get().tableLogicalPlan.output 46 | 47 | 48 | /** 49 | * let's take care the third situation, any agg filed both in view/query, we should replace it with new field in view 50 | * 51 | * query: 52 | * 53 | * SELECT deptno, COUNT(*) AS c, SUM(salary) AS s 54 | * FROM emps 55 | * GROUP BY deptno 56 | * 57 | * view: 58 | * 59 | * SELECT empid, deptno, COUNT(*) AS c, SUM(salary) AS s 60 | * FROM emps 61 | * GROUP BY empid, deptno 62 | * 63 | * target: 64 | * 65 | * SELECT deptno, SUM(c), SUM(s) 66 | * FROM mv 67 | * GROUP BY deptno 68 | * 69 | * here we should convert SUM(salary) to SUM(s) or s 70 | */ 71 | 72 | 73 | val exactlySame = query.filterNot { item => 74 | item match { 75 | case a@Alias(agg@AggregateExpression(Average(ar@_), _, _, _), name) => false 76 | case a@Alias(agg@AggregateExpression(Count(_), _, _, _), name) => false 77 | case _ => true 78 | } 79 | } 80 | 81 | 82 | val success = exactlySame.map { item => 83 | if (view.filter { f => 84 | cleanAlias(f).semanticEquals(cleanAlias(item)) 85 | }.size > 0) 1 else 0 86 | }.sum == exactlySame.size 87 | 88 | if (!success) return RewriteFail.AGG_COLUMNS_UNMATCH(this) 89 | 90 | var queryReplaceAgg = query 91 | 92 | queryReplaceAgg = queryReplaceAgg.map { item => 93 | item transformUp { 94 | case a@Alias(agg@AggregateExpression(Average(ar@_), _, _, _), name) => a 95 | case a@Alias(agg@AggregateExpression(Count(_), _, _, _), name) => a 96 | case a@Alias(agg@AggregateExpression(_, _, _, _), name) => 97 | val (vItem, index) = view.zipWithIndex.filter { case (vItem, index) => 98 | cleanAlias(vItem).semanticEquals(cleanAlias(a)) 99 | }.head 100 | val newVItem = vItem transformDown { 101 | case a@AttributeReference(_, _, _, _) => viewProjectOrAggList(index) 102 | } 103 | Alias(cleanAlias(newVItem), name)() 104 | } 105 | } 106 | 107 | 108 | var queryReplaceCountStar = queryReplaceAgg 109 | 110 | if (queryCountStar.size > 0) { 111 | val replaceItem = viewCountStar.head 112 | val arInViewTable = extractAttributeReferenceFromFirstLevel(viewProjectOrAggList).filter { ar => 113 | ar.name == replaceItem.asInstanceOf[Alias].name 114 | }.head 115 | 116 | queryReplaceCountStar = queryReplaceCountStar map { expr => 117 | expr transformDown { 118 | case Alias(agg@AggregateExpression(Count(Seq(Literal(1, IntegerType))), _, _, _), name) => 119 | Alias(agg.copy(aggregateFunction = Sum(arInViewTable)), name)() 120 | } 121 | } 122 | } 123 | 
124 | // let's take care the second situation, if there are AVG(k) in query,then count(*) 125 | // should also be in view and we should replace it with sum(k)/view_count(*) 126 | 127 | val queryAvg = getAvgList(query) 128 | 129 | if (queryAvg.size > 0 && viewCountStar == 0) return RewriteFail.AGG_VIEW_MISSING_COUNTING_STAR(this) 130 | 131 | var queryReplaceAvg = queryReplaceCountStar 132 | 133 | if (queryAvg.size > 0) { 134 | val replaceItem = viewCountStar.head 135 | val arInViewTable = extractAttributeReferenceFromFirstLevel(viewProjectOrAggList).filter { ar => 136 | ar.name == replaceItem.asInstanceOf[Alias].name 137 | }.head 138 | 139 | queryReplaceAvg = queryReplaceAvg.map { expr => 140 | val newExpr = expr transformDown { 141 | case a@Alias(agg@AggregateExpression(Average(ar@_), _, _, _), name) => 142 | // and ar should be also in viewProjectOrAggList 143 | val sum = agg.copy(aggregateFunction = Sum(ar)) 144 | Alias(Divide(sum, arInViewTable), name)() 145 | } 146 | newExpr 147 | } 148 | } 149 | 150 | 151 | CompensationExpressions(true, queryReplaceAvg) 152 | 153 | } 154 | 155 | 156 | private def getCountStartList(items: Seq[Expression]) = { 157 | val queryCountStar = ArrayBuffer[Expression]() 158 | items.zipWithIndex.foreach { case (expr, index) => 159 | expr transformDown { 160 | case a@Alias(AggregateExpression(Count(Seq(Literal(1, IntegerType))), _, _, _), name) => 161 | queryCountStar += a 162 | a 163 | } 164 | 165 | } 166 | queryCountStar 167 | } 168 | 169 | private def getAvgList(items: Seq[Expression]) = { 170 | val avgList = ArrayBuffer[Expression]() 171 | items.foreach { expr => 172 | expr transformDown { 173 | case a@Alias(AggregateExpression(Average(ar@_), _, _, _), name) => 174 | avgList += a 175 | a 176 | } 177 | } 178 | avgList 179 | } 180 | 181 | private def cleanAlias(expr: Expression) = { 182 | expr match { 183 | case Alias(child, _) => child 184 | case _ => expr 185 | } 186 | } 187 | } 188 | 189 | 190 | 191 | 192 | 193 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/GroupByMatcher.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component 2 | 3 | import org.apache.spark.sql.catalyst.expressions.Expression 4 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule._ 5 | 6 | import scala.collection.mutable.ArrayBuffer 7 | 8 | /** 9 | * 2019-07-15 WilliamZhu(allwefantasy@gmail.com) 10 | */ 11 | class GroupByMatcher(rewriteContext: RewriteContext) extends ExpressionMatcher { 12 | override def compare: CompensationExpressions = { 13 | /** 14 | * Query: 15 | * 16 | * SELECT deptno 17 | * FROM emps 18 | * WHERE deptno > 10 19 | * GROUP BY deptno 20 | * 21 | * View: 22 | * 23 | * SELECT empid, deptno 24 | * FROM emps 25 | * WHERE deptno > 5 26 | * GROUP BY empid, deptno 27 | * 28 | * Target: 29 | * 30 | * SELECT deptno 31 | * FROM mv 32 | * WHERE deptno > 10 33 | * GROUP BY deptno 34 | * 35 | * then query isSubSet of view . Please take care of the order in group by. 
36 | */ 37 | val query = rewriteContext.processedComponent.get().queryGroupingExpressions 38 | val view = rewriteContext.processedComponent.get().viewGroupingExpressions 39 | val viewAggregateExpressions = rewriteContext.processedComponent.get().viewAggregateExpressions 40 | 41 | if (query.size > view.size) return RewriteFail.GROUP_BY_SIZE_UNMATCH(this) 42 | if (!isSubSetOf(query, view)) return RewriteFail.GROUP_BY_SIZE_UNMATCH(this) 43 | 44 | // again make sure the columns in queryLeft is also in view project/agg 45 | 46 | val viewAttrs = extractAttributeReferenceFromFirstLevel(viewAggregateExpressions) 47 | 48 | val compensationCondAllInViewProjectList = isSubSetOf(query.flatMap(extractAttributeReference), viewAttrs) 49 | 50 | if (!compensationCondAllInViewProjectList) return RewriteFail.GROUP_BY_COLUMNS_NOT_IN_VIEW_PROJECT_OR_AGG(this) 51 | 52 | CompensationExpressions(true, query) 53 | 54 | } 55 | 56 | private def extractTheSameExpressionsOrder(view: Seq[Expression], query: Seq[Expression]) = { 57 | val viewLeft = ArrayBuffer[Expression](view: _*) 58 | val queryLeft = ArrayBuffer[Expression](query: _*) 59 | val common = ArrayBuffer[Expression]() 60 | 61 | (0 until view.size).foreach { index => 62 | if (view(index).semanticEquals(query(index))) { 63 | common += view(index) 64 | viewLeft -= view(index) 65 | queryLeft -= query(index) 66 | } 67 | } 68 | 69 | (viewLeft, queryLeft, common) 70 | } 71 | } 72 | 73 | 74 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/JoinMatcher.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component 2 | 3 | import org.apache.spark.sql.catalyst.expressions.AttributeReference 4 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule._ 5 | import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, SubqueryAlias} 6 | 7 | /** 8 | * 2019-07-16 WilliamZhu(allwefantasy@gmail.com) 9 | */ 10 | class JoinMatcher(rewriteContext: RewriteContext 11 | ) extends ExpressionMatcher { 12 | override def compare: CompensationExpressions = { 13 | 14 | 15 | val viewJoin = rewriteContext.processedComponent.get().viewJoins.head 16 | val queryJoin = rewriteContext.processedComponent.get().queryJoins.head 17 | // since the prediate condition will be pushed down into Join filter, 18 | // but we have compare them in Predicate Matcher/Rewrite step, so when compare Join, 19 | // we should clean the filter from Join 20 | if (!sameJoinPlan(cleanJoinFilter(viewJoin), cleanJoinFilter(queryJoin))) return RewriteFail.JOIN_UNMATCH(this) 21 | CompensationExpressions(true, Seq()) 22 | } 23 | 24 | def cleanJoinFilter(join: Join) = { 25 | val newPlan = join transformUp { 26 | case a@Filter(_, child) => 27 | child 28 | case SubqueryAlias(_, a@SubqueryAlias(_, _)) => 29 | a 30 | case a@Join(_, _, _, condition) => 31 | if (condition.isDefined) { 32 | val newConditions = condition.get transformUp { 33 | case a@AttributeReference(name, dataType, nullable, metadata) => 34 | AttributeReference(name, dataType, nullable, metadata)(a.exprId, Seq()) 35 | } 36 | a.copy(condition = Option(newConditions)) 37 | 38 | } else a 39 | 40 | } 41 | newPlan 42 | } 43 | 44 | } 45 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/PredicateMatcher.scala: 
-------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component 2 | 3 | import org.apache.spark.sql.catalyst.expressions.{Cast, EqualNullSafe, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal} 4 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule._ 5 | import org.apache.spark.sql.types._ 6 | 7 | import scala.collection.mutable.ArrayBuffer 8 | 9 | 10 | /** 11 | * Here we compare where conditions 12 | * 13 | * 1. Equal view subSetOf query 14 | * 2. NoEqual (greater/less) , we first convert them to RangeCondition, 15 | * so we can define the range contains(range in query should be narrow then range in view) 16 | * between RangeCondition, and the 17 | * final check view subSetOf query 18 | * 19 | * 3. others. Using expression semanticEqual, and make sure view subSetOf query 20 | * 21 | * 22 | */ 23 | class PredicateMatcher(rewriteContext: RewriteContext) extends ExpressionMatcher { 24 | 25 | override def compare: CompensationExpressions = { 26 | 27 | val compensationCond = ArrayBuffer[Expression]() 28 | 29 | if (rewriteContext.processedComponent.get().viewConjunctivePredicates.size > rewriteContext.processedComponent.get().queryConjunctivePredicates.size) return RewriteFail.PREDICATE_UNMATCH(this) 30 | 31 | // equal expression compare 32 | val viewEqual = extractEqualConditions(rewriteContext.processedComponent.get().viewConjunctivePredicates) 33 | val queryEqual = extractEqualConditions(rewriteContext.processedComponent.get().queryConjunctivePredicates) 34 | 35 | // if viewEqual are not subset of queryEqual, then it will not match. 36 | if (!isSubSetOf(viewEqual, queryEqual)) return RewriteFail.PREDICATE_EQUALS_UNMATCH(this) 37 | compensationCond ++= subset[Expression](queryEqual, viewEqual) 38 | 39 | // less/greater expressions compare 40 | 41 | // make sure all less/greater expression with the same presentation 42 | // for example if exits a < 3 && a>=1 then we should change to RangeCondition(a,1,3) 43 | // or b < 3 then RangeCondition(b,None,3) 44 | val viewRange = extractRangeConditions(rewriteContext.processedComponent.get().viewConjunctivePredicates).map(RangeFilter.convertRangeCon) 45 | val queryRange = extractRangeConditions(rewriteContext.processedComponent.get().queryConjunctivePredicates).map(RangeFilter.convertRangeCon) 46 | 47 | // combine something like 48 | // RangeCondition(a,1,None),RangeCondition(a,None,3) into RangeCondition(a,1,3) 49 | 50 | 51 | val viewRangeCondition = RangeFilter.combineAndMergeRangeCondition(viewRange).toSeq 52 | val queryRangeCondtion = RangeFilter.combineAndMergeRangeCondition(queryRange).toSeq 53 | 54 | //again make sure viewRangeCondition.size is small queryRangeCondtion.size 55 | if (viewRangeCondition.size > queryRangeCondtion.size) return RewriteFail.PREDICATE_RANGE_UNMATCH(this) 56 | 57 | //all view rangeCondition should a SubRangeCondition of query 58 | val isRangeMatch = viewRangeCondition.map { viewRC => 59 | if (queryRangeCondtion.filter(queryRC => queryRC.isSubRange(viewRC)).size >= 1) 1 else 0 60 | }.sum == viewRangeCondition.size 61 | 62 | if (!isRangeMatch) return RewriteFail.PREDICATE_RANGE_UNMATCH(this) 63 | 64 | compensationCond ++= queryRangeCondtion.flatMap(_.toExpression) 65 | 66 | // other conditions compare 67 | val viewResidual = extractResidualConditions(rewriteContext.processedComponent.get().viewConjunctivePredicates) 68 | val queryResidual = 
extractResidualConditions(rewriteContext.processedComponent.get().queryConjunctivePredicates) 69 | if (!isSubSetOf(viewResidual, queryResidual)) return RewriteFail.PREDICATE_EXACLTY_SAME_UNMATCH(this) 70 | compensationCond ++= subset[Expression](queryResidual, viewResidual) 71 | 72 | // make sure all attributeReference in compensationCond is also in output of view 73 | // we get all columns without applied any function in projectList of viewCreateLogicalPlan 74 | val viewAttrs = extractAttributeReferenceFromFirstLevel(rewriteContext.viewLogicalPlan.get().viewCreateLogicalPlan.output) 75 | 76 | val compensationCondAllInViewProjectList = isSubSetOf(compensationCond.flatMap(extractAttributeReference), viewAttrs) 77 | 78 | if (!compensationCondAllInViewProjectList) return RewriteFail.PREDICATE_COLUMNS_NOT_IN_VIEW_PROJECT_OR_AGG(this) 79 | 80 | // return the compensation expressions 81 | CompensationExpressions(true, compensationCond) 82 | } 83 | 84 | 85 | def extractEqualConditions(conjunctivePredicates: Seq[Expression]) = { 86 | conjunctivePredicates.filter(RangeFilter.equalCon) 87 | } 88 | 89 | def extractRangeConditions(conjunctivePredicates: Seq[Expression]) = { 90 | conjunctivePredicates.filter(RangeFilter.rangeCon) 91 | } 92 | 93 | def extractResidualConditions(conjunctivePredicates: Seq[Expression]) = { 94 | conjunctivePredicates.filterNot(RangeFilter.equalCon).filterNot(RangeFilter.rangeCon) 95 | } 96 | } 97 | 98 | case class RangeCondition(key: Expression, lowerBound: Option[Literal], upperBound: Option[Literal], 99 | includeLowerBound: Boolean, 100 | includeUpperBound: Boolean) { 101 | 102 | def toExpression: Seq[Expression] = { 103 | (lowerBound, upperBound) match { 104 | case (None, None) => Seq() 105 | case (Some(l), None) => if (includeLowerBound) 106 | Seq(GreaterThanOrEqual(key, l)) else Seq(GreaterThan(key, Cast(l, key.dataType))) 107 | case (None, Some(l)) => if (includeUpperBound) 108 | Seq(LessThanOrEqual(key, l)) else Seq(LessThan(key, Cast(l, key.dataType))) 109 | case (Some(a), Some(b)) => 110 | val aSeq = if (includeLowerBound) 111 | Seq(GreaterThanOrEqual(key, Cast(a, key.dataType))) else Seq(GreaterThan(key, Cast(a, key.dataType))) 112 | val bSeq = if (includeUpperBound) 113 | Seq(LessThanOrEqual(key, Cast(b, key.dataType))) else Seq(LessThan(key, Cast(b, key.dataType))) 114 | aSeq ++ bSeq 115 | } 116 | } 117 | 118 | def isSubRange(other: RangeCondition) = { 119 | this.key.semanticEquals(other.key) && 120 | greaterThenOrEqual(this.lowerBound, other.lowerBound, true) && 121 | greaterThenOrEqual(other.upperBound, this.upperBound, false) 122 | } 123 | 124 | def greaterThenOrEqual(lit1: Option[Literal], lit2: Option[Literal], isLowerBound: Boolean) = { 125 | (lit1, lit2) match { 126 | case (None, None) => true 127 | case (Some(l), None) => if (isLowerBound) true else false 128 | case (None, Some(l)) => if (isLowerBound) false else true 129 | case (Some(a), Some(b)) => 130 | a.dataType match { 131 | 132 | case ShortType | IntegerType | LongType | FloatType | DoubleType => a.value.toString.toDouble >= b.value.toString.toDouble 133 | case StringType => a.value.toString >= b.value.toString 134 | case _ => throw new RuntimeException("not support type") 135 | } 136 | } 137 | } 138 | 139 | def +(other: RangeCondition) = { 140 | assert(this.key.semanticEquals(other.key)) 141 | 142 | 143 | val _lowerBound = if (greaterThenOrEqual(this.lowerBound, other.lowerBound, true)) 144 | (this.lowerBound, this.includeLowerBound) else (other.lowerBound, other.includeLowerBound) 145 | 146 | 
val _upperBound = if (greaterThenOrEqual(this.upperBound, other.upperBound, false)) 147 | (other.upperBound, other.includeUpperBound) else (this.upperBound, this.includeUpperBound) 148 | RangeCondition(key, _lowerBound._1, _upperBound._1, _lowerBound._2, _upperBound._2) 149 | } 150 | 151 | 152 | } 153 | 154 | object RangeFilter { 155 | val equalCon = (f: Expression) => { 156 | f.isInstanceOf[EqualNullSafe] || f.isInstanceOf[EqualTo] 157 | } 158 | 159 | val convertRangeCon = (f: Expression) => { 160 | f match { 161 | case GreaterThan(a, Cast(v@Literal(_, _), _, _)) => RangeCondition(a, Option(v), None, false, false) 162 | case GreaterThan(a, v@Literal(_, _)) => RangeCondition(a, Option(v), None, false, false) 163 | case GreaterThan(v@Literal(_, _), a) => RangeCondition(a, None, Option(v), false, false) 164 | case GreaterThan(Cast(v@Literal(_, _), _, _), a) => RangeCondition(a, None, Option(v), false, false) 165 | case GreaterThanOrEqual(a, v@Literal(_, _)) => RangeCondition(a, Option(v), None, true, false) 166 | case GreaterThanOrEqual(a, Cast(v@Literal(_, _), _, _)) => RangeCondition(a, Option(v), None, true, false) 167 | case GreaterThanOrEqual(v@Literal(_, _), a) => RangeCondition(a, None, Option(v), false, true) 168 | case GreaterThanOrEqual(Cast(v@Literal(_, _), _, _), a) => RangeCondition(a, None, Option(v), false, true) 169 | case LessThan(a, Cast(v@Literal(_, _), _, _)) => RangeCondition(a, None, Option(v), false, false) 170 | case LessThan(a, v@Literal(_, _)) => RangeCondition(a, None, Option(v), false, false) 171 | case LessThan(Cast(v@Literal(_, _), _, _), a) => RangeCondition(a, Option(v), None, false, true) 172 | case LessThan(v@Literal(_, _), a) => RangeCondition(a, Option(v), None, false, true) 173 | case LessThanOrEqual(a, Cast(v@Literal(_, _), _, _)) => RangeCondition(a, None, Option(v), false, true) 174 | case LessThanOrEqual(a, v@Literal(_, _)) => RangeCondition(a, None, Option(v), false, true) 175 | case LessThanOrEqual(a, Cast(v@Literal(_, _), _, _)) => RangeCondition(a, None, Option(v), false, true) 176 | case LessThanOrEqual(v@Literal(_, _), a) => RangeCondition(a, Option(v), None, true, false) 177 | } 178 | } 179 | 180 | val rangeCon = (f: Expression) => { 181 | f match { 182 | case GreaterThan(_, Literal(_, _)) | GreaterThan(Literal(_, _), _) => true 183 | case GreaterThan(_, Cast(Literal(_, _), _, _)) | GreaterThan(Cast(Literal(_, _), _, _), _) => true 184 | case GreaterThanOrEqual(_, Literal(_, _)) | GreaterThanOrEqual(Literal(_, _), _) => true 185 | case GreaterThanOrEqual(_, Cast(Literal(_, _), _, _)) | GreaterThanOrEqual(Cast(Literal(_, _), _, _), _) => true 186 | case LessThan(_, Literal(_, _)) | LessThan(Literal(_, _), _) => true 187 | case LessThan(_, Cast(Literal(_, _), _, _)) | LessThan(Cast(Literal(_, _), _, _), _) => true 188 | case LessThanOrEqual(_, Literal(_, _)) | LessThanOrEqual(Literal(_, _), _) => true 189 | case LessThanOrEqual(_, Cast(Literal(_, _), _, _)) | LessThanOrEqual(Cast(Literal(_, _), _, _), _) => true 190 | case _ => false 191 | } 192 | } 193 | 194 | def combineAndMergeRangeCondition(items: Seq[RangeCondition]) = { 195 | items.groupBy(f => f.key).map { f => 196 | val first = f._2.head.copy(lowerBound = None, upperBound = None) 197 | f._2.foldLeft(first) { (result, item) => 198 | result + item 199 | } 200 | } 201 | } 202 | } 203 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/ProjectMatcher.scala: 
-------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component 2 | 3 | import org.apache.spark.sql.catalyst.optimizer.rewrite.component.util.{ExpressionIntersectResp, ExpressionSemanticEquals} 4 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule._ 5 | 6 | /** 7 | * 2019-07-14 WilliamZhu(allwefantasy@gmail.com) 8 | */ 9 | class ProjectMatcher(rewriteContext: RewriteContext) extends ExpressionMatcher { 10 | /** 11 | * 12 | * @param query the project expression list in query 13 | * @param view the project expression list in view 14 | * @return 15 | * 16 | * We should make sure all query project list isSubSet of view project list. 17 | * 18 | * 19 | */ 20 | override def compare: CompensationExpressions = { 21 | 22 | val query = rewriteContext.processedComponent.get().queryProjectList 23 | val view = rewriteContext.processedComponent.get().viewProjectList 24 | val ExpressionIntersectResp(queryLeft, viewLeft, _) = ExpressionSemanticEquals.process(query, view) 25 | // for now, we must make sure the queryLeft's columns(not alias) all in viewLeft.columns(not alias) 26 | val queryColumns = queryLeft.flatMap(extractAttributeReference) 27 | val viewColumns = viewLeft.flatMap(extractAttributeReference) 28 | 29 | val ExpressionIntersectResp(queryColumnsLeft, viewColumnsLeft, _) = ExpressionSemanticEquals.process(queryColumns, viewColumns) 30 | if (queryColumnsLeft.size > 0) return RewriteFail.PROJECT_UNMATCH(this) 31 | CompensationExpressions(true, Seq()) 32 | } 33 | 34 | 35 | } 36 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/TableNonOpMatcher.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component 2 | 3 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule.{CompensationExpressions, ExpressionMatcher, RewriteContext} 4 | 5 | /** 6 | * 2019-07-15 WilliamZhu(allwefantasy@gmail.com) 7 | */ 8 | class TableNonOpMatcher(rewriteContext: RewriteContext) extends ExpressionMatcher { 9 | override def compare: CompensationExpressions = CompensationExpressions(true, Seq()) 10 | } 11 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/rewrite/AggRewrite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component.rewrite 2 | 3 | import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression} 4 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule.{LogicalPlanRewrite, RewriteContext, RewritedLeafLogicalPlan} 5 | import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} 6 | 7 | /** 8 | * 2019-07-15 WilliamZhu(allwefantasy@gmail.com) 9 | */ 10 | class AggRewrite(rewriteContext: RewriteContext) extends LogicalPlanRewrite { 11 | override def rewrite(plan: LogicalPlan): LogicalPlan = { 12 | val projectOrAggList = rewriteContext.viewLogicalPlan.get().tableLogicalPlan.output 13 | 14 | val newExpressions = _compensationExpressions.compensation.map { expr => 15 | expr transformDown { 16 | case a@AttributeReference(name, dt, _, _) => 17 | val newAr = extractAttributeReferenceFromFirstLevel(projectOrAggList).filter(f => attributeReferenceEqual(a, f)).head 18 | 
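// record how the query attribute (qualifier stripped) maps to the view's output attribute, then substitute it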
rewriteContext.replacedARMapping += (a.withQualifier(Seq()) -> newAr) 19 | newAr 20 | } 21 | }.map(_.asInstanceOf[NamedExpression]) 22 | val newPlan = plan transformDown { 23 | case Aggregate(groupingExpressions, _, child) => 24 | RewritedLeafLogicalPlan(Aggregate(groupingExpressions, newExpressions, child)) 25 | } 26 | _back(newPlan) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/rewrite/GroupByRewrite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component.rewrite 2 | 3 | import org.apache.spark.sql.catalyst.expressions.AttributeReference 4 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule.{LogicalPlanRewrite, RewriteContext, RewritedLeafLogicalPlan} 5 | import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} 6 | 7 | /** 8 | * 2019-07-15 WilliamZhu(allwefantasy@gmail.com) 9 | */ 10 | class GroupByRewrite(rewriteContext: RewriteContext) extends LogicalPlanRewrite { 11 | override def rewrite(plan: LogicalPlan): LogicalPlan = { 12 | 13 | val projectOrAggList = rewriteContext.viewLogicalPlan.get().tableLogicalPlan.output 14 | 15 | val newExpressions = _compensationExpressions.compensation.map { expr => 16 | expr transformDown { 17 | case a@AttributeReference(name, dt, _, _) => 18 | extractAttributeReferenceFromFirstLevel(projectOrAggList).filter(f => attributeReferenceEqual(a, f)).head 19 | } 20 | } 21 | 22 | 23 | val newPlan = plan transformDown { 24 | case Aggregate(_, aggregateExpressions, child) => 25 | RewritedLeafLogicalPlan(Aggregate(newExpressions, aggregateExpressions, child)) 26 | } 27 | _back(newPlan) 28 | } 29 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/rewrite/JoinRewrite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component.rewrite 2 | 3 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule.{LogicalPlanRewrite, RewriteContext} 4 | import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project} 5 | 6 | /** 7 | * 2019-07-16 WilliamZhu(allwefantasy@gmail.com) 8 | */ 9 | class JoinRewrite(rewriteContext: RewriteContext) extends LogicalPlanRewrite { 10 | override def rewrite(plan: LogicalPlan): LogicalPlan = { 11 | 12 | plan transformUp { 13 | case Join(_, _, _, _) => rewriteContext.viewLogicalPlan.get().tableLogicalPlan match { 14 | case Project(_, child) => child 15 | case _ => rewriteContext.viewLogicalPlan.get().tableLogicalPlan 16 | } 17 | } 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/rewrite/PredicateRewrite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component.rewrite 2 | 3 | import org.apache.spark.sql.catalyst.expressions.AttributeReference 4 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule.{LogicalPlanRewrite, RewriteContext, RewritedLeafLogicalPlan, ViewLogicalPlan} 5 | import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} 6 | 7 | /** 8 | * 2019-07-14 WilliamZhu(allwefantasy@gmail.com) 9 | */ 10 | class 
PredicateRewrite(rewriteContext: RewriteContext) extends LogicalPlanRewrite { 11 | override def rewrite(plan: LogicalPlan): LogicalPlan = { 12 | 13 | val projectOrAggList = rewriteContext.viewLogicalPlan.get().tableLogicalPlan.output 14 | 15 | val newExpressions = _compensationExpressions.compensation.map { expr => 16 | expr transformDown { 17 | case a@AttributeReference(name, dt, _, _) => 18 | extractAttributeReferenceFromFirstLevel(projectOrAggList).filter(f => attributeReferenceEqual(a, f)).head 19 | } 20 | } 21 | 22 | 23 | val newPlan = plan transformDown { 24 | case a@Filter(condition, child) => 25 | if (newExpressions.isEmpty) { 26 | RewritedLeafLogicalPlan(child) 27 | } else { 28 | RewritedLeafLogicalPlan(Filter(mergeConjunctiveExpressions(newExpressions), child)) 29 | } 30 | 31 | } 32 | _back(newPlan) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/rewrite/ProjectRewrite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component.rewrite 2 | 3 | import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression} 4 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule.{LogicalPlanRewrite, RewriteContext, RewritedLogicalPlan} 5 | import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} 6 | 7 | /** 8 | * 2019-07-15 WilliamZhu(allwefantasy@gmail.com) 9 | */ 10 | class ProjectRewrite(rewriteContext: RewriteContext) extends LogicalPlanRewrite { 11 | 12 | override def rewrite(plan: LogicalPlan): LogicalPlan = { 13 | 14 | val projectOrAggList = rewriteContext.viewLogicalPlan.get().tableLogicalPlan.output 15 | 16 | def rewriteProject(plan: LogicalPlan): LogicalPlan = { 17 | plan match { 18 | case Project(projectList, child) => 19 | val newProjectList = projectList.map { expr => 20 | expr transformDown { 21 | case a@AttributeReference(name, dt, _, _) => 22 | val newAr = extractAttributeReferenceFromFirstLevel(projectOrAggList).filter(f => attributeReferenceEqual(a, f)).head 23 | rewriteContext.replacedARMapping += (a.withQualifier(Seq()) -> newAr) 24 | newAr 25 | } 26 | }.map(_.asInstanceOf[NamedExpression]) 27 | Project(newProjectList, child) 28 | case RewritedLogicalPlan(inner, _) => rewriteProject(inner) 29 | case _ => plan 30 | } 31 | } 32 | 33 | val newPlan = rewriteProject(plan) 34 | _back(RewritedLogicalPlan(newPlan, false)) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/rewrite/SPGJPredicateRewrite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component.rewrite 2 | 3 | import org.apache.spark.sql.catalyst.expressions.AttributeReference 4 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule.{LogicalPlanRewrite, RewriteContext, RewritedLeafLogicalPlan} 5 | import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan} 6 | 7 | /** 8 | * 2019-07-16 WilliamZhu(allwefantasy@gmail.com) 9 | */ 10 | class SPGJPredicateRewrite(rewriteContext: RewriteContext) extends LogicalPlanRewrite { 11 | override def rewrite(plan: LogicalPlan): LogicalPlan = { 12 | 13 | val projectOrAggList = rewriteContext.viewLogicalPlan.get().tableLogicalPlan.output 14 | 15 | val newExpressions = 
_compensationExpressions.compensation.map { expr => 16 | expr transformDown { 17 | case a@AttributeReference(name, dt, _, _) => 18 | extractAttributeReferenceFromFirstLevel(projectOrAggList).filter(f => attributeReferenceEqual(a, f)).head 19 | } 20 | } 21 | 22 | //clean filter and then add new filter before Join 23 | var newPlan = plan transformDown { 24 | case a@Filter(condition, child) => 25 | child 26 | } 27 | 28 | var lastJoin: Join = null 29 | newPlan = plan transformUp { 30 | case a@Join(_, _, _, _) => 31 | lastJoin = a 32 | a 33 | } 34 | 35 | newPlan = plan transformDown { 36 | case a@Join(_, _, _, _) => 37 | if (a == lastJoin) { 38 | RewritedLeafLogicalPlan(Filter(mergeConjunctiveExpressions(newExpressions), a)) 39 | } else a 40 | } 41 | 42 | 43 | _back(newPlan) 44 | 45 | } 46 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/rewrite/TableOrViewRewrite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component.rewrite 2 | 3 | import org.apache.spark.sql.catalyst.catalog.HiveTableRelation 4 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule.{LogicalPlanRewrite, RewriteContext, RewritedLeafLogicalPlan} 5 | import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, SubqueryAlias} 6 | import org.apache.spark.sql.execution.datasources.LogicalRelation 7 | 8 | /** 9 | * 2019-07-14 WilliamZhu(allwefantasy@gmail.com) 10 | */ 11 | class TableOrViewRewrite(rewriteContext: RewriteContext) extends LogicalPlanRewrite { 12 | override def rewrite(plan: LogicalPlan): LogicalPlan = { 13 | val finalTable = rewriteContext.viewLogicalPlan.get().tableLogicalPlan match { 14 | case Project(_, child) => child 15 | case _ => rewriteContext.viewLogicalPlan.get().tableLogicalPlan 16 | } 17 | val newPlan = plan transformDown { 18 | case SubqueryAlias(_, _) => 19 | RewritedLeafLogicalPlan(finalTable) 20 | case HiveTableRelation(_, _, _) => 21 | RewritedLeafLogicalPlan(finalTable) 22 | case LogicalRelation(_, output, catalogTable, _) => 23 | RewritedLeafLogicalPlan(finalTable) 24 | } 25 | 26 | _back(newPlan) 27 | 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/component/util/ExpressionSemanticEquals.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.component.util 2 | 3 | import org.apache.spark.sql.catalyst.expressions.Expression 4 | import org.apache.spark.sql.catalyst.optimizer.RewriteHelper 5 | 6 | /** 7 | * 2019-07-14 WilliamZhu(allwefantasy@gmail.com) 8 | */ 9 | object ExpressionSemanticEquals extends RewriteHelper { 10 | def process(query: Seq[Expression], view: Seq[Expression]) = { 11 | val (viewLeft, queryLeft, common) = extractTheSameExpressions(view, query) 12 | ExpressionIntersectResp(queryLeft, viewLeft, common) 13 | } 14 | } 15 | 16 | case class ExpressionIntersectResp( 17 | queryLeft: Seq[Expression], 18 | viewLeft: Seq[Expression], 19 | common: Seq[Expression] 20 | ) 21 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/rule/RewriteMatchRule.scala: -------------------------------------------------------------------------------- 1 | package 
org.apache.spark.sql.catalyst.optimizer.rewrite.rule 2 | 3 | import java.util.concurrent.atomic.AtomicReference 4 | 5 | import org.apache.spark.internal.Logging 6 | import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} 7 | import org.apache.spark.sql.catalyst.optimizer.RewriteHelper 8 | import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} 9 | 10 | import scala.collection.mutable 11 | 12 | 13 | /** 14 | * This is entry point of Plan rewrite. 15 | * Every Rewrite Rule contains a LogicalPlanRewritePipeline which is composed by a bunch of PipelineItemExecutor. 16 | * 17 | * PipelineItemExecutor contains: 18 | * 19 | * 1. A ExpressionMatcher, check where we can rewrite some part of SQL, and if we can, how to compensate expressions. 20 | * 2. A LogicalPlanRewrite, do the logical plan rewrite and return a new plan. 21 | * 22 | * For example: 23 | * 24 | * [[WithoutJoinGroupRule]] is a RewriteMatchRule, it is designed for the SQL like `select * from a where m='yes'` which 25 | * without agg,groupby and join. 26 | * 27 | * WithoutJoinGroupRule have three items in PipelineItemExecutor: 28 | * 29 | * 1. Project Matcher/Rewriter 30 | * 2. Predicate Matcher/Rewriter 31 | * 3. Table(View) Matcher/Rewriter 32 | */ 33 | trait RewriteMatchRule extends RewriteHelper { 34 | def fetchView(plan: LogicalPlan, rewriteContext: RewriteContext): Seq[ViewLogicalPlan] 35 | 36 | def rewrite(plan: LogicalPlan, rewriteContext: RewriteContext): LogicalPlan 37 | 38 | def buildPipeline[_](rewriteContext: RewriteContext, items: Seq[MatchOrRewrite]) = { 39 | val pipeline = mutable.ArrayBuffer[PipelineItemExecutor]() 40 | items.grouped(2).foreach { items => 41 | pipeline += PipelineItemExecutor(items(0).asInstanceOf[ExpressionMatcher], items(1).asInstanceOf[LogicalPlanRewrite]) 42 | } 43 | pipeline 44 | } 45 | 46 | 47 | } 48 | 49 | trait MatchOrRewrite { 50 | def rewrite(plan: LogicalPlan): LogicalPlan = { 51 | plan 52 | } 53 | 54 | def compare: CompensationExpressions = { 55 | CompensationExpressions(false, Seq()) 56 | } 57 | } 58 | 59 | trait LogicalPlanRewrite extends MatchOrRewrite with RewriteHelper { 60 | protected var _compensationExpressions: CompensationExpressions = null 61 | 62 | def compensationExpressions(ce: CompensationExpressions) = { 63 | _compensationExpressions = ce 64 | this 65 | } 66 | 67 | def _back(newPlan: LogicalPlan) = { 68 | newPlan transformDown { 69 | case RewritedLeafLogicalPlan(inner) => inner 70 | } 71 | } 72 | 73 | def rewrite(plan: LogicalPlan): LogicalPlan 74 | } 75 | 76 | trait ExpressionMatcher extends MatchOrRewrite with ExpressionMatcherHelper { 77 | var rewriteFail: Option[RewriteFail] = None 78 | 79 | def compare: CompensationExpressions 80 | } 81 | 82 | object RewriteFail { 83 | val DEFAULT = CompensationExpressions(false, Seq()) 84 | 85 | def apply(msg: String): RewriteFail = RewriteFail(msg, DEFAULT) 86 | 87 | def msg(value: String, matcher: ExpressionMatcher) = { 88 | matcher.rewriteFail = Option(apply(value)) 89 | DEFAULT 90 | } 91 | 92 | def GROUP_BY_SIZE_UNMATCH(matcher: ExpressionMatcher) = { 93 | msg("GROUP_BY_SIZE_UNMATCH", matcher) 94 | } 95 | 96 | def GROUP_BY_COLUMNS_NOT_IN_VIEW_PROJECT_OR_AGG(matcher: ExpressionMatcher) = { 97 | msg("GROUP_BY_COLUMNS_NOT_IN_VIEW_PROJECT_OR_AGG", matcher) 98 | } 99 | 100 | def AGG_NUMBER_UNMATCH(matcher: ExpressionMatcher) = { 101 | msg("AGG_UNMATCH", matcher) 102 | } 103 | 104 | def AGG_COLUMNS_UNMATCH(matcher: ExpressionMatcher) = { 105 | msg("AGG_COLUMNS_UNMATCH", matcher) 106 | } 
107 | 108 | def AGG_VIEW_MISSING_COUNTING_STAR(matcher: ExpressionMatcher) = { 109 | msg("AGG_VIEW_MISSING_COUNTING_STAR", matcher) 110 | } 111 | 112 | def JOIN_UNMATCH(matcher: ExpressionMatcher) = { 113 | msg("JOIN_UNMATCH", matcher) 114 | } 115 | 116 | def PREDICATE_UNMATCH(matcher: ExpressionMatcher) = { 117 | msg("PREDICATE_UNMATCH", matcher) 118 | } 119 | 120 | def PREDICATE_EQUALS_UNMATCH(matcher: ExpressionMatcher) = { 121 | msg("PREDICATE_EQUALS_UNMATCH", matcher) 122 | } 123 | 124 | def PREDICATE_RANGE_UNMATCH(matcher: ExpressionMatcher) = { 125 | msg("PREDICATE_RANGE_UNMATCH", matcher) 126 | } 127 | 128 | def PREDICATE_EXACLTY_SAME_UNMATCH(matcher: ExpressionMatcher) = { 129 | msg("PREDICATE_EXACLTY_SAME_UNMATCH", matcher) 130 | } 131 | 132 | def PREDICATE_COLUMNS_NOT_IN_VIEW_PROJECT_OR_AGG(matcher: ExpressionMatcher) = { 133 | msg("PREDICATE_COLUMNS_NOT_IN_VIEW_PROJECT_OR_AGG", matcher) 134 | } 135 | 136 | def PROJECT_UNMATCH(matcher: ExpressionMatcher) = { 137 | msg("PREDICATE_COLUMNS_NOT_IN_VIEW_PROJECT_OR_AGG", matcher) 138 | } 139 | } 140 | 141 | 142 | case class RewriteFail(val msg: String, val ce: CompensationExpressions) 143 | 144 | 145 | trait ExpressionMatcherHelper extends MatchOrRewrite with RewriteHelper { 146 | def isSubSetOf(e1: Seq[Expression], e2: Seq[Expression]) = { 147 | e1.map { item1 => 148 | e2.map { item2 => 149 | if (item2.semanticEquals(item1)) 1 else 0 150 | }.sum 151 | }.sum == e1.size 152 | } 153 | 154 | def isSubSetOfWithOrder(e1: Seq[Expression], e2: Seq[Expression]) = { 155 | val zipCount = Math.min(e1.size, e2.size) 156 | (0 until zipCount).map { index => 157 | if (e1(index).semanticEquals(e2(index))) 158 | 0 159 | else 1 160 | }.sum == 0 161 | } 162 | 163 | def subset[T](e1: Seq[T], e2: Seq[T]) = { 164 | assert(e1.size >= e2.size) 165 | if (e1.size == 0) Seq[Expression]() 166 | e1.slice(e2.size, e1.size) 167 | } 168 | } 169 | 170 | case class CompensationExpressions(isRewriteSuccess: Boolean, compensation: Seq[Expression]) 171 | 172 | class LogicalPlanRewritePipeline(pipeline: Seq[PipelineItemExecutor]) extends Logging { 173 | def rewrite(plan: LogicalPlan): LogicalPlan = { 174 | 175 | var planRewrite: RewritedLogicalPlan = RewritedLogicalPlan(plan, false) 176 | 177 | (0 until pipeline.size).foreach { index => 178 | 179 | if (!planRewrite.stopPipeline) { 180 | pipeline(index).execute(planRewrite) match { 181 | case a@RewritedLogicalPlan(_, true) => 182 | logInfo(s"Pipeline item [${pipeline(index)}] fails. 
") 183 | planRewrite = a 184 | case a@RewritedLogicalPlan(_, false) => 185 | planRewrite = a 186 | } 187 | } 188 | } 189 | planRewrite 190 | } 191 | } 192 | 193 | object LogicalPlanRewritePipeline { 194 | def apply(pipeline: Seq[PipelineItemExecutor]): LogicalPlanRewritePipeline = new LogicalPlanRewritePipeline(pipeline) 195 | } 196 | 197 | case class PipelineItemExecutor(matcher: ExpressionMatcher, reWriter: LogicalPlanRewrite) extends Logging { 198 | def execute(plan: LogicalPlan) = { 199 | val compsation = matcher.compare 200 | compsation match { 201 | case CompensationExpressions(true, _) => 202 | reWriter.compensationExpressions(compsation) 203 | reWriter.rewrite(plan) 204 | case CompensationExpressions(false, _) => 205 | logInfo(s"=====Rewrite fail:${matcher.rewriteFail.map(_.msg).getOrElse("NONE")}=====") 206 | println(s"=====Rewrite fail:${matcher.rewriteFail.map(_.msg).getOrElse("NONE")}=====") 207 | RewritedLogicalPlan(plan, stopPipeline = true) 208 | } 209 | } 210 | } 211 | 212 | case class RewritedLogicalPlan(inner: LogicalPlan, val stopPipeline: Boolean = false) extends LogicalPlan { 213 | override def output: Seq[Attribute] = inner.output 214 | 215 | override def children: Seq[LogicalPlan] = Seq(inner) 216 | 217 | } 218 | 219 | case class RewritedLeafLogicalPlan(inner: LogicalPlan) extends LogicalPlan { 220 | override def output: Seq[Attribute] = Seq() 221 | 222 | override def children: Seq[LogicalPlan] = Seq() 223 | } 224 | 225 | case class ViewLogicalPlan(tableLogicalPlan: LogicalPlan, viewCreateLogicalPlan: LogicalPlan) 226 | 227 | case class RewriteContext(viewLogicalPlan: AtomicReference[ViewLogicalPlan], processedComponent: AtomicReference[ProcessedComponent], 228 | replacedARMapping: mutable.HashMap[AttributeReference, AttributeReference] = 229 | mutable.HashMap[AttributeReference, AttributeReference]()) 230 | 231 | case class ProcessedComponent( 232 | queryConjunctivePredicates: Seq[Expression] = Seq(), 233 | viewConjunctivePredicates: Seq[Expression] = Seq(), 234 | queryProjectList: Seq[Expression] = Seq(), 235 | viewProjectList: Seq[Expression] = Seq(), 236 | queryGroupingExpressions: Seq[Expression] = Seq(), 237 | viewGroupingExpressions: Seq[Expression] = Seq(), 238 | queryAggregateExpressions: Seq[Expression] = Seq(), 239 | viewAggregateExpressions: Seq[Expression] = Seq(), 240 | viewJoins: Seq[Join] = Seq(), 241 | queryJoins: Seq[Join] = Seq() 242 | ) 243 | 244 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/rule/SPGJRule.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.rule 2 | 3 | import org.apache.spark.sql.catalyst.expressions.Expression 4 | import org.apache.spark.sql.catalyst.optimizer.PreOptimizeRewrite 5 | import org.apache.spark.sql.catalyst.optimizer.rewrite.component._ 6 | import org.apache.spark.sql.catalyst.optimizer.rewrite.component.rewrite._ 7 | import org.apache.spark.sql.catalyst.plans.logical._ 8 | import org.apache.spark.sql.execution.LogicalRDD 9 | import org.apache.spark.sql.execution.datasources.LogicalRelation 10 | import tech.mlsql.sqlbooster.meta.ViewCatalyst 11 | 12 | import scala.collection.mutable.ArrayBuffer 13 | 14 | /** 15 | * 2019-07-15 WilliamZhu(allwefantasy@gmail.com) 16 | */ 17 | 18 | object SPGJRule { 19 | def apply: SPGJRule = new SPGJRule() 20 | } 21 | 22 | class SPGJRule extends RewriteMatchRule { 23 | 24 | /** 25 | * 26 | * 
@param plan 27 | * @return 28 | */ 29 | override def fetchView(plan: LogicalPlan, rewriteContext: RewriteContext): Seq[ViewLogicalPlan] = { 30 | require(plan.resolved, "LogicalPlan must be resolved.") 31 | 32 | if (!isJoinExists(plan)) return Seq() 33 | 34 | // get all tables in join and the first table 35 | val tables = extractTablesFromPlan(plan) 36 | if (tables.size == 0) return Seq() 37 | 38 | var mainTableLogicalPlan: LogicalPlan = null 39 | 40 | plan transformUp { 41 | case a@Join(_, _, _, _) => 42 | a.left transformUp { 43 | case a@SubqueryAlias(_, child@LogicalRelation(_, _, _, _)) => 44 | mainTableLogicalPlan = a 45 | a 46 | case a@SubqueryAlias(_, child@LogicalRDD(_, _, _, _, _)) => 47 | mainTableLogicalPlan = a 48 | a 49 | } 50 | a 51 | } 52 | 53 | val mainTable = extractTablesFromPlan(mainTableLogicalPlan).head 54 | 55 | val viewPlan = ViewCatalyst.meta.getCandidateViewsByTable(mainTable) match { 56 | case Some(viewNames) => 57 | viewNames.filter { viewName => 58 | ViewCatalyst.meta.getViewCreateLogicalPlan(viewName) match { 59 | case Some(viewLogicalPlan) => 60 | extractTablesFromPlan(viewLogicalPlan).toSet == tables.toSet 61 | case None => false 62 | } 63 | }.map { targetViewName => 64 | ViewLogicalPlan( 65 | ViewCatalyst.meta.getViewLogicalPlan(targetViewName).get, 66 | ViewCatalyst.meta.getViewCreateLogicalPlan(targetViewName).get) 67 | }.toSeq 68 | case None => Seq() 69 | 70 | 71 | } 72 | viewPlan 73 | } 74 | 75 | override def rewrite(_plan: LogicalPlan, rewriteContext: RewriteContext): LogicalPlan = { 76 | val plan = PreOptimizeRewrite.execute(_plan) 77 | var targetViewPlanOption = fetchView(plan, rewriteContext) 78 | if (targetViewPlanOption.isEmpty) return plan 79 | 80 | targetViewPlanOption = targetViewPlanOption.map(f => 81 | f.copy(viewCreateLogicalPlan = PreOptimizeRewrite.execute(f.viewCreateLogicalPlan))) 82 | 83 | var shouldBreak = false 84 | var finalPlan = RewritedLogicalPlan(plan, true) 85 | 86 | targetViewPlanOption.foreach { targetViewPlan => 87 | if (!shouldBreak) { 88 | rewriteContext.viewLogicalPlan.set(targetViewPlan) 89 | val res = _rewrite(plan, rewriteContext) 90 | res match { 91 | case a@RewritedLogicalPlan(_, true) => 92 | finalPlan = a 93 | case a@RewritedLogicalPlan(_, false) => 94 | finalPlan = a 95 | shouldBreak = true 96 | } 97 | } 98 | } 99 | finalPlan 100 | } 101 | 102 | def _rewrite(plan: LogicalPlan, rewriteContext: RewriteContext): LogicalPlan = { 103 | 104 | generateRewriteContext(plan, rewriteContext) 105 | 106 | val pipeline = buildPipeline(rewriteContext: RewriteContext, Seq( 107 | new PredicateMatcher(rewriteContext), 108 | new SPGJPredicateRewrite(rewriteContext), 109 | new GroupByMatcher(rewriteContext), 110 | new GroupByRewrite(rewriteContext), 111 | new AggMatcher(rewriteContext), 112 | new AggRewrite(rewriteContext), 113 | new JoinMatcher(rewriteContext), 114 | new JoinRewrite(rewriteContext), 115 | new ProjectMatcher(rewriteContext), 116 | new ProjectRewrite(rewriteContext) 117 | 118 | )) 119 | 120 | /** 121 | * When we are rewriting plan, any step fails, we should return the original plan. 122 | * So we should check the mark in RewritedLogicalPlan is final success or fail. 
123 | */ 124 | LogicalPlanRewritePipeline(pipeline).rewrite(plan) 125 | } 126 | 127 | 128 | } 129 | 130 | 131 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/rule/WithoutJoinGroupRule.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.rule 2 | 3 | import org.apache.spark.sql.catalyst.expressions.Expression 4 | import org.apache.spark.sql.catalyst.optimizer.rewrite.component._ 5 | import org.apache.spark.sql.catalyst.optimizer.rewrite.component.rewrite._ 6 | import org.apache.spark.sql.catalyst.plans.logical._ 7 | import tech.mlsql.sqlbooster.meta.ViewCatalyst 8 | 9 | /** 10 | * 2019-07-14 WilliamZhu(allwefantasy@gmail.com) 11 | */ 12 | object WithoutJoinGroupRule { 13 | def apply: WithoutJoinGroupRule = new WithoutJoinGroupRule() 14 | } 15 | 16 | 17 | class WithoutJoinGroupRule extends RewriteMatchRule { 18 | override def fetchView(plan: LogicalPlan, rewriteContext: RewriteContext): Seq[ViewLogicalPlan] = { 19 | require(plan.resolved, "LogicalPlan must be resolved.") 20 | 21 | if (isAggExistsExists(plan) || isJoinExists(plan)) return Seq() 22 | 23 | val tables = extractTablesFromPlan(plan) 24 | if (tables.size == 0) return Seq() 25 | val table = tables.head 26 | val viewPlan = ViewCatalyst.meta.getCandidateViewsByTable(table) match { 27 | case Some(viewNames) => 28 | viewNames.filter { viewName => 29 | ViewCatalyst.meta.getViewCreateLogicalPlan(viewName) match { 30 | case Some(viewLogicalPlan) => 31 | extractTablesFromPlan(viewLogicalPlan).toSet == Set(table) 32 | case None => false 33 | } 34 | }.map { targetViewName => 35 | ViewLogicalPlan( 36 | ViewCatalyst.meta.getViewLogicalPlan(targetViewName).get, 37 | ViewCatalyst.meta.getViewCreateLogicalPlan(targetViewName).get) 38 | }.toSeq 39 | case None => Seq() 40 | 41 | 42 | } 43 | 44 | 45 | viewPlan 46 | } 47 | 48 | 49 | /** 50 | * query: select * from a,b where a.name=b.name and a.name2="jack" and b.jack="wow"; 51 | * view: a_view= select * from a,b where a.name=b.name and a.name2="jack" ; 52 | * target: select * from a_view where b.jack="wow" 53 | * 54 | * step 0: tables equal check 55 | * step 1: View equivalence classes: 56 | * query: PE:{a.name,b.name} 57 | * NPE: {a.name2="jack"} {b.jack="wow"} 58 | * view: PE: {a.name,b.name},NPE: {a.name2="jack"} 59 | * 60 | * step2 QPE < VPE, and QNPE < VNPE. 
We should normalize the PE make sure a=b equal to b=a, and 61 | * compare the NPE with Range Check, the others just check exactly 62 | * 63 | * step3: output columns check 64 | * 65 | * @param plan 66 | * @return 67 | */ 68 | override def rewrite(plan: LogicalPlan, rewriteContext: RewriteContext): LogicalPlan = { 69 | val targetViewPlanOption = fetchView(plan, rewriteContext) 70 | if (targetViewPlanOption.isEmpty) return plan 71 | 72 | var shouldBreak = false 73 | var finalPlan = RewritedLogicalPlan(plan, true) 74 | 75 | targetViewPlanOption.foreach { targetViewPlan => 76 | if (!shouldBreak) { 77 | rewriteContext.viewLogicalPlan.set(targetViewPlan) 78 | val res = _rewrite(plan, rewriteContext) 79 | res match { 80 | case a@RewritedLogicalPlan(_, true) => 81 | finalPlan = a 82 | case a@RewritedLogicalPlan(_, false) => 83 | finalPlan = a 84 | shouldBreak = true 85 | } 86 | } 87 | } 88 | finalPlan 89 | } 90 | 91 | def _rewrite(plan: LogicalPlan, rewriteContext: RewriteContext): LogicalPlan = { 92 | 93 | var queryConjunctivePredicates: Seq[Expression] = Seq() 94 | var viewConjunctivePredicates: Seq[Expression] = Seq() 95 | 96 | var queryProjectList: Seq[Expression] = Seq() 97 | var viewProjectList: Seq[Expression] = Seq() 98 | 99 | // check projectList and where condition 100 | normalizePlan(plan) match { 101 | case Project(projectList, Filter(condition, _)) => 102 | queryConjunctivePredicates = splitConjunctivePredicates(condition) 103 | queryProjectList = projectList 104 | case Project(projectList, _) => 105 | queryProjectList = projectList 106 | } 107 | 108 | normalizePlan(rewriteContext.viewLogicalPlan.get().viewCreateLogicalPlan) match { 109 | case Project(projectList, Filter(condition, _)) => 110 | viewConjunctivePredicates = splitConjunctivePredicates(condition) 111 | viewProjectList = projectList 112 | case Project(projectList, _) => 113 | viewProjectList = projectList 114 | } 115 | 116 | 117 | rewriteContext.processedComponent.set(ProcessedComponent( 118 | queryConjunctivePredicates, 119 | viewConjunctivePredicates, 120 | queryProjectList, 121 | viewProjectList, 122 | Seq(), 123 | Seq(), 124 | Seq(), 125 | Seq(), 126 | Seq(), 127 | Seq() 128 | )) 129 | 130 | /** 131 | * Three match/rewrite steps: 132 | * 1. Predicate 133 | * 2. Project 134 | * 3. Table(View) 135 | */ 136 | val pipeline = buildPipeline(rewriteContext: RewriteContext, Seq( 137 | new PredicateMatcher(rewriteContext), 138 | new PredicateRewrite(rewriteContext), 139 | new ProjectMatcher(rewriteContext), 140 | new ProjectRewrite(rewriteContext), 141 | new TableNonOpMatcher(rewriteContext), 142 | new TableOrViewRewrite(rewriteContext) 143 | 144 | )) 145 | 146 | /** 147 | * When we are rewriting plan, any step fails, we should return the original plan. 148 | * So we should check the mark in RewritedLogicalPlan is final success or fail. 
149 | */ 150 | LogicalPlanRewritePipeline(pipeline).rewrite(plan) 151 | 152 | } 153 | 154 | 155 | } 156 | 157 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/optimizer/rewrite/rule/WithoutJoinRule.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.optimizer.rewrite.rule 2 | 3 | import org.apache.spark.sql.catalyst.optimizer.rewrite.component.rewrite.{AggRewrite, GroupByRewrite, PredicateRewrite, TableOrViewRewrite} 4 | import org.apache.spark.sql.catalyst.optimizer.rewrite.component.{AggMatcher, GroupByMatcher, PredicateMatcher, TableNonOpMatcher} 5 | import org.apache.spark.sql.catalyst.plans.logical._ 6 | import tech.mlsql.sqlbooster.meta.ViewCatalyst 7 | 8 | /** 9 | * 2019-07-15 WilliamZhu(allwefantasy@gmail.com) 10 | */ 11 | class WithoutJoinRule extends RewriteMatchRule { 12 | 13 | 14 | override def fetchView(plan: LogicalPlan, rewriteContext: RewriteContext): Seq[ViewLogicalPlan] = { 15 | require(plan.resolved, "LogicalPlan must be resolved.") 16 | if (isJoinExists(plan)) return Seq() 17 | 18 | val tables = extractTablesFromPlan(plan) 19 | if (tables.size == 0) return Seq() 20 | val table = tables.head 21 | val viewPlan = ViewCatalyst.meta.getCandidateViewsByTable(table) match { 22 | case Some(viewNames) => 23 | viewNames.filter { viewName => 24 | ViewCatalyst.meta.getViewCreateLogicalPlan(viewName) match { 25 | case Some(viewLogicalPlan) => 26 | extractTablesFromPlan(viewLogicalPlan).toSet == Set(table) 27 | case None => false 28 | } 29 | }.map { targetViewName => 30 | ViewLogicalPlan( 31 | ViewCatalyst.meta.getViewLogicalPlan(targetViewName).get, 32 | ViewCatalyst.meta.getViewCreateLogicalPlan(targetViewName).get) 33 | }.toSeq 34 | case None => Seq() 35 | 36 | 37 | } 38 | viewPlan 39 | } 40 | 41 | override def rewrite(plan: LogicalPlan, rewriteContext: RewriteContext): LogicalPlan = { 42 | val targetViewPlanOption = fetchView(plan, rewriteContext) 43 | if (targetViewPlanOption.isEmpty) return plan 44 | 45 | var shouldBreak = false 46 | var finalPlan = RewritedLogicalPlan(plan, true) 47 | 48 | targetViewPlanOption.foreach { targetViewPlan => 49 | if (!shouldBreak) { 50 | rewriteContext.viewLogicalPlan.set(targetViewPlan) 51 | val res = _rewrite(plan, rewriteContext) 52 | res match { 53 | case a@RewritedLogicalPlan(_, true) => 54 | finalPlan = a 55 | case a@RewritedLogicalPlan(_, false) => 56 | finalPlan = a 57 | shouldBreak = true 58 | } 59 | } 60 | } 61 | finalPlan 62 | } 63 | 64 | def _rewrite(plan: LogicalPlan, rewriteContext: RewriteContext): LogicalPlan = { 65 | 66 | generateRewriteContext(plan, rewriteContext) 67 | /** 68 | * Three match/rewrite steps: 69 | * 1. Predicate 70 | * 2. GroupBy 71 | * 3. Project 72 | * 4. Table(View) 73 | */ 74 | val pipeline = buildPipeline(rewriteContext: RewriteContext, Seq( 75 | new PredicateMatcher(rewriteContext), 76 | new PredicateRewrite(rewriteContext), 77 | new GroupByMatcher(rewriteContext), 78 | new GroupByRewrite(rewriteContext), 79 | new AggMatcher(rewriteContext), 80 | new AggRewrite(rewriteContext), 81 | new TableNonOpMatcher(rewriteContext), 82 | new TableOrViewRewrite(rewriteContext) 83 | 84 | )) 85 | 86 | /** 87 | * When we are rewriting plan, any step fails, we should return the original plan. 88 | * So we should check the mark in RewritedLogicalPlan is final success or fail. 
89 | */ 90 | LogicalPlanRewritePipeline(pipeline).rewrite(plan) 91 | 92 | 93 | } 94 | } 95 | 96 | object WithoutJoinRule { 97 | def apply: WithoutJoinRule = new WithoutJoinRule() 98 | } 99 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/sqlgenerator/BasicSQLDialect.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.sqlgenerator 2 | 3 | import java.sql.Connection 4 | 5 | import org.apache.spark.sql.catalyst.dsl.plans._ 6 | import org.apache.spark.sql.execution.LogicalRDD 7 | import org.apache.spark.sql.execution.datasources.LogicalRelation 8 | import tech.mlsql.sqlbooster.meta.ViewCatalyst 9 | 10 | /** 11 | * 2019-07-14 WilliamZhu(allwefantasy@gmail.com) 12 | */ 13 | class BasicSQLDialect extends SQLDialect { 14 | override def canHandle(url: String): Boolean = url.toLowerCase().startsWith("jdbc:mysql") 15 | 16 | override def quote(name: String): String = { 17 | "`" + name.replace("`", "``") + "`" 18 | } 19 | 20 | override def explainSQL(sql: String): String = s"EXPLAIN $sql" 21 | 22 | override def relation(relation: LogicalRelation): String = { 23 | val view = relation.select(relation.output: _*) 24 | ViewCatalyst.meta.getViewNameByLogicalPlan(view) match { 25 | case Some(i) => i 26 | case None => ViewCatalyst.meta.getTableNameByLogicalPlan(relation.logicalPlan) match { 27 | case Some(i) => i 28 | case None => null 29 | } 30 | } 31 | 32 | 33 | } 34 | 35 | override def relation2(relation: LogicalRDD): String = { 36 | ViewCatalyst.meta.getTableNameByLogicalPlan(relation.logicalPlan) match { 37 | case Some(i) => i 38 | case None => null 39 | } 40 | } 41 | 42 | override def maybeQuote(name: String): String = { 43 | name 44 | } 45 | 46 | override def getIndexes(conn: Connection, url: String, tableName: String): Set[String] = { 47 | Set() 48 | } 49 | 50 | override def getTableStat(conn: Connection, url: String, tableName: String): (Option[BigInt], Option[Long]) = { 51 | (None, None) 52 | } 53 | 54 | override def enableCanonicalize: Boolean = false 55 | 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/sqlgenerator/LogicalPlanSQL.scala: -------------------------------------------------------------------------------- 1 | /*- 2 | * << 3 | * Moonbox 4 | * == 5 | * Copyright (C) 2016 - 2019 EDP 6 | * == 7 | * Licensed under the Apache License, Version 2.0 (the "License"); 8 | * you may not use this file except in compliance with the License. 9 | * You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 
18 | * >> 19 | */ 20 | package org.apache.spark.sql.catalyst.sqlgenerator 21 | 22 | import java.util.concurrent.atomic.AtomicLong 23 | 24 | import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Last} 25 | import org.apache.spark.sql.catalyst.expressions.{Alias, And, AttributeReference, BinaryOperator, CaseWhen, Cast, CheckOverflow, Coalesce, Contains, DayOfMonth, EndsWith, EqualTo, Exists, ExprId, Expression, GetArrayStructFields, GetStructField, Hour, If, In, InSet, IsNotNull, IsNull, Like, ListQuery, Literal, MakeDecimal, Minute, Month, NamedExpression, Not, ParseToDate, RLike, RegExpExtract, RegExpReplace, ScalarSubquery, Second, SortOrder, StartsWith, StringLocate, StringPredicate, SubqueryExpression, UnscaledValue, Year} 26 | import org.apache.spark.sql.catalyst.optimizer.{CollapseProject, CombineUnions} 27 | import org.apache.spark.sql.catalyst.plans.logical._ 28 | import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} 29 | import org.apache.spark.sql.execution.LogicalRDD 30 | import org.apache.spark.sql.execution.datasources.LogicalRelation 31 | import org.apache.spark.sql.types._ 32 | import org.apache.spark.unsafe.types.UTF8String 33 | 34 | import scala.collection.mutable 35 | import scala.util.control.NonFatal 36 | 37 | /** 38 | * 2019-07-13 WilliamZhu(allwefantasy@gmail.com) 39 | */ 40 | class LogicalPlanSQL(plan: LogicalPlan, dialect: SQLDialect) { 41 | require(plan.resolved, "LogicalPlan must be resolved.") 42 | 43 | import LogicalPlanSQL._ 44 | 45 | private val nextSubqueryId = new AtomicLong(0) 46 | 47 | private def newSubqueryName(): String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}" 48 | 49 | var finalLogicalPlan: LogicalPlan = finalPlan(plan) 50 | 51 | def toSQL: String = { 52 | try { 53 | //println(finalPlan.toString()) 54 | logicalPlanToSQL(finalLogicalPlan) 55 | } catch { 56 | case NonFatal(e) => 57 | throw e 58 | } 59 | } 60 | 61 | def canonicalize(plan: LogicalPlan): LogicalPlan = 62 | Canonicalizer.execute(plan) 63 | 64 | def finalPlan(_plan: LogicalPlan): LogicalPlan = { 65 | 66 | // pull up the filter out of the join and combine all where conditions 67 | val plan = _plan transformUp { 68 | case a@Join(l@Filter(lc, lchild), r@Filter(rc, rchild), joinType, condition) => 69 | Filter(And(lc, rc), Join(lchild, rchild, joinType, condition)) 70 | case a@Join(f@Filter(lc, lchild), r, joinType, condition) => Filter(lc, Join(lchild, r, joinType, condition)) 71 | case a@Join(l, r@Filter(rc, rchild), joinType, condition) => Filter(rc, Join(l, rchild, joinType, condition)) 72 | } transformDown { 73 | case Filter(con, Filter(con1, child)) => Filter(And(con, con1), child) 74 | } 75 | 76 | val realOutputNames: Seq[String] = plan.output.map(_.name) 77 | val canonicalizedPlan = if (dialect.enableCanonicalize) canonicalize(plan) else plan 78 | val canonicalizedToReal = canonicalizedPlan.output.zip(realOutputNames) 79 | val needRename = canonicalizedToReal.filter { 80 | case (attr, name) => attr.name != name 81 | }.toMap 82 | if (needRename.isEmpty) canonicalizedPlan 83 | else { 84 | val afterRenamed = canonicalizedToReal.map { 85 | case (attr, name) if needRename.contains(attr) => 86 | Alias(attr.withQualifier(Seq()), name)() 87 | case (attr, name) => 88 | attr 89 | } 90 | Project(afterRenamed, SubqueryAlias(newSubqueryName(), canonicalizedPlan)) 91 | } 92 | } 93 | 94 | def logicalPlanToSQL(logicalPlan: LogicalPlan): String = logicalPlan match { 95 | case Distinct(p: Project) => 96 | val child = 
logicalPlanToSQL(p.child) 97 | val expression = p.projectList.map(expressionToSQL(_)).mkString(",") 98 | dialect.projectToSQL(p, isDistinct = true, child, expression) 99 | case p: Project => 100 | val child = logicalPlanToSQL(p.child) 101 | val expression = p.projectList.map(expressionToSQL(_)).mkString(",") 102 | dialect.projectToSQL(p, isDistinct = false, child, expression) 103 | case SubqueryAlias(alias, child) => 104 | // here we can reduce too much subquery 105 | val tableName = child match { 106 | case a@LogicalRelation(_, _, _, _) => dialect.relation(a) 107 | case a@LogicalRDD(_, _, _, _, _) => dialect.relation2(a) 108 | case _ => null 109 | } 110 | if (tableName != null) { 111 | tableName 112 | } else { 113 | val childSql = logicalPlanToSQL(child) 114 | dialect.subqueryAliasToSQL(alias.identifier, childSql) 115 | } 116 | 117 | case a: Aggregate => 118 | aggregateToSQL(a) 119 | case w: Window => 120 | windowToSQL(w) 121 | case u: Union => 122 | val childrenSQL = u.children.filter { 123 | case l: LocalRelation if l.data.isEmpty => false 124 | case _ => true 125 | }.map(logicalPlanToSQL) 126 | if (childrenSQL.length > 1) s"(${childrenSQL.mkString(" UNION ALL ")})" 127 | else childrenSQL.head 128 | case r: LogicalRelation => 129 | dialect.relation(r) 130 | case r: LogicalRDD => dialect.relation2(r) 131 | case r: OneRowRelation => "__SHOULD_NOT_BE_HERE__" 132 | case r@Filter(condition, child) => 133 | val whereOrHaving = child match { 134 | case _: Aggregate => "HAVING" 135 | case _ => "WHERE" 136 | } 137 | build(logicalPlanToSQL(child), whereOrHaving, expressionToSQL(condition)) 138 | case Limit(limitExpr, child) => 139 | dialect.limitSQL(logicalPlanToSQL(child), expressionToSQL(limitExpr)) 140 | case GlobalLimit(limitExpr, child) => 141 | dialect.limitSQL(logicalPlanToSQL(child), expressionToSQL(limitExpr)) 142 | case LocalLimit(limitExpr, child) => 143 | dialect.limitSQL(logicalPlanToSQL(child), expressionToSQL(limitExpr)) 144 | case s: Sort => 145 | build( 146 | logicalPlanToSQL(s.child), 147 | if (s.global) "ORDER BY" else "SORT BY", 148 | s.order.map(expressionToSQL).mkString(", ") 149 | ) 150 | case p: Join => 151 | val left = logicalPlanToSQL(p.left) 152 | val right = logicalPlanToSQL(p.right) 153 | val condition = p.condition.map(condition => " ON " + expressionToSQL(condition)).getOrElse("") 154 | dialect.joinSQL(p, left, right, condition) 155 | } 156 | 157 | def expressionToSQL(expression: Expression): String = expression match { 158 | /*case a@Alias(array@GetArrayStructFields(child, field, _, _, _), name) => 159 | val colName = expressionToSQL(array) 160 | s"$colName AS ${dialect.quote(colName)}"*/ 161 | case toDate@ParseToDate(_, _, child) => 162 | s"${dialect.expressionToSQL(toDate)}(${expressionToSQL(child)})" 163 | case year@Year(child) => 164 | s"${dialect.expressionToSQL(year)}(${expressionToSQL(child)})" 165 | case month@Month(child) => 166 | s"${dialect.expressionToSQL(month)}(${expressionToSQL(child)})" 167 | case dayOfMonth@DayOfMonth(child) => 168 | s"${dialect.expressionToSQL(dayOfMonth)}(${expressionToSQL(child)})" 169 | case hour@Hour(child, _) => 170 | s"${dialect.expressionToSQL(hour)}(${expressionToSQL(child)})" 171 | case miniute@Minute(child, _) => 172 | s"${dialect.expressionToSQL(miniute)}(${expressionToSQL(child)})}" 173 | case second@Second(child, _) => 174 | s"${dialect.expressionToSQL(second)}(${expressionToSQL(child)})" 175 | case a@Alias(child, name) => 176 | val qualifierPrefix = a.qualifier.map(_ + ".").headOption.getOrElse("") 177 | 
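// emit the child expression followed by AS and the alias name, quoted and prefixed with its qualifier if present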
s"${expressionToSQL(child)} AS $qualifierPrefix${dialect.quote(name)}" 178 | case GetStructField(a: AttributeReference, _, Some(name)) => 179 | dialect.quote(s"${expressionToSQL(a)}.$name") 180 | case GetArrayStructFields(child, field, _, _, _) => 181 | dialect.quote(s"${expressionToSQL(child)}.${field.name}") 182 | case a: AttributeReference => 183 | dialect.getAttributeName(a) 184 | case c@Cast(child, dataType, _) => dataType match { 185 | case _: ArrayType | _: MapType | _: StructType => expressionToSQL(child) 186 | case _ => s"CAST(${expressionToSQL(child)} AS ${dialect.dataTypeToSQL(dataType)})" 187 | // case _: DecimalType => s"CAST(${expressionToSQL(child)} AS ${dialect.dataTypeToSQL(dataType)})" 188 | // case _ => expressionToSQL(child) 189 | } 190 | case l@StringLocate(substr, str, Literal(1, IntegerType)) => 191 | s"${dialect.expressionToSQL(l)}(${expressionToSQL(substr)}, ${expressionToSQL(str)})" 192 | case r@RLike(left, right) => 193 | s"${dialect.expressionToSQL(r)}(${expressionToSQL(left)}, ${expressionToSQL(right)})" 194 | case extract@RegExpExtract(subject, regexp, Literal(1, IntegerType)) => 195 | s"${dialect.expressionToSQL(extract)}(${expressionToSQL(subject)}, ${expressionToSQL(regexp)})" 196 | case replace@RegExpReplace(subject, regexp, rep) => 197 | s"${dialect.expressionToSQL(replace)}(${expressionToSQL(subject)}, ${expressionToSQL(regexp)}, ${expressionToSQL(rep)})" 198 | case last@Last(child, _) => 199 | s"${dialect.expressionToSQL(last)}(${expressionToSQL(child)})" 200 | case If(predicate, trueValue, falseValue) => 201 | // calcite 202 | s"CASE WHEN ${expressionToSQL(predicate)} THEN ${expressionToSQL(trueValue)} ELSE ${expressionToSQL(falseValue)} END" 203 | // mysql 204 | /* 205 | * s"if(${expressionToSQL(predicate)}, ${expressionToSQL(trueValue)}, ${expressionToSQL(falseValue)})" 206 | * */ 207 | case IsNull(child) => 208 | s"${expressionToSQL(child)} IS NULL" 209 | case IsNotNull(child) => 210 | s"${expressionToSQL(child)} IS NOT NULL" 211 | case Coalesce(children) => 212 | //calcite 213 | s"coalesce(${children.map(expressionToSQL).mkString(",")})" 214 | // mysql 215 | /*children.init.foldRight(expressionToSQL(children.last)){ 216 | case (child, sql) => s"IFNULL(${expressionToSQL(child)}, $sql)" 217 | }*/ 218 | case CaseWhen(branches, elseValue) => 219 | val cases = branches.map { case (c, v) => s" WHEN ${expressionToSQL(c)} THEN ${expressionToSQL(v)}" }.mkString 220 | val elseCase = elseValue.map(" ELSE " + expressionToSQL(_)).getOrElse("") 221 | "CASE" + cases + elseCase + " END" 222 | case UnscaledValue(child) => 223 | expressionToSQL(child) 224 | case AggregateExpression(aggFunc, _, isDistinct, _) => 225 | val distinct = if (isDistinct) "DISTINCT " else "" 226 | s"${aggFunc.prettyName}($distinct${aggFunc.children.map(expressionToSQL).mkString(", ")})" 227 | case a: AggregateFunction => 228 | s"${a.prettyName}(${a.children.map(expressionToSQL).mkString(", ")})" 229 | case literal@Literal(v, t) => 230 | dialect.literalToSQL(v, t) 231 | case MakeDecimal(child, precision, scala) => 232 | s"CAST(${expressionToSQL(child)} AS DECIMAL($precision, $scala))" 233 | case Not(EqualTo(left, right)) => 234 | s"${expressionToSQL(left)} <> ${expressionToSQL(right)}" 235 | case Not(Like(left, right)) => 236 | s"${expressionToSQL(left)} NOT LIKE ${expressionToSQL(right)}" 237 | case Not(child) => 238 | s"(NOT ${expressionToSQL(child)})" 239 | case In(value, list) => 240 | val childrenSQL = (value +: list).map(expressionToSQL) 241 | val valueSQL = childrenSQL.head 242 | val 
listSQL = childrenSQL.tail.mkString(", ") 243 | s"($valueSQL IN ($listSQL))" 244 | case InSet(child, hset) => 245 | val valueSQL = expressionToSQL(child) 246 | val listSQL = hset.toSeq.map(s => { 247 | val literal = s match { 248 | case v: UTF8String => Literal(v, StringType) 249 | case v => Literal(v) 250 | } 251 | expressionToSQL(Literal(literal)) 252 | }).mkString(", ") 253 | s"($valueSQL IN ($listSQL))" 254 | case b: BinaryOperator => 255 | s"${expressionToSQL(b.left)} ${b.sqlOperator} ${expressionToSQL(b.right)}" 256 | case s: StringPredicate => 257 | stringPredicate(s) 258 | case c@CheckOverflow(child, _) => 259 | expressionToSQL(child) 260 | case s@SortOrder(child, direction, nullOrdering, _) => 261 | s"${expressionToSQL(child)} ${direction.sql}" 262 | case subquery: SubqueryExpression => 263 | subqueryExpressionToSQL(subquery) 264 | case e: Expression => 265 | e.sql 266 | } 267 | 268 | private def windowToSQL(w: Window): String = { 269 | build( 270 | "SELECT", 271 | (w.child.output ++ w.windowOutputSet).map(expressionToSQL).mkString(", "), 272 | if (w.child == OneRowRelation) "" else "FROM", 273 | logicalPlanToSQL(w.child) 274 | ) 275 | } 276 | 277 | private def aggregateToSQL(a: Aggregate): String = { 278 | val groupingSQL = a.groupingExpressions.map(expressionToSQL).mkString(",") 279 | val aggregateSQL = if (a.aggregateExpressions.nonEmpty) a.aggregateExpressions.map(expressionToSQL).mkString(", ") 280 | else if (a.groupingExpressions.nonEmpty) groupingSQL 281 | else throw new Exception("both aggregateExpression and groupingExpression in Aggregate are empty.") 282 | // 283 | build( 284 | "SELECT", 285 | aggregateSQL, 286 | if (a.child == OneRowRelation) "" else "FROM", 287 | logicalPlanToSQL(a.child), 288 | if (groupingSQL.isEmpty) "" else "GROUP BY", 289 | groupingSQL 290 | ) 291 | } 292 | 293 | def stringPredicate(s: StringPredicate): String = s match { 294 | case StartsWith(left, right) => 295 | s"${expressionToSQL(left)} LIKE '${expressionToSQL(right).stripPrefix("'").stripSuffix("'")}%'" 296 | case EndsWith(left, right) => 297 | s"${expressionToSQL(left)} LIKE '%${expressionToSQL(right).stripPrefix("'").stripSuffix("'")}'" 298 | case Contains(left, right) => 299 | s"${expressionToSQL(left)} LIKE '%${expressionToSQL(right).stripPrefix("'").stripSuffix("'")}%'" 300 | } 301 | 302 | def subqueryExpressionToSQL(subquery: Expression): String = subquery match { 303 | case Exists(plan, children, _) => 304 | s"EXISTS (${logicalPlanToSQL(finalPlan(plan))})" 305 | case ScalarSubquery(plan, children, _) => 306 | s"(${logicalPlanToSQL(finalPlan(plan))})" 307 | case ListQuery(plan, children, _, _) => 308 | s"IN (${logicalPlanToSQL(finalPlan(plan))})" 309 | } 310 | 311 | 312 | object Canonicalizer extends RuleExecutor[LogicalPlan] { 313 | override protected def batches: Seq[Batch] = Seq( 314 | Batch("Prepare", FixedPoint(100), 315 | CollapseProject, 316 | CombineUnions, 317 | EliminateProject, 318 | EliminateEmptyColumn 319 | ), 320 | Batch("Recover Scoping Info", Once, 321 | AddProject, 322 | AddSubqueryAlias, 323 | NormalizeAttribute 324 | ) 325 | ) 326 | } 327 | 328 | object NormalizeAttribute extends Rule[LogicalPlan] { 329 | override def apply(plan: LogicalPlan): LogicalPlan = { 330 | plan.transformUp { 331 | case l@LogicalRelation(_, output, _, _) => 332 | l.transformExpressions { 333 | case a: AttributeReference => 334 | AttributeReference( 335 | name = a.name, 336 | dataType = a.dataType, 337 | nullable = a.nullable, 338 | metadata = a.metadata)( 339 | exprId = a.exprId, 340 | 
qualifier = Seq()) 341 | } 342 | case l: LeafNode => l 343 | case u => 344 | val exprIdToQualifier = u.children.flatMap(_.output).map(a => (a.exprId, a.qualifier)).toMap 345 | u.transformExpressions { 346 | case a: AttributeReference => 347 | AttributeReference( 348 | name = a.name, 349 | dataType = a.dataType, 350 | nullable = a.nullable, 351 | metadata = a.metadata)( 352 | exprId = a.exprId, 353 | qualifier = exprIdToQualifier.getOrElse(a.exprId, Seq())) 354 | } 355 | } 356 | } 357 | } 358 | 359 | object NormalizedAttribute extends Rule[LogicalPlan] { 360 | 361 | private def findLogicalRelation(plan: LogicalPlan, 362 | logicalRelations: mutable.ArrayBuffer[LogicalRelation]): Unit = { 363 | plan.foreach { 364 | case l: LogicalRelation => 365 | logicalRelations.+=(l) 366 | case Filter(condition, _) => 367 | traverseExpression(condition) 368 | case Project(projectList, _) => 369 | projectList.foreach(traverseExpression) 370 | case Aggregate(groupingExpressions, aggregateExpressions, _) => 371 | groupingExpressions.foreach(traverseExpression) 372 | aggregateExpressions.foreach(traverseExpression) 373 | case Window(windowExpressions, _, _, _) => 374 | windowExpressions.foreach(traverseExpression) 375 | case _ => 376 | 377 | } 378 | 379 | def traverseExpression(expr: Expression): Unit = { 380 | expr.foreach { 381 | case ScalarSubquery(plan, _, _) => findLogicalRelation(plan, logicalRelations) 382 | case Exists(plan, _, _) => findLogicalRelation(plan, logicalRelations) 383 | case ListQuery(plan, _, _, _) => findLogicalRelation(plan, logicalRelations) 384 | case _ => 385 | } 386 | } 387 | } 388 | 389 | override def apply(plan: LogicalPlan): LogicalPlan = { 390 | val logicalRelations = new mutable.ArrayBuffer[LogicalRelation]() 391 | findLogicalRelation(plan, logicalRelations) 392 | val colNames = new mutable.HashSet[String]() 393 | val conflict = new mutable.HashMap[LogicalRelation, Seq[AttributeReference]]() 394 | val isGenerated = new mutable.HashSet[LogicalPlan]() 395 | logicalRelations.foreach { table => 396 | val (in, notIn) = table.output.partition(attr => colNames.contains(attr.name)) 397 | if (in.nonEmpty) conflict.put(table, in) 398 | colNames.++=(notIn.map(_.name)) 399 | } 400 | val renamedExprId = new mutable.HashSet[ExprId]() 401 | 402 | val plan1 = plan.transformUp { 403 | case l@LogicalRelation(relation, output, catalogTable, _) if conflict.contains(l) => 404 | val renamedOutput = output.map { attr => 405 | if (conflict(l).contains(attr)) { 406 | renamedExprId.add(attr.exprId) 407 | Alias(attr, normalizedName(attr))(exprId = attr.exprId, qualifier = Seq()) 408 | } else AttributeReference(name = attr.name, 409 | dataType = attr.dataType, nullable = attr.nullable, 410 | metadata = attr.metadata)(exprId = attr.exprId, qualifier = Seq()) 411 | } 412 | val generateProject = Project(renamedOutput, l) 413 | isGenerated.add(generateProject) 414 | SubqueryAlias(newSubqueryName(), generateProject) 415 | } 416 | plan1.transformUp { 417 | case l: LogicalRelation => l 418 | case p@Project(_, r: LogicalRelation) => 419 | if (isGenerated.contains(p)) { 420 | p 421 | } else { 422 | p.transformExpressions { 423 | case a: AttributeReference => 424 | val name = if (renamedExprId.contains(a.exprId)) normalizedName(a) else a.name 425 | AttributeReference(name, a.dataType)(exprId = a.exprId, qualifier = Seq()) 426 | case a: Alias => 427 | val name = if (renamedExprId.contains(a.exprId)) normalizedName(a) else a.name 428 | Alias(a.child, name)(exprId = a.exprId, qualifier = Seq()) 429 | } 430 | } 431 | 
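// for every other node, rewrite attributes and aliases whose exprId was renamed above, dropping their qualifiers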
case o => o.transformExpressions { 432 | case a: AttributeReference => 433 | val name = if (renamedExprId.contains(a.exprId)) normalizedName(a) else a.name 434 | AttributeReference(name, a.dataType)(exprId = a.exprId, qualifier = Seq()) 435 | case a: Alias => 436 | val name = if (renamedExprId.contains(a.exprId)) normalizedName(a) else a.name 437 | Alias(a.child, name)(exprId = a.exprId, qualifier = Seq()) 438 | } 439 | } 440 | } 441 | 442 | def normalizedName(n: NamedExpression): String = { 443 | "genattr" + n.exprId.id 444 | } 445 | } 446 | 447 | object EliminateProject extends Rule[LogicalPlan] { 448 | override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { 449 | case a@Aggregate(groupingExpressions, aggregateExpressions, p: Project) => 450 | a.copy(child = p.child) 451 | case p1@Project(projectList, s@Sort(_, _, p2: Project)) => 452 | Sort(s.order, s.global, Project(p1.projectList, p2.child)) 453 | } 454 | } 455 | 456 | object EliminateEmptyColumn extends Rule[LogicalPlan] { 457 | override def apply(plan: LogicalPlan): LogicalPlan = plan transform { 458 | case a: Aggregate if a.aggregateExpressions.isEmpty => 459 | a.child 460 | case p: Project if p.projectList.isEmpty => 461 | p.child 462 | case w: Window if w.windowExpressions.isEmpty => 463 | w.child 464 | } 465 | } 466 | 467 | object AddSubqueryAlias extends Rule[LogicalPlan] { 468 | override def apply(plan: LogicalPlan): LogicalPlan = { 469 | val points = new mutable.HashSet[(LogicalPlan, LogicalPlan)]() 470 | findPoint(plan, points) 471 | if (points.nonEmpty) { 472 | plan.transformDown { 473 | case a => 474 | val newChildren = a.children.map(a -> _).map { parentChild => 475 | if (points.contains(parentChild)) { 476 | //points.remove(parentChild) 477 | SubqueryAlias(newSubqueryName(), parentChild._2) 478 | } else parentChild._2 479 | } 480 | a.withNewChildren(newChildren) 481 | } 482 | } else plan 483 | } 484 | 485 | def findPoint(node: LogicalPlan, points: mutable.HashSet[(LogicalPlan, LogicalPlan)]): Boolean = { 486 | val hasSelect: Seq[Boolean] = node.children.map(findPoint(_, points)) 487 | node match { 488 | case l: LeafNode => false 489 | case p: Project => 490 | if (hasSelect.head) { 491 | points.add(p -> p.child) 492 | true 493 | } else true 494 | case p: Aggregate => 495 | if (hasSelect.head) { 496 | points.add(p -> p.child) 497 | true 498 | } else true 499 | case p: Window => 500 | if (hasSelect.head) { 501 | points.add(p -> p.child) 502 | true 503 | } else true 504 | case p: Generate => 505 | if (hasSelect.head) { 506 | points.add(p -> p.child) 507 | true 508 | } else true 509 | case j@Join(left, right, _, _) => 510 | if (hasSelect.head) { 511 | points.add(j -> left) 512 | } 513 | if (hasSelect.last) { 514 | points.add(j -> right) 515 | } 516 | false 517 | case j@Intersect(left, right, _) => 518 | if (hasSelect.head) { 519 | points.add(j -> left) 520 | } 521 | if (hasSelect.last) { 522 | points.add(j -> right) 523 | } 524 | false 525 | case u@Union(children) => 526 | hasSelect.zip(children).foreach { 527 | case (has, p) if has => points.add(u -> p) 528 | case _ => 529 | } 530 | false 531 | case a => hasSelect.head 532 | } 533 | } 534 | } 535 | 536 | object AddProject extends Rule[LogicalPlan] { 537 | private val orderCode = Map[Class[_], Int]( 538 | classOf[LogicalRelation] -> 1, 539 | classOf[Filter] -> 2, 540 | classOf[Project] -> 3, 541 | classOf[Aggregate] -> 4, 542 | classOf[Sort] -> 5, 543 | classOf[LocalLimit] -> 6, 544 | classOf[GlobalLimit] -> 7 545 | ) 546 | 547 | override def apply(plan: 
LogicalPlan): LogicalPlan = { 548 | val points = new mutable.HashSet[LogicalPlan]() 549 | findPoint(plan, plan, points) 550 | if (points.nonEmpty) { 551 | plan.transformDown { 552 | case a if points.contains(a) => { 553 | points.remove(a) 554 | Project(a.output, a) 555 | } 556 | } 557 | } else plan 558 | } 559 | 560 | /** 561 | * 562 | * @param node current 563 | * @param root root 564 | * @return (scope , has select) 565 | */ 566 | private def findPoint(node: LogicalPlan, root: LogicalPlan, points: mutable.HashSet[LogicalPlan]): (LogicalPlan, Boolean) = { 567 | val children = node.children.map(child => findPoint(child, root, points)) 568 | node match { 569 | // has select in scope 570 | case p: Project => 571 | (children.head._1, true) 572 | case a: Aggregate => 573 | (children.head._1, true) 574 | case w: Window => 575 | (children.head._1, true) 576 | case g: Generate => 577 | (children.head._1, true) 578 | // scope changed 579 | case j: Join => 580 | j.children.zip(children).foreach { 581 | case (start, state) => 582 | if (!start.isInstanceOf[LeafNode]) { 583 | find(start, state, points) 584 | } 585 | } 586 | if (j == root) find(j, (j, false), points) 587 | (j, false) 588 | case u: Union => 589 | u.children.zip(children).foreach { 590 | case (start, state) => 591 | find(start, state, points) 592 | /*if (!start.isInstanceOf[LeafNode]) { 593 | find(start, state, points) 594 | }*/ 595 | } 596 | if (u == root) find(u, (u, false), points) 597 | (u, false) 598 | case i: Intersect => 599 | i.children.zip(children).foreach { 600 | case (start, state) => if (!start.isInstanceOf[LeafNode]) { 601 | find(start, state, points) 602 | } 603 | } 604 | if (i == root) find(i, (i, false), points) 605 | (i, false) 606 | case s: SubqueryAlias => 607 | s.children.zip(children).foreach { 608 | case (start, state) => find(start, state, points) 609 | } 610 | (s, false) 611 | case g: GlobalLimit => 612 | g.children.zip(children).foreach { 613 | case (start, state) => find(start, state, points) 614 | } 615 | (g, false) 616 | case a => { 617 | val res = children.headOption 618 | if (res.isDefined) { 619 | if (a == root) { 620 | a.children.zip(children).foreach { 621 | case (start, state) => find(start, state, points) 622 | } 623 | } 624 | res.get 625 | } 626 | else { 627 | if (a == root) find(a, (a, false), points) 628 | (a, false) 629 | } 630 | } 631 | // 632 | } 633 | 634 | } 635 | 636 | private def find(start: LogicalPlan, state: (LogicalPlan, Boolean), points: mutable.HashSet[LogicalPlan]): Unit = { 637 | val hasSelect = state._2 638 | if (!hasSelect) { 639 | var flag = true 640 | var current = start 641 | val until = state._1 642 | while (flag) { 643 | if (current == until) { 644 | flag = false 645 | points.add(current) 646 | } else { 647 | if (orderCode(current.getClass) < orderCode(classOf[Project])) { 648 | points.add(current) 649 | flag = false 650 | } else { 651 | current = current.children.head 652 | } 653 | } 654 | } 655 | } 656 | } 657 | } 658 | 659 | } 660 | 661 | case object LogicalPlanSQL { 662 | def build(segments: String*): String = { 663 | segments.map(_.trim).filter(_.nonEmpty).mkString(" ") 664 | } 665 | } 666 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/catalyst/sqlgenerator/SQLDialect.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst.sqlgenerator 2 | 3 | /** 4 | * 2019-07-14 WilliamZhu(allwefantasy@gmail.com) 5 | */ 6 | 7 | import 
java.sql.Connection 8 | 9 | import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} 10 | import org.apache.spark.sql.catalyst.plans.logical.{Join, OneRowRelation, Project} 11 | import org.apache.spark.sql.catalyst.util.DateTimeUtils 12 | import org.apache.spark.sql.execution.LogicalRDD 13 | import org.apache.spark.sql.execution.datasources.LogicalRelation 14 | import org.apache.spark.sql.types._ 15 | import org.apache.spark.unsafe.types.UTF8String 16 | 17 | trait SQLDialect { 18 | 19 | import LogicalPlanSQL._ 20 | import SQLDialect._ 21 | 22 | registerDialect(this) 23 | 24 | def relation(relation: LogicalRelation): String 25 | 26 | def relation2(relation: LogicalRDD): String 27 | 28 | def enableCanonicalize: Boolean 29 | 30 | def canHandle(url: String): Boolean 31 | 32 | def explainSQL(sql: String): String 33 | 34 | def quote(name: String): String 35 | 36 | def maybeQuote(name: String): String 37 | 38 | def getIndexes(conn: Connection, url: String, tableName: String): Set[String] 39 | 40 | def getTableStat(conn: Connection, url: String, tableName: String): ((Option[BigInt], Option[Long])) 41 | 42 | def projectToSQL(p: Project, isDistinct: Boolean, child: String, expression: String): String = { 43 | build( 44 | "SELECT", 45 | if (isDistinct) "DISTINCT" else "", 46 | expression, 47 | if (p.child == OneRowRelation) "" else "FROM", 48 | child) 49 | } 50 | 51 | def subqueryAliasToSQL(alias: String, child: String) = { 52 | build(s"($child) $alias") 53 | } 54 | 55 | def dataTypeToSQL(dataType: DataType): String = { 56 | dataType.sql 57 | } 58 | 59 | def literalToSQL(value: Any, dataType: DataType): String = (value, dataType) match { 60 | case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => "NULL" 61 | case (v: UTF8String, StringType) => "'" + v.toString.replace("\\", "\\\\").replace("'", "\\'") + "'" 62 | case (v: Byte, ByteType) => v + "" 63 | case (v: Boolean, BooleanType) => s"'$v'" 64 | case (v: Short, ShortType) => v + "" 65 | case (v: Long, LongType) => v + "" 66 | case (v: Float, FloatType) => v + "" 67 | case (v: Double, DoubleType) => v + "" 68 | case (v: Decimal, t: DecimalType) => v + "" 69 | case (v: Int, DateType) => s"'${DateTimeUtils.toJavaDate(v)}'" 70 | case (v: Long, TimestampType) => s"'${DateTimeUtils.toJavaTimestamp(v)}'" 71 | case _ => if (value == null) "NULL" else value.toString 72 | } 73 | 74 | def limitSQL(sql: String, limit: String): String = { 75 | s"$sql LIMIT $limit" 76 | } 77 | 78 | def joinSQL(p: Join, left: String, right: String, condition: String): String = { 79 | build( 80 | left, 81 | p.joinType.sql, 82 | "JOIN", 83 | right, 84 | condition) 85 | } 86 | 87 | def getAttributeName(e: AttributeReference): String = { 88 | val qualifierPrefix = e.qualifier.map(_ + ".").headOption.getOrElse("") 89 | s"$qualifierPrefix${quote(e.name)}" 90 | } 91 | 92 | def expressionToSQL(e: Expression): String = { 93 | e.prettyName 94 | } 95 | 96 | } 97 | 98 | object SQLDialect { 99 | private[this] var dialects = List[SQLDialect]() 100 | 101 | 102 | def registerDialect(dialect: SQLDialect): Unit = synchronized { 103 | dialects = dialect :: dialects.filterNot(_ == dialect) 104 | } 105 | 106 | def unregisterDialect(dialect: SQLDialect): Unit = synchronized { 107 | dialects = dialects.filterNot(_ == dialect) 108 | } 109 | 110 | def get(url: String): SQLDialect = { 111 | val matchingDialects = dialects.filter(_.canHandle(url)) 112 | matchingDialects.headOption match { 113 | case None => throw new NoSuchElementException(s"no suitable 
MbDialect from $url") 114 | case Some(d) => d 115 | } 116 | } 117 | 118 | } 119 | -------------------------------------------------------------------------------- /src/main/scala/tech/mlsql/sqlbooster/DataLineageExtractor.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.sqlbooster 2 | 3 | import org.apache.spark.sql.catalyst.catalog.HiveTableRelation 4 | import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} 5 | import org.apache.spark.sql.catalyst.optimizer.RewriteHelper 6 | import org.apache.spark.sql.catalyst.optimizer.rewrite.rule.RewritedLeafLogicalPlan 7 | import org.apache.spark.sql.catalyst.plans.logical._ 8 | import org.apache.spark.sql.execution.LogicalRDD 9 | import org.apache.spark.sql.execution.datasources.LogicalRelation 10 | import tech.mlsql.sqlbooster.analysis._ 11 | 12 | import scala.collection.mutable.ArrayBuffer 13 | 14 | /** 15 | * 2019-07-19 WilliamZhu(allwefantasy@gmail.com) 16 | */ 17 | object DataLineageExtractor extends RewriteHelper { 18 | def execute(plan: LogicalPlan): DataLineage = { 19 | 20 | 21 | // collect tables 22 | val tables = extractTableHolderFromPlan(plan) 23 | 24 | def findDependencesFromColumns(columns: Seq[Expression]) = { 25 | val tempHolder = ArrayBuffer[TableAndUsedColumns]() 26 | tables.foreach { table => 27 | val tableAndUsedColumns = TableAndUsedColumns(table.table, Seq(), Seq()) 28 | val tableOutput = table.output 29 | val tempItems = columns.flatMap { atr => 30 | tableOutput.filter(f => f.semanticEquals(atr)).toSet 31 | } 32 | val tempColumns = tempItems.map(f => (f.name, Location.locate(plan, f))).groupBy(_._1).map { ar => 33 | (ar._1, ar._2.flatMap(f => f._2).toSet.toSeq) 34 | } 35 | tempHolder += tableAndUsedColumns.copy( 36 | columns = tempColumns.map(_._1).toSeq, 37 | locates = tempColumns.map(_._2).toSeq 38 | ) 39 | } 40 | tempHolder 41 | } 42 | 43 | val arBuffer = ArrayBuffer[AttributeReference]() 44 | 45 | // collect all attributeRef without original tables 46 | 47 | val newPlan = plan transformDown { 48 | case a@SubqueryAlias(_, LogicalRelation(_, _, _, _)) => 49 | RewritedLeafLogicalPlan(null) 50 | case a@SubqueryAlias(_, LogicalRDD(_, _, _, _, _)) => 51 | RewritedLeafLogicalPlan(null) 52 | case a@SubqueryAlias(_, m@HiveTableRelation(tableMeta, _, _)) => 53 | RewritedLeafLogicalPlan(null) 54 | case m@HiveTableRelation(tableMeta, _, _) => 55 | RewritedLeafLogicalPlan(null) 56 | case m@LogicalRelation(_, output, catalogTable, _) => 57 | RewritedLeafLogicalPlan(null) 58 | } 59 | 60 | newPlan.transformAllExpressions { 61 | case a@AttributeReference(_, _, _, _) => 62 | arBuffer += a 63 | a 64 | } 65 | val dependences = findDependencesFromColumns(arBuffer) 66 | 67 | 68 | val outputMapToSourceTable = plan.output.map { case columnItem => 69 | val arBuffer = ArrayBuffer[AttributeReference]() 70 | columnItem.transformDown { 71 | case a@AttributeReference(_, _, _, _) => 72 | arBuffer += a 73 | a 74 | } 75 | OutputColumnToSourceTableAndColumn(columnItem.name, findDependencesFromColumns(arBuffer)) 76 | } 77 | 78 | DataLineage(outputMapToSourceTable, dependences) 79 | } 80 | } 81 | 82 | 83 | -------------------------------------------------------------------------------- /src/main/scala/tech/mlsql/sqlbooster/MaterializedViewOptimizeRewrite.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.sqlbooster 2 | 3 | import org.apache.spark.sql.catalyst.optimizer.RewriteTableToViews 4 | import 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan 5 | import org.apache.spark.sql.catalyst.rules.RuleExecutor 6 | 7 | /** 8 | * 2019-07-16 WilliamZhu(allwefantasy@gmail.com) 9 | */ 10 | object MaterializedViewOptimizeRewrite extends RuleExecutor[LogicalPlan] { 11 | val batches = 12 | Batch("Materialized view rewrite", Once, 13 | RewriteTableToViews) :: Nil 14 | } -------------------------------------------------------------------------------- /src/main/scala/tech/mlsql/sqlbooster/SchemaRegistry.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.sqlbooster 2 | 3 | import com.alibaba.druid.sql.SQLUtils 4 | import com.alibaba.druid.util.JdbcConstants 5 | import org.apache.spark.sql.catalyst.SessionUtil 6 | import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} 7 | import org.apache.spark.sql.catalyst.sqlgenerator.{BasicSQLDialect, LogicalPlanSQL} 8 | import org.apache.spark.sql.types.{DataType, StructType} 9 | import org.apache.spark.sql.{Row, SparkSession} 10 | import tech.mlsql.schema.parser.SparkSimpleSchemaParser 11 | import tech.mlsql.sqlbooster.db.RDSchema 12 | import tech.mlsql.sqlbooster.meta.ViewCatalyst 13 | 14 | /** 15 | * 2019-07-18 WilliamZhu(allwefantasy@gmail.com) 16 | */ 17 | class SchemaRegistry(_spark: SparkSession) { 18 | val spark = SessionUtil.cloneSession(_spark) 19 | 20 | def createTableFromDBSQL(createSQL: String) = { 21 | val rd = new RDSchema(JdbcConstants.MYSQL) 22 | val tableName = rd.createTable(createSQL) 23 | val schema = rd.getTableSchema(tableName) 24 | val df = spark.createDataFrame(spark.sparkContext.emptyRDD[Row], schema) 25 | df.createOrReplaceTempView(tableName) 26 | ViewCatalyst.meta.registerTableFromLogicalPlan(tableName, df.queryExecution.analyzed) 27 | } 28 | 29 | def createTableFromHiveSQL(tableName: String, createSQL: String) = { 30 | spark.sql(createSQL) 31 | val lp = spark.table(tableName).queryExecution.analyzed match { 32 | case a@SubqueryAlias(name, child) => child 33 | case a@_ => a 34 | } 35 | ViewCatalyst.meta.registerTableFromLogicalPlan(tableName, lp) 36 | } 37 | 38 | def createTableFromSimpleSchema(tableName: String, schemaText: String) = { 39 | val schema = SparkSimpleSchemaParser.parse(schemaText).asInstanceOf[StructType] 40 | val df = spark.createDataFrame(spark.sparkContext.emptyRDD[Row], schema) 41 | df.createOrReplaceTempView(tableName) 42 | ViewCatalyst.meta.registerTableFromLogicalPlan(tableName, df.queryExecution.analyzed) 43 | } 44 | 45 | def createTableFromJson(tableName: String, schemaJson: String) = { 46 | val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType] 47 | val df = spark.createDataFrame(spark.sparkContext.emptyRDD[Row], schema) 48 | df.createOrReplaceTempView(tableName) 49 | ViewCatalyst.meta.registerTableFromLogicalPlan(tableName, df.queryExecution.analyzed) 50 | } 51 | 52 | def createMV(viewName: String, viewCreate: String) = { 53 | val createViewTable1 = spark.sql(viewCreate) 54 | val df = spark.createDataFrame(spark.sparkContext.emptyRDD[Row], createViewTable1.schema) 55 | df.createOrReplaceTempView(viewName) 56 | ViewCatalyst.meta.registerTableFromLogicalPlan(viewName, df.queryExecution.analyzed) 57 | ViewCatalyst.meta.registerMaterializedViewFromLogicalPlan(viewName, df.queryExecution.analyzed, createViewTable1.queryExecution.analyzed) 58 | } 59 | 60 | def toLogicalPlan(sql: String) = { 61 | val temp = spark.sql(sql).queryExecution.analyzed 62 | temp 63 | } 64 | 65 | def genSQL(lp: LogicalPlan) = { 66 | val 
temp = new LogicalPlanSQL(lp, new BasicSQLDialect).toSQL 67 | temp 68 | } 69 | 70 | def genPrettySQL(lp: LogicalPlan) = { 71 | SQLUtils.format(genSQL(lp), JdbcConstants.HIVE) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/scala/tech/mlsql/sqlbooster/analysis/protocals.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.sqlbooster.analysis 2 | 3 | import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} 4 | import org.apache.spark.sql.catalyst.plans.logical._ 5 | 6 | import scala.collection.mutable.ArrayBuffer 7 | 8 | case class DataLineage(outputMapToSourceTable: Seq[OutputColumnToSourceTableAndColumn], dependences: Seq[TableAndUsedColumns]) 9 | 10 | case class TableAndUsedColumns(tableName: String, columns: Seq[String], locates: Seq[Seq[String]]) 11 | 12 | case class OutputColumnToSourceTableAndColumn(name: String, sources: Seq[TableAndUsedColumns]) 13 | 14 | 15 | object Location { 16 | val FILTER = "FILTER" 17 | val GROUP_BY = "GROUP_BY" 18 | val JOIN = "JOIN" 19 | val PROJECT = "PROJECT" 20 | 21 | def locate(plan: LogicalPlan, exp: Expression) = { 22 | val locates = ArrayBuffer[String]() 23 | plan transformDown { 24 | case a@Filter(condition, child) => 25 | if (existsIn(exp, Seq(condition))) { 26 | locates += FILTER 27 | } 28 | a 29 | case a@Join(_, _, _, condition) => 30 | if (condition.isDefined && existsIn(exp, Seq(condition.get))) { 31 | locates += JOIN 32 | } 33 | a 34 | case a@Aggregate(groupingExpressions, aggregateExpressions, _) => 35 | if (existsIn(exp, groupingExpressions)) { 36 | locates += GROUP_BY 37 | } 38 | if (existsIn(exp, aggregateExpressions)) { 39 | locates += PROJECT 40 | } 41 | a 42 | case a@Project(projectList, _) => 43 | if (existsIn(exp, projectList)) { 44 | locates += PROJECT 45 | } 46 | a 47 | } 48 | locates.toSet.toSeq 49 | } 50 | 51 | def existsIn(exp: Expression, targetExpr: Seq[Expression]) = { 52 | var exists = false 53 | targetExpr.map { item => 54 | item transformDown { 55 | case ar@AttributeReference(_, _, _, _) => 56 | if (ar.semanticEquals(exp)) { 57 | exists = true 58 | } 59 | ar 60 | } 61 | } 62 | exists 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /src/main/scala/tech/mlsql/sqlbooster/db/RDSchema.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.sqlbooster.db 2 | 3 | import java.sql.{JDBCType, SQLException} 4 | 5 | import com.alibaba.druid.sql.SQLUtils 6 | import com.alibaba.druid.sql.ast.SQLDataType 7 | import com.alibaba.druid.sql.ast.statement.{SQLColumnDefinition, SQLCreateTableStatement} 8 | import com.alibaba.druid.sql.repository.SchemaRepository 9 | import org.apache.spark.sql.jdbc.JdbcDialects 10 | import org.apache.spark.sql.types.DecimalType.{MAX_PRECISION, MAX_SCALE} 11 | import org.apache.spark.sql.types._ 12 | 13 | import scala.collection.JavaConverters._ 14 | import scala.math.min 15 | 16 | /** 17 | * 2019-07-18 WilliamZhu(allwefantasy@gmail.com) 18 | */ 19 | class RDSchema(dbType: String) { 20 | 21 | private val repository = new SchemaRepository(dbType) 22 | 23 | def createTable(sql: String) = { 24 | repository.console(sql) 25 | val tableName = SQLUtils.parseStatements(sql, dbType).get(0).asInstanceOf[SQLCreateTableStatement]. 
26 | getTableSource.getName.getSimpleName 27 | SQLUtils.normalize(tableName) 28 | } 29 | 30 | def getTableSchema(table: String) = { 31 | val dialect = JdbcDialects.get(s"jdbc:${dbType}") 32 | 33 | 34 | def extractfieldSize = (dataType: SQLDataType) => { 35 | dataType.getArguments.asScala.map { f => 36 | try { 37 | f.toString.toInt 38 | } catch { 39 | case e: Exception => 0 40 | } 41 | 42 | }.headOption 43 | } 44 | 45 | val fields = repository.findTable(table).getStatement.asInstanceOf[SQLCreateTableStatement]. 46 | getTableElementList().asScala.filter(f => f.isInstanceOf[SQLColumnDefinition]). 47 | map { 48 | _.asInstanceOf[SQLColumnDefinition] 49 | }.map { column => 50 | 51 | val columnName = column.getName.getSimpleName 52 | val dataType = RawDBTypeToJavaType.convert(dbType, column.getDataType.getName) 53 | val isNullable = !column.containsNotNullConstaint() 54 | 55 | val fieldSize = extractfieldSize(column.getDataType) match { 56 | case Some(i) => i 57 | case None => 0 58 | } 59 | val fieldScale = 0 60 | 61 | val columnType = dialect.getCatalystType(dataType, column.getDataType.getName, fieldSize, new MetadataBuilder()). 62 | getOrElse( 63 | getCatalystType(dataType, fieldSize, fieldScale, false)) 64 | 65 | StructField(columnName, columnType, isNullable) 66 | } 67 | new StructType(fields.toArray) 68 | 69 | } 70 | 71 | private def getCatalystType( 72 | sqlType: Int, 73 | precision: Int, 74 | scale: Int, 75 | signed: Boolean): DataType = { 76 | val answer = sqlType match { 77 | // scalastyle:off 78 | case java.sql.Types.ARRAY => null 79 | case java.sql.Types.BIGINT => if (signed) { 80 | LongType 81 | } else { 82 | DecimalType(20, 0) 83 | } 84 | case java.sql.Types.BINARY => BinaryType 85 | case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks 86 | case java.sql.Types.BLOB => BinaryType 87 | case java.sql.Types.BOOLEAN => BooleanType 88 | case java.sql.Types.CHAR => StringType 89 | case java.sql.Types.CLOB => StringType 90 | case java.sql.Types.DATALINK => null 91 | case java.sql.Types.DATE => DateType 92 | case java.sql.Types.DECIMAL 93 | if precision != 0 || scale != 0 => DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) 94 | case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT 95 | case java.sql.Types.DISTINCT => null 96 | case java.sql.Types.DOUBLE => DoubleType 97 | case java.sql.Types.FLOAT => FloatType 98 | case java.sql.Types.INTEGER => if (signed) { 99 | IntegerType 100 | } else { 101 | LongType 102 | } 103 | case java.sql.Types.JAVA_OBJECT => null 104 | case java.sql.Types.LONGNVARCHAR => StringType 105 | case java.sql.Types.LONGVARBINARY => BinaryType 106 | case java.sql.Types.LONGVARCHAR => StringType 107 | case java.sql.Types.NCHAR => StringType 108 | case java.sql.Types.NCLOB => StringType 109 | case java.sql.Types.NULL => null 110 | case java.sql.Types.NUMERIC 111 | if precision != 0 || scale != 0 => DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) 112 | case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT 113 | case java.sql.Types.NVARCHAR => StringType 114 | case java.sql.Types.OTHER => null 115 | case java.sql.Types.REAL => DoubleType 116 | case java.sql.Types.REF => StringType 117 | case java.sql.Types.REF_CURSOR => null 118 | case java.sql.Types.ROWID => LongType 119 | case java.sql.Types.SMALLINT => IntegerType 120 | case java.sql.Types.SQLXML => StringType 121 | case java.sql.Types.STRUCT => StringType 122 | case java.sql.Types.TIME => TimestampType 123 | case java.sql.Types.TIME_WITH_TIMEZONE 124 | => null 
125 | case java.sql.Types.TIMESTAMP => TimestampType 126 | case java.sql.Types.TIMESTAMP_WITH_TIMEZONE 127 | => null 128 | case java.sql.Types.TINYINT => IntegerType 129 | case java.sql.Types.VARBINARY => BinaryType 130 | case java.sql.Types.VARCHAR => StringType 131 | case _ => 132 | throw new SQLException("Unrecognized SQL type " + sqlType) 133 | // scalastyle:on 134 | } 135 | 136 | if (answer == null) { 137 | throw new SQLException("Unsupported type " + JDBCType.valueOf(sqlType).getName) 138 | } 139 | answer 140 | } 141 | } 142 | 143 | 144 | -------------------------------------------------------------------------------- /src/main/scala/tech/mlsql/sqlbooster/db/RawDBTypeToJavaType.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.sqlbooster.db 2 | 3 | import com.alibaba.druid.util.JdbcConstants 4 | import com.mysql.cj.MysqlType 5 | 6 | /** 7 | * 2019-07-18 WilliamZhu(allwefantasy@gmail.com) 8 | */ 9 | object RawDBTypeToJavaType { 10 | def convert(dbType: String, name: String) = { 11 | dbType match { 12 | case JdbcConstants.MYSQL => MysqlType.valueOf(name.toUpperCase()).getJdbcType 13 | case _ => throw new RuntimeException(s"dbType ${dbType} is not supported yet") 14 | } 15 | 16 | } 17 | 18 | } 19 | -------------------------------------------------------------------------------- /src/main/scala/tech/mlsql/sqlbooster/meta/ViewCatalyst.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.sqlbooster.meta 2 | 3 | import org.apache.spark.sql.catalyst.expressions.NamedExpression 4 | import org.apache.spark.sql.catalyst.optimizer.RewriteHelper 5 | import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan 6 | 7 | import scala.collection.JavaConverters._ 8 | 9 | /** 10 | * 2019-07-11 WilliamZhu(allwefantasy@gmail.com) 11 | */ 12 | 13 | trait ViewCatalyst { 14 | def registerMaterializedViewFromLogicalPlan(name: String, tableLogicalPlan: LogicalPlan, createLP: LogicalPlan): ViewCatalyst 15 | 16 | def registerTableFromLogicalPlan(name: String, tableLogicalPlan: LogicalPlan): ViewCatalyst 17 | 18 | def getCandidateViewsByTable(tableName: String): Option[Set[String]] 19 | 20 | def getViewLogicalPlan(viewName: String): Option[LogicalPlan] 21 | 22 | def getViewCreateLogicalPlan(viewName: String): Option[LogicalPlan] 23 | 24 | def getViewNameByLogicalPlan(viewLP: LogicalPlan): Option[String] 25 | 26 | def getTableNameByLogicalPlan(viewLP: LogicalPlan): Option[String] 27 | 28 | } 29 | 30 | class SimpleViewCatalyst extends ViewCatalyst with RewriteHelper { 31 | 32 | //view name -> LogicalPlan 33 | private val viewToCreateLogicalPlan = new java.util.concurrent.ConcurrentHashMap[String, LogicalPlan]() 34 | 35 | //view name -> LogicalPlan 36 | private val viewToLogicalPlan = new java.util.concurrent.ConcurrentHashMap[String, LogicalPlan]() 37 | 38 | //table -> view 39 | private val tableToViews = new java.util.concurrent.ConcurrentHashMap[String, Set[String]]() 40 | 41 | // simple meta data for LogicalPlanSQL 42 | private val logicalPlanToTableName = new java.util.concurrent.ConcurrentHashMap[LogicalPlan, String]() 43 | 44 | 45 | override def registerMaterializedViewFromLogicalPlan(name: String, tableLogicalPlan: LogicalPlan, createLP: LogicalPlan) = { 46 | 47 | def pushToTableToViews(tableName: String) = { 48 | val items = tableToViews.asScala.getOrElse(tableName, Set[String]()) 49 | tableToViews.put(tableName, items ++ Set(name)) 50 | } 51 | 52 | extractTablesFromPlan(createLP).foreach 
{ tableName => 53 | pushToTableToViews(tableName) 54 | } 55 | 56 | viewToCreateLogicalPlan.put(name, createLP) 57 | viewToLogicalPlan.put(name, tableLogicalPlan) 58 | this 59 | 60 | } 61 | 62 | override def registerTableFromLogicalPlan(name: String, tableLogicalPlan: LogicalPlan) = { 63 | logicalPlanToTableName.put(tableLogicalPlan, name) 64 | this 65 | 66 | } 67 | 68 | 69 | override def getCandidateViewsByTable(tableName: String) = { 70 | tableToViews.asScala.get(tableName) 71 | } 72 | 73 | override def getViewLogicalPlan(viewName: String) = { 74 | viewToLogicalPlan.asScala.get(viewName) 75 | } 76 | 77 | override def getViewCreateLogicalPlan(viewName: String) = { 78 | viewToCreateLogicalPlan.asScala.get(viewName) 79 | } 80 | 81 | override def getViewNameByLogicalPlan(viewLP: LogicalPlan) = { 82 | viewToLogicalPlan.asScala.filter(f => f._2 == viewLP).map(f => f._1).headOption 83 | } 84 | 85 | override def getTableNameByLogicalPlan(viewLP: LogicalPlan) = { 86 | logicalPlanToTableName.asScala.get(viewLP) 87 | } 88 | } 89 | 90 | case class TableHolder(db: String, table: String, output: Seq[NamedExpression], lp: LogicalPlan) 91 | 92 | object ViewCatalyst { 93 | private var _meta: ViewCatalyst = null 94 | 95 | def createViewCatalyst(clzz: Option[String] = None) = { 96 | _meta = if (clzz.isDefined) Class.forName(clzz.get).newInstance().asInstanceOf[ViewCatalyst] else new SimpleViewCatalyst() 97 | } 98 | 99 | def meta = { 100 | if (_meta == null) throw new RuntimeException("ViewCatalyst is not initialed. Please invoke createViewCatalyst before call this function.") 101 | _meta 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/sql/catalyst/BaseSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst 2 | 3 | import java.io.File 4 | 5 | import org.apache.commons.io.FileUtils 6 | import org.apache.spark.sql.SparkSession 7 | import org.apache.spark.sql.catalyst.expressions.PredicateHelper 8 | import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation 9 | import org.apache.spark.sql.internal.SQLConf 10 | import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION 11 | import org.apache.spark.{DebugFilesystem, SparkConf} 12 | import org.scalatest.{BeforeAndAfterAll, FunSuite} 13 | import tech.mlsql.sqlbooster.SchemaRegistry 14 | 15 | /** 16 | * 2019-07-18 WilliamZhu(allwefantasy@gmail.com) 17 | */ 18 | class BaseSuite extends FunSuite 19 | with BeforeAndAfterAll with PredicateHelper { 20 | var spark: SparkSession = _ 21 | var schemaReg: SchemaRegistry = null 22 | 23 | def init(): Unit = { 24 | FileUtils.deleteDirectory(new File("./metastore_db")) 25 | FileUtils.deleteDirectory(new File("/tmp/spark-warehouse")) 26 | spark = SparkSession.builder(). 27 | config(sparkConf). 28 | master("local[*]"). 29 | appName("base-test"). 
30 | enableHiveSupport().getOrCreate() 31 | schemaReg = new SchemaRegistry(spark) 32 | } 33 | 34 | def prepareDefaultTables = { 35 | schemaReg.createTableFromDBSQL( 36 | """ 37 | |CREATE TABLE depts( 38 | | deptno INT NOT NULL, 39 | | deptname VARCHAR(20), 40 | | PRIMARY KEY (deptno) 41 | |); 42 | """.stripMargin) 43 | 44 | schemaReg.createTableFromDBSQL( 45 | """ 46 | |CREATE TABLE locations( 47 | | locationid INT NOT NULL, 48 | | state CHAR(2), 49 | | PRIMARY KEY (locationid) 50 | |); 51 | """.stripMargin) 52 | 53 | schemaReg.createTableFromDBSQL( 54 | """ 55 | |CREATE TABLE emps( 56 | | empid INT NOT NULL, 57 | | deptno INT NOT NULL, 58 | | locationid INT NOT NULL, 59 | | empname VARCHAR(20) NOT NULL, 60 | | salary DECIMAL (18, 2), 61 | | PRIMARY KEY (empid), 62 | | FOREIGN KEY (deptno) REFERENCES depts(deptno), 63 | | FOREIGN KEY (locationid) REFERENCES locations(locationid) 64 | |); 65 | """.stripMargin) 66 | 67 | schemaReg.createTableFromHiveSQL("src", 68 | """ 69 | |CREATE TABLE IF NOT EXISTS src (key INT, value STRING) USING hive 70 | """.stripMargin) 71 | } 72 | 73 | override def afterAll(): Unit = { 74 | //SparkSession.cleanupAnyExistingSession() 75 | spark.close() 76 | } 77 | 78 | def sparkConf = { 79 | new SparkConf() 80 | .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) 81 | .set("spark.unsafe.exceptionOnMemoryLeak", "true") 82 | .set(SQLConf.CODEGEN_FALLBACK.key, "false") 83 | // Disable ConvertToLocalRelation for better test coverage. Test cases built on 84 | // LocalRelation will exercise the optimization rules better by disabling it as 85 | // this rule may potentially block testing of other optimization rules such as 86 | // ConstantPropagation etc. 87 | .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) 88 | .set(CATALOG_IMPLEMENTATION.key, "hive") 89 | .set("spark.sql.warehouse.dir", "/tmp/spark-warehouse") 90 | } 91 | 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/sql/catalyst/DataLineageSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst 2 | 3 | import net.liftweb.{json => SJSon} 4 | import tech.mlsql.sqlbooster.DataLineageExtractor 5 | import tech.mlsql.sqlbooster.meta.ViewCatalyst 6 | 7 | /** 8 | * 2019-07-19 WilliamZhu(allwefantasy@gmail.com) 9 | */ 10 | class DataLineageSuite extends BaseSuite { 11 | ViewCatalyst.createViewCatalyst() 12 | 13 | test("data lineage test") { 14 | val result = DataLineageExtractor.execute(schemaReg.toLogicalPlan( 15 | """ 16 | |select * from (SELECT e.empid 17 | |FROM emps e 18 | |JOIN depts d 19 | |ON e.deptno = d.deptno 20 | |where e.empid=1) as a where a.empid=2 21 | """.stripMargin)) 22 | println(JSONTool.pretty(result)) 23 | } 24 | 25 | override def beforeAll() = { 26 | super.init() 27 | super.prepareDefaultTables 28 | } 29 | 30 | } 31 | 32 | object JSONTool { 33 | 34 | def pretty(item: AnyRef) = { 35 | implicit val formats = SJSon.Serialization.formats(SJSon.NoTypeHints) 36 | SJSon.Serialization.writePretty(item) 37 | } 38 | 39 | def parseJson[T](str: String)(implicit m: Manifest[T]) = { 40 | implicit val formats = SJSon.DefaultFormats 41 | SJSon.parse(str).extract[T] 42 | } 43 | 44 | def toJsonStr(item: AnyRef) = { 45 | implicit val formats = SJSon.Serialization.formats(SJSon.NoTypeHints) 46 | SJSon.Serialization.write(item) 47 | } 48 | 49 | } 50 | 
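51 | // A rough sketch (not asserted by the suite) of the DataLineage value that the test above
52 | // serializes to JSON and prints. The field names come from tech.mlsql.sqlbooster.analysis
53 | // (DataLineage, OutputColumnToSourceTableAndColumn, TableAndUsedColumns, Location); the concrete
54 | // columns and locations shown are what the sample query over emps/depts is expected to yield and
55 | // will differ if the registered schemas or the query change:
56 | //
57 | // DataLineage(
58 | //   outputMapToSourceTable = Seq(
59 | //     OutputColumnToSourceTableAndColumn("empid",
60 | //       Seq(TableAndUsedColumns("emps", Seq("empid"), Seq(Seq("PROJECT", "FILTER"))), ...))),
61 | //   dependences = Seq(
62 | //     TableAndUsedColumns("emps", Seq("empid", "deptno"), Seq(Seq("PROJECT", "FILTER"), Seq("JOIN"))),
63 | //     TableAndUsedColumns("depts", Seq("deptno"), Seq(Seq("JOIN"))), ...))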
-------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/sql/catalyst/NewMVSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst 2 | 3 | import tech.mlsql.sqlbooster.MaterializedViewOptimizeRewrite 4 | import tech.mlsql.sqlbooster.meta.ViewCatalyst 5 | 6 | /** 7 | * 2019-07-18 WilliamZhu(allwefantasy@gmail.com) 8 | */ 9 | class NewMVSuite extends BaseSuite { 10 | 11 | ViewCatalyst.createViewCatalyst() 12 | 13 | override def beforeAll() = { 14 | super.init() 15 | super.prepareDefaultTables 16 | // schemaReg.createTableFromSimpleSchema("table1","""st(field(a,string),field(b,string))""") 17 | // println(SparkSimpleSchemaParser.parse("""st(field(a,string),field(b,string))""").asInstanceOf[StructType].json) 18 | // schemaReg.createTableFromJson("table2", 19 | // """ 20 | // |{"type":"struct","fields":[{"name":"a","type":"string","nullable":true,"metadata":{}},{"name":"b","type":"string","nullable":true,"metadata":{}}]} 21 | // """.stripMargin) 22 | } 23 | 24 | test("test join") { 25 | 26 | schemaReg.createMV("emps_mv", 27 | """ 28 | |SELECT empid 29 | |FROM emps 30 | |JOIN depts ON depts.deptno = emps.deptno 31 | """.stripMargin) 32 | 33 | val rewrite3 = MaterializedViewOptimizeRewrite.execute(schemaReg.toLogicalPlan( 34 | """ 35 | |select * from (SELECT e.empid 36 | |FROM emps e 37 | |JOIN depts d 38 | |ON e.deptno = d.deptno 39 | |where e.empid=1) as a where a.empid=2 40 | """.stripMargin)) 41 | assert(schemaReg.genSQL(rewrite3) 42 | == "SELECT a.`empid` FROM (SELECT `empid` FROM emps_mv WHERE `empid` = CAST(1 AS BIGINT)) a WHERE a.`empid` = CAST(2 AS BIGINT)") 43 | 44 | 45 | val rewrite = MaterializedViewOptimizeRewrite.execute(schemaReg.toLogicalPlan( 46 | """ 47 | |SELECT empid 48 | |FROM emps 49 | |JOIN depts 50 | |ON depts.deptno = emps.deptno 51 | |where emps.empid=1 52 | """.stripMargin)) 53 | 54 | assert(schemaReg.genSQL(rewrite) 55 | == "SELECT `empid` FROM emps_mv WHERE `empid` = CAST(1 AS BIGINT)") 56 | 57 | 58 | val rewrite2 = MaterializedViewOptimizeRewrite.execute(schemaReg.toLogicalPlan( 59 | """ 60 | |SELECT e.empid 61 | |FROM emps e 62 | |JOIN depts d 63 | |ON e.deptno = d.deptno 64 | |where e.empid=1 65 | """.stripMargin)) 66 | 67 | assert(schemaReg.genSQL(rewrite2) 68 | == "SELECT `empid` FROM emps_mv WHERE `empid` = CAST(1 AS BIGINT)") 69 | 70 | } 71 | test("test group ") { 72 | schemaReg.createMV("emps_mv", 73 | """ 74 | |SELECT empid, deptno 75 | |FROM emps 76 | |WHERE deptno > 5 77 | |GROUP BY empid, deptno 78 | """.stripMargin) 79 | 80 | val rewrite3 = MaterializedViewOptimizeRewrite.execute(schemaReg.toLogicalPlan( 81 | """ 82 | |SELECT deptno 83 | |FROM emps 84 | |WHERE deptno > 10 85 | |GROUP BY deptno 86 | """.stripMargin)) 87 | 88 | assert(schemaReg.genSQL(rewrite3) 89 | == "SELECT `deptno` FROM emps_mv WHERE `deptno` > CAST(10 AS BIGINT) GROUP BY `deptno`") 90 | 91 | val rewrite4 = MaterializedViewOptimizeRewrite.execute(schemaReg.toLogicalPlan( 92 | """ 93 | |SELECT deptno 94 | |FROM emps 95 | |WHERE deptno > 4 96 | |GROUP BY deptno 97 | """.stripMargin)) 98 | 99 | assert(schemaReg.genSQL(rewrite4) == "SELECT emps.`deptno` FROM emps WHERE emps.`deptno` > CAST(4 AS BIGINT) GROUP BY emps.`deptno`") 100 | 101 | val rewrite5 = MaterializedViewOptimizeRewrite.execute(schemaReg.toLogicalPlan( 102 | """ 103 | |SELECT deptno 104 | |FROM emps 105 | |WHERE deptno > 5 106 | |GROUP BY deptno 107 | """.stripMargin)) 108 | 109 | 
assert(schemaReg.genSQL(rewrite5) == "SELECT `deptno` FROM emps_mv WHERE `deptno` > CAST(5 AS BIGINT) GROUP BY `deptno`") 110 | 111 | val rewrite6 = MaterializedViewOptimizeRewrite.execute(schemaReg.toLogicalPlan( 112 | """ 113 | |SELECT deptno 114 | |FROM emps 115 | |WHERE deptno > 5 116 | |AND deptno <10 117 | |GROUP BY deptno 118 | """.stripMargin)) 119 | 120 | assert(schemaReg.genSQL(rewrite6) == "SELECT `deptno` FROM emps_mv " + 121 | "WHERE `deptno` > CAST(5 AS BIGINT) AND `deptno` < CAST(10 AS BIGINT) GROUP BY `deptno`") 122 | 123 | } 124 | 125 | test("test agg") { 126 | schemaReg.createMV("emps_mv", 127 | """ 128 | |SELECT empid, deptno, COUNT(*) AS c, SUM(salary) AS s 129 | |FROM emps 130 | |GROUP BY empid, deptno 131 | """.stripMargin) 132 | 133 | val rewrite1 = MaterializedViewOptimizeRewrite.execute(schemaReg.toLogicalPlan( 134 | """ 135 | |SELECT deptno, COUNT(*) AS c, SUM(salary) AS m 136 | |FROM emps 137 | |GROUP BY deptno 138 | """.stripMargin)) 139 | 140 | assert(schemaReg.genSQL(rewrite1) == 141 | "SELECT `deptno`, sum(`c`) AS `c`, sum(`s`) AS `m` FROM emps_mv GROUP BY `deptno`") 142 | } 143 | 144 | } 145 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/sql/catalyst/RangeSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.catalyst 2 | 3 | import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal} 4 | import org.apache.spark.sql.catalyst.optimizer.rewrite.component.RangeFilter 5 | import org.apache.spark.sql.types.LongType 6 | import org.scalatest.{BeforeAndAfterAll, FunSuite} 7 | 8 | /** 9 | * 2019-07-18 WilliamZhu(allwefantasy@gmail.com) 10 | */ 11 | class RangeSuite extends FunSuite 12 | with BeforeAndAfterAll { 13 | test("range compare") { 14 | val ar = AttributeReference("a", LongType)() 15 | val items = Seq(GreaterThan(ar, Literal(10, LongType)), 16 | LessThan(ar, Literal(20, LongType))) 17 | val rangeCondition = RangeFilter.combineAndMergeRangeCondition(items.filter(RangeFilter.rangeCon).map(RangeFilter.convertRangeCon)).head 18 | assert(rangeCondition.lowerBound.get == Literal(10, LongType)) 19 | assert(rangeCondition.upperBound.get == Literal(20, LongType)) 20 | } 21 | 22 | test("range choose the bigger in lowerBound") { 23 | val ar = AttributeReference("a", LongType)() 24 | val items = Seq(GreaterThan(ar, Literal(10, LongType)), 25 | GreaterThan(ar, Literal(20, LongType))) 26 | val rangeCondition = RangeFilter.combineAndMergeRangeCondition(items.filter(RangeFilter.rangeCon).map(RangeFilter.convertRangeCon)).head 27 | assert(rangeCondition.lowerBound.get == Literal(20, LongType)) 28 | assert(rangeCondition.upperBound == None) 29 | } 30 | 31 | test("range single") { 32 | val ar = AttributeReference("a", LongType)() 33 | val items = Seq(GreaterThan(ar, Literal(10, LongType))) 34 | val rangeCondition = RangeFilter.combineAndMergeRangeCondition(items.filter(RangeFilter.rangeCon).map(RangeFilter.convertRangeCon)).head 35 | assert(rangeCondition.lowerBound.get == Literal(10, LongType)) 36 | assert(rangeCondition.upperBound == None) 37 | } 38 | 39 | test("range choose the bigger in lowerBound with include") { 40 | val ar = AttributeReference("a", LongType)() 41 | val items = Seq(GreaterThan(ar, Literal(10, LongType)), 42 | GreaterThanOrEqual(ar, Literal(20, LongType))) 43 | val rangeCondition = 
RangeFilter.combineAndMergeRangeCondition(items.filter(RangeFilter.rangeCon).map(RangeFilter.convertRangeCon)).head 44 | assert(rangeCondition.lowerBound.get == Literal(20, LongType)) 45 | assert(rangeCondition.includeLowerBound == true) 46 | assert(rangeCondition.upperBound == None) 47 | } 48 | 49 | test("range choose the bigger in lowerBound with include and same value") { 50 | val ar = AttributeReference("a", LongType)() 51 | val items = Seq(GreaterThan(ar, Literal(10, LongType)), 52 | GreaterThanOrEqual(ar, Literal(10, LongType))) 53 | val rangeCondition = RangeFilter.combineAndMergeRangeCondition(items.filter(RangeFilter.rangeCon).map(RangeFilter.convertRangeCon)).head 54 | assert(rangeCondition.lowerBound.get == Literal(10, LongType)) 55 | assert(rangeCondition.includeLowerBound == false) 56 | assert(rangeCondition.upperBound == None) 57 | } 58 | test("range compare with includeUpperBound and not includeLowerBound") { 59 | val ar = AttributeReference("a", LongType)() 60 | val items = Seq(GreaterThan(ar, Literal(10, LongType)), 61 | LessThanOrEqual(ar, Literal(20, LongType)), LessThanOrEqual(ar, Literal(11, LongType))) 62 | val rangeCondition = RangeFilter.combineAndMergeRangeCondition(items.filter(RangeFilter.rangeCon).map(RangeFilter.convertRangeCon)).head 63 | assert(rangeCondition.lowerBound.get == Literal(10, LongType)) 64 | assert(rangeCondition.upperBound.get == Literal(11, LongType)) 65 | assert(rangeCondition.includeLowerBound == false) 66 | assert(rangeCondition.includeUpperBound == true) 67 | } 68 | } 69 | --------------------------------------------------------------------------------