diff --git a/.github/workflows/bot.yml b/.github/workflows/bot.yml index 17688bc6d5217..8bef6fb23edce 100644 --- a/.github/workflows/bot.yml +++ b/.github/workflows/bot.yml @@ -614,6 +614,9 @@ jobs: - scalaProfile: "scala-2.13" sparkProfile: "spark4.1" sparkModules: "hudi-spark-datasource/hudi-spark4.1.x" + - scalaProfile: "scala-2.13" + sparkProfile: "spark4.2" + sparkModules: "hudi-spark-datasource/hudi-spark4.2.x" steps: - uses: actions/checkout@v5 @@ -664,6 +667,9 @@ jobs: - scalaProfile: "scala-2.13" sparkProfile: "spark4.1" sparkModules: "hudi-spark-datasource/hudi-spark4.1.x" + - scalaProfile: "scala-2.13" + sparkProfile: "spark4.2" + sparkModules: "hudi-spark-datasource/hudi-spark4.2.x" steps: - uses: actions/checkout@v5 @@ -727,6 +733,9 @@ jobs: - scalaProfile: "scala-2.13" sparkProfile: "spark4.1" sparkModules: "hudi-spark-datasource/hudi-spark4.1.x" + - scalaProfile: "scala-2.13" + sparkProfile: "spark4.2" + sparkModules: "hudi-spark-datasource/hudi-spark4.2.x" steps: - uses: actions/checkout@v5 @@ -784,6 +793,9 @@ jobs: - scalaProfile: "scala-2.13" sparkProfile: "spark4.1" sparkModules: "hudi-spark-datasource/hudi-spark4.1.x" + - scalaProfile: "scala-2.13" + sparkProfile: "spark4.2" + sparkModules: "hudi-spark-datasource/hudi-spark4.2.x" steps: - uses: actions/checkout@v5 @@ -841,6 +853,9 @@ jobs: - scalaProfile: "scala-2.13" sparkProfile: "spark4.1" sparkModules: "hudi-spark-datasource/hudi-spark4.1.x" + - scalaProfile: "scala-2.13" + sparkProfile: "spark4.2" + sparkModules: "hudi-spark-datasource/hudi-spark4.2.x" steps: - uses: actions/checkout@v5 @@ -1000,6 +1015,10 @@ jobs: flinkProfile: 'flink1.20' sparkProfile: 'spark4.1' sparkRuntime: 'spark4.1.1' + - scalaProfile: 'scala-2.13' + flinkProfile: 'flink1.20' + sparkProfile: 'spark4.2' + sparkRuntime: 'spark4.2.0-preview4' steps: - uses: actions/checkout@v5 @@ -1112,6 +1131,12 @@ jobs: flinkParquetVersion: '1.13.1' sparkProfile: 'spark4.1' sparkRuntime: 'spark4.1.1' + - scalaProfile: 'scala-2.13' + flinkProfile: 'flink1.20' + flinkAvroVersion: '1.11.4' + flinkParquetVersion: '1.13.1' + sparkProfile: 'spark4.2' + sparkRuntime: 'spark4.2.0-preview4' steps: - uses: actions/checkout@v5 - name: Set up JDK 17 diff --git a/README.md b/README.md index 38e12fc7a375b..2903582174561 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,7 @@ Refer to the table below for building with different Spark and Scala versions. 
| `-Dspark3.5 -Dscala-2.13` | hudi-spark3.5-bundle_2.13 | For Spark 3.5.x and Scala 2.13 | | `-Dspark4.0` | hudi-spark4.0-bundle_2.13 | For Spark 4.0 and Scala 2.13 (Needs java 17) | | `-Dspark4.1` | hudi-spark4.1-bundle_2.13 | For Spark 4.1 and Scala 2.13 (Needs java 17) | +| `-Dspark4.2` | hudi-spark4.2-bundle_2.13 | For Spark 4.2 and Scala 2.13 (Needs java 17) | | `-Dspark3` | hudi-spark3-bundle_2.12 (legacy bundle name) | For Spark 3.5.x and Scala 2.12 | Please note that only Spark-related bundles, i.e., `hudi-spark-bundle`, `hudi-utilities-bundle`, diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala index 8784dc807e323..83a770478ef0a 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala @@ -60,12 +60,14 @@ private[hudi] trait SparkVersionsSupport { def isSpark3_5: Boolean = getSparkVersion.startsWith("3.5") def isSpark4_0: Boolean = getSparkVersion.startsWith("4.0") def isSpark4_1: Boolean = getSparkVersion.startsWith("4.1") + def isSpark4_2: Boolean = getSparkVersion.startsWith("4.2") def gteqSpark3_3_2: Boolean = getSparkVersion >= "3.3.2" def gteqSpark3_4: Boolean = getSparkVersion >= "3.4" def gteqSpark3_5: Boolean = getSparkVersion >= "3.5" def gteqSpark4_0: Boolean = getSparkVersion >= "4.0" def gteqSpark4_1: Boolean = getSparkVersion >= "4.1" + def gteqSpark4_2: Boolean = getSparkVersion >= "4.2" } object HoodieSparkUtils extends SparkAdapterSupport with SparkVersionsSupport with Logging { diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala index e546310b89a22..2cab797a3bcb1 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala @@ -33,7 +33,9 @@ trait SparkAdapterSupport { object SparkAdapterSupport { lazy val sparkAdapter: SparkAdapter = { - val adapterClass = if (HoodieSparkUtils.isSpark4_1) { + val adapterClass = if (HoodieSparkUtils.isSpark4_2) { + "org.apache.spark.sql.adapter.Spark4_2Adapter" + } else if (HoodieSparkUtils.isSpark4_1) { "org.apache.spark.sql.adapter.Spark4_1Adapter" } else if (HoodieSparkUtils.isSpark4_0) { "org.apache.spark.sql.adapter.Spark4_0Adapter" diff --git a/hudi-common/pom.xml b/hudi-common/pom.xml index 17288103e946e..b20ceb2990a00 100644 --- a/hudi-common/pom.xml +++ b/hudi-common/pom.xml @@ -258,9 +258,9 @@ - org.lz4 + ${lz4.groupId} lz4-java - 1.8.0 + ${lz4.version} diff --git a/hudi-spark-datasource/README.md b/hudi-spark-datasource/README.md index 2295c8919c731..1a96eafa0a018 100644 --- a/hudi-spark-datasource/README.md +++ b/hudi-spark-datasource/README.md @@ -36,6 +36,7 @@ The modules are organized in a layered architecture to maximize code reuse acros | `hudi-spark3.5.x` | Spark 3.5.x-specific adapter implementation (default). | | `hudi-spark4.0.x` | Spark 4.0.x-specific adapter implementation. | | `hudi-spark4.1.x` | Spark 4.1.x-specific adapter implementation. | +| `hudi-spark4.2.x` | Spark 4.2.x-specific adapter implementation. | | `hudi-spark` | Main Spark datasource module containing Spark Session extensions, stored procedures, SQL parser, and logical plans. 
| ## Spark Version Support @@ -47,6 +48,7 @@ The modules are organized in a layered architecture to maximize code reuse acros | 3.5.x (default) | `hudi-spark3.5.x` | 2.12, 2.13 | 11+ | `-Dspark3.5` | | 4.0.x | `hudi-spark4.0.x` | 2.13 | 17+ | `-Dspark4.0` | | 4.1.x | `hudi-spark4.1.x` | 2.13 | 17+ | `-Dspark4.1` | +| 4.2.x | `hudi-spark4.2.x` | 2.13 | 17+ | `-Dspark4.2` | ## Key Features diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkFilterHelper.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkFilterHelper.scala index ba0f4dd982c2d..46cc6308c6ac1 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkFilterHelper.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkFilterHelper.scala @@ -110,7 +110,7 @@ object SparkFilterHelper { Types.FloatType.get() case DoubleType => Types.DoubleType.get() - case StringType | CharType(_) | VarcharType(_) => + case StringType | (_: CharType) | (_: VarcharType) => Types.StringType.get() case DateType => Types.DateType.get() diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala index d01627ae1fd56..dcf60dca48c30 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala @@ -57,7 +57,9 @@ object HoodieAnalysis extends SparkAdapterSupport { val adaptIngestionTargetLogicalRelations: RuleBuilder = session => AdaptIngestionTargetLogicalRelations(session) rules += adaptIngestionTargetLogicalRelations - val dataSourceV2ToV1FallbackClass = if (HoodieSparkUtils.isSpark4_1) { + val dataSourceV2ToV1FallbackClass = if (HoodieSparkUtils.isSpark4_2) { + "org.apache.spark.sql.hudi.analysis.HoodieSpark42DataSourceV2ToV1Fallback" + } else if (HoodieSparkUtils.isSpark4_1) { "org.apache.spark.sql.hudi.analysis.HoodieSpark41DataSourceV2ToV1Fallback" } else if (HoodieSparkUtils.isSpark4_0) { "org.apache.spark.sql.hudi.analysis.HoodieSpark40DataSourceV2ToV1Fallback" @@ -83,7 +85,10 @@ object HoodieAnalysis extends SparkAdapterSupport { // leading to all relations resolving as V2 instead of current expectation of them being resolved as V1) rules ++= Seq(dataSourceV2ToV1Fallback, resolveReferences) - if (HoodieSparkUtils.isSpark4_1) { + if (HoodieSparkUtils.isSpark4_2) { + rules += (_ => instantiateKlass( + "org.apache.spark.sql.hudi.analysis.HoodieSpark42ResolveColumnsForInsertInto")) + } else if (HoodieSparkUtils.isSpark4_1) { rules += (_ => instantiateKlass( "org.apache.spark.sql.hudi.analysis.HoodieSpark41ResolveColumnsForInsertInto")) } else if (HoodieSparkUtils.isSpark4_0) { @@ -95,7 +100,9 @@ object HoodieAnalysis extends SparkAdapterSupport { } val resolveAlterTableCommandsClass = - if (HoodieSparkUtils.isSpark4_1) { + if (HoodieSparkUtils.isSpark4_2) { + "org.apache.spark.sql.hudi.Spark42ResolveHudiAlterTableCommand" + } else if (HoodieSparkUtils.isSpark4_1) { "org.apache.spark.sql.hudi.Spark41ResolveHudiAlterTableCommand" } else if (HoodieSparkUtils.isSpark4_0) { "org.apache.spark.sql.hudi.Spark40ResolveHudiAlterTableCommand" @@ -144,7 +151,9 @@ object HoodieAnalysis extends SparkAdapterSupport { ) val nestedSchemaPruningClass = - if (HoodieSparkUtils.isSpark4_1) { + if 
(HoodieSparkUtils.isSpark4_2) { + "org.apache.spark.sql.execution.datasources.Spark42NestedSchemaPruning" + } else if (HoodieSparkUtils.isSpark4_1) { "org.apache.spark.sql.execution.datasources.Spark41NestedSchemaPruning" } else if (HoodieSparkUtils.isSpark4_0) { "org.apache.spark.sql.execution.datasources.Spark40NestedSchemaPruning" diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/others/TestMergeIntoTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/others/TestMergeIntoTable.scala index b82b95af8662c..84902ce995b8a 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/others/TestMergeIntoTable.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/others/TestMergeIntoTable.scala @@ -1788,7 +1788,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo """.stripMargin) spark.sql(s"insert into $sourceTable values(1, 'a1', 10, 1000)") val nonExistentTable = "hudi_test_table" - val exception = intercept[org.apache.spark.sql.AnalysisException] { + val exception = intercept[Exception] { spark.sql( s""" | MERGE INTO $nonExistentTable AS target @@ -1800,8 +1800,9 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo | WHEN NOT MATCHED THEN INSERT * """.stripMargin) } - assert(exception.getMessage.contains("TABLE_OR_VIEW_NOT_FOUND") || - exception.getMessage.contains("Table or view not found"), + val fullMsg = exception.getMessage + Option(exception.getCause).map(_.getMessage).getOrElse("") + assert(fullMsg.contains("TABLE_OR_VIEW_NOT_FOUND") || + fullMsg.contains("Table or view not found"), s"Expected TABLE_OR_VIEW_NOT_FOUND error but got: ${exception.getMessage}") } } diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/others/TestMergeIntoTable2.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/others/TestMergeIntoTable2.scala index 91bf94ff4363f..31974f4ac32f1 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/others/TestMergeIntoTable2.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/dml/others/TestMergeIntoTable2.scala @@ -135,7 +135,9 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { Seq(0) ) - val errorMsg = if (HoodieSparkUtils.gteqSpark4_0) + val errorMsg = if (HoodieSparkUtils.gteqSpark4_2) + "[INTERNAL_ERROR] Executed command failed. You hit a bug in Spark or the Spark plugins you use. Please, report this bug to the corresponding communities or vendors, and provide the full stack trace. SQLSTATE: XX000" + else if (HoodieSparkUtils.gteqSpark4_0) "[INTERNAL_ERROR] Eagerly executed command failed. You hit a bug in Spark or the Spark plugins you use. Please, report this bug to the corresponding communities or vendors, and provide the full stack trace. SQLSTATE: XX000" else "assertion failed: Target table's field(price) cannot be the right-value of the update clause for MOR table." 
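Note on the repeated pattern above: SparkAdapterSupport and HoodieAnalysis both select a fully qualified, version-specific class name via HoodieSparkUtils predicates and then instantiate it reflectively, so adding Spark 4.2 support only requires a new branch at each dispatch site plus the Spark42* implementations introduced in the new module below. A minimal Scala sketch of that dispatch pattern, for orientation only — the loadInstance helper is hypothetical and is not part of this patch, which relies on the instantiation utilities already present in the codebase:

```scala
import org.apache.hudi.HoodieSparkUtils

// Sketch of the version-gated dispatch used by SparkAdapterSupport and HoodieAnalysis.
// Only the Spark 4.2 branch is new in this patch; older release lines are elided here.
object VersionDispatchSketch {
  def adapterClassName: String =
    if (HoodieSparkUtils.isSpark4_2) {
      "org.apache.spark.sql.adapter.Spark4_2Adapter"
    } else if (HoodieSparkUtils.isSpark4_1) {
      "org.apache.spark.sql.adapter.Spark4_1Adapter"
    } else if (HoodieSparkUtils.isSpark4_0) {
      "org.apache.spark.sql.adapter.Spark4_0Adapter"
    } else {
      sys.error("older Spark lines are handled by the existing branches (omitted in this sketch)")
    }

  // Hypothetical helper: load the selected class and invoke its no-arg constructor.
  def loadInstance[T](className: String): T =
    Class.forName(className).getDeclaredConstructor().newInstance().asInstanceOf[T]
}
```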
diff --git a/hudi-spark-datasource/hudi-spark4.2.x/pom.xml b/hudi-spark-datasource/hudi-spark4.2.x/pom.xml new file mode 100644 index 0000000000000..c36970db41d76 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/pom.xml @@ -0,0 +1,325 @@ + + + + + hudi-spark-datasource + org.apache.hudi + 1.3.0-SNAPSHOT + + 4.0.0 + + hudi-spark4.2.x_${scala.binary.version} + 1.3.0-SNAPSHOT + + hudi-spark4.2.x_${scala.binary.version} + jar + + + ${project.parent.parent.basedir} + + + + + + src/main/resources + + + + + + net.alchim31.maven + scala-maven-plugin + ${scala-maven-plugin.version} + + + -nobootcp + + false + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + copy-dependencies + prepare-package + + copy-dependencies + + + ${project.build.directory}/lib + true + true + true + + + + + + net.alchim31.maven + scala-maven-plugin + + + org.apache.maven.plugins + maven-compiler-plugin + + + org.apache.maven.plugins + maven-jar-plugin + + + + test-jar + + test-compile + + + + false + + + + org.apache.maven.plugins + maven-surefire-plugin + + ${skip.hudi-spark4.unit.tests} + + + + org.apache.rat + apache-rat-plugin + + + org.scalastyle + scalastyle-maven-plugin + + + org.jacoco + jacoco-maven-plugin + + + org.antlr + antlr4-maven-plugin + ${antlr.version} + + + + antlr4 + + + + + true + true + ../hudi-spark4.2.x/src/main/antlr4 + ../hudi-spark4.2.x/src/main/antlr4/imports + + + + + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark42.version} + provided + true + + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${spark42.version} + provided + true + + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark42.version} + provided + true + + + * + * + + + + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-core + + + + org.apache.hudi + hudi-spark-client + ${project.version} + + + + org.apache.hudi + hudi-spark-common_${scala.binary.version} + ${project.version} + + + + org.apache.parquet + parquet-hadoop-bundle + + + + + + org.json4s + json4s-jackson_${scala.binary.version} + 4.0.7 + + + com.fasterxml.jackson.core + * + + + + + + + org.apache.hudi + hudi-spark4-common + ${project.version} + + + + + org.apache.hudi + hudi-tests-common + ${project.version} + test + + + + org.apache.hudi + hudi-client-common + ${project.version} + tests + test-jar + test + + + + org.apache.hudi + hudi-spark-client + ${project.version} + tests + test-jar + test + + + + org.apache.hudi + hudi-common + ${project.version} + tests + test-jar + test + + + + org.apache.hudi + hudi-hadoop-common + ${project.version} + tests + test-jar + test + + + + org.apache.hudi + hudi-spark-common_${scala.binary.version} + ${project.version} + tests + test-jar + test + + + + org.apache.parquet + parquet-hadoop-bundle + + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark42.version} + tests + test + + + + org.apache.parquet + parquet-avro + test + + + + org.apache.hadoop + hadoop-hdfs + tests + test + + + + org.mortbay.jetty + * + + + javax.servlet.jsp + * + + + javax.servlet + * + + + + + + + diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/antlr4/imports/SqlBase.g4 b/hudi-spark-datasource/hudi-spark4.2.x/src/main/antlr4/imports/SqlBase.g4 new file mode 100644 index 0000000000000..8f15fab16b0a7 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/antlr4/imports/SqlBase.g4 @@ 
-0,0 +1,1978 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. + */ + +// The parser file is forked from spark 3.2.0's SqlBase.g4. +grammar SqlBase; + +@parser::members { + /** + * When false, INTERSECT is given the greater precedence over the other set + * operations (UNION, EXCEPT and MINUS) as per the SQL standard. + */ + public boolean legacy_setops_precedence_enabled = false; + + /** + * When false, a literal with an exponent would be converted into + * double type rather than decimal type. + */ + public boolean legacy_exponent_literal_as_decimal_enabled = false; + + /** + * When true, the behavior of keywords follows ANSI SQL standard. + */ + public boolean SQL_standard_keyword_behavior = false; + + /** + * When false, parameter markers (? and :param) are only allowed in constant contexts. + * When true, parameter markers are allowed everywhere a literal is supported. + */ + public boolean parameter_substitution_enabled = true; +} + +@lexer::members { + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + /** + * This method will be called when we see '/*' and try to match it as a bracketed comment. + * If the next character is '+', it should be parsed as hint later, and we cannot match + * it as a bracketed comment. + * + * Returns true if the next character is '+'. + */ + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } +} + +singleStatement + : statement ';'* EOF + ; + +singleExpression + : namedExpression EOF + ; + +singleTableIdentifier + : tableIdentifier EOF + ; + +singleMultipartIdentifier + : multipartIdentifier EOF + ; + +singleFunctionIdentifier + : functionIdentifier EOF + ; + +singleDataType + : dataType EOF + ; + +singleTableSchema + : colTypeList EOF + ; + +statement + : query #statementDefault + | ctes? dmlStatementNoWith #dmlStatement + | USE NAMESPACE? multipartIdentifier #use + | CREATE namespace (IF NOT EXISTS)? 
multipartIdentifier + (commentSpec | + locationSpec | + (WITH (DBPROPERTIES | PROPERTIES) tablePropertyList))* #createNamespace + | ALTER namespace multipartIdentifier + SET (DBPROPERTIES | PROPERTIES) tablePropertyList #setNamespaceProperties + | ALTER namespace multipartIdentifier + SET locationSpec #setNamespaceLocation + | DROP namespace (IF EXISTS)? multipartIdentifier + (RESTRICT | CASCADE)? #dropNamespace + | SHOW (DATABASES | NAMESPACES) ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=STRING)? #showNamespaces + | createTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #createTable + | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier + LIKE source=tableIdentifier + (tableProvider | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* #createTableLike + | replaceTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #replaceTable + | ANALYZE TABLE multipartIdentifier partitionSpec? COMPUTE STATISTICS + (identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze + | ANALYZE TABLES ((FROM | IN) multipartIdentifier)? COMPUTE STATISTICS + (identifier)? #analyzeTables + | ALTER TABLE multipartIdentifier + ADD (COLUMN | COLUMNS) + columns=qualifiedColTypeWithPositionList #addTableColumns + | ALTER TABLE multipartIdentifier + ADD (COLUMN | COLUMNS) + '(' columns=qualifiedColTypeWithPositionList ')' #addTableColumns + | ALTER TABLE table=multipartIdentifier + RENAME COLUMN + from=multipartIdentifier TO to=errorCapturingIdentifier #renameTableColumn + | ALTER TABLE multipartIdentifier + DROP (COLUMN | COLUMNS) + '(' columns=multipartIdentifierList ')' #dropTableColumns + | ALTER TABLE multipartIdentifier + DROP (COLUMN | COLUMNS) columns=multipartIdentifierList #dropTableColumns + | ALTER (TABLE | VIEW) from=multipartIdentifier + RENAME TO to=multipartIdentifier #renameTable + | ALTER (TABLE | VIEW) multipartIdentifier + SET TBLPROPERTIES tablePropertyList #setTableProperties + | ALTER (TABLE | VIEW) multipartIdentifier + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties + | ALTER TABLE table=multipartIdentifier + (ALTER | CHANGE) COLUMN? column=multipartIdentifier + alterColumnAction? #alterTableAlterColumn + | ALTER TABLE table=multipartIdentifier partitionSpec? + CHANGE COLUMN? + colName=multipartIdentifier colType colPosition? #hiveChangeColumn + | ALTER TABLE table=multipartIdentifier partitionSpec? + REPLACE COLUMNS + '(' columns=qualifiedColTypeWithPositionList ')' #hiveReplaceColumns + | ALTER TABLE multipartIdentifier (partitionSpec)? + SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe + | ALTER TABLE multipartIdentifier (partitionSpec)? + SET SERDEPROPERTIES tablePropertyList #setTableSerDe + | ALTER (TABLE | VIEW) multipartIdentifier ADD (IF NOT EXISTS)? + partitionSpecLocation+ #addTablePartition + | ALTER TABLE multipartIdentifier + from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition + | ALTER (TABLE | VIEW) multipartIdentifier + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions + | ALTER TABLE multipartIdentifier + (partitionSpec)? SET locationSpec #setTableLocation + | ALTER TABLE multipartIdentifier RECOVER PARTITIONS #recoverPartitions + | DROP TABLE (IF EXISTS)? multipartIdentifier PURGE? #dropTable + | DROP VIEW (IF EXISTS)? multipartIdentifier #dropView + | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? + VIEW (IF NOT EXISTS)? 
multipartIdentifier + identifierCommentList? + (commentSpec | + (PARTITIONED ON identifierList) | + (TBLPROPERTIES tablePropertyList))* + AS query #createView + | CREATE (OR REPLACE)? GLOBAL? TEMPORARY VIEW + tableIdentifier ('(' colTypeList ')')? tableProvider + (OPTIONS tablePropertyList)? #createTempViewUsing + | ALTER VIEW multipartIdentifier AS? query #alterViewQuery + | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF NOT EXISTS)? + multipartIdentifier AS className=STRING + (USING resource (',' resource)*)? #createFunction + | DROP TEMPORARY? FUNCTION (IF EXISTS)? multipartIdentifier #dropFunction + | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? + statement #explain + | SHOW TABLES ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=STRING)? #showTables + | SHOW TABLE EXTENDED ((FROM | IN) ns=multipartIdentifier)? + LIKE pattern=STRING partitionSpec? #showTableExtended + | SHOW TBLPROPERTIES table=multipartIdentifier + ('(' key=tablePropertyKey ')')? #showTblProperties + | SHOW COLUMNS (FROM | IN) table=multipartIdentifier + ((FROM | IN) ns=multipartIdentifier)? #showColumns + | SHOW VIEWS ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=STRING)? #showViews + | SHOW PARTITIONS multipartIdentifier partitionSpec? #showPartitions + | SHOW identifier? FUNCTIONS + (LIKE? (multipartIdentifier | pattern=STRING))? #showFunctions + | SHOW CREATE TABLE multipartIdentifier (AS SERDE)? #showCreateTable + | SHOW CURRENT NAMESPACE #showCurrentNamespace + | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction + | (DESC | DESCRIBE) namespace EXTENDED? + multipartIdentifier #describeNamespace + | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? + multipartIdentifier partitionSpec? describeColName? #describeRelation + | (DESC | DESCRIBE) QUERY? query #describeQuery + | COMMENT ON namespace multipartIdentifier IS + comment=(STRING | NULL) #commentNamespace + | COMMENT ON TABLE multipartIdentifier IS comment=(STRING | NULL) #commentTable + | REFRESH TABLE multipartIdentifier #refreshTable + | REFRESH FUNCTION multipartIdentifier #refreshFunction + | REFRESH (STRING | .*?) #refreshResource + | CACHE LAZY? TABLE multipartIdentifier + (OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable + | UNCACHE TABLE (IF EXISTS)? multipartIdentifier #uncacheTable + | CLEAR CACHE #clearCache + | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE + multipartIdentifier partitionSpec? #loadData + | TRUNCATE TABLE multipartIdentifier partitionSpec? #truncateTable + | MSCK REPAIR TABLE multipartIdentifier + (option=(ADD|DROP|SYNC) PARTITIONS)? #repairTable + | op=(ADD | LIST) identifier .*? #manageResource + | SET ROLE .*? #failNativeCommand + | SET TIME ZONE interval #setTimeZone + | SET TIME ZONE timezone=(STRING | LOCAL) #setTimeZone + | SET TIME ZONE .*? #setTimeZone + | SET configKey EQ configValue #setQuotedConfiguration + | SET configKey (EQ .*?)? #setQuotedConfiguration + | SET .*? EQ configValue #setQuotedConfiguration + | SET .*? #setConfiguration + | RESET configKey #resetQuotedConfiguration + | RESET .*? #resetConfiguration + | unsupportedHiveNativeCommands .*? #failNativeCommand + ; + +configKey + : quotedIdentifier + ; + +configValue + : quotedIdentifier + ; + +unsupportedHiveNativeCommands + : kw1=CREATE kw2=ROLE + | kw1=DROP kw2=ROLE + | kw1=GRANT kw2=ROLE? + | kw1=REVOKE kw2=ROLE? + | kw1=SHOW kw2=GRANT + | kw1=SHOW kw2=ROLE kw3=GRANT? 
+ | kw1=SHOW kw2=PRINCIPALS + | kw1=SHOW kw2=ROLES + | kw1=SHOW kw2=CURRENT kw3=ROLES + | kw1=EXPORT kw2=TABLE + | kw1=IMPORT kw2=TABLE + | kw1=SHOW kw2=COMPACTIONS + | kw1=SHOW kw2=CREATE kw3=TABLE + | kw1=SHOW kw2=TRANSACTIONS + | kw1=SHOW kw2=INDEXES + | kw1=SHOW kw2=LOCKS + | kw1=CREATE kw2=INDEX + | kw1=DROP kw2=INDEX + | kw1=ALTER kw2=INDEX + | kw1=LOCK kw2=TABLE + | kw1=LOCK kw2=DATABASE + | kw1=UNLOCK kw2=TABLE + | kw1=UNLOCK kw2=DATABASE + | kw1=CREATE kw2=TEMPORARY kw3=MACRO + | kw1=DROP kw2=TEMPORARY kw3=MACRO + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=CLUSTERED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=CLUSTERED kw4=BY + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SORTED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=SKEWED kw4=BY + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SKEWED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=STORED kw5=AS kw6=DIRECTORIES + | kw1=ALTER kw2=TABLE tableIdentifier kw3=SET kw4=SKEWED kw5=LOCATION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=EXCHANGE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=ARCHIVE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=UNARCHIVE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=TOUCH + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=COMPACT + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CONCATENATE + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=SET kw4=FILEFORMAT + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=REPLACE kw4=COLUMNS + | kw1=START kw2=TRANSACTION + | kw1=COMMIT + | kw1=ROLLBACK + | kw1=DFS + ; + +createTableHeader + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? multipartIdentifier + ; + +replaceTableHeader + : (CREATE OR)? REPLACE TABLE multipartIdentifier + ; + +bucketSpec + : CLUSTERED BY identifierList + (SORTED BY orderedIdentifierList)? + INTO INTEGER_VALUE BUCKETS + ; + +skewSpec + : SKEWED BY identifierList + ON (constantList | nestedConstantList) + (STORED AS DIRECTORIES)? + ; + +locationSpec + : LOCATION STRING + ; + +commentSpec + : COMMENT STRING + ; + +query + : ctes? queryTerm queryOrganization + ; + +insertInto + : INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? identifierList? #insertOverwriteTable + | INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? identifierList? #insertIntoTable + | INSERT OVERWRITE LOCAL? DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir + | INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? #insertOverwriteDir + ; + +partitionSpecLocation + : partitionSpec locationSpec? + ; + +partitionSpec + : PARTITION '(' partitionVal (',' partitionVal)* ')' + ; + +partitionVal + : identifier (EQ constant)? + ; + +namespace + : NAMESPACE + | DATABASE + | SCHEMA + ; + +describeFuncName + : qualifiedName + | STRING + | comparisonOperator + | arithmeticOperator + | predicateOperator + ; + +describeColName + : nameParts+=identifier ('.' nameParts+=identifier)* + ; + +ctes + : WITH RECURSIVE? namedQuery (COMMA namedQuery)* + ; + +namedQuery + : name=errorCapturingIdentifier (columnAliases=identifierList)? (MAX RECURSION LEVEL integerValue)? AS? 
LEFT_PAREN query RIGHT_PAREN + ; + +tableProvider + : USING multipartIdentifier + ; + +createTableClauses + :((OPTIONS options=tablePropertyList) | + (PARTITIONED BY partitioning=partitionFieldList) | + skewSpec | + bucketSpec | + rowFormat | + createFileFormat | + locationSpec | + commentSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* + ; + +tablePropertyList + : '(' tableProperty (',' tableProperty)* ')' + ; + +tableProperty + : key=tablePropertyKey (EQ? value=tablePropertyValue)? + ; + +tablePropertyKey + : identifier ('.' identifier)* + | STRING + ; + +tablePropertyValue + : INTEGER_VALUE + | DECIMAL_VALUE + | booleanValue + | STRING + ; + +constantList + : '(' constant (',' constant)* ')' + ; + +nestedConstantList + : '(' constantList (',' constantList)* ')' + ; + +createFileFormat + : STORED AS fileFormat + | STORED BY storageHandler + ; + +fileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING #tableFileFormat + | identifier #genericFileFormat + ; + +storageHandler + : STRING (WITH SERDEPROPERTIES tablePropertyList)? + ; + +resource + : identifier STRING + ; + +dmlStatementNoWith + : insertInto queryTerm queryOrganization #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery + | DELETE FROM multipartIdentifier tableAlias whereClause? #deleteFromTable + | UPDATE multipartIdentifier tableAlias setClause whereClause? #updateTable + | MERGE INTO target=multipartIdentifier targetAlias=tableAlias + USING (source=multipartIdentifier | + '(' sourceQuery=query')') sourceAlias=tableAlias + ON mergeCondition=booleanExpression + matchedClause* + notMatchedClause* #mergeIntoTable + ; + +queryOrganization + : (ORDER BY order+=sortItem (',' order+=sortItem)*)? + (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)? + (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? + (SORT BY sort+=sortItem (',' sort+=sortItem)*)? + windowClause? + (LIMIT (ALL | limit=expression))? + ; + +multiInsertQueryBody + : insertInto fromStatementBody + ; + +queryTerm + : queryPrimary #queryTermDefault + | left=queryTerm {legacy_setops_precedence_enabled}? + operator=(INTERSECT | UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enabled}? + operator=INTERSECT setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enabled}? + operator=(UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + ; + +queryPrimary + : querySpecification #queryPrimaryDefault + | fromStatement #fromStmt + | TABLE multipartIdentifier #table + | inlineTable #inlineTableDefault1 + | '(' query ')' #subquery + ; + +sortItem + : expression ordering=(ASC | DESC)? (NULLS nullOrder=(LAST | FIRST))? + ; + +fromStatement + : fromClause fromStatementBody+ + ; + +fromStatementBody + : transformClause + whereClause? + queryOrganization + | selectClause + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? + queryOrganization + ; + +querySpecification + : transformClause + fromClause? + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? #transformQuerySpecification + | selectClause + fromClause? + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? #regularQuerySpecification + ; + +transformClause + : (SELECT kind=TRANSFORM '(' setQuantifier? expressionSeq ')' + | kind=MAP setQuantifier? expressionSeq + | kind=REDUCE setQuantifier? expressionSeq) + inRowFormat=rowFormat? 
+ (RECORDWRITER recordWriter=STRING)? + USING script=STRING + (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))? + outRowFormat=rowFormat? + (RECORDREADER recordReader=STRING)? + ; + +selectClause + : SELECT (hints+=hint)* setQuantifier? namedExpressionSeq + ; + +setClause + : SET assignmentList + ; + +matchedClause + : WHEN MATCHED (AND matchedCond=booleanExpression)? THEN matchedAction + ; +notMatchedClause + : WHEN NOT MATCHED (AND notMatchedCond=booleanExpression)? THEN notMatchedAction + ; + +matchedAction + : DELETE + | UPDATE SET ASTERISK + | UPDATE SET assignmentList + ; + +notMatchedAction + : INSERT ASTERISK + | INSERT '(' columns=multipartIdentifierList ')' + VALUES '(' expression (',' expression)* ')' + ; + +assignmentList + : assignment (',' assignment)* + ; + +assignment + : key=multipartIdentifier EQ value=expression + ; + +whereClause + : WHERE booleanExpression + ; + +havingClause + : HAVING booleanExpression + ; + +hint + : '/*+' hintStatements+=hintStatement (','? hintStatements+=hintStatement)* '*/' + ; + +hintStatement + : hintName=identifier + | hintName=identifier '(' parameters+=primaryExpression (',' parameters+=primaryExpression)* ')' + ; + +fromClause + : FROM relation (',' relation)* lateralView* pivotClause? + ; + +temporalClause + : FOR? (SYSTEM_TIME | TIMESTAMP) AS OF timestamp=valueExpression + | FOR? (SYSTEM_VERSION | VERSION) AS OF version=(INTEGER_VALUE | STRING) + ; + +aggregationClause + : GROUP BY groupingExpressionsWithGroupingAnalytics+=groupByClause + (',' groupingExpressionsWithGroupingAnalytics+=groupByClause)* + | GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( + WITH kind=ROLLUP + | WITH kind=CUBE + | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? + ; + +groupByClause + : groupingAnalytics + | expression + ; + +groupingAnalytics + : (ROLLUP | CUBE) '(' groupingSet (',' groupingSet)* ')' + | GROUPING SETS '(' groupingElement (',' groupingElement)* ')' + ; + +groupingElement + : groupingAnalytics + | groupingSet + ; + +groupingSet + : '(' (expression (',' expression)*)? ')' + | expression + ; + +pivotClause + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn IN '(' pivotValues+=pivotValue (',' pivotValues+=pivotValue)* ')' ')' + ; + +pivotColumn + : identifiers+=identifier + | '(' identifiers+=identifier (',' identifiers+=identifier)* ')' + ; + +pivotValue + : expression (AS? identifier)? + ; + +lateralView + : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? + ; + +setQuantifier + : DISTINCT + | ALL + ; + +relation + : LATERAL? relationPrimary joinRelation* + ; + +joinRelation + : (joinType) JOIN LATERAL? right=relationPrimary joinCriteria? + | NATURAL joinType JOIN LATERAL? right=relationPrimary + ; + +joinType + : INNER? + | CROSS + | LEFT OUTER? + | LEFT? SEMI + | RIGHT OUTER? + | FULL OUTER? + | LEFT? ANTI + ; + +joinCriteria + : ON booleanExpression + | USING identifierList + ; + +sample + : TABLESAMPLE '(' sampleMethod? ')' + ; + +sampleMethod + : negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile + | expression ROWS #sampleByRows + | sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE + (ON (identifier | qualifiedName '(' ')'))? 
#sampleByBucket + | bytes=expression #sampleByBytes + ; + +identifierList + : '(' identifierSeq ')' + ; + +identifierSeq + : ident+=errorCapturingIdentifier (',' ident+=errorCapturingIdentifier)* + ; + +orderedIdentifierList + : '(' orderedIdentifier (',' orderedIdentifier)* ')' + ; + +orderedIdentifier + : ident=errorCapturingIdentifier ordering=(ASC | DESC)? + ; + +identifierCommentList + : '(' identifierComment (',' identifierComment)* ')' + ; + +identifierComment + : identifier commentSpec? + ; + +relationPrimary + : multipartIdentifier temporalClause? + sample? tableAlias #tableName + | '(' query ')' sample? tableAlias #aliasedQuery + | '(' relation ')' sample? tableAlias #aliasedRelation + | inlineTable #inlineTableDefault2 + | functionTable #tableValuedFunction + ; + +inlineTable + : VALUES expression (',' expression)* tableAlias + ; + +functionTable + : funcName=functionName '(' (expression (',' expression)*)? ')' tableAlias + ; + +tableAlias + : (AS? strictIdentifier identifierList?)? + ; + +rowFormat + : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde + | ROW FORMAT DELIMITED + (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)? + (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)? + (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)? + (LINES TERMINATED BY linesSeparatedBy=STRING)? + (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited + ; + +multipartIdentifierList + : multipartIdentifier (',' multipartIdentifier)* + ; + +multipartIdentifier + : parts+=errorCapturingIdentifier ('.' parts+=errorCapturingIdentifier)* + ; + +tableIdentifier + : (db=errorCapturingIdentifier '.')? table=errorCapturingIdentifier + ; + +functionIdentifier + : (db=errorCapturingIdentifier '.')? function=errorCapturingIdentifier + ; + +multipartIdentifierPropertyList + : multipartIdentifierProperty (COMMA multipartIdentifierProperty)* + ; + +multipartIdentifierProperty + : multipartIdentifier (OPTIONS options=propertyList)? + ; + +propertyList + : LEFT_PAREN property (COMMA property)* RIGHT_PAREN + ; + +property + : key=propertyKey (EQ? value=propertyValue)? + ; + +propertyKey + : identifier (DOT identifier)* + | STRING + ; + +propertyValue + : INTEGER_VALUE + | DECIMAL_VALUE + | booleanValue + | STRING + ; + +namedExpression + : expression (AS? (name=errorCapturingIdentifier | identifierList))? + ; + +namedExpressionSeq + : namedExpression (',' namedExpression)* + ; + +partitionFieldList + : '(' fields+=partitionField (',' fields+=partitionField)* ')' + ; + +partitionField + : transform #partitionTransform + | colType #partitionColumn + ; + +transform + : qualifiedName #identityTransform + | transformName=identifier + '(' argument+=transformArgument (',' argument+=transformArgument)* ')' #applyTransform + ; + +transformArgument + : qualifiedName + | constant + ; + +expression + : booleanExpression + ; + +expressionSeq + : expression (',' expression)* + ; + +booleanExpression + : NOT booleanExpression #logicalNot + | EXISTS '(' query ')' #exists + | valueExpression predicate? #predicated + | left=booleanExpression operator=AND right=booleanExpression #logicalBinary + | left=booleanExpression operator=OR right=booleanExpression #logicalBinary + ; + +predicate + : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression + | NOT? kind=IN '(' expression (',' expression)* ')' + | NOT? kind=IN '(' query ')' + | NOT? kind=RLIKE pattern=valueExpression + | NOT? 
kind=LIKE quantifier=(ANY | SOME | ALL) ('('')' | '(' expression (',' expression)* ')') + | NOT? kind=LIKE pattern=valueExpression (ESCAPE escapeChar=STRING)? + | IS NOT? kind=NULL + | IS NOT? kind=(TRUE | FALSE | UNKNOWN) + | IS NOT? kind=DISTINCT FROM right=valueExpression + ; + +valueExpression + : primaryExpression #valueExpressionDefault + | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary + | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS | CONCAT_PIPE) right=valueExpression #arithmeticBinary + | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary + | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary + | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary + | left=valueExpression comparisonOperator right=valueExpression #comparison + ; + +primaryExpression + : name=(CURRENT_DATE | CURRENT_TIMESTAMP | CURRENT_USER) #currentLike + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | name=(CAST | TRY_CAST) '(' expression AS dataType ')' #cast + | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? ')' #struct + | FIRST '(' expression (IGNORE NULLS)? ')' #first + | LAST '(' expression (IGNORE NULLS)? ')' #last + | POSITION '(' substr=valueExpression IN str=valueExpression ')' #position + | constant #constantDefault + | ASTERISK #star + | qualifiedName '.' ASTERISK #star + | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor + | '(' query ')' #subqueryExpression + | functionName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' + (FILTER '(' WHERE where=booleanExpression ')')? + (nullsOption=(IGNORE | RESPECT) NULLS)? ( OVER windowSpec)? #functionCall + | identifier '->' expression #lambda + | '(' identifier (',' identifier)+ ')' '->' expression #lambda + | value=primaryExpression '[' index=valueExpression ']' #subscript + | identifier #columnReference + | base=primaryExpression '.' fieldName=identifier #dereference + | '(' expression ')' #parenthesizedExpression + | EXTRACT '(' field=identifier FROM source=valueExpression ')' #extract + | (SUBSTR | SUBSTRING) '(' str=valueExpression (FROM | ',') pos=valueExpression + ((FOR | ',') len=valueExpression)? ')' #substring + | TRIM '(' trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)? + FROM srcStr=valueExpression ')' #trim + | OVERLAY '(' input=valueExpression PLACING replace=valueExpression + FROM position=valueExpression (FOR length=valueExpression)? ')' #overlay + ; + +constant + : NULL #nullLiteral + | interval #intervalLiteral + | identifier STRING #typeConstructor + | number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + ; + +comparisonOperator + : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ + ; + +arithmeticOperator + : PLUS | MINUS | ASTERISK | SLASH | PERCENT | DIV | TILDE | AMPERSAND | PIPE | CONCAT_PIPE | HAT + ; + +predicateOperator + : OR | AND | IN | NOT + ; + +booleanValue + : TRUE | FALSE + ; + +interval + : INTERVAL (errorCapturingMultiUnitsInterval | errorCapturingUnitToUnitInterval)? + ; + +errorCapturingMultiUnitsInterval + : body=multiUnitsInterval unitToUnitInterval? 
+ ; + +multiUnitsInterval + : (intervalValue unit+=identifier)+ + ; + +errorCapturingUnitToUnitInterval + : body=unitToUnitInterval (error1=multiUnitsInterval | error2=unitToUnitInterval)? + ; + +unitToUnitInterval + : value=intervalValue from=identifier TO to=identifier + ; + +intervalValue + : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE | STRING) + ; + +colPosition + : position=FIRST | position=AFTER afterCol=errorCapturingIdentifier + ; + +dataType + : complex=ARRAY '<' dataType '>' #complexDataType + | complex=MAP '<' dataType ',' dataType '>' #complexDataType + | complex=STRUCT ('<' complexColTypeList? '>' | NEQ) #complexDataType + | INTERVAL from=(YEAR | MONTH) (TO to=MONTH)? #yearMonthIntervalDataType + | INTERVAL from=(DAY | HOUR | MINUTE | SECOND) + (TO to=(HOUR | MINUTE | SECOND))? #dayTimeIntervalDataType + | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType + ; + +qualifiedColTypeWithPositionList + : qualifiedColTypeWithPosition (',' qualifiedColTypeWithPosition)* + ; + +qualifiedColTypeWithPosition + : name=multipartIdentifier dataType (NOT NULL)? commentSpec? colPosition? + ; + +colTypeList + : colType (',' colType)* + ; + +colType + : colName=errorCapturingIdentifier dataType (NOT NULL)? commentSpec? + ; + +complexColTypeList + : complexColType (',' complexColType)* + ; + +complexColType + : identifier ':'? dataType (NOT NULL)? commentSpec? + ; + +whenClause + : WHEN condition=expression THEN result=expression + ; + +windowClause + : WINDOW namedWindow (',' namedWindow)* + ; + +namedWindow + : name=errorCapturingIdentifier AS windowSpec + ; + +windowSpec + : name=errorCapturingIdentifier #windowRef + | '('name=errorCapturingIdentifier')' #windowRef + | '(' + ( CLUSTER BY partition+=expression (',' partition+=expression)* + | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? + ((ORDER | SORT) BY sortItem (',' sortItem)*)?) + windowFrame? + ')' #windowDef + ; + +windowFrame + : frameType=RANGE start=frameBound + | frameType=ROWS start=frameBound + | frameType=RANGE BETWEEN start=frameBound AND end=frameBound + | frameType=ROWS BETWEEN start=frameBound AND end=frameBound + ; + +frameBound + : UNBOUNDED boundType=(PRECEDING | FOLLOWING) + | boundType=CURRENT ROW + | expression boundType=(PRECEDING | FOLLOWING) + ; + +qualifiedNameList + : qualifiedName (',' qualifiedName)* + ; + +functionName + : qualifiedName + | FILTER + | LEFT + | RIGHT + ; + +qualifiedName + : identifier ('.' identifier)* + ; + +// this rule is used for explicitly capturing wrong identifiers such as test-table, which should actually be `test-table` +// replace identifier with errorCapturingIdentifier where the immediate follow symbol is not an expression, otherwise +// valid expressions such as "a-b" can be recognized as an identifier +errorCapturingIdentifier + : identifier errorCapturingIdentifierExtra + ; + +// extra left-factoring grammar +errorCapturingIdentifierExtra + : (MINUS identifier)+ #errorIdent + | #realIdent + ; + +identifier + : strictIdentifier + | {!SQL_standard_keyword_behavior}? strictNonReserved + ; + +// simpleIdentifier: like identifier but without IDENTIFIER('literal') support +// Use this for contexts where IDENTIFIER() syntax is not appropriate: +// - Named parameters (:param_name) +// - Extract field names (EXTRACT(field FROM ...)) +// - Other keyword-like or string-like uses +simpleIdentifier + : simpleStrictIdentifier + | {!SQL_standard_keyword_behavior}? 
strictNonReserved + ; + +strictIdentifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | {SQL_standard_keyword_behavior}? ansiNonReserved #unquotedIdentifier + | {!SQL_standard_keyword_behavior}? nonReserved #unquotedIdentifier + ; + +// simpleStrictIdentifier: like strictIdentifier but without IDENTIFIER('literal') support +simpleStrictIdentifier + : IDENTIFIER #simpleUnquotedIdentifier + | quotedIdentifier #simpleQuotedIdentifierAlternative + | {SQL_standard_keyword_behavior}? ansiNonReserved #simpleUnquotedIdentifier + | {!SQL_standard_keyword_behavior}? nonReserved #simpleUnquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +number + : {!legacy_exponent_literal_as_decimal_enabled}? MINUS? EXPONENT_VALUE #exponentLiteral + | {!legacy_exponent_literal_as_decimal_enabled}? MINUS? DECIMAL_VALUE #decimalLiteral + | {legacy_exponent_literal_as_decimal_enabled}? MINUS? (EXPONENT_VALUE | DECIMAL_VALUE) #legacyDecimalLiteral + | MINUS? INTEGER_VALUE #integerLiteral + | MINUS? BIGINT_LITERAL #bigIntLiteral + | MINUS? SMALLINT_LITERAL #smallIntLiteral + | MINUS? TINYINT_LITERAL #tinyIntLiteral + | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? FLOAT_LITERAL #floatLiteral + | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral + ; + +integerValue + : INTEGER_VALUE #integerVal + | parameterMarker #parameterIntegerValue + ; + +alterColumnAction + : TYPE dataType + | commentSpec + | colPosition + | setOrDrop=(SET | DROP) NOT NULL + ; + +parameterMarker + : {parameter_substitution_enabled}? namedParameterMarker #namedParameterMarkerRule + | {parameter_substitution_enabled}? QUESTION #positionalParameterMarkerRule + ; + +namedParameterMarker + : COLON simpleIdentifier + ; + +// When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. +// - Reserved keywords: +// Keywords that are reserved and can't be used as identifiers for table, view, column, +// function, alias, etc. +// - Non-reserved keywords: +// Keywords that have a special meaning only in particular contexts and can be used as +// identifiers in other contexts. For example, `EXPLAIN SELECT ...` is a command, but EXPLAIN +// can be used as identifiers in other places. +// You can find the full keywords list by searching "Start of the keywords list" in this file. +// The non-reserved keywords are listed below. Keywords not in this list are reserved keywords. 
+ansiNonReserved +//--ANSI-NON-RESERVED-START + : ADD + | AFTER + | ALTER + | ANALYZE + | ANTI + | ARCHIVE + | ARRAY + | ASC + | AT + | BETWEEN + | BUCKET + | BUCKETS + | BY + | CACHE + | CASCADE + | CHANGE + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLECTION + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | COST + | CUBE + | CURRENT + | DATA + | DATABASE + | DATABASES + | DAY + | DBPROPERTIES + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTRIBUTE + | DIV + | DROP + | ESCAPED + | EXCHANGE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FIELDS + | FILEFORMAT + | FIRST + | FOLLOWING + | FORMAT + | FORMATTED + | FUNCTION + | FUNCTIONS + | GLOBAL + | GROUPING + | HOUR + | IF + | IGNORE + | IMPORT + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INTERVAL + | ITEMS + | KEYS + | LAST + | LAZY + | LIKE + | LIMIT + | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | MACRO + | MAP + | MATCHED + | MERGE + | MINUTE + | MONTH + | MSCK + | NAMESPACE + | NAMESPACES + | NO + | NULLS + | OF + | OPTION + | OPTIONS + | OUT + | OUTPUTFORMAT + | OVER + | OVERLAY + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENTLIT + | PIVOT + | PLACING + | POSITION + | PRECEDING + | PRINCIPALS + | PROPERTIES + | PURGE + | QUERY + | RANGE + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFRESH + | RENAME + | REPAIR + | REPLACE + | RESET + | RESPECT + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SCHEMA + | SECOND + | SEMI + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SET + | SETMINUS + | SETS + | SHOW + | SKEWED + | SORT + | SORTED + | START + | STATISTICS + | STORED + | STRATIFY + | STRUCT + | SUBSTR + | SUBSTRING + | SYNC + | TABLES + | TABLESAMPLE + | TBLPROPERTIES + | TEMPORARY + | TERMINATED + | TOUCH + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRIM + | TRUE + | TRUNCATE + | TRY_CAST + | TYPE + | UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNLOCK + | UNSET + | UPDATE + | USE + | VALUES + | VIEW + | VIEWS + | WINDOW + | YEAR + | ZONE +//--ANSI-NON-RESERVED-END + ; + +// When `SQL_standard_keyword_behavior=false`, there are 2 kinds of keywords in Spark SQL. +// - Non-reserved keywords: +// Same definition as the one when `SQL_standard_keyword_behavior=true`. +// - Strict-non-reserved keywords: +// A strict version of non-reserved keywords, which can not be used as table alias. +// You can find the full keywords list by searching "Start of the keywords list" in this file. +// The strict-non-reserved keywords are listed in `strictNonReserved`. +// The non-reserved keywords are listed in `nonReserved`. +// These 2 together contain all the keywords. 
+strictNonReserved + : ANTI + | CROSS + | EXCEPT + | FULL + | INNER + | INTERSECT + | JOIN + | LATERAL + | LEFT + | NATURAL + | ON + | RIGHT + | SEMI + | SETMINUS + | UNION + | USING + ; + +nonReserved +//--DEFAULT-NON-RESERVED-START + : ADD + | AFTER + | ALL + | ALTER + | ANALYZE + | AND + | ANY + | ARCHIVE + | ARRAY + | AS + | ASC + | AT + | AUTHORIZATION + | BETWEEN + | BOTH + | BUCKET + | BUCKETS + | BY + | CACHE + | CASCADE + | CASE + | CAST + | CHANGE + | CHECK + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLATE + | COLLECTION + | COLUMN + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | CONSTRAINT + | COST + | CREATE + | CUBE + | CURRENT + | CURRENT_DATE + | CURRENT_TIME + | CURRENT_TIMESTAMP + | CURRENT_USER + | DATA + | DATABASE + | DATABASES + | DAY + | DBPROPERTIES + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTINCT + | DISTRIBUTE + | DIV + | DROP + | ELSE + | END + | ESCAPE + | ESCAPED + | EXCHANGE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FALSE + | FETCH + | FILTER + | FIELDS + | FILEFORMAT + | FIRST + | FOLLOWING + | FOR + | FOREIGN + | FORMAT + | FORMATTED + | FROM + | FUNCTION + | FUNCTIONS + | GLOBAL + | GRANT + | GROUP + | GROUPING + | HAVING + | HOUR + | IF + | IGNORE + | IMPORT + | IN + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INTERVAL + | INTO + | IS + | ITEMS + | KEYS + | LAST + | LAZY + | LEADING + | LIKE + | LIMIT + | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | MACRO + | MAP + | MATCHED + | MERGE + | MINUTE + | MONTH + | MSCK + | NAMESPACE + | NAMESPACES + | NO + | NOT + | NULL + | NULLS + | OF + | ONLY + | OPTION + | OPTIONS + | OR + | ORDER + | OUT + | OUTER + | OUTPUTFORMAT + | OVER + | OVERLAPS + | OVERLAY + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENTLIT + | PIVOT + | PLACING + | POSITION + | PRECEDING + | PRIMARY + | PRINCIPALS + | PROPERTIES + | PURGE + | QUERY + | RANGE + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFERENCES + | REFRESH + | RENAME + | REPAIR + | REPLACE + | RESET + | RESPECT + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SCHEMA + | SECOND + | SELECT + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SESSION_USER + | SET + | SETS + | SHOW + | SKEWED + | SOME + | SORT + | SORTED + | START + | STATISTICS + | STORED + | STRATIFY + | STRUCT + | SUBSTR + | SUBSTRING + | SYNC + | TABLE + | TABLES + | TABLESAMPLE + | TBLPROPERTIES + | TEMPORARY + | TERMINATED + | THEN + | TIME + | TO + | TOUCH + | TRAILING + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRIM + | TRUE + | TRUNCATE + | TRY_CAST + | TYPE + | UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNIQUE + | UNKNOWN + | UNLOCK + | UNSET + | UPDATE + | USE + | USER + | VALUES + | VIEW + | VIEWS + | WHEN + | WHERE + | WINDOW + | WITH + | YEAR + | ZONE + | SYSTEM_VERSION + | VERSION + | SYSTEM_TIME + | TIMESTAMP +//--DEFAULT-NON-RESERVED-END + ; + +// NOTE: If you add a new token in the list below, you should update the list of keywords +// and reserved tag in `docs/sql-ref-ansi-compliance.md#sql-keywords`. 
+ +//============================ +// Start of the keywords list +//============================ +//--SPARK-KEYWORD-LIST-START +ADD: 'ADD'; +AFTER: 'AFTER'; +ALL: 'ALL'; +ALTER: 'ALTER'; +ANALYZE: 'ANALYZE'; +AND: 'AND'; +ANTI: 'ANTI'; +ANY: 'ANY'; +ARCHIVE: 'ARCHIVE'; +ARRAY: 'ARRAY'; +AS: 'AS'; +ASC: 'ASC'; +AT: 'AT'; +AUTHORIZATION: 'AUTHORIZATION'; +BETWEEN: 'BETWEEN'; +BOTH: 'BOTH'; +BUCKET: 'BUCKET'; +BUCKETS: 'BUCKETS'; +BY: 'BY'; +CACHE: 'CACHE'; +CASCADE: 'CASCADE'; +CASE: 'CASE'; +CAST: 'CAST'; +CHANGE: 'CHANGE'; +CHECK: 'CHECK'; +CLEAR: 'CLEAR'; +CLUSTER: 'CLUSTER'; +CLUSTERED: 'CLUSTERED'; +CODEGEN: 'CODEGEN'; +COLLATE: 'COLLATE'; +COLLECTION: 'COLLECTION'; +COLUMN: 'COLUMN'; +COLUMNS: 'COLUMNS'; +COMMENT: 'COMMENT'; +COMMIT: 'COMMIT'; +COMPACT: 'COMPACT'; +COMPACTIONS: 'COMPACTIONS'; +COMPUTE: 'COMPUTE'; +CONCATENATE: 'CONCATENATE'; +CONSTRAINT: 'CONSTRAINT'; +COST: 'COST'; +CREATE: 'CREATE'; +CROSS: 'CROSS'; +CUBE: 'CUBE'; +CURRENT: 'CURRENT'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIME: 'CURRENT_TIME'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +CURRENT_USER: 'CURRENT_USER'; +DAY: 'DAY'; +DATA: 'DATA'; +DATABASE: 'DATABASE'; +DATABASES: 'DATABASES' | 'SCHEMAS'; +DBPROPERTIES: 'DBPROPERTIES'; +DEFINED: 'DEFINED'; +DELETE: 'DELETE'; +DELIMITED: 'DELIMITED'; +DESC: 'DESC'; +DESCRIBE: 'DESCRIBE'; +DFS: 'DFS'; +DIRECTORIES: 'DIRECTORIES'; +DIRECTORY: 'DIRECTORY'; +DISTINCT: 'DISTINCT'; +DISTRIBUTE: 'DISTRIBUTE'; +DIV: 'DIV'; +DROP: 'DROP'; +ELSE: 'ELSE'; +END: 'END'; +ESCAPE: 'ESCAPE'; +ESCAPED: 'ESCAPED'; +EXCEPT: 'EXCEPT'; +EXCHANGE: 'EXCHANGE'; +EXISTS: 'EXISTS'; +EXPLAIN: 'EXPLAIN'; +EXPORT: 'EXPORT'; +EXTENDED: 'EXTENDED'; +EXTERNAL: 'EXTERNAL'; +EXTRACT: 'EXTRACT'; +FALSE: 'FALSE'; +FETCH: 'FETCH'; +FIELDS: 'FIELDS'; +FILTER: 'FILTER'; +FILEFORMAT: 'FILEFORMAT'; +FIRST: 'FIRST'; +FOLLOWING: 'FOLLOWING'; +FOR: 'FOR'; +FOREIGN: 'FOREIGN'; +FORMAT: 'FORMAT'; +FORMATTED: 'FORMATTED'; +FROM: 'FROM'; +FULL: 'FULL'; +FUNCTION: 'FUNCTION'; +FUNCTIONS: 'FUNCTIONS'; +GLOBAL: 'GLOBAL'; +GRANT: 'GRANT'; +GROUP: 'GROUP'; +GROUPING: 'GROUPING'; +HAVING: 'HAVING'; +HOUR: 'HOUR'; +IF: 'IF'; +IGNORE: 'IGNORE'; +IMPORT: 'IMPORT'; +IN: 'IN'; +INDEX: 'INDEX'; +INDEXES: 'INDEXES'; +INNER: 'INNER'; +INPATH: 'INPATH'; +INPUTFORMAT: 'INPUTFORMAT'; +INSERT: 'INSERT'; +INTERSECT: 'INTERSECT'; +INTERVAL: 'INTERVAL'; +INTO: 'INTO'; +IS: 'IS'; +ITEMS: 'ITEMS'; +JOIN: 'JOIN'; +KEYS: 'KEYS'; +LAST: 'LAST'; +LATERAL: 'LATERAL'; +LAZY: 'LAZY'; +LEADING: 'LEADING'; +LEFT: 'LEFT'; +LIKE: 'LIKE'; +LIMIT: 'LIMIT'; +LINES: 'LINES'; +LIST: 'LIST'; +LOAD: 'LOAD'; +LOCAL: 'LOCAL'; +LOCATION: 'LOCATION'; +LOCK: 'LOCK'; +LOCKS: 'LOCKS'; +LOGICAL: 'LOGICAL'; +MACRO: 'MACRO'; +MAP: 'MAP'; +MATCHED: 'MATCHED'; +MERGE: 'MERGE'; +MINUTE: 'MINUTE'; +MONTH: 'MONTH'; +MSCK: 'MSCK'; +NAMESPACE: 'NAMESPACE'; +NAMESPACES: 'NAMESPACES'; +NATURAL: 'NATURAL'; +NO: 'NO'; +NOT: 'NOT' | '!'; +NULL: 'NULL'; +NULLS: 'NULLS'; +OF: 'OF'; +ON: 'ON'; +ONLY: 'ONLY'; +OPTION: 'OPTION'; +OPTIONS: 'OPTIONS'; +OR: 'OR'; +ORDER: 'ORDER'; +OUT: 'OUT'; +OUTER: 'OUTER'; +OUTPUTFORMAT: 'OUTPUTFORMAT'; +OVER: 'OVER'; +OVERLAPS: 'OVERLAPS'; +OVERLAY: 'OVERLAY'; +OVERWRITE: 'OVERWRITE'; +PARTITION: 'PARTITION'; +PARTITIONED: 'PARTITIONED'; +PARTITIONS: 'PARTITIONS'; +PERCENTLIT: 'PERCENT'; +PIVOT: 'PIVOT'; +PLACING: 'PLACING'; +POSITION: 'POSITION'; +PRECEDING: 'PRECEDING'; +PRIMARY: 'PRIMARY'; +PRINCIPALS: 'PRINCIPALS'; +PROPERTIES: 'PROPERTIES'; +PURGE: 'PURGE'; +QUERY: 'QUERY'; +RANGE: 'RANGE'; +RECORDREADER: 'RECORDREADER'; +RECORDWRITER: 
'RECORDWRITER'; +RECOVER: 'RECOVER'; +REDUCE: 'REDUCE'; +REFERENCES: 'REFERENCES'; +REFRESH: 'REFRESH'; +RENAME: 'RENAME'; +REPAIR: 'REPAIR'; +REPLACE: 'REPLACE'; +RESET: 'RESET'; +RESPECT: 'RESPECT'; +RESTRICT: 'RESTRICT'; +REVOKE: 'REVOKE'; +RIGHT: 'RIGHT'; +RLIKE: 'RLIKE' | 'REGEXP'; +ROLE: 'ROLE'; +ROLES: 'ROLES'; +ROLLBACK: 'ROLLBACK'; +ROLLUP: 'ROLLUP'; +ROW: 'ROW'; +ROWS: 'ROWS'; +SECOND: 'SECOND'; +SCHEMA: 'SCHEMA'; +SELECT: 'SELECT'; +SEMI: 'SEMI'; +SEPARATED: 'SEPARATED'; +SERDE: 'SERDE'; +SERDEPROPERTIES: 'SERDEPROPERTIES'; +SESSION_USER: 'SESSION_USER'; +SET: 'SET'; +SETMINUS: 'MINUS'; +SETS: 'SETS'; +SHOW: 'SHOW'; +SKEWED: 'SKEWED'; +SOME: 'SOME'; +SORT: 'SORT'; +SORTED: 'SORTED'; +START: 'START'; +STATISTICS: 'STATISTICS'; +STORED: 'STORED'; +STRATIFY: 'STRATIFY'; +STRUCT: 'STRUCT'; +SUBSTR: 'SUBSTR'; +SUBSTRING: 'SUBSTRING'; +SYNC: 'SYNC'; +TABLE: 'TABLE'; +TABLES: 'TABLES'; +TABLESAMPLE: 'TABLESAMPLE'; +TBLPROPERTIES: 'TBLPROPERTIES'; +TEMPORARY: 'TEMPORARY' | 'TEMP'; +TERMINATED: 'TERMINATED'; +THEN: 'THEN'; +TIME: 'TIME'; +TO: 'TO'; +TOUCH: 'TOUCH'; +TRAILING: 'TRAILING'; +TRANSACTION: 'TRANSACTION'; +TRANSACTIONS: 'TRANSACTIONS'; +TRANSFORM: 'TRANSFORM'; +TRIM: 'TRIM'; +TRUE: 'TRUE'; +TRUNCATE: 'TRUNCATE'; +TRY_CAST: 'TRY_CAST'; +TYPE: 'TYPE'; +UNARCHIVE: 'UNARCHIVE'; +UNBOUNDED: 'UNBOUNDED'; +UNCACHE: 'UNCACHE'; +UNION: 'UNION'; +UNIQUE: 'UNIQUE'; +UNKNOWN: 'UNKNOWN'; +UNLOCK: 'UNLOCK'; +UNSET: 'UNSET'; +UPDATE: 'UPDATE'; +USE: 'USE'; +USER: 'USER'; +USING: 'USING'; +VALUES: 'VALUES'; +VIEW: 'VIEW'; +VIEWS: 'VIEWS'; +WHEN: 'WHEN'; +WHERE: 'WHERE'; +WINDOW: 'WINDOW'; +WITH: 'WITH'; +YEAR: 'YEAR'; +ZONE: 'ZONE'; + +SYSTEM_VERSION: 'SYSTEM_VERSION'; +VERSION: 'VERSION'; +SYSTEM_TIME: 'SYSTEM_TIME'; +TIMESTAMP: 'TIMESTAMP'; +//--SPARK-KEYWORD-LIST-END +//============================ +// End of the keywords list +//============================ +LEFT_PAREN: '('; +RIGHT_PAREN: ')'; +COMMA: ','; +DOT: '.'; + +EQ : '=' | '=='; +NSEQ: '<=>'; +NEQ : '<>'; +NEQJ: '!='; +LT : '<'; +LTE : '<=' | '!>'; +GT : '>'; +GTE : '>=' | '!<'; + +PLUS: '+'; +MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +TILDE: '~'; +AMPERSAND: '&'; +PIPE: '|'; +CONCAT_PIPE: '||'; +HAT: '^'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '"' ( ~('"'|'\\') | ('\\' .) )* '"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +EXPONENT_VALUE + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + ; + +DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + +FLOAT_LITERAL + : DIGIT+ EXPONENT? 'F' + | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? + ; + +DOUBLE_LITERAL + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? (BRACKETED_COMMENT|.)*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. 
+// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . + ; diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 b/hudi-spark-datasource/hudi-spark4.2.x/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 new file mode 100644 index 0000000000000..ddbecfefc760d --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +grammar HoodieSqlBase; + +import SqlBase; + +singleStatement + : statement EOF + ; + +statement + : query #queryStatement + | ctes? dmlStatementNoWith #dmlStatement + | createTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #createTable + | CREATE INDEX (IF NOT EXISTS)? identifier ON TABLE? + tableIdentifier (USING indexType=identifier)? + LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN + (OPTIONS indexOptions=propertyList)? #createIndex + | DROP INDEX (IF EXISTS)? identifier ON TABLE? tableIdentifier #dropIndex + | SHOW INDEXES (FROM | IN) TABLE? tableIdentifier #showIndexes + | REFRESH INDEX identifier ON TABLE? tableIdentifier #refreshIndex + | .*? #passThrough + ; diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/Spark42HoodieFileScanRDD.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/Spark42HoodieFileScanRDD.scala new file mode 100644 index 0000000000000..fb4b69ab1b759 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/Spark42HoodieFileScanRDD.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hudi + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} +import org.apache.spark.sql.types.StructType + +class Spark42HoodieFileScanRDD(@transient private val sparkSession: SparkSession, + read: PartitionedFile => Iterator[InternalRow], + @transient filePartitions: Seq[FilePartition], + readDataSchema: StructType, + metadataColumns: Seq[AttributeReference] = Seq.empty) + extends FileScanRDD(sparkSession, read, filePartitions, readDataSchema, metadataColumns) + with HoodieUnsafeRDD { + + override final def collect(): Array[InternalRow] = super[HoodieUnsafeRDD].collect() +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/Spark42HoodiePartitionCDCFileGroupMapping.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/Spark42HoodiePartitionCDCFileGroupMapping.scala new file mode 100644 index 0000000000000..50c2c584bd30a --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/Spark42HoodiePartitionCDCFileGroupMapping.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.hudi + +import org.apache.hudi.common.table.cdc.HoodieCDCFileSplit + +import org.apache.spark.sql.catalyst.InternalRow + +class Spark42HoodiePartitionCDCFileGroupMapping(partitionValues: InternalRow, + fileSplits: List[HoodieCDCFileSplit]) + extends Spark42HoodiePartitionValues(partitionValues) + with HoodiePartitionCDCFileGroupMapping { + + override def getFileSplits(): List[HoodieCDCFileSplit] = { + fileSplits + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/Spark42HoodiePartitionFileSliceMapping.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/Spark42HoodiePartitionFileSliceMapping.scala new file mode 100644 index 0000000000000..f4b3838c79477 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/Spark42HoodiePartitionFileSliceMapping.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.hudi + +import org.apache.hudi.common.model.FileSlice + +import org.apache.spark.sql.catalyst.InternalRow + +class Spark42HoodiePartitionFileSliceMapping(values: InternalRow, + slices: Map[String, FileSlice]) + extends Spark42HoodiePartitionValues(values) + with HoodiePartitionFileSliceMapping { + + override def getSlice(fileId: String): Option[FileSlice] = { + slices.get(fileId) + } + + override def getPartitionValues: InternalRow = values +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/Spark42HoodiePartitionValues.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/Spark42HoodiePartitionValues.scala new file mode 100644 index 0000000000000..831475d013640 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/Spark42HoodiePartitionValues.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.hudi + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types.{DataType, Decimal} +import org.apache.spark.unsafe.types.{CalendarInterval, GeographyVal, GeometryVal, UTF8String, VariantVal} + +case class Spark42HoodiePartitionValues(values: InternalRow) extends HoodiePartitionValues { + override def numFields: Int = { + values.numFields + } + + override def setNullAt(i: Int): Unit = { + values.setNullAt(i) + } + + override def update(i: Int, value: Any): Unit = { + values.update(i, value) + } + + override def copy(): InternalRow = { + Spark42HoodiePartitionValues(values.copy()) + } + + override def isNullAt(ordinal: Int): Boolean = { + values.isNullAt(ordinal) + } + + override def getBoolean(ordinal: Int): Boolean = { + values.getBoolean(ordinal) + } + + override def getByte(ordinal: Int): Byte = { + values.getByte(ordinal) + } + + override def getShort(ordinal: Int): Short = { + values.getShort(ordinal) + } + + override def getInt(ordinal: Int): Int = { + values.getInt(ordinal) + } + + override def getLong(ordinal: Int): Long = { + values.getLong(ordinal) + } + + override def getFloat(ordinal: Int): Float = { + values.getFloat(ordinal) + } + + override def getDouble(ordinal: Int): Double = { + values.getDouble(ordinal) + } + + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = { + values.getDecimal(ordinal, precision, scale) + } + + override def getUTF8String(ordinal: Int): UTF8String = { + values.getUTF8String(ordinal) + } + + override def getBinary(ordinal: Int): Array[Byte] = { + values.getBinary(ordinal) + } + + override def getGeography(ordinal: Int): GeographyVal = { + values.getGeography(ordinal) + } + + override def getGeometry(ordinal: Int): GeometryVal = { + values.getGeometry(ordinal) + } + + override def getInterval(ordinal: Int): CalendarInterval = { + values.getInterval(ordinal) + } + + override def getVariant(ordinal: Int): VariantVal = { + values.getVariant(ordinal) + } + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = { + values.getStruct(ordinal, numFields) + } + + override def getArray(ordinal: Int): ArrayData = { + values.getArray(ordinal) + } + + override def getMap(ordinal: Int): MapData = { + values.getMap(ordinal) + } + + override def get(ordinal: Int, dataType: DataType): AnyRef = { + values.get(ordinal, dataType) + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/client/model/Spark42HoodieInternalRow.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/client/model/Spark42HoodieInternalRow.scala new file mode 100644 index 0000000000000..115c2e7151f13 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/hudi/client/model/Spark42HoodieInternalRow.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hudi.client.model + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, UTF8String, VariantVal} + +class Spark42HoodieInternalRow( + metaFields: Array[UTF8String], + sourceRow: InternalRow, + sourceContainsMetaFields: Boolean) + extends HoodieInternalRow(metaFields, sourceRow, sourceContainsMetaFields) { + + override def getVariant(ordinal: Int): VariantVal = { + ruleOutMetaFieldsAccess(ordinal, classOf[VariantVal]) + sourceRow.getVariant(rebaseOrdinal(ordinal)) + } + + override def getGeography(ordinal: Int): GeographyVal = { + ruleOutMetaFieldsAccess(ordinal, classOf[GeographyVal]) + sourceRow.getGeography(rebaseOrdinal(ordinal)) + } + + override def getGeometry(ordinal: Int): GeometryVal = { + ruleOutMetaFieldsAccess(ordinal, classOf[GeometryVal]) + sourceRow.getGeometry(rebaseOrdinal(ordinal)) + } + + override def copy(): InternalRow = { + val copyMetaFields = metaFields.map(f => if (f != null) f.copy() else null) + new Spark42HoodieInternalRow( + copyMetaFields, + if (sourceRow == null) null else sourceRow.copy(), + sourceContainsMetaFields) + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark42CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark42CatalystExpressionUtils.scala new file mode 100644 index 0000000000000..2b3df6546f152 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark42CatalystExpressionUtils.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.HoodieSparkTypeUtils.isCastPreservingOrdering +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, EvalMode, Exp, Expm1, Expression, FromUnixTime, FromUTCTimestamp, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, PredicateHelper, ShiftLeft, ShiftRight, ToUnixTimestamp, ToUTCTimestamp, Upper} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.types.{DataType, StructType} + +object HoodieSpark42CatalystExpressionUtils extends HoodieSpark4CatalystExpressionUtils with PredicateHelper { + + override def getEncoder(schema: StructType): ExpressionEncoder[Row] = { + ExpressionEncoder.apply(schema).resolveAndBind() + } + + override def normalizeExprs(exprs: Seq[Expression], attributes: Seq[Attribute]): Seq[Expression] = { + DataSourceStrategy.normalizeExprs(exprs, attributes) + } + + override def extractPredicatesWithinOutputSet(condition: Expression, outputSet: AttributeSet): Option[Expression] = { + super[PredicateHelper].extractPredicatesWithinOutputSet(condition, outputSet) + } + + override def matchCast(expr: Expression): Option[(Expression, DataType, Option[String])] = { + expr match { + case Cast(child, dataType, timeZoneId, _) => Some((child, dataType, timeZoneId)) + case _ => None + } + } + + override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] = { + expr match { + case OrderPreservingTransformation(attrRef) => Some(attrRef) + case _ => None + } + } + + def canUpCast(fromType: DataType, toType: DataType): Boolean = + Cast.canUpCast(fromType, toType) + + override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] = + expr match { + case Cast(castedExpr, dataType, timeZoneId, ansiEnabled) => + Some((castedExpr, dataType, timeZoneId, if (ansiEnabled == EvalMode.ANSI) true else false)) + case _ => None + } + + private object OrderPreservingTransformation { + def unapply(expr: Expression): Option[AttributeReference] = { + expr match { + // Date/Time Expressions + case DateFormatClass(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case DateAdd(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateSub(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateDiff(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateDiff(_, OrderPreservingTransformation(attrRef)) => Some(attrRef) + case FromUnixTime(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case FromUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case ParseToDate(OrderPreservingTransformation(attrRef), _, _, _) => Some(attrRef) + case ParseToTimestamp(OrderPreservingTransformation(attrRef), _, _, _, _) => Some(attrRef) + case ToUnixTimestamp(OrderPreservingTransformation(attrRef), _, _, _) => Some(attrRef) + case ToUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + + // String Expressions + case Lower(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Upper(OrderPreservingTransformation(attrRef)) => Some(attrRef) + // Left API change: Improve RuntimeReplaceable + // https://issues.apache.org/jira/browse/SPARK-38240 + case 
org.apache.spark.sql.catalyst.expressions.Left(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + + // Math Expressions + // Binary + case Add(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case Add(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case Multiply(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case Multiply(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case Divide(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case BitwiseOr(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case BitwiseOr(_, OrderPreservingTransformation(attrRef)) => Some(attrRef) + // Unary + case Exp(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Expm1(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log10(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log1p(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log2(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case ShiftLeft(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case ShiftRight(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + + // Other + case cast@Cast(OrderPreservingTransformation(attrRef), _, _, _) + if isCastPreservingOrdering(cast.child.dataType, cast.dataType) => Some(attrRef) + + // Identity transformation + case attrRef: AttributeReference => Some(attrRef) + // No match + case _ => None + } + } + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark42CatalystPlanUtils.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark42CatalystPlanUtils.scala new file mode 100644 index 0000000000000..e9375efbbaf40 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark42CatalystPlanUtils.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{AnalysisErrorAt, ResolvedTable} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, ProjectionOverSchema} +import org.apache.spark.sql.catalyst.planning.ScanOperation +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog} +import org.apache.spark.sql.execution.command.RepairTableCommand +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.parquet.{HoodieFormatTrait, ParquetFileFormat} +import org.apache.spark.sql.execution.streaming.runtime.SerializedOffset +import org.apache.spark.sql.types.StructType + +object HoodieSpark42CatalystPlanUtils extends BaseHoodieCatalystPlanUtils { + + def unapplyResolvedTable(plan: LogicalPlan): Option[(TableCatalog, Identifier, Table)] = + plan match { + case ResolvedTable(catalog, identifier, table, _) => Some((catalog, identifier, table)) + case _ => None + } + + override def unapplyMergeIntoTable(plan: LogicalPlan): Option[(LogicalPlan, LogicalPlan, Expression)] = { + plan match { + case MergeIntoTable(targetTable, sourceTable, mergeCondition, _, _, _, _) => + Some((targetTable, sourceTable, mergeCondition)) + case _ => None + } + } + + override def maybeApplyForNewFileFormat(plan: LogicalPlan): LogicalPlan = { + plan match { + case s@ScanOperation(_, _, _, + l@LogicalRelation(fs: HadoopFsRelation, _, _, _, _)) + if fs.fileFormat.isInstanceOf[ParquetFileFormat with HoodieFormatTrait] + && !fs.fileFormat.asInstanceOf[ParquetFileFormat with HoodieFormatTrait].isProjected => + FileFormatUtilsForFileGroupReader.applyNewFileFormatChanges(s, l, fs) + case _ => plan + } + } + + override def projectOverSchema(schema: StructType, output: AttributeSet): ProjectionOverSchema = + ProjectionOverSchema(schema, output) + + override def isRepairTable(plan: LogicalPlan): Boolean = { + plan.isInstanceOf[RepairTableCommand] + } + + override def getRepairTableChildren(plan: LogicalPlan): Option[(TableIdentifier, Boolean, Boolean, String)] = { + plan match { + case rtc: RepairTableCommand => + Some((rtc.tableName, rtc.enableAddPartitions, rtc.enableDropPartitions, rtc.cmd)) + case _ => + None + } + } + + override def failAnalysisForMIT(a: Attribute, cols: String): Unit = { + a.failAnalysis( + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + messageParameters = Map( + "objectName" -> a.sql, + "proposal" -> cols)) + } + + override def failTableNotFound(tableName: String): Unit = { + throw new AnalysisException( + errorClass = "TABLE_OR_VIEW_NOT_FOUND", + messageParameters = Map("relationName" -> s"`$tableName`")) + } + + override def unapplyCreateIndex(plan: LogicalPlan): Option[(LogicalPlan, String, String, Boolean, Seq[(Seq[String], Map[String, String])], Map[String, String])] = { + plan match { + case ci@CreateIndex(table, indexName, indexType, ignoreIfExists, columns, properties) => + Some((table, indexName, indexType, ignoreIfExists, columns.map(col => (col._1.name, col._2)), properties)) + case _ => + None + } + } + + override def unapplyDropIndex(plan: LogicalPlan): Option[(LogicalPlan, String, Boolean)] = { + plan match { + case ci@DropIndex(table, indexName, ignoreIfNotExists) => + Some((table, indexName, ignoreIfNotExists)) + case _ => + None + } + } + + override def unapplyShowIndexes(plan: LogicalPlan): Option[(LogicalPlan, 
Seq[Attribute])] = { + plan match { + case ci@HoodieShowIndexes(table, output) => + Some((table, output)) + case _ => + None + } + } + + override def unapplyRefreshIndex(plan: LogicalPlan): Option[(LogicalPlan, String)] = { + plan match { + case ci@RefreshIndex(table, indexName) => + Some((table, indexName)) + case _ => + None + } + } + + override def unapplyInsertIntoStatement(plan: LogicalPlan): Option[(LogicalPlan, Seq[String], Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] = { + plan match { + case insert: InsertIntoStatement => + Some((insert.table, insert.userSpecifiedCols, insert.partitionSpec, insert.query, insert.overwrite, insert.ifPartitionNotExists)) + case _ => + None + } + } + + override def createProjectForByNameQuery(lr: LogicalRelation, plan: LogicalPlan): Option[LogicalPlan] = { + plan match { + case insert: InsertIntoStatement => + Some(ResolveInsertionBase.createProjectForByNameQuery(lr.catalogTable.get.qualifiedName, insert)) + case _ => + None + } + } + + override def unapplyUpdateAction(mergeAction: Any): Option[(Option[Expression], Seq[Assignment])] = { + mergeAction match { + case UpdateAction(condition, assignments, _) => Some((condition, assignments)) + case _ => None + } + } + + override def extractJsonFromSerializedOffset(offset: Any): Option[String] = { + offset match { + case SerializedOffset(json) => Some(json) + case _ => None + } + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark42SchemaUtils.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark42SchemaUtils.scala new file mode 100644 index 0000000000000..4d5339ade1f5f --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark42SchemaUtils.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.jdbc.JdbcDialect +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.SchemaUtils + +import java.sql.{Connection, ResultSet} + +/** + * Schema utilities for Spark 4.2.x.
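+ * Delegates to Spark's `SchemaUtils`, `DataTypeUtils` and `JdbcUtils` helpers.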
+ */ +object HoodieSpark42SchemaUtils extends HoodieSchemaUtils { + override def checkColumnNameDuplication(columnNames: Seq[String], + colType: String, + caseSensitiveAnalysis: Boolean): Unit = { + SchemaUtils.checkColumnNameDuplication(columnNames, caseSensitiveAnalysis) + } + + override def toAttributes(struct: StructType): Seq[Attribute] = { + DataTypeUtils.toAttributes(struct) + } + + override def getSchema(conn: Connection, + resultSet: ResultSet, + dialect: JdbcDialect, + alwaysNullable: Boolean = false, + isTimestampNTZ: Boolean = false): StructType = { + JdbcUtils.getSchema(conn, resultSet, dialect, alwaysNullable) + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/adapter/Spark4_2Adapter.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/adapter/Spark4_2Adapter.scala new file mode 100644 index 0000000000000..df409752c7381 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/adapter/Spark4_2Adapter.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.adapter + +import org.apache.hudi.{HoodiePartitionCDCFileGroupMapping, HoodiePartitionFileSliceMapping, Spark42HoodieFileScanRDD, Spark42HoodiePartitionCDCFileGroupMapping, Spark42HoodiePartitionFileSliceMapping} +import org.apache.hudi.client.model.{HoodieInternalRow, Spark42HoodieInternalRow} +import org.apache.hudi.common.model.FileSlice +import org.apache.hudi.common.schema.HoodieSchema +import org.apache.hudi.common.table.cdc.HoodieCDCFileSplit + +import org.apache.hadoop.conf.Configuration +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.avro._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, ResolvedTable} +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.catalyst.util.{METADATA_COL_ATTR_KEY, RebaseDateTime} +import org.apache.spark.sql.connector.catalog.{V1Table, V2TableWithV1Fallback} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.lance.SparkLanceReaderBase +import org.apache.spark.sql.execution.datasources.orc.Spark42OrcReader +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Spark42LegacyHoodieParquetFileFormat, Spark42ParquetReader} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.hudi.{HoodieMemoryStream, SparkAdapter} +import org.apache.spark.sql.hudi.analysis.TableValuedFunctions +import org.apache.spark.sql.hudi.blob.{BatchedBlobReaderStrategy, ScalarFunctions} +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} +import org.apache.spark.sql.parser.{HoodieExtendedParserInterface, HoodieSpark4_2ExtendedSqlParser} +import org.apache.spark.sql.types.{DataType, DataTypes, Metadata, MetadataBuilder, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatchRow +import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.StorageLevel._ +import org.apache.spark.unsafe.types.UTF8String + +import scala.jdk.CollectionConverters.MapHasAsScala + +/** + * Implementation of [[SparkAdapter]] for Spark 4.2.x branch + */ +class Spark4_2Adapter extends BaseSpark4Adapter { + + override def resolveHoodieTable(plan: LogicalPlan): Option[CatalogTable] = { + super.resolveHoodieTable(plan).orElse { + EliminateSubqueryAliases(plan) match { + // First, we need to weed out unresolved plans + case plan if !plan.resolved => None + // NOTE: When resolving Hudi table we allow [[Filter]]s and [[Project]]s be applied + // on top of it + case PhysicalOperation(_, _, DataSourceV2Relation(v2: V2TableWithV1Fallback, _, _, _, _, _)) if isHoodieTable(v2) => + Some(v2.v1Table) + case ResolvedTable(_, _, V1Table(v1Table), _) if isHoodieTable(v1Table) => + Some(v1Table) + case _ => None + } + } + } + + def isHoodieTable(v2Table: V2TableWithV1Fallback): Boolean = { + v2Table.getClass.getName.contains("HoodieInternalV2Table") + } + + override def isColumnarBatchRow(r: InternalRow): Boolean = r.isInstanceOf[ColumnarBatchRow] + + def 
createCatalystMetadataForMetaField: Metadata = + new MetadataBuilder() + .putBoolean(METADATA_COL_ATTR_KEY, value = true) + .build() + + override def getCatalystExpressionUtils: HoodieCatalystExpressionUtils = HoodieSpark42CatalystExpressionUtils + + override def getCatalystPlanUtils: HoodieCatalystPlansUtils = HoodieSpark42CatalystPlanUtils + + override def getSchemaUtils: HoodieSchemaUtils = HoodieSpark42SchemaUtils + + override def getSparkPartitionedFileUtils: HoodieSparkPartitionedFileUtils = HoodieSpark42PartitionedFileUtils + + override def newParseException(command: Option[String], + exception: AnalysisException, + start: Origin, + stop: Origin): ParseException = { + new ParseException(command, start, exception.getErrorClass, exception.getMessageParameters.asScala.toMap) + } + + + + override def createAvroSerializer(rootCatalystType: DataType, rootType: HoodieSchema, nullable: Boolean): HoodieAvroSerializer = + new HoodieSpark4_2AvroSerializer(rootCatalystType, rootType.toAvroSchema, nullable) + + override def createAvroDeserializer(rootType: HoodieSchema, rootCatalystType: DataType): HoodieAvroDeserializer = + new HoodieSpark4_2AvroDeserializer(rootType.toAvroSchema, rootCatalystType) + + override def createExtendedSparkParser(spark: SparkSession, delegate: ParserInterface): HoodieExtendedParserInterface = + new HoodieSpark4_2ExtendedSqlParser(spark, delegate) + + override def createLegacyHoodieParquetFileFormat(appendPartitionValues: Boolean): Option[ParquetFileFormat] = { + Some(new Spark42LegacyHoodieParquetFileFormat(appendPartitionValues)) + } + + override def createInternalRow(metaFields: Array[UTF8String], + sourceRow: InternalRow, + sourceContainsMetaFields: Boolean): HoodieInternalRow = { + new Spark42HoodieInternalRow(metaFields, sourceRow, sourceContainsMetaFields) + } + + override def createPartitionCDCFileGroupMapping(partitionValues: InternalRow, + fileSplits: List[HoodieCDCFileSplit]): HoodiePartitionCDCFileGroupMapping = { + new Spark42HoodiePartitionCDCFileGroupMapping(partitionValues, fileSplits) + } + + override def createPartitionFileSliceMapping(values: InternalRow, + slices: Map[String, FileSlice]): HoodiePartitionFileSliceMapping = { + new Spark42HoodiePartitionFileSliceMapping(values, slices) + } + + override def createHoodieFileScanRDD(sparkSession: SparkSession, + readFunction: PartitionedFile => Iterator[InternalRow], + filePartitions: Seq[FilePartition], + readDataSchema: StructType, + metadataColumns: Seq[AttributeReference] = Seq.empty): FileScanRDD = { + new Spark42HoodieFileScanRDD(sparkSession, readFunction, filePartitions, readDataSchema, metadataColumns) + } + + override def extractDeleteCondition(deleteFromTable: Command): Expression = { + deleteFromTable.asInstanceOf[DeleteFromTable].condition + } + + override def injectTableFunctions(extensions: SparkSessionExtensions): Unit = { + TableValuedFunctions.funcs.foreach(extensions.injectTableFunction) + } + + override def injectScalarFunctions(extensions: SparkSessionExtensions): Unit = { + ScalarFunctions.funcs.foreach(extensions.injectFunction) + } + + override def injectPlannerStrategies(extensions: SparkSessionExtensions): Unit = { + extensions.injectPlannerStrategy { session => + BatchedBlobReaderStrategy(session) + } + } + + /** + * Converts instance of [[StorageLevel]] to a corresponding string + */ + override def convertStorageLevelToString(level: StorageLevel): String = level match { + case NONE => "NONE" + case DISK_ONLY => "DISK_ONLY" + case DISK_ONLY_2 => "DISK_ONLY_2" + case 
DISK_ONLY_3 => "DISK_ONLY_3" + case MEMORY_ONLY => "MEMORY_ONLY" + case MEMORY_ONLY_2 => "MEMORY_ONLY_2" + case MEMORY_ONLY_SER => "MEMORY_ONLY_SER" + case MEMORY_ONLY_SER_2 => "MEMORY_ONLY_SER_2" + case MEMORY_AND_DISK => "MEMORY_AND_DISK" + case MEMORY_AND_DISK_2 => "MEMORY_AND_DISK_2" + case MEMORY_AND_DISK_SER => "MEMORY_AND_DISK_SER" + case MEMORY_AND_DISK_SER_2 => "MEMORY_AND_DISK_SER_2" + case OFF_HEAP => "OFF_HEAP" + case _ => throw new IllegalArgumentException(s"Invalid StorageLevel: $level") + } + + /** + * Get parquet file reader + * + * @param vectorized true if vectorized reading is not prohibited due to schema, reading mode, etc + * @param sqlConf the [[SQLConf]] used for the read + * @param options passed as a param to the file format + * @param hadoopConf some configs will be set for the hadoopConf + * @return parquet file reader + */ + override def createParquetFileReader(vectorized: Boolean, + sqlConf: SQLConf, + options: Map[String, String], + hadoopConf: Configuration): SparkColumnarFileReader = { + Spark42ParquetReader.build(vectorized, sqlConf, options, hadoopConf) + } + + /** + * Get ORC file reader + * + * @param vectorized true if vectorized reading is not prohibited due to schema, reading mode, etc + * @param sqlConf the [[SQLConf]] used for the read + * @param options passed as a param to the file format + * @param hadoopConf some configs will be set for the hadoopConf + * @param dataSchema the data schema of the ORC file + * @return ORC file reader + */ + override def createOrcFileReader(vectorized: Boolean, + sqlConf: SQLConf, + options: Map[String, String], + hadoopConf: Configuration, + dataSchema: StructType): SparkColumnarFileReader = { + Spark42OrcReader.build(vectorized, sqlConf, options, hadoopConf, dataSchema) + } + + override def createLanceFileReader(vectorized: Boolean, + sqlConf: SQLConf, + options: Map[String, String], + hadoopConf: Configuration): Option[SparkColumnarFileReader] = { + Some(new SparkLanceReaderBase(vectorized)) + } + + override def stopSparkContext(jssc: JavaSparkContext, exitCode: Int): Unit = { + jssc.sc.stop(exitCode) + } + + override def getDateTimeRebaseMode(): LegacyBehaviorPolicy.Value = { + SQLConf.get.getConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE) + } + + override def isLegacyBehaviorPolicy(value: Object): Boolean = { + value == LegacyBehaviorPolicy.LEGACY + } + + override def isTimestampNTZType(dataType: DataType): Boolean = { + dataType == DataTypes.TimestampNTZType + } + + override def getRebaseSpec(policy: String): RebaseDateTime.RebaseSpec = { + RebaseDateTime.RebaseSpec(LegacyBehaviorPolicy.withName(policy)) + } + + override def createMemoryStream[T: Encoder](id: Int, sparkSession: SparkSession): HoodieMemoryStream[T] = { + // In Spark 4.1, MemoryStream is in org.apache.spark.sql.execution.streaming.runtime package + // and takes SparkSession directly instead of SQLContext + val memoryStream = new MemoryStream[T](id, sparkSession) + new HoodieMemoryStream[T] { + override def addData(data: TraversableOnce[T]): Unit = memoryStream.addData(data) + override def toDS(): Dataset[T] = memoryStream.toDS() + } + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala new file mode 100644 index 0000000000000..9ae690f1512e8 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -0,0 +1,592 @@ 
+/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.hudi.common.schema.HoodieSchema +import org.apache.hudi.common.schema.HoodieSchema.VectorLogicalType + +import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder} +import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.LogicalTypes.{LocalTimestampMicros, LocalTimestampMillis, TimestampMicros, TimestampMillis} +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic._ +import org.apache.avro.util.Utf8 +import org.apache.spark.sql.avro.AvroDeserializer.{createDateRebaseFuncInRead, createTimestampRebaseFuncInRead, RebaseSpec} +import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField} +import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters} +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData, RebaseDateTime} +import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_DAY +import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.internal.LegacyBehaviorPolicy +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{UTF8String, VariantVal} + +import java.math.BigDecimal +import java.nio.ByteBuffer +import java.util.TimeZone + +import scala.collection.JavaConverters._ + +/** + * A deserializer to deserialize data in avro format to data in catalyst format. + * + * NOTE: This code is borrowed from Spark 3.3.0 + * This code is borrowed, so that we can better control compatibility w/in Spark minor + * branches (3.2.x, 3.1.x, etc) + * + * PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY + */ +private[sql] class AvroDeserializer(rootAvroType: Schema, + rootCatalystType: DataType, + positionalFieldMatch: Boolean, + datetimeRebaseSpec: RebaseSpec, + filters: StructFilters) { + + def this(rootAvroType: Schema, + rootCatalystType: DataType, + datetimeRebaseMode: LegacyBehaviorPolicy.Value) = { + this( + rootAvroType, + rootCatalystType, + positionalFieldMatch = false, + RebaseSpec(datetimeRebaseMode), + new NoopFilters) + } + + private lazy val decimalConversions = new DecimalConversion() + + private val dateRebaseFunc = createDateRebaseFuncInRead(datetimeRebaseSpec.mode, "Avro") + + private val timestampRebaseFunc = createTimestampRebaseFuncInRead(datetimeRebaseSpec, "Avro") + + private val converter: Any => Option[Any] = try { + rootCatalystType match { + // A shortcut for empty schema. 
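+ // (No field writers are built in this case; every input record is mapped to InternalRow.empty.)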
+ case st: StructType if st.isEmpty => + (_: Any) => Some(InternalRow.empty) + + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + val fieldUpdater = new RowUpdater(resultRow) + val applyFilters = filters.skipRow(resultRow, _) + val writer = getRecordWriter(rootAvroType, st, Nil, Nil, applyFilters) + (data: Any) => { + val record = data.asInstanceOf[GenericRecord] + val skipRow = writer(fieldUpdater, record) + if (skipRow) None else Some(resultRow) + } + + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val fieldUpdater = new RowUpdater(tmpRow) + val writer = newWriter(rootAvroType, rootCatalystType, Nil, Nil) + (data: Any) => { + writer(fieldUpdater, 0, data) + Some(tmpRow.get(0, rootCatalystType)) + } + } + } catch { + case ise: IncompatibleSchemaException => throw new IncompatibleSchemaException( + s"Cannot convert Avro type $rootAvroType to SQL type ${rootCatalystType.sql}.", ise) + } + + def deserialize(data: Any): Option[Any] = converter(data) + + /** + * Creates a writer to write avro values to Catalyst values at the given ordinal with the given + * updater. + */ + private def newWriter(avroType: Schema, + catalystType: DataType, + avroPath: Seq[String], + catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = { + val errorPrefix = s"Cannot convert Avro ${toFieldStr(avroPath)} to " + + s"SQL ${toFieldStr(catalystPath)} because " + val incompatibleMsg = errorPrefix + + s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})" + + (avroType.getType, catalystType) match { + case (NULL, NullType) => (updater, ordinal, _) => + updater.setNullAt(ordinal) + + // TODO: we can avoid boxing if future version of avro provide primitive accessors. + case (BOOLEAN, BooleanType) => (updater, ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + + case (INT, IntegerType) => (updater, ordinal, value) => + value match { + case localDate: java.time.LocalDate => + updater.setInt(ordinal, localDate.toEpochDay.toInt) + case _ => + updater.setInt(ordinal, value.asInstanceOf[Int]) + } + + case (INT, DateType) => (updater, ordinal, value) => + val days = value match { + case localDate: java.time.LocalDate => localDate.toEpochDay.toInt + case _ => value.asInstanceOf[Int] + } + updater.setInt(ordinal, dateRebaseFunc(days)) + + case (LONG, LongType) => (updater, ordinal, value) => + val longVal = value match { + case instant: java.time.Instant => + instant.getEpochSecond * 1000000L + instant.getNano / 1000L + case l => l.asInstanceOf[Long] + } + updater.setLong(ordinal, longVal) + + case (LONG, TimestampType) => avroType.getLogicalType match { + // For backward compatibility, if the Avro type is Long and it is not logical type + // (the `null` case), the value is processed as timestamp type with millisecond precision. 
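+ // Millisecond values are widened to microseconds via DateTimeUtils.millisToMicros before the timestamp rebase function is applied.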
+ case null | _: TimestampMillis => (updater, ordinal, value) => + val millis = value match { + case instant: java.time.Instant => instant.toEpochMilli + case l => l.asInstanceOf[Long] + } + val micros = DateTimeUtils.millisToMicros(millis) + updater.setLong(ordinal, timestampRebaseFunc(micros)) + case _: TimestampMicros => (updater, ordinal, value) => + val micros = value match { + case instant: java.time.Instant => + instant.getEpochSecond * 1000000L + instant.getNano / 1000L + case l => l.asInstanceOf[Long] + } + updater.setLong(ordinal, timestampRebaseFunc(micros)) + case other => throw new IncompatibleSchemaException(errorPrefix + + s"Avro logical type $other cannot be converted to SQL type ${TimestampType.sql}.") + } + + case (LONG, TimestampNTZType) => avroType.getLogicalType match { + // To keep consistent with TimestampType, if the Avro type is Long and it is not + // logical type (the `null` case), the value is processed as TimestampNTZ + // with millisecond precision. + case null | _: LocalTimestampMillis => (updater, ordinal, value) => + val millis = value match { + case ldt: java.time.LocalDateTime => + java.time.Duration.between(java.time.Instant.EPOCH, ldt.toInstant(java.time.ZoneOffset.UTC)).toMillis + case l => l.asInstanceOf[Long] + } + val micros = DateTimeUtils.millisToMicros(millis) + updater.setLong(ordinal, micros) + case _: LocalTimestampMicros => (updater, ordinal, value) => + val micros = value match { + case ldt: java.time.LocalDateTime => + val instant = ldt.toInstant(java.time.ZoneOffset.UTC) + instant.getEpochSecond * 1000000L + instant.getNano / 1000L + case l => l.asInstanceOf[Long] + } + updater.setLong(ordinal, micros) + case other => throw new IncompatibleSchemaException(errorPrefix + + s"Avro logical type $other cannot be converted to SQL type ${TimestampNTZType.sql}.") + } + + // Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date. + // For backward compatibility, we still keep this conversion. 
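+ // The long value is treated as epoch milliseconds and divided by MILLIS_PER_DAY (86,400,000 ms) to obtain an epoch-day count.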
+ case (LONG, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, (value.asInstanceOf[Long] / MILLIS_PER_DAY).toInt) + + case (FLOAT, FloatType) => (updater, ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[Float]) + + case (DOUBLE, DoubleType) => (updater, ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[Double]) + + case (STRING, StringType) => (updater, ordinal, value) => + val str = value match { + case s: String => UTF8String.fromString(s) + case s: Utf8 => + val bytes = new Array[Byte](s.getByteLength) + System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength) + UTF8String.fromBytes(bytes) + case s: GenericData.EnumSymbol => UTF8String.fromString(s.toString) + } + updater.set(ordinal, str) + + case (ENUM, StringType) => (updater, ordinal, value) => + updater.set(ordinal, UTF8String.fromString(value.toString)) + + case (FIXED, BinaryType) => (updater, ordinal, value) => + updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone()) + + case (BYTES, BinaryType) => (updater, ordinal, value) => + val bytes = value match { + case b: ByteBuffer => + val bytes = new Array[Byte](b.remaining) + b.get(bytes) + // Do not forget to reset the position + b.rewind() + bytes + case b: Array[Byte] => b + case other => + throw new RuntimeException(errorPrefix + s"$other is not a valid avro binary.") + } + updater.set(ordinal, bytes) + + case (FIXED, _: DecimalType) => (updater, ordinal, value) => + val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal] + val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, d) + val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale) + updater.setDecimal(ordinal, decimal) + + // Handle VECTOR logical type (FLOAT, DOUBLE, INT8) + case (FIXED, ArrayType(elementType, false)) => avroType.getLogicalType match { + case vectorLogicalType: VectorLogicalType => + val dimension = vectorLogicalType.getDimension + val vecElementType = HoodieSchema.Vector.VectorElementType.fromString(vectorLogicalType.getElementType) + val elementSize = vecElementType.getElementSize + (updater, ordinal, value) => { + val bytes = value.asInstanceOf[GenericData.Fixed].bytes() + val expectedSize = Math.multiplyExact(dimension, elementSize) + if (bytes.length != expectedSize) { + throw new IncompatibleSchemaException( + s"VECTOR byte size mismatch: expected=$expectedSize, actual=${bytes.length}") + } + elementType match { + case FloatType => + val buffer = ByteBuffer.wrap(bytes).order(VectorLogicalType.VECTOR_BYTE_ORDER) + val floats = new Array[Float](dimension) + var i = 0; while (i < dimension) { floats(i) = buffer.getFloat(); i += 1 } + updater.set(ordinal, ArrayData.toArrayData(floats)) + case DoubleType => + val buffer = ByteBuffer.wrap(bytes).order(VectorLogicalType.VECTOR_BYTE_ORDER) + val doubles = new Array[Double](dimension) + var i = 0; while (i < dimension) { doubles(i) = buffer.getDouble(); i += 1 } + updater.set(ordinal, ArrayData.toArrayData(doubles)) + case ByteType => + updater.set(ordinal, ArrayData.toArrayData(bytes.clone())) + } + } + case _ => throw new IncompatibleSchemaException(incompatibleMsg) + } + + case (BYTES, _: DecimalType) => (updater, ordinal, value) => + val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal] + val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType, d) + val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale) + updater.setDecimal(ordinal, decimal) + + case (RECORD, 
VariantType) if avroType.getLogicalType != null + && avroType.getLogicalType.getName == HoodieSchema.VARIANT_TYPE_NAME => + // Validation & Pre-calculation with fail fast logic + val valueField = avroType.getField(HoodieSchema.Variant.VARIANT_VALUE_FIELD) + val metadataField = avroType.getField(HoodieSchema.Variant.VARIANT_METADATA_FIELD) + + if (valueField == null || metadataField == null) { + throw new IncompatibleSchemaException(incompatibleMsg + + ": Variant logical type requires 'value' and 'metadata' fields") + } + + val valueIdx = valueField.pos() + val metadataIdx = metadataField.pos() + + // Variant types are stored as records with "value" and "metadata" binary fields + // Deserialize them back to VariantVal + (updater, ordinal, value) => + val record = value.asInstanceOf[IndexedRecord] + + val valueBuffer = record.get(valueIdx).asInstanceOf[ByteBuffer] + val valueBytes = new Array[Byte](valueBuffer.remaining) + valueBuffer.get(valueBytes) + valueBuffer.rewind() + + val metadataBuffer = record.get(metadataIdx).asInstanceOf[ByteBuffer] + val metadataBytes = new Array[Byte](metadataBuffer.remaining) + metadataBuffer.get(metadataBytes) + metadataBuffer.rewind() + + val variant = new VariantVal(valueBytes, metadataBytes) + updater.set(ordinal, variant) + + case (RECORD, st: StructType) => + // Avro datasource doesn't accept filters with nested attributes. See SPARK-32328. + // We can always return `false` from `applyFilters` for nested records. + val writeRecord = + getRecordWriter(avroType, st, avroPath, catalystPath, applyFilters = _ => false) + (updater, ordinal, value) => + val row = new SpecificInternalRow(st) + writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord]) + updater.set(ordinal, row) + + case (ARRAY, ArrayType(elementType, containsNull)) => + val avroElementPath = avroPath :+ "element" + val elementWriter = newWriter(avroType.getElementType, elementType, + avroElementPath, catalystPath :+ "element") + (updater, ordinal, value) => + val collection = value.asInstanceOf[java.util.Collection[Any]] + val result = createArrayData(elementType, collection.size()) + val elementUpdater = new ArrayDataUpdater(result) + + var i = 0 + val iter = collection.iterator() + while (iter.hasNext) { + val element = iter.next() + if (element == null) { + if (!containsNull) { + throw new RuntimeException( + s"Array value at path ${toFieldStr(avroElementPath)} is not allowed to be null") + } else { + elementUpdater.setNullAt(i) + } + } else { + elementWriter(elementUpdater, i, element) + } + i += 1 + } + + updater.set(ordinal, result) + + case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType => + val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, + avroPath :+ "key", catalystPath :+ "key") + val valueWriter = newWriter(avroType.getValueType, valueType, + avroPath :+ "value", catalystPath :+ "value") + (updater, ordinal, value) => + val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]] + val keyArray = createArrayData(keyType, map.size()) + val keyUpdater = new ArrayDataUpdater(keyArray) + val valueArray = createArrayData(valueType, map.size()) + val valueUpdater = new ArrayDataUpdater(valueArray) + val iter = map.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + assert(entry.getKey != null) + keyWriter(keyUpdater, i, entry.getKey) + if (entry.getValue == null) { + if (!valueContainsNull) { + throw new RuntimeException( + s"Map value at path ${toFieldStr(avroPath :+ "value")} is not 
allowed to be null") + } else { + valueUpdater.setNullAt(i) + } + } else { + valueWriter(valueUpdater, i, entry.getValue) + } + i += 1 + } + + // The Avro map will never have null or duplicated map keys, it's safe to create a + // ArrayBasedMapData directly here. + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + + case (UNION, _) => + val allTypes = avroType.getTypes.asScala + val nonNullTypes = allTypes.filter(_.getType != NULL) + val nonNullAvroType = Schema.createUnion(nonNullTypes.asJava) + if (nonNullTypes.nonEmpty) { + if (nonNullTypes.length == 1) { + newWriter(nonNullTypes.head, catalystType, avroPath, catalystPath) + } else { + nonNullTypes.map(_.getType).toSeq match { + case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case l: java.lang.Long => updater.setLong(ordinal, l) + case i: java.lang.Integer => updater.setLong(ordinal, i.longValue()) + } + + case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case d: java.lang.Double => updater.setDouble(ordinal, d) + case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue()) + } + + case _ => + catalystType match { + case st: StructType if st.length == nonNullTypes.size => + val fieldWriters = nonNullTypes.zip(st.fields).map { + case (schema, field) => + newWriter(schema, field.dataType, avroPath, catalystPath :+ field.name) + }.toArray + (updater, ordinal, value) => { + val row = new SpecificInternalRow(st) + val fieldUpdater = new RowUpdater(row) + val i = GenericData.get().resolveUnion(nonNullAvroType, value) + fieldWriters(i)(fieldUpdater, i, value) + updater.set(ordinal, row) + } + + case _ => throw new IncompatibleSchemaException(incompatibleMsg) + } + } + } + } else { + (updater, ordinal, _) => updater.setNullAt(ordinal) + } + + case (INT, _: YearMonthIntervalType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (LONG, _: DayTimeIntervalType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + + case _ => throw new IncompatibleSchemaException(incompatibleMsg) + } + } + + // TODO: move the following method in Decimal object on creating Decimal from BigDecimal? + private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { + if (precision <= Decimal.MAX_LONG_DIGITS) { + // Constructs a `Decimal` with an unscaled `Long` value if possible. + Decimal(decimal.unscaledValue().longValue(), precision, scale) + } else { + // Otherwise, resorts to an unscaled `BigInteger` instead. 
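+      // (A precision above Decimal.MAX_LONG_DIGITS, i.e. more than 18 digits, cannot fit in an unscaled Long.)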
+ Decimal(decimal, precision, scale) + } + } + + private def getRecordWriter( + avroType: Schema, + catalystType: StructType, + avroPath: Seq[String], + catalystPath: Seq[String], + applyFilters: Int => Boolean): (CatalystDataUpdater, GenericRecord) => Boolean = { + + val avroSchemaHelper = new AvroUtils.AvroSchemaHelper( + avroType, catalystType, avroPath, catalystPath, positionalFieldMatch) + + avroSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = true) + // no need to validateNoExtraAvroFields since extra Avro fields are ignored + + val (validFieldIndexes, fieldWriters) = avroSchemaHelper.matchedFields.map { + case AvroMatchedField(catalystField, ordinal, avroField) => + val baseWriter = newWriter(avroField.schema(), catalystField.dataType, + avroPath :+ avroField.name, catalystPath :+ catalystField.name) + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) + } + } + (avroField.pos(), fieldWriter) + }.toArray.unzip + + (fieldUpdater, record) => { + var i = 0 + var skipRow = false + while (i < validFieldIndexes.length && !skipRow) { + fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i))) + skipRow = applyFilters(i) + i += 1 + } + skipRow + } + } + + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. 
+ */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = + row.setDecimal(ordinal, value, value.precision) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value) + } +} + +object AvroDeserializer { + + // NOTE: Following methods have been renamed in Spark 3.2.1 [1] making [[AvroDeserializer]] implementation + // (which relies on it) be only compatible with the exact same version of [[DataSourceUtils]]. 
+ // To make sure this implementation is compatible w/ all Spark versions w/in Spark 3.2.x branch, + // we're preemptively cloned those methods to make sure Hudi is compatible w/ Spark 3.2.0 as well as + // w/ Spark >= 3.2.1 + // + // [1] https://github.com/apache/spark/pull/34978 + + // Specification of rebase operation including `mode` and the time zone in which it is performed + case class RebaseSpec(mode: LegacyBehaviorPolicy.Value, originTimeZone: Option[String] = None) { + // Use the default JVM time zone for backward compatibility + def timeZone: String = originTimeZone.getOrElse(TimeZone.getDefault.getID) + } + + def createDateRebaseFuncInRead(rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Int => Int = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => days: Int => + if (days < RebaseDateTime.lastSwitchJulianDay) { + throw DataSourceUtils.newRebaseExceptionInRead(format) + } + days + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianDays + case LegacyBehaviorPolicy.CORRECTED => identity[Int] + } + + def createTimestampRebaseFuncInRead(rebaseSpec: RebaseSpec, + format: String): Long => Long = rebaseSpec.mode match { + case LegacyBehaviorPolicy.EXCEPTION => micros: Long => + if (micros < RebaseDateTime.lastSwitchJulianTs) { + throw DataSourceUtils.newRebaseExceptionInRead(format) + } + micros + case LegacyBehaviorPolicy.LEGACY => micros: Long => + RebaseDateTime.rebaseJulianToGregorianMicros(TimeZone.getTimeZone(rebaseSpec.timeZone), micros) + case LegacyBehaviorPolicy.CORRECTED => identity[Long] + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala new file mode 100644 index 0000000000000..ee31534a8b0fc --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -0,0 +1,515 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.avro + +import org.apache.hudi.common.schema.HoodieSchema +import org.apache.hudi.common.schema.HoodieSchema.VectorLogicalType + +import org.apache.avro.{LogicalTypes, Schema} +import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.LogicalTypes.{LocalTimestampMicros, LocalTimestampMillis, TimestampMicros, TimestampMillis} +import org.apache.avro.Schema.Type +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic.GenericData +import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record} +import org.apache.avro.util.Utf8 +import org.apache.spark.internal.Logging +import org.apache.spark.sql.avro.AvroSerializer.{createDateRebaseFuncInWrite, createTimestampRebaseFuncInWrite} +import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, RebaseDateTime} +import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} +import org.apache.spark.sql.types._ + +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.util.TimeZone + +import scala.collection.JavaConverters._ + +/** + * A serializer to serialize data in catalyst format to data in avro format. + * + * NOTE: This code is borrowed from Spark 3.3.0 + * This code is borrowed, so that we can better control compatibility w/in Spark minor + * branches (3.2.x, 3.1.x, etc) + * + * NOTE: THIS IMPLEMENTATION HAS BEEN MODIFIED FROM ITS ORIGINAL VERSION WITH THE MODIFICATION + * BEING EXPLICITLY ANNOTATED INLINE. PLEASE MAKE SURE TO UNDERSTAND PROPERLY ALL THE + * MODIFICATIONS. 
+ * + * PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY + */ +private[sql] class AvroSerializer(rootCatalystType: DataType, + rootAvroType: Schema, + nullable: Boolean, + positionalFieldMatch: Boolean, + datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends Logging { + + def this(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) = { + this(rootCatalystType, rootAvroType, nullable, positionalFieldMatch = false, + SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE, LegacyBehaviorPolicy.CORRECTED)) + } + + def serialize(catalystData: Any): Any = { + converter.apply(catalystData) + } + + private val dateRebaseFunc = createDateRebaseFuncInWrite( + datetimeRebaseMode, "Avro") + + private val timestampRebaseFunc = createTimestampRebaseFuncInWrite( + datetimeRebaseMode, "Avro") + + private val converter: Any => Any = { + val actualAvroType = resolveNullableType(rootAvroType, nullable) + val baseConverter = try { + rootCatalystType match { + case st: StructType => + newStructConverter(st, actualAvroType, Nil, Nil).asInstanceOf[Any => Any] + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val converter = newConverter(rootCatalystType, actualAvroType, Nil, Nil) + (data: Any) => + tmpRow.update(0, data) + converter.apply(tmpRow, 0) + } + } catch { + case ise: IncompatibleSchemaException => throw new IncompatibleSchemaException( + s"Cannot convert SQL type ${rootCatalystType.sql} to Avro type $rootAvroType.", ise) + } + if (nullable) { + (data: Any) => + if (data == null) { + null + } else { + baseConverter.apply(data) + } + } else { + baseConverter + } + } + + private type Converter = (SpecializedGetters, Int) => Any + + private lazy val decimalConversions = new DecimalConversion() + + private def newConverter(catalystType: DataType, + avroType: Schema, + catalystPath: Seq[String], + avroPath: Seq[String]): Converter = { + val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " + + s"to Avro ${toFieldStr(avroPath)} because " + (catalystType, avroType.getType) match { + case (NullType, NULL) => + (getter, ordinal) => null + case (BooleanType, BOOLEAN) => + (getter, ordinal) => getter.getBoolean(ordinal) + case (ByteType, INT) => + (getter, ordinal) => getter.getByte(ordinal).toInt + case (ShortType, INT) => + (getter, ordinal) => getter.getShort(ordinal).toInt + case (IntegerType, INT) => + (getter, ordinal) => getter.getInt(ordinal) + case (LongType, LONG) => + (getter, ordinal) => getter.getLong(ordinal) + case (FloatType, FLOAT) => + (getter, ordinal) => getter.getFloat(ordinal) + case (DoubleType, DOUBLE) => + (getter, ordinal) => getter.getDouble(ordinal) + case (d: DecimalType, FIXED) + if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) => + (getter, ordinal) => + val decimal = getter.getDecimal(ordinal, d.precision, d.scale) + decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType, + LogicalTypes.decimal(d.precision, d.scale)) + + case (d: DecimalType, BYTES) + if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) => + (getter, ordinal) => + val decimal = getter.getDecimal(ordinal, d.precision, d.scale) + decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType, + LogicalTypes.decimal(d.precision, d.scale)) + + // Handle VECTOR logical type (FLOAT, DOUBLE, INT8) + case (ArrayType(elementType, false), FIXED) => avroType.getLogicalType match { + case vectorLogicalType: VectorLogicalType => + val dimension = vectorLogicalType.getDimension + val 
vecElementType = HoodieSchema.Vector.VectorElementType.fromString(vectorLogicalType.getElementType) + val bufferSize = Math.multiplyExact(dimension, vecElementType.getElementSize) + (getter, ordinal) => { + val arrayData = getter.getArray(ordinal) + if (arrayData.numElements() != dimension) { + throw new IncompatibleSchemaException( + s"VECTOR dimension mismatch at ${toFieldStr(catalystPath)}: " + + s"expected=$dimension, actual=${arrayData.numElements()}") + } + elementType match { + case FloatType => + val buffer = ByteBuffer.allocate(bufferSize).order(VectorLogicalType.VECTOR_BYTE_ORDER) + var i = 0; while (i < dimension) { buffer.putFloat(arrayData.getFloat(i)); i += 1 } + new Fixed(avroType, buffer.array()) + case DoubleType => + val buffer = ByteBuffer.allocate(bufferSize).order(VectorLogicalType.VECTOR_BYTE_ORDER) + var i = 0; while (i < dimension) { buffer.putDouble(arrayData.getDouble(i)); i += 1 } + new Fixed(avroType, buffer.array()) + case ByteType => + val bytes = new Array[Byte](dimension) + var i = 0; while (i < dimension) { bytes(i) = arrayData.getByte(i); i += 1 } + new Fixed(avroType, bytes) + case _ => throw new IncompatibleSchemaException(errorPrefix + + s"schema is incompatible (sqlType = ${catalystType.sql}, avroType = $avroType)") + } + } + case _ => throw new IncompatibleSchemaException(errorPrefix + + s"schema is incompatible (sqlType = ${catalystType.sql}, avroType = $avroType)") + } + + case (StringType, ENUM) => + val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet + (getter, ordinal) => + val data = getter.getUTF8String(ordinal).toString + if (!enumSymbols.contains(data)) { + throw new IncompatibleSchemaException(errorPrefix + + s""""$data" cannot be written since it's not defined in enum """ + + enumSymbols.mkString("\"", "\", \"", "\"")) + } + new EnumSymbol(avroType, data) + + case (StringType, STRING) => + (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) + + case (BinaryType, FIXED) => + val size = avroType.getFixedSize + (getter, ordinal) => + val data: Array[Byte] = getter.getBinary(ordinal) + if (data.length != size) { + def len2str(len: Int): String = s"$len ${if (len > 1) "bytes" else "byte"}" + + throw new IncompatibleSchemaException(errorPrefix + len2str(data.length) + + " of binary data cannot be written into FIXED type with size of " + len2str(size)) + } + new Fixed(avroType, data) + + case (BinaryType, BYTES) => + (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + + case (DateType, INT) => + (getter, ordinal) => dateRebaseFunc(getter.getInt(ordinal)) + + case (TimestampType, LONG) => avroType.getLogicalType match { + // For backward compatibility, if the Avro type is Long and it is not logical type + // (the `null` case), output the timestamp value as with millisecond precision. + case null | _: TimestampMillis => (getter, ordinal) => + DateTimeUtils.microsToMillis(timestampRebaseFunc(getter.getLong(ordinal))) + case _: TimestampMicros => (getter, ordinal) => + timestampRebaseFunc(getter.getLong(ordinal)) + case other => throw new IncompatibleSchemaException(errorPrefix + + s"SQL type ${TimestampType.sql} cannot be converted to Avro logical type $other") + } + + case (TimestampNTZType, LONG) => avroType.getLogicalType match { + // To keep consistent with TimestampType, if the Avro type is Long and it is not + // logical type (the `null` case), output the TimestampNTZ as long value + // in millisecond precision. 
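+      // Note that, unlike the TimestampType branch above, no timestampRebaseFunc is applied to TimestampNTZ values.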
+ case null | _: LocalTimestampMillis => (getter, ordinal) => + DateTimeUtils.microsToMillis(getter.getLong(ordinal)) + case _: LocalTimestampMicros => (getter, ordinal) => + getter.getLong(ordinal) + case other => throw new IncompatibleSchemaException(errorPrefix + + s"SQL type ${TimestampNTZType.sql} cannot be converted to Avro logical type $other") + } + + case (ArrayType(et, containsNull), ARRAY) => + val elementConverter = newConverter( + et, resolveNullableType(avroType.getElementType, containsNull), + catalystPath :+ "element", avroPath :+ "element") + (getter, ordinal) => { + val arrayData = getter.getArray(ordinal) + val len = arrayData.numElements() + val result = new Array[Any](len) + var i = 0 + while (i < len) { + if (containsNull && arrayData.isNullAt(i)) { + result(i) = null + } else { + result(i) = elementConverter(arrayData, i) + } + i += 1 + } + // avro writer is expecting a Java Collection, so we convert it into + // `ArrayList` backed by the specified array without data copying. + java.util.Arrays.asList(result: _*) + } + + case (VariantType, RECORD) if avroType.getProp("logicalType") == HoodieSchema.VARIANT_TYPE_NAME => + // Fail fast if schema is mismatched + val valueField = avroType.getField("value") + val metadataField = avroType.getField("metadata") + + if (valueField == null || metadataField == null) { + throw new IncompatibleSchemaException(errorPrefix + + s"Avro schema with 'variant' logical type must have 'value' and 'metadata' fields. " + + s"Found: ${avroType.getFields.asScala.map(_.name()).mkString(", ")}") + } + + // Pre-calculation: Cache indices for performance + val valueIdx = valueField.pos() + val metadataIdx = metadataField.pos() + + // Variant types are stored as records with "value" and "metadata" binary fields + // This matches the schema created in SchemaConverters.toAvroType + (getter, ordinal) => + val variant = getter.getVariant(ordinal) + val record = new GenericData.Record(avroType) + + // Use positional access in serialization loop + record.put(valueIdx, ByteBuffer.wrap(variant.getValue)) + record.put(metadataIdx, ByteBuffer.wrap(variant.getMetadata)) + record + + case (st: StructType, RECORD) => + val structConverter = newStructConverter(st, avroType, catalystPath, avroPath) + val numFields = st.length + (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) + + //////////////////////////////////////////////////////////////////////////////////////////// + // Following section is amended to the original (Spark's) implementation + // >>> BEGINS + //////////////////////////////////////////////////////////////////////////////////////////// + + case (st: StructType, UNION) => + val unionConverter = newUnionConverter(st, avroType, catalystPath, avroPath) + val numFields = st.length + (getter, ordinal) => unionConverter(getter.getStruct(ordinal, numFields)) + + //////////////////////////////////////////////////////////////////////////////////////////// + // <<< ENDS + //////////////////////////////////////////////////////////////////////////////////////////// + + case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType => + val valueConverter = newConverter( + vt, resolveNullableType(avroType.getValueType, valueContainsNull), + catalystPath :+ "value", avroPath :+ "value") + (getter, ordinal) => + val mapData = getter.getMap(ordinal) + val len = mapData.numElements() + val result = new java.util.HashMap[String, Any](len) + val keyArray = mapData.keyArray() + val valueArray = mapData.valueArray() + var i = 0 + while 
(i < len) { + val key = keyArray.getUTF8String(i).toString + if (valueContainsNull && valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) + } + i += 1 + } + result + + case (_: YearMonthIntervalType, INT) => + (getter, ordinal) => getter.getInt(ordinal) + + case (_: DayTimeIntervalType, LONG) => + (getter, ordinal) => getter.getLong(ordinal) + + case _ => + throw new IncompatibleSchemaException(errorPrefix + + s"schema is incompatible (sqlType = ${catalystType.sql}, avroType = $avroType)") + } + } + + private def newStructConverter(catalystStruct: StructType, + avroStruct: Schema, + catalystPath: Seq[String], + avroPath: Seq[String]): InternalRow => Record = { + + val avroSchemaHelper = new AvroUtils.AvroSchemaHelper( + avroStruct, catalystStruct, avroPath, catalystPath, positionalFieldMatch) + + avroSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false) + avroSchemaHelper.validateNoExtraRequiredAvroFields() + + val (avroIndices, fieldConverters) = avroSchemaHelper.matchedFields.map { + case AvroMatchedField(catalystField, _, avroField) => + val converter = newConverter(catalystField.dataType, + resolveNullableType(avroField.schema(), catalystField.nullable), + catalystPath :+ catalystField.name, avroPath :+ avroField.name) + (avroField.pos(), converter) + }.toArray.unzip + + val numFields = catalystStruct.length + row: InternalRow => + val result = new Record(avroStruct) + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + result.put(avroIndices(i), null) + } else { + result.put(avroIndices(i), fieldConverters(i).apply(row, i)) + } + i += 1 + } + result + } + + //////////////////////////////////////////////////////////////////////////////////////////// + // Following section is amended to the original (Spark's) implementation + // >>> BEGINS + //////////////////////////////////////////////////////////////////////////////////////////// + + private def newUnionConverter(catalystStruct: StructType, + avroUnion: Schema, + catalystPath: Seq[String], + avroPath: Seq[String]): InternalRow => Any = { + if (avroUnion.getType != UNION || !canMapUnion(catalystStruct, avroUnion)) { + throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " + + s"Avro type $avroUnion.") + } + val nullable = avroUnion.getTypes.size() > 0 && avroUnion.getTypes.get(0).getType == Type.NULL + val avroInnerTypes = if (nullable) { + avroUnion.getTypes.asScala.tail + } else { + avroUnion.getTypes.asScala + } + val fieldConverters = catalystStruct.zip(avroInnerTypes).map { + case (f1, f2) => newConverter(f1.dataType, f2, catalystPath, avroPath) + } + val numFields = catalystStruct.length + (row: InternalRow) => + var i = 0 + var result: Any = null + while (i < numFields) { + if (!row.isNullAt(i)) { + if (result != null) { + throw new IncompatibleSchemaException(s"Cannot convert Catalyst record $catalystStruct to " + + s"Avro union $avroUnion. Record has more than one optional values set") + } + result = fieldConverters(i).apply(row, i) + } + i += 1 + } + if (!nullable && result == null) { + throw new IncompatibleSchemaException(s"Cannot convert Catalyst record $catalystStruct to " + + s"Avro union $avroUnion. 
Record has no values set, while should have exactly one") + } + result + } + + private def canMapUnion(catalystStruct: StructType, avroStruct: Schema): Boolean = { + (avroStruct.getTypes.size() > 0 && + avroStruct.getTypes.get(0).getType == Type.NULL && + avroStruct.getTypes.size() - 1 == catalystStruct.length) || avroStruct.getTypes.size() == catalystStruct.length + } + + //////////////////////////////////////////////////////////////////////////////////////////// + // <<< ENDS + //////////////////////////////////////////////////////////////////////////////////////////// + + + /** + * Resolve a possibly nullable Avro Type. + * + * An Avro type is nullable when it is a [[UNION]] of two types: one null type and another + * non-null type. This method will check the nullability of the input Avro type and return the + * non-null type within when it is nullable. Otherwise it will return the input Avro type + * unchanged. It will throw an [[UnsupportedAvroTypeException]] when the input Avro type is an + * unsupported nullable type. + * + * It will also log a warning message if the nullability for Avro and catalyst types are + * different. + */ + private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = { + val (avroNullable, resolvedAvroType) = resolveAvroType(avroType) + warnNullabilityDifference(avroNullable, nullable) + resolvedAvroType + } + + /** + * Check the nullability of the input Avro type and resolve it when it is nullable. The first + * return value is a [[Boolean]] indicating if the input Avro type is nullable. The second + * return value is the possibly resolved type. + */ + private def resolveAvroType(avroType: Schema): (Boolean, Schema) = { + if (avroType.getType == Type.UNION) { + val fields = avroType.getTypes.asScala + val actualType = fields.filter(_.getType != Type.NULL) + if (fields.length == 2 && actualType.length == 1) { + (true, actualType.head) + } else { + // This is just a normal union, not used to designate nullability + (false, avroType) + } + } else { + (false, avroType) + } + } + + /** + * log a warning message if the nullability for Avro and catalyst types are different. + */ + private def warnNullabilityDifference(avroNullable: Boolean, catalystNullable: Boolean): Unit = { + if (avroNullable && !catalystNullable) { + logWarning("Writing Avro files with nullable Avro schema and non-nullable catalyst schema.") + } + if (!avroNullable && catalystNullable) { + logWarning("Writing Avro files with non-nullable Avro schema and nullable catalyst " + + "schema will throw runtime exception if there is a record with null value.") + } + } +} + +object AvroSerializer { + + // NOTE: Following methods have been renamed in Spark 3.2.1 [1] making [[AvroSerializer]] implementation + // (which relies on it) be only compatible with the exact same version of [[DataSourceUtils]]. 
+ // To make sure this implementation is compatible w/ all Spark versions w/in Spark 3.2.x branch, + // we're preemptively cloned those methods to make sure Hudi is compatible w/ Spark 3.2.0 as well as + // w/ Spark >= 3.2.1 + // + // [1] https://github.com/apache/spark/pull/34978 + + def createDateRebaseFuncInWrite(rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Int => Int = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => days: Int => + if (days < RebaseDateTime.lastSwitchGregorianDay) { + throw DataSourceUtils.newRebaseExceptionInWrite(format) + } + days + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianDays + case LegacyBehaviorPolicy.CORRECTED => identity[Int] + } + + def createTimestampRebaseFuncInWrite(rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Long => Long = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => micros: Long => + if (micros < RebaseDateTime.lastSwitchGregorianTs) { + throw DataSourceUtils.newRebaseExceptionInWrite(format) + } + micros + case LegacyBehaviorPolicy.LEGACY => + val timeZone = SQLConf.get.sessionLocalTimeZone + RebaseDateTime.rebaseGregorianToJulianMicros(TimeZone.getTimeZone(timeZone), _) + case LegacyBehaviorPolicy.CORRECTED => identity[Long] + } + +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala new file mode 100644 index 0000000000000..8aae6b442f8a1 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema +import org.apache.avro.file. 
FileReader +import org.apache.avro.generic.GenericRecord +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +import java.util.Locale + +import scala.collection.JavaConverters._ + +/** + * NOTE: This code is borrowed from Spark 3.3.0 + * This code is borrowed, so that we can better control compatibility w/in Spark minor + * branches (3.2.x, 3.1.x, etc) + * + * PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY + */ +private[sql] object AvroUtils extends Logging { + + def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportsDataType(f.dataType) } + + case ArrayType(elementType, _) => supportsDataType(elementType) + + case MapType(keyType, valueType, _) => + supportsDataType(keyType) && supportsDataType(valueType) + + case udt: UserDefinedType[_] => supportsDataType(udt.sqlType) + + case _: NullType => true + + case _ => false + } + + // The trait provides iterator-like interface for reading records from an Avro file, + // deserializing and returning them as internal rows. + trait RowReader { + protected val fileReader: FileReader[GenericRecord] + protected val deserializer: AvroDeserializer + protected val stopPosition: Long + + private[this] var completed = false + private[this] var currentRow: Option[InternalRow] = None + + def hasNextRow: Boolean = { + while (!completed && currentRow.isEmpty) { + val r = fileReader.hasNext && !fileReader.pastSync(stopPosition) + if (!r) { + fileReader.close() + completed = true + currentRow = None + } else { + val record = fileReader.next() + // the row must be deserialized in hasNextRow, because AvroDeserializer#deserialize + // potentially filters rows + currentRow = deserializer.deserialize(record).asInstanceOf[Option[InternalRow]] + } + } + currentRow.isDefined + } + + def nextRow: InternalRow = { + if (currentRow.isEmpty) { + hasNextRow + } + val returnRow = currentRow + currentRow = None // free up hasNextRow to consume more Avro records, if not exhausted + returnRow.getOrElse { + throw new NoSuchElementException("next on empty iterator") + } + } + } + + /** Wrapper for a pair of matched fields, one Catalyst and one corresponding Avro field. */ + private[sql] case class AvroMatchedField( + catalystField: StructField, + catalystPosition: Int, + avroField: Schema.Field) + + /** + * Helper class to perform field lookup/matching on Avro schemas. + * + * This will match `avroSchema` against `catalystSchema`, attempting to find a matching field in + * the Avro schema for each field in the Catalyst schema and vice-versa, respecting settings for + * case sensitivity. The match results can be accessed using the getter methods. + * + * @param avroSchema The schema in which to search for fields. Must be of type RECORD. + * @param catalystSchema The Catalyst schema to use for matching. + * @param avroPath The seq of parent field names leading to `avroSchema`. + * @param catalystPath The seq of parent field names leading to `catalystSchema`. + * @param positionalFieldMatch If true, perform field matching in a positional fashion + * (structural comparison between schemas, ignoring names); + * otherwise, perform field matching using field names. 
+ */ + class AvroSchemaHelper( + avroSchema: Schema, + catalystSchema: StructType, + avroPath: Seq[String], + catalystPath: Seq[String], + positionalFieldMatch: Boolean) { + if (avroSchema.getType != Schema.Type.RECORD) { + throw new IncompatibleSchemaException( + s"Attempting to treat ${avroSchema.getName} as a RECORD, but it was: ${avroSchema.getType}") + } + + private[this] val avroFieldArray = avroSchema.getFields.asScala.toArray + private[this] val fieldMap = avroSchema.getFields.asScala + .groupBy(_.name.toLowerCase(Locale.ROOT)) + .mapValues(_.toSeq) // toSeq needed for scala 2.13 + + /** The fields which have matching equivalents in both Avro and Catalyst schemas. */ + val matchedFields: Seq[AvroMatchedField] = catalystSchema.zipWithIndex.flatMap { + case (sqlField, sqlPos) => + getAvroField(sqlField.name, sqlPos).map(AvroMatchedField(sqlField, sqlPos, _)) + } + + /** + * Validate that there are no Catalyst fields which don't have a matching Avro field, throwing + * [[IncompatibleSchemaException]] if such extra fields are found. If `ignoreNullable` is false, + * consider nullable Catalyst fields to be eligible to be an extra field; otherwise, + * ignore nullable Catalyst fields when checking for extras. + */ + def validateNoExtraCatalystFields(ignoreNullable: Boolean): Unit = + catalystSchema.zipWithIndex.foreach { case (sqlField, sqlPos) => + if (getAvroField(sqlField.name, sqlPos).isEmpty && + (!ignoreNullable || !sqlField.nullable)) { + if (positionalFieldMatch) { + throw new IncompatibleSchemaException("Cannot find field at position " + + s"$sqlPos of ${toFieldStr(avroPath)} from Avro schema (using positional matching)") + } else { + throw new IncompatibleSchemaException( + s"Cannot find ${toFieldStr(catalystPath :+ sqlField.name)} in Avro schema") + } + } + } + + /** + * Validate that there are no Avro fields which don't have a matching Catalyst field, throwing + * [[IncompatibleSchemaException]] if such extra fields are found. Only required (non-nullable) + * fields are checked; nullable fields are ignored. + */ + def validateNoExtraRequiredAvroFields(): Unit = { + val extraFields = avroFieldArray.toSet -- matchedFields.map(_.avroField) + extraFields.filterNot(isNullable).foreach { extraField => + if (positionalFieldMatch) { + throw new IncompatibleSchemaException(s"Found field '${extraField.name()}' at position " + + s"${extraField.pos()} of ${toFieldStr(avroPath)} from Avro schema but there is no " + + s"match in the SQL schema at ${toFieldStr(catalystPath)} (using positional matching)") + } else { + throw new IncompatibleSchemaException( + s"Found ${toFieldStr(avroPath :+ extraField.name())} in Avro schema but there is no " + + "match in the SQL schema") + } + } + } + + /** + * Extract a single field from the contained avro schema which has the desired field name, + * performing the matching with proper case sensitivity according to SQLConf.resolver. + * + * @param name The name of the field to search for. + * @return `Some(match)` if a matching Avro field is found, otherwise `None`. 
+ */ + private[avro] def getFieldByName(name: String): Option[Schema.Field] = { + + // get candidates, ignoring case of field name + val candidates = fieldMap.getOrElse(name.toLowerCase(Locale.ROOT), Seq.empty) + + // search candidates, taking into account case sensitivity settings + candidates.filter(f => SQLConf.get.resolver(f.name(), name)) match { + case Seq(avroField) => Some(avroField) + case Seq() => None + case matches => throw new IncompatibleSchemaException(s"Searching for '$name' in Avro " + + s"schema at ${toFieldStr(avroPath)} gave ${matches.size} matches. Candidates: " + + matches.map(_.name()).mkString("[", ", ", "]") + ) + } + } + + /** Get the Avro field corresponding to the provided Catalyst field name/position, if any. */ + def getAvroField(fieldName: String, catalystPos: Int): Option[Schema.Field] = { + if (positionalFieldMatch) { + avroFieldArray.lift(catalystPos) + } else { + getFieldByName(fieldName) + } + } + } + + /** + * Convert a sequence of hierarchical field names (like `Seq(foo, bar)`) into a human-readable + * string representing the field, like "field 'foo.bar'". If `names` is empty, the string + * "top-level record" is returned. + */ + private[avro] def toFieldStr(names: Seq[String]): String = names match { + case Seq() => "top-level record" + case n => s"field '${n.mkString(".")}'" + } + + /** Return true iff `avroField` is nullable, i.e. `UNION` type and has `NULL` as an option. */ + private[avro] def isNullable(avroField: Schema.Field): Boolean = + avroField.schema().getType == Schema.Type.UNION && + avroField.schema().getTypes.asScala.exists(_.getType == Schema.Type.NULL) +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark4_2AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark4_2AvroDeserializer.scala new file mode 100644 index 0000000000000..00e84fa304447 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark4_2AvroDeserializer.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} +import org.apache.spark.sql.types.DataType + +class HoodieSpark4_2AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) + extends HoodieAvroDeserializer { + + private val avroDeserializer = new AvroDeserializer(rootAvroType, rootCatalystType, + SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_READ, LegacyBehaviorPolicy.CORRECTED)) + + def deserialize(data: Any): Option[Any] = avroDeserializer.deserialize(data) +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark4_2AvroSerializer.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark4_2AvroSerializer.scala new file mode 100644 index 0000000000000..173f6c4d14de6 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark4_2AvroSerializer.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema +import org.apache.spark.sql.types.DataType + +class HoodieSpark4_2AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) + extends HoodieAvroSerializer { + + val avroSerializer = new AvroSerializer(rootCatalystType, rootAvroType, nullable) + + override def serialize(catalystData: Any): Any = avroSerializer.serialize(catalystData) +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark42PartitionedFileUtils.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark42PartitionedFileUtils.scala new file mode 100644 index 0000000000000..1942c2ed38eca --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark42PartitionedFileUtils.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.hudi.common.util.ReflectionUtils
+import org.apache.hudi.storage.StoragePath
+
+import org.apache.hadoop.fs.FileStatus
+import org.apache.spark.paths.SparkPath
+import org.apache.spark.sql.catalyst.InternalRow
+
+/**
+ * Utils on Spark [[PartitionedFile]] and [[PartitionDirectory]] for Spark 4.2.
+ */
+object HoodieSpark42PartitionedFileUtils extends HoodieSparkPartitionedFileUtils {
+  override def getPathFromPartitionedFile(partitionedFile: PartitionedFile): StoragePath = {
+    new StoragePath(partitionedFile.filePath.toUri)
+  }
+
+  override def getStringPathFromPartitionedFile(partitionedFile: PartitionedFile): String = {
+    partitionedFile.filePath.toPath.toString
+  }
+
+  override def createPartitionedFile(partitionValues: InternalRow,
+                                     filePath: StoragePath,
+                                     start: Long,
+                                     length: Long): PartitionedFile = {
+    PartitionedFile(partitionValues, SparkPath.fromUri(filePath.toUri), start, length, Array.empty)
+  }
+
+  override def toFileStatuses(partitionDirs: Seq[PartitionDirectory]): Seq[FileStatus] = {
+    val files: Seq[FileStatusWithMetadata] = partitionDirs.flatMap(_.files)
+    try {
+      files.map(_.fileStatus)
+    } catch {
+      case _: NoSuchMethodException | _: NoSuchMethodError | _: IllegalArgumentException =>
+        val methodOpt = ReflectionUtils.getMethod(classOf[FileStatusWithMetadata], "toFileStatus")
+        if (methodOpt.isPresent) {
+          val method = methodOpt.get()
+          files.map(f => method.invoke(f).asInstanceOf[FileStatus])
+        } else {
+          throw new RuntimeException(
+            "Cannot find toFileStatus method on FileStatusWithMetadata in custom Spark Runtime")
+        }
+    }
+  }
+
+  override def newPartitionDirectory(internalRow: InternalRow, statuses: Seq[FileStatus]): PartitionDirectory = {
+    PartitionDirectory(internalRow, statuses.toArray)
+  }
+}
diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/Spark42NestedSchemaPruning.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/Spark42NestedSchemaPruning.scala
new file mode 100644
index 0000000000000..2a26e4a0dd9c0
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/Spark42NestedSchemaPruning.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hudi.HoodieBaseRelation + +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.types.StructType + +class Spark42NestedSchemaPruning extends BaseHoodieNestedSchemaPruning { + + // Prune the given output to make it consistent with `requiredSchema`. + protected def getPrunedOutput(output: Seq[AttributeReference], + requiredSchema: StructType): Seq[AttributeReference] = { + // We need to replace the expression ids of the pruned relation output attributes + // with the expression ids of the original relation output attributes so that + // references to the original relation's output are not broken + val outputIdMap = output.map(att => (att.name, att.exprId)).toMap + DataTypeUtils.toAttributes(requiredSchema) + .map { + case att if outputIdMap.contains(att.name) => + att.withExprId(outputIdMap(att.name)) + case att => att + } + } + + override protected def apply0(plan: LogicalPlan): LogicalPlan = + plan transformDown { + case op@PhysicalOperation(projects, filters, + // NOTE: This is modified to accommodate for Hudi's custom relations, given that original + // [[NestedSchemaPruning]] rule is tightly coupled w/ [[HadoopFsRelation]] + // TODO generalize to any file-based relation + l@LogicalRelation(relation: HoodieBaseRelation, _, _, _, _)) + if relation.canPruneRelationSchema => + + prunePhysicalColumns(l.output, projects, filters, relation.dataSchema, + prunedDataSchema => { + val prunedRelation = + relation.updatePrunedDataSchema(prunedSchema = prunedDataSchema) + buildPrunedRelation(l, prunedRelation) + }).getOrElse(op) + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/orc/Spark42OrcReader.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/orc/Spark42OrcReader.scala new file mode 100644 index 0000000000000..fccf826ee99f2 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/orc/Spark42OrcReader.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.spark.memory.MemoryMode +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.execution.datasources.{FileFormat, PartitionedFile} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +class Spark42OrcReader(enableVectorizedReader: Boolean, + memoryMode: MemoryMode, + dataSchema: StructType, + orcFilterPushDown: Boolean, + isCaseSensitive: Boolean, + capacity: Int) extends SparkOrcReaderBase(enableVectorizedReader, dataSchema, orcFilterPushDown, isCaseSensitive) { + + override def partitionedFileToPath(file: PartitionedFile): Path = { + file.toPath + } + + override def buildReader(): OrcColumnarBatchReader = { + new OrcColumnarBatchReader(capacity, memoryMode) + } + + override def structTypeToAttributes(schema: StructType): Seq[Attribute] = { + toAttributes(schema) + } +} + +object Spark42OrcReader { + /** + * Get ORC file reader + * + * @param vectorized true if vectorized reading is not prohibited due to schema, reading mode, etc + * @param sqlConf the [[SQLConf]] used for the read + * @param options passed as a param to the file format + * @param hadoopConf some configs will be set for the hadoopConf + * @return ORC file reader + */ + def build(vectorized: Boolean, + sqlConf: SQLConf, + options: Map[String, String], + hadoopConf: Configuration, + dataSchema: StructType): Spark42OrcReader = { + //set hadoopconf + hadoopConf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sqlConf.sessionLocalTimeZone) + hadoopConf.setBoolean(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, sqlConf.nestedSchemaPruningEnabled) + hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, sqlConf.caseSensitiveAnalysis) + + val memoryMode = if (sqlConf.offHeapColumnVectorEnabled) { + MemoryMode.OFF_HEAP + } else { + MemoryMode.ON_HEAP + } + + val enableVectorizedReader = sqlConf.orcVectorizedReaderEnabled && + options.getOrElse(FileFormat.OPTION_RETURNING_BATCH, + throw new IllegalArgumentException( + "OPTION_RETURNING_BATCH should always be set for OrcFileFormat. " + + "To workaround this issue, set spark.sql.orc.enableVectorizedReader=false.")) + .equals("true") + + new Spark42OrcReader( + enableVectorizedReader = enableVectorizedReader && vectorized, + memoryMode = memoryMode, + isCaseSensitive = sqlConf.caseSensitiveAnalysis, + capacity = sqlConf.orcVectorizedReaderBatchSize, + orcFilterPushDown = sqlConf.orcFilterPushDown, + dataSchema = dataSchema) + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark42DataSourceUtils.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark42DataSourceUtils.scala new file mode 100644 index 0000000000000..7c62b8125cc28 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark42DataSourceUtils.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} +import org.apache.spark.util.Utils + +object Spark42DataSourceUtils { + + /** + * NOTE: This method was copied from [[Spark32PlusDataSourceUtils]], and is required to maintain runtime + * compatibility against Spark 3.5.0 + */ + // scalastyle:off + def int96RebaseMode(lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version => + // Files written by Spark 3.0 and earlier follow the legacy hybrid calendar and we need to + // rebase the INT96 timestamp values. + // Files written by Spark 3.1 and latter may also need the rebase if they were written with + // the "LEGACY" rebase mode. + if (version < "3.1.0" || lookupFileMeta("org.apache.spark.legacyINT96") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + // scalastyle:on + + /** + * NOTE: This method was copied from Spark 3.2.0, and is required to maintain runtime + * compatibility against Spark 3.2.0 + */ + // scalastyle:off + def datetimeRebaseMode(lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version => + // Files written by Spark 2.4 and earlier follow the legacy hybrid calendar and we need to + // rebase the datetime values. + // Files written by Spark 3.0 and latter may also need the rebase if they were written with + // the "LEGACY" rebase mode. 
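+      // The "org.apache.spark.legacyDateTime" property is written into the file metadata by Spark
+      // when the file was produced with the LEGACY datetime rebase mode.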
+ if (version < "3.0.0" || lookupFileMeta("org.apache.spark.legacyDateTime") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + // scalastyle:on + +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark42LegacyHoodieParquetFileFormat.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark42LegacyHoodieParquetFileFormat.scala new file mode 100644 index 0000000000000..af5481c9bba12 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark42LegacyHoodieParquetFileFormat.scala @@ -0,0 +1,474 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hudi.client.utils.SparkInternalSchemaConverter +import org.apache.hudi.common.fs.FSUtils +import org.apache.hudi.common.table.timeline.TimelineLayout +import org.apache.hudi.common.table.timeline.versioning.TimelineLayoutVersion +import org.apache.hudi.common.util.InternalSchemaCache +import org.apache.hudi.common.util.StringUtils.isNullOrEmpty +import org.apache.hudi.common.util.collection.Pair +import org.apache.hudi.internal.schema.InternalSchema +import org.apache.hudi.internal.schema.action.InternalSchemaMerger +import org.apache.hudi.internal.schema.utils.{InternalSchemaUtils, SerDeHelper} +import org.apache.hudi.storage.hadoop.HoodieHadoopStorage + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapred.FileSplit +import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.parquet.filter2.compat.FilterCompat +import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader} +import org.apache.spark.TaskContext +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Cast, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.WholeStageCodegenExec +import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator} +import org.apache.spark.sql.execution.datasources.parquet.Spark42LegacyHoodieParquetFileFormat._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources._ +import 
org.apache.spark.sql.types.{AtomicType, DataType, StructField, StructType} +import org.apache.spark.util.{SerializableConfiguration, Utils} + +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` + +/** + * This class is an extension of [[ParquetFileFormat]] overriding Spark-specific behavior + * that's not possible to customize in any other way + * + * NOTE: This is a version of [[AvroDeserializer]] impl from Spark 3.2.1 w/ w/ the following changes applied to it: + *
+ *   1. Avoiding appending partition values to the rows read from the data file + *   2. Schema on-read
+ */ +class Spark42LegacyHoodieParquetFileFormat(private val shouldAppendPartitionValues: Boolean) extends ParquetFileFormat { + + def supportsColumnar(sparkSession: SparkSession, schema: StructType): Boolean = { + val conf = getSqlConf(sparkSession) + // Only output columnar if there is WSCG to read it. + val requiredWholeStageCodegenSettings = + conf.wholeStageEnabled && !WholeStageCodegenExec.isTooManyFields(conf, schema) + requiredWholeStageCodegenSettings && + supportBatch(sparkSession, schema) + } + + override def buildReaderWithPartitionValues(sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + val sqlConf = getSqlConf(sparkSession) + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) + hadoopConf.set( + ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, + requiredSchema.json) + hadoopConf.set( + ParquetWriteSupport.SPARK_ROW_SCHEMA, + requiredSchema.json) + hadoopConf.set( + SQLConf.SESSION_LOCAL_TIMEZONE.key, + sqlConf.sessionLocalTimeZone) + hadoopConf.setBoolean( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, + sqlConf.nestedSchemaPruningEnabled) + hadoopConf.setBoolean( + SQLConf.CASE_SENSITIVE.key, + sqlConf.caseSensitiveAnalysis) + + ParquetWriteSupport.setSchema(requiredSchema, hadoopConf) + + // Sets flags for `ParquetToSparkSchemaConverter` + hadoopConf.setBoolean( + SQLConf.PARQUET_BINARY_AS_STRING.key, + sqlConf.isParquetBinaryAsString) + hadoopConf.setBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sqlConf.isParquetINT96AsTimestamp) + // Using string value of this conf to preserve compatibility across spark versions. + hadoopConf.setBoolean( + SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key, + sqlConf.getConfString( + SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key, + SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.defaultValueString).toBoolean + ) + hadoopConf.setBoolean(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key, sqlConf.parquetInferTimestampNTZEnabled) + hadoopConf.setBoolean(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key, sqlConf.legacyParquetNanosAsLong) + val internalSchemaStr = hadoopConf.get(SparkInternalSchemaConverter.HOODIE_QUERY_SCHEMA) + // For Spark DataSource v1, there's no Physical Plan projection/schema pruning w/in Spark itself, + // therefore it's safe to do schema projection here + if (!isNullOrEmpty(internalSchemaStr)) { + val prunedInternalSchemaStr = + pruneInternalSchema(internalSchemaStr, requiredSchema) + hadoopConf.set(SparkInternalSchemaConverter.HOODIE_QUERY_SCHEMA, prunedInternalSchemaStr) + } + + val broadcastedHadoopConf = + SerializableConfiguration.broadcast(sparkSession.sparkContext, hadoopConf) + + // TODO: if you move this into the closure it reverts to the default values. + // If true, enable using the custom RecordReader for parquet. This only works for + // a subset of the types (no complex types). 
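+    // NOTE: the columnar/batch decision below is made against the full output schema (partition columns together +    // with the required data columns), since WholeStageCodegen's field limit applies to the combined row width.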
+ val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) + val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled + val enableVectorizedReader: Boolean = supportBatch(sparkSession, resultSchema) + val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled + val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion + val capacity = sqlConf.parquetVectorizedReaderBatchSize + val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown + val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal + val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringPredicate + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + val isCaseSensitive = sqlConf.caseSensitiveAnalysis + val parquetOptions = new ParquetOptions(options, sqlConf) + val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead + val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead + val timeZoneId = Option(sqlConf.sessionLocalTimeZone) + // Should always be set by FileSourceScanExec creating this. + // Check conf before checking option, to allow working around an issue by changing conf. + val returningBatch = sqlConf.parquetVectorizedReaderEnabled && + supportsColumnar(sparkSession, resultSchema).toString.equals("true") + + + (file: PartitionedFile) => { + assert(!shouldAppendPartitionValues || file.partitionValues.numFields == partitionSchema.size) + + val filePath = file.filePath.toPath + val split = new FileSplit(filePath, file.start, file.length, Array.empty[String]) + + val sharedConf = broadcastedHadoopConf.value.value + + // Fetch internal schema + val internalSchemaStr = sharedConf.get(SparkInternalSchemaConverter.HOODIE_QUERY_SCHEMA) + // Internal schema has to be pruned at this point + val querySchemaOption = SerDeHelper.fromJson(internalSchemaStr) + + var shouldUseInternalSchema = !isNullOrEmpty(internalSchemaStr) && querySchemaOption.isPresent + + val tablePath = sharedConf.get(SparkInternalSchemaConverter.HOODIE_TABLE_PATH) + val fileSchema = if (shouldUseInternalSchema) { + val commitInstantTime = FSUtils.getCommitTime(filePath.getName).toLong; + val validCommits = sharedConf.get(SparkInternalSchemaConverter.HOODIE_VALID_COMMITS_LIST) + val storage = new HoodieHadoopStorage(tablePath, sharedConf) + //TODO: HARDCODED TIMELINE OBJECT + val layout = TimelineLayout.fromVersion(TimelineLayoutVersion.CURR_LAYOUT_VERSION) + InternalSchemaCache.getInternalSchemaByVersionId( + commitInstantTime, tablePath, storage, if (validCommits == null) "" else validCommits, layout) + } else { + null + } + + // When there are vectorized reads, we can avoid + // 1. opening the file twice by transfering the SeekableInputStream + // 2. reading the footer twice by reading all row groups in advance and filter row groups + // according to filters that require push down + val openedFooter = ParquetFooterReader.openFileAndReadFooter(sharedConf, file, enableVectorizedReader) + assert(openedFooter.inputStreamOpt.isPresent == enableVectorizedReader) + + // Before transferring the ownership of inputStream to the vectorizedReader, + // we must take responsibility to close the inputStream if something goes wrong + // to avoid resource leak. 
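+      // The flag below starts out true whenever a stream was opened and is cleared once the vectorized reader +      // takes ownership of the stream during initialize().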
+ val shouldCloseInputStream = new AtomicBoolean(openedFooter.inputStreamOpt.isPresent) + try { + val footerFileMetaData = openedFooter.footer.getFileMetaData + val datetimeRebaseSpec = DataSourceUtils.datetimeRebaseSpec(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + val int96RebaseSpec = DataSourceUtils.int96RebaseSpec(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringStartWith, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) + filters.map(rebuildFilterFromParquet(_, fileSchema, querySchemaOption.orElse(null))) + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + + // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps' + // *only* if the file was created by something other than "parquet-mr", so check the actual + // writer here for this file. We have to do this per-file, as each file in the table may + // have different writers. + // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads. + def isCreatedByParquetMr: Boolean = + footerFileMetaData.getCreatedBy().startsWith("parquet-mr") + + val convertTz = + if (timestampConversion && !isCreatedByParquetMr) { + Some(DateTimeUtils.getZoneId(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) + } else { + None + } + + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + + // Clone new conf + val hadoopAttemptConf = new Configuration(broadcastedHadoopConf.value.value) + val typeChangeInfos: java.util.Map[Integer, Pair[DataType, DataType]] = if (shouldUseInternalSchema) { + val mergedInternalSchema = new InternalSchemaMerger(fileSchema, querySchemaOption.get(), true, true).mergeSchema() + val mergedSchema = SparkInternalSchemaConverter.constructSparkSchemaFromInternalSchema(mergedInternalSchema) + + hadoopAttemptConf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, mergedSchema.json) + + SparkInternalSchemaConverter.collectTypeChangedCols(querySchemaOption.get(), mergedInternalSchema) + } else { + val (implicitTypeChangeInfo, sparkRequestSchema) = HoodieParquetFileFormatHelper.buildImplicitSchemaChangeInfo(hadoopAttemptConf, footerFileMetaData, requiredSchema) + if (!implicitTypeChangeInfo.isEmpty) { + shouldUseInternalSchema = true + hadoopAttemptConf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, sparkRequestSchema.json) + } + implicitTypeChangeInfo + } + + if (enableVectorizedReader && shouldUseInternalSchema && + !typeChangeInfos.values().forall(_.getLeft.isInstanceOf[AtomicType])) { + throw new IllegalArgumentException( + "Nested types with type changes(implicit or explicit) cannot be read in vectorized mode. " + + "To workaround this issue, set spark.sql.parquet.enableVectorizedReader=false.") + } + + val hadoopAttemptContext = + new TaskAttemptContextImpl(hadoopAttemptConf, attemptId) + + // Try to push down filters when filter push-down is enabled. + // Notice: This push-down is RowGroups level, not individual records. 
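+        // Record-level filtering, when enabled, is applied separately in the row-based (parquet-mr) branch below +        // via FilterCompat; here the predicate only prunes row groups.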
+ pushed.foreach { + ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, _) + } + val taskContext = Option(TaskContext.get()) + if (enableVectorizedReader) { + val vectorizedReader = + if (shouldUseInternalSchema) { + new HoodieVectorizedParquetRecordReader( + convertTz.orNull, + datetimeRebaseSpec.mode.toString, + datetimeRebaseSpec.timeZone, + int96RebaseSpec.mode.toString, + int96RebaseSpec.timeZone, + enableOffHeapColumnVector && taskContext.isDefined, + capacity, + typeChangeInfos) + } else { + new VectorizedParquetRecordReader( + convertTz.orNull, + datetimeRebaseSpec.mode.toString, + datetimeRebaseSpec.timeZone, + int96RebaseSpec.mode.toString, + int96RebaseSpec.timeZone, + enableOffHeapColumnVector && taskContext.isDefined, + capacity) + } + + // SPARK-37089: We cannot register a task completion listener to close this iterator here + // because downstream exec nodes have already registered their listeners. Since listeners + // are executed in reverse order of registration, a listener registered here would close the + // iterator while downstream exec nodes are still running. When off-heap column vectors are + // enabled, this can cause a use-after-free bug leading to a segfault. + // + // Instead, we use FileScanRDD's task completion listener to close this iterator. + val iter = new RecordReaderIterator(vectorizedReader) + try { + vectorizedReader.initialize( + split, hadoopAttemptContext, Some(openedFooter.inputFile), + Some(openedFooter.inputStream), Some(openedFooter.footer)) + // The caller don't need to take care of the close of inputStream after calling + // `initialize` because the ownership of inputStream has been transferred to the + // vectorizedReader + shouldCloseInputStream.set(false) + + // NOTE: We're making appending of the partitioned values to the rows read from the + // data file configurable + if (shouldAppendPartitionValues) { + logDebug(s"Appending $partitionSchema ${file.partitionValues}") + vectorizedReader.initBatch(partitionSchema, file.partitionValues) + } else { + vectorizedReader.initBatch(StructType(Nil), InternalRow.empty) + } + + if (returningBatch) { + vectorizedReader.enableReturningBatches() + } + + // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. + iter.asInstanceOf[Iterator[InternalRow]] + } catch { + case e: Throwable => + // SPARK-23457: In case there is an exception in initialization, close the iterator to + // avoid leaking resources. + iter.close() + throw e + } + } else { + logDebug(s"Falling back to parquet-mr") + val readSupport = new HoodieParquetReadSupport( + convertTz, + enableVectorizedReader = false, + enableTimestampFieldRepair = true, + datetimeRebaseSpec, + int96RebaseSpec) + + val reader = if (pushed.isDefined && enableRecordFilter) { + val parquetFilter = FilterCompat.get(pushed.get, null) + new ParquetRecordReader[InternalRow](readSupport, parquetFilter) + } else { + new ParquetRecordReader[InternalRow](readSupport) + } + val readerWithRowIndexes = ParquetRowIndexUtil.addRowIndexToRecordReaderIfNeeded(reader, requiredSchema) + val iter = new RecordReaderIterator[InternalRow](readerWithRowIndexes) + try { + readerWithRowIndexes.initialize(split, hadoopAttemptContext) + + val fullSchema = DataTypeUtils.toAttributes(requiredSchema) ++ DataTypeUtils.toAttributes(partitionSchema) + val unsafeProjection = if (typeChangeInfos.isEmpty) { + GenerateUnsafeProjection.generate(fullSchema, fullSchema) + } else { + // find type changed. 
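+              // Rebuild the read schema using the types actually stored in the file for the changed columns, then +              // cast each of them back to the corresponding query type in the projection generated below.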
+ val newSchema = new StructType(requiredSchema.fields.zipWithIndex.map { case (f, i) => + if (typeChangeInfos.containsKey(i)) { + StructField(f.name, typeChangeInfos.get(i).getRight, f.nullable, f.metadata) + } else f + }) + val newFullSchema = DataTypeUtils.toAttributes(newSchema) ++ DataTypeUtils.toAttributes(partitionSchema) + val castSchema = newFullSchema.zipWithIndex.map { case (attr, i) => + if (typeChangeInfos.containsKey(i)) { + val srcType = typeChangeInfos.get(i).getRight + val dstType = typeChangeInfos.get(i).getLeft + val needTimeZone = Cast.needsTimeZone(srcType, dstType) + Cast(attr, dstType, if (needTimeZone) timeZoneId else None) + } else attr + } + GenerateUnsafeProjection.generate(castSchema, newFullSchema) + } + + // NOTE: We're making appending of the partitioned values to the rows read from the + // data file configurable + if (!shouldAppendPartitionValues || partitionSchema.length == 0) { + // There is no partition columns + iter.map(unsafeProjection) + } else { + val joinedRow = new JoinedRow() + iter.map(d => unsafeProjection(joinedRow(d, file.partitionValues))) + } + } catch { + case e: Throwable => + // SPARK-23457: In case there is an exception in initialization, close the iterator to + // avoid leaking resources. + iter.close() + throw e + } + } + } finally { + if (shouldCloseInputStream.get) { + openedFooter.inputStreamOpt.ifPresent(Utils.closeQuietly) + } + } + } + } +} + +object Spark42LegacyHoodieParquetFileFormat { + + def pruneInternalSchema(internalSchemaStr: String, requiredSchema: StructType): String = { + val querySchemaOption = SerDeHelper.fromJson(internalSchemaStr) + if (querySchemaOption.isPresent && requiredSchema.nonEmpty) { + val prunedSchema = SparkInternalSchemaConverter.convertAndPruneStructTypeToInternalSchema(requiredSchema, querySchemaOption.get()) + SerDeHelper.toJson(prunedSchema) + } else { + internalSchemaStr + } + } + + private def rebuildFilterFromParquet(oldFilter: Filter, fileSchema: InternalSchema, querySchema: InternalSchema): Filter = { + if (fileSchema == null || querySchema == null) { + oldFilter + } else { + oldFilter match { + case eq: EqualTo => + val newAttribute = InternalSchemaUtils.reBuildFilterName(eq.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else eq.copy(attribute = newAttribute) + case eqs: EqualNullSafe => + val newAttribute = InternalSchemaUtils.reBuildFilterName(eqs.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else eqs.copy(attribute = newAttribute) + case gt: GreaterThan => + val newAttribute = InternalSchemaUtils.reBuildFilterName(gt.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else gt.copy(attribute = newAttribute) + case gtr: GreaterThanOrEqual => + val newAttribute = InternalSchemaUtils.reBuildFilterName(gtr.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else gtr.copy(attribute = newAttribute) + case lt: LessThan => + val newAttribute = InternalSchemaUtils.reBuildFilterName(lt.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else lt.copy(attribute = newAttribute) + case lte: LessThanOrEqual => + val newAttribute = InternalSchemaUtils.reBuildFilterName(lte.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else lte.copy(attribute = newAttribute) + case i: In => + val newAttribute = InternalSchemaUtils.reBuildFilterName(i.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else i.copy(attribute = newAttribute) + case 
isn: IsNull => + val newAttribute = InternalSchemaUtils.reBuildFilterName(isn.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else isn.copy(attribute = newAttribute) + case isnn: IsNotNull => + val newAttribute = InternalSchemaUtils.reBuildFilterName(isnn.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else isnn.copy(attribute = newAttribute) + case And(left, right) => + And(rebuildFilterFromParquet(left, fileSchema, querySchema), rebuildFilterFromParquet(right, fileSchema, querySchema)) + case Or(left, right) => + Or(rebuildFilterFromParquet(left, fileSchema, querySchema), rebuildFilterFromParquet(right, fileSchema, querySchema)) + case Not(child) => + Not(rebuildFilterFromParquet(child, fileSchema, querySchema)) + case ssw: StringStartsWith => + val newAttribute = InternalSchemaUtils.reBuildFilterName(ssw.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else ssw.copy(attribute = newAttribute) + case ses: StringEndsWith => + val newAttribute = InternalSchemaUtils.reBuildFilterName(ses.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else ses.copy(attribute = newAttribute) + case sc: StringContains => + val newAttribute = InternalSchemaUtils.reBuildFilterName(sc.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else sc.copy(attribute = newAttribute) + case AlwaysTrue => + AlwaysTrue + case AlwaysFalse => + AlwaysFalse + case _ => + AlwaysTrue + } + } + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark42ParquetReader.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark42ParquetReader.scala new file mode 100644 index 0000000000000..60fbae0566ee8 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark42ParquetReader.scala @@ -0,0 +1,388 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.hudi.common.util.{Option => HOption} +import org.apache.hudi.internal.schema.InternalSchema +import org.apache.hudi.io.storage.HoodieSparkParquetReader.ENABLE_LOGICAL_TIMESTAMP_REPAIR + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapred.FileSplit +import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.parquet.filter2.compat.FilterCompat +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} +import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader} +import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} +import org.apache.parquet.schema.{MessageType, SchemaRepair} +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, RebaseDateTime} +import org.apache.spark.sql.execution.datasources.{DataSourceUtils, FileFormat, PartitionedFile, RecordReaderIterator, SparkColumnarFileReader} +import org.apache.spark.sql.execution.datasources.parquet.Spark42ParquetReader.repairFooterSchema +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +import java.io.Closeable +import java.time.ZoneId +import java.util.concurrent.atomic.AtomicBoolean + +class Spark42ParquetReader(enableVectorizedReader: Boolean, + datetimeRebaseModeInRead: String, + int96RebaseModeInRead: String, + enableParquetFilterPushDown: Boolean, + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownInFilterThreshold: Int, + pushDownStringPredicate: Boolean, + isCaseSensitive: Boolean, + timestampConversion: Boolean, + enableOffHeapColumnVector: Boolean, + capacity: Int, + returningBatch: Boolean, + enableRecordFilter: Boolean, + enableLogicalTimestampRepair: Boolean, + timeZoneId: Option[String]) extends SparkParquetReaderBase( + enableVectorizedReader = enableVectorizedReader, + enableParquetFilterPushDown = enableParquetFilterPushDown, + pushDownDate = pushDownDate, + pushDownTimestamp = pushDownTimestamp, + pushDownDecimal = pushDownDecimal, + pushDownInFilterThreshold = pushDownInFilterThreshold, + isCaseSensitive = isCaseSensitive, + timestampConversion = timestampConversion, + enableOffHeapColumnVector = enableOffHeapColumnVector, + capacity = capacity, + returningBatch = returningBatch, + enableRecordFilter = enableRecordFilter, + enableLogicalTimestampRepair = enableLogicalTimestampRepair, + timeZoneId = timeZoneId) { + + /** + * Read an individual parquet file + * Code from ParquetFileFormat#buildReaderWithPartitionValues from Spark v3.5.1 adapted here + * + * @param file parquet file to read + * @param requiredSchema desired output schema of the data + * @param partitionSchema schema of the partition columns. Partition values will be appended to the end of every row + * @param internalSchemaOpt option of internal schema for schema.on.read + * @param filters filters for data skipping. Not guaranteed to be used; the spark plan will also apply the filters. 
+ * @param sharedConf the hadoop conf + * @return iterator of rows read from the file output type says [[InternalRow]] but could be [[ColumnarBatch]] + */ + override protected def doRead(file: PartitionedFile, + requiredSchema: StructType, + partitionSchema: StructType, + internalSchemaOpt: HOption[InternalSchema], + filters: scala.Seq[Filter], + sharedConf: Configuration, + tableSchemaOpt: HOption[MessageType]): Iterator[InternalRow] = { + assert(file.partitionValues.numFields == partitionSchema.size) + + val filePath = file.toPath + val split = new FileSplit(filePath, file.start, file.length, Array.empty[String]) + + val schemaEvolutionUtils = new ParquetSchemaEvolutionUtils(sharedConf, filePath, requiredSchema, + partitionSchema, internalSchemaOpt) + + // When there are vectorized reads, we can avoid + // 1. opening the file twice by transfering the SeekableInputStream + // 2. reading the footer twice by reading all row groups in advance and filter row groups + // according to filters that require push down + val originalFooter = + ParquetFooterReader.openFileAndReadFooter(sharedConf, file, enableVectorizedReader) + val openedFooter = if (enableLogicalTimestampRepair) { + repairFooterSchema(originalFooter, tableSchemaOpt) + } else { + originalFooter + } + assert(openedFooter.inputStreamOpt.isPresent == enableVectorizedReader) + + // Before transferring the ownership of inputStream to the vectorizedReader, + // we must take responsibility to close the inputStream if something goes wrong + // to avoid resource leak. + val shouldCloseInputStream = new AtomicBoolean(openedFooter.inputStreamOpt.isPresent) + try { + val footerFileMetaData = openedFooter.footer.getFileMetaData + val datetimeRebaseSpec = DataSourceUtils.datetimeRebaseSpec( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + val int96RebaseSpec = DataSourceUtils.int96RebaseSpec( + footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) + + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringPredicate, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) + filters.map(schemaEvolutionUtils.rebuildFilterFromParquet) + // Collects all converted Parquet filter predicates. Notice that not all predicates + // can be converted (`ParquetFilters.createFilter` returns an `Option`). That's why + // a `flatMap` is used here. + .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + + // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 + // timestamps *only* if the file was created by something other than "parquet-mr", + // so check the actual writer here for this file. We have to do this per-file, + // as each file in the table may have different writers. + // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads. 
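+      // e.g. a createdBy value such as "parquet-mr version 1.13.1" (illustrative) indicates a parquet-mr writer, +      // in which case no timezone conversion is applied.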
+ def isCreatedByParquetMr: Boolean = + footerFileMetaData.getCreatedBy.startsWith("parquet-mr") + + val convertTz = + if (timestampConversion && !isCreatedByParquetMr) { + Some(DateTimeUtils.getZoneId(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) + } else { + None + } + + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = + new TaskAttemptContextImpl(schemaEvolutionUtils.getHadoopConfClone(footerFileMetaData, enableVectorizedReader), attemptId) + + // Try to push down filters when filter push-down is enabled. + // Notice: This push-down is RowGroups level, not individual records. + pushed.foreach { + ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, _) + } + if (enableVectorizedReader) { + buildVectorizedIterator( + hadoopAttemptContext, split, file.partitionValues, partitionSchema, convertTz, + datetimeRebaseSpec, int96RebaseSpec, enableOffHeapColumnVector, returningBatch, + capacity, openedFooter, shouldCloseInputStream, schemaEvolutionUtils) + } else { + buildRowBasedIterator( + hadoopAttemptContext, split, file.partitionValues, partitionSchema, convertTz, + datetimeRebaseSpec, int96RebaseSpec, requiredSchema, pushed, enableRecordFilter, + enableLogicalTimestampRepair, tableSchemaOpt, schemaEvolutionUtils) + } + } finally { + if (shouldCloseInputStream.get) { + openedFooter.inputStreamOpt.ifPresent(Utils.closeQuietly) + } + } + } + + // scalastyle:off parameter.number + private def buildVectorizedIterator( + hadoopAttemptContext: TaskAttemptContextImpl, + split: FileSplit, + partitionValues: InternalRow, + partitionSchema: StructType, + convertTz: Option[ZoneId], + datetimeRebaseSpec: RebaseDateTime.RebaseSpec, + int96RebaseSpec: RebaseDateTime.RebaseSpec, + enableOffHeapColumnVector: Boolean, + returningBatch: Boolean, + batchSize: Int, + openedFooter: OpenedParquetFooter, + shouldCloseInputStream: AtomicBoolean, + schemaEvolutionUtils: ParquetSchemaEvolutionUtils): Iterator[InternalRow] = { + // scalastyle:on parameter.number + assert(openedFooter.inputStreamOpt.isPresent) + val vectorizedReader = schemaEvolutionUtils.buildVectorizedReader( + convertTz.orNull, + datetimeRebaseSpec.mode.toString, + datetimeRebaseSpec.timeZone, + int96RebaseSpec.mode.toString, + int96RebaseSpec.timeZone, + enableOffHeapColumnVector && TaskContext.get() != null, + batchSize) + // SPARK-37089: We cannot register a task completion listener to close this iterator here + // because downstream exec nodes have already registered their listeners. Since listeners + // are executed in reverse order of registration, a listener registered here would close the + // iterator while downstream exec nodes are still running. When off-heap column vectors are + // enabled, this can cause a use-after-free bug leading to a segfault. + // + // Instead, we use FileScanRDD's task completion listener to close this iterator. 
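+    // RecordReaderIterator exposes the Hadoop RecordReader as a Scala Iterator; FileScanRDD's task completion +    // listener (mentioned above) is what eventually closes it.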
+ val iter = new RecordReaderIterator(vectorizedReader) + try { + vectorizedReader.initialize( + split, hadoopAttemptContext, Some(openedFooter.inputFile), + Some(openedFooter.inputStream), Some(openedFooter.footer)) + // The caller don't need to take care of the close of inputStream after calling + // `initialize` because the ownership of inputStream has been transferred to the + // vectorizedReader + shouldCloseInputStream.set(false) + vectorizedReader.initBatch(partitionSchema, partitionValues) + if (returningBatch) { + vectorizedReader.enableReturningBatches() + } + + // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. + iter.asInstanceOf[Iterator[InternalRow]] + } catch { + case e: Throwable => + // SPARK-23457: In case there is an exception in initialization, close the iterator to + // avoid leaking resources. + iter.close() + throw e + } + } + + // scalastyle:off parameter.number + private def buildRowBasedIterator( + hadoopAttemptContext: TaskAttemptContextImpl, + split: FileSplit, + partitionValues: InternalRow, + partitionSchema: StructType, + convertTz: Option[ZoneId], + datetimeRebaseSpec: RebaseDateTime.RebaseSpec, + int96RebaseSpec: RebaseDateTime.RebaseSpec, + requiredSchema: StructType, + pushed: Option[FilterPredicate], + enableRecordFilter: Boolean, + enableLogicalTimestampRepair: Boolean, + tableSchemaOpt: HOption[MessageType], + schemaEvolutionUtils: ParquetSchemaEvolutionUtils): Iterator[InternalRow] with Closeable = { + // scalastyle:on parameter.number + // ParquetRecordReader returns InternalRow + val readSupport = new HoodieParquetReadSupport( + convertTz, + enableVectorizedReader = false, + enableLogicalTimestampRepair, + datetimeRebaseSpec, + int96RebaseSpec, + tableSchemaOpt) + val reader = if (pushed.isDefined && enableRecordFilter) { + val parquetFilter = FilterCompat.get(pushed.get, null) + new ParquetRecordReader[InternalRow](readSupport, parquetFilter) + } else { + new ParquetRecordReader[InternalRow](readSupport) + } + val readerWithRowIndexes = ParquetRowIndexUtil.addRowIndexToRecordReaderIfNeeded(reader, + requiredSchema) + val iter = new RecordReaderIterator[InternalRow](readerWithRowIndexes) + try { + readerWithRowIndexes.initialize(split, hadoopAttemptContext) + + val fullSchema = toAttributes(requiredSchema) ++ toAttributes(partitionSchema) + val unsafeProjection = schemaEvolutionUtils.generateUnsafeProjection(fullSchema, timeZoneId) + + if (partitionSchema.length == 0) { + // There is no partition columns + iter.map(unsafeProjection) + } else { + val joinedRow = new JoinedRow() + iter.map(d => unsafeProjection(joinedRow(d, partitionValues))) + } + } catch { + case e: Throwable => + // SPARK-23457: In case there is an exception in initialization, close the iterator to + // avoid leaking resources. 
+ iter.close() + throw e + } + } +} + +object Spark42ParquetReader extends SparkParquetReaderBuilder { + /** + * Get parquet file reader + * + * @param vectorized true if vectorized reading is not prohibited due to schema, reading mode, etc + * @param sqlConf the [[SQLConf]] used for the read + * @param options passed as a param to the file format + * @param hadoopConf some configs will be set for the hadoopConf + * @return parquet file reader + */ + def build(vectorized: Boolean, + sqlConf: SQLConf, + options: Map[String, String], + hadoopConf: Configuration): SparkColumnarFileReader = { + //set hadoopconf + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) + hadoopConf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, sqlConf.sessionLocalTimeZone) + hadoopConf.setBoolean(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, sqlConf.nestedSchemaPruningEnabled) + hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, sqlConf.caseSensitiveAnalysis) + hadoopConf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, sqlConf.isParquetBinaryAsString) + hadoopConf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, sqlConf.isParquetINT96AsTimestamp) + // Using string value of this conf to preserve compatibility across spark versions. See [HUDI-5868] + hadoopConf.setBoolean( + SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key, + sqlConf.getConfString( + SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key, + SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.defaultValueString).toBoolean + ) + hadoopConf.setBoolean(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key, sqlConf.parquetInferTimestampNTZEnabled) + + val enableLogicalTimestampRepair = hadoopConf.getBoolean(ENABLE_LOGICAL_TIMESTAMP_REPAIR, true) + val returningBatch = sqlConf.parquetVectorizedReaderEnabled && + options.getOrElse(FileFormat.OPTION_RETURNING_BATCH, + throw new IllegalArgumentException( + "OPTION_RETURNING_BATCH should always be set for ParquetFileFormat. 
" + + "To workaround this issue, set spark.sql.parquet.enableVectorizedReader=false.")) + .equals("true") + + val parquetOptions = new ParquetOptions(options, sqlConf) + new Spark42ParquetReader( + enableVectorizedReader = vectorized, + datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead, + int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead, + enableParquetFilterPushDown = sqlConf.parquetFilterPushDown, + pushDownDate = sqlConf.parquetFilterPushDownDate, + pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp, + pushDownDecimal = sqlConf.parquetFilterPushDownDecimal, + pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold, + pushDownStringPredicate = sqlConf.parquetFilterPushDownStringPredicate, + isCaseSensitive = sqlConf.caseSensitiveAnalysis, + timestampConversion = sqlConf.isParquetINT96TimestampConversion, + enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled, + capacity = sqlConf.parquetVectorizedReaderBatchSize, + returningBatch = returningBatch, + enableRecordFilter = sqlConf.parquetRecordFilterEnabled, + enableLogicalTimestampRepair = enableLogicalTimestampRepair, + timeZoneId = Some(sqlConf.sessionLocalTimeZone)) + } + + // Helper to repair the schema if needed + def repairFooterSchema(original: OpenedParquetFooter, + tableSchemaOpt: HOption[MessageType]): OpenedParquetFooter = { + val originalParquetMetadata = original.footer(); + val repairedSchema = SchemaRepair.repairLogicalTypes(originalParquetMetadata.getFileMetaData.getSchema, tableSchemaOpt) + val oldMeta = originalParquetMetadata.getFileMetaData + new OpenedParquetFooter(new ParquetMetadata( + new FileMetaData( + repairedSchema, + oldMeta.getKeyValueMetaData, + oldMeta.getCreatedBy, + oldMeta.getEncryptionType, + oldMeta.getFileDecryptor + ), + originalParquetMetadata.getBlocks + ), original.inputFile(), original.inputStreamOpt()) + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/hudi/Spark42ResolveHudiAlterTableCommand.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/hudi/Spark42ResolveHudiAlterTableCommand.scala new file mode 100644 index 0000000000000..c654a2079f9d8 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/hudi/Spark42ResolveHudiAlterTableCommand.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hudi + +import org.apache.hudi.common.config.HoodieCommonConfig +import org.apache.hudi.internal.schema.action.TableChange.ColumnChangeID + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis.ResolvedTable +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.hudi.catalog.HoodieInternalV2Table +import org.apache.spark.sql.hudi.command.{AlterTableCommand => HudiAlterTableCommand} + +/** + * Rule to mostly resolve, normalize and rewrite column names based on case sensitivity. + * for alter table column commands. + */ +class Spark42ResolveHudiAlterTableCommand(sparkSession: SparkSession) extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = { + if (ProvidesHoodieConfig.isSchemaEvolutionEnabled(sparkSession)) { + plan.resolveOperatorsUp { + case set@SetTableProperties(ResolvedHoodieV2TablePlan(t), _) if set.resolved => + HudiAlterTableCommand(t.v1Table, set.changes, ColumnChangeID.PROPERTY_CHANGE) + case unSet@UnsetTableProperties(ResolvedHoodieV2TablePlan(t), _, _) if unSet.resolved => + HudiAlterTableCommand(t.v1Table, unSet.changes, ColumnChangeID.PROPERTY_CHANGE) + case drop@DropColumns(ResolvedHoodieV2TablePlan(t), _, _) if drop.resolved => + HudiAlterTableCommand(t.v1Table, drop.changes, ColumnChangeID.DELETE) + case add@AddColumns(ResolvedHoodieV2TablePlan(t), _) if add.resolved => + HudiAlterTableCommand(t.v1Table, add.changes, ColumnChangeID.ADD) + case renameColumn@RenameColumn(ResolvedHoodieV2TablePlan(t), _, _) if renameColumn.resolved => + HudiAlterTableCommand(t.v1Table, renameColumn.changes, ColumnChangeID.UPDATE) + case alter@AlterColumns(ResolvedHoodieV2TablePlan(t), _) if alter.resolved => + HudiAlterTableCommand(t.v1Table, alter.changes, ColumnChangeID.UPDATE) + case replace@ReplaceColumns(ResolvedHoodieV2TablePlan(t), _) if replace.resolved => + HudiAlterTableCommand(t.v1Table, replace.changes, ColumnChangeID.REPLACE) + } + } else { + plan + } + } + + object ResolvedHoodieV2TablePlan { + def unapply(plan: LogicalPlan): Option[HoodieInternalV2Table] = { + plan match { + case ResolvedTable(_, _, v2Table: HoodieInternalV2Table, _) => Some(v2Table) + case _ => None + } + } + } +} + diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark42Analysis.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark42Analysis.scala new file mode 100644 index 0000000000000..54ab3f66baa8a --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark42Analysis.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.analysis + +import org.apache.hudi.{DefaultSource, EmptyRelation, HoodieBaseRelation} +import org.apache.hudi.SparkAdapterSupport.sparkAdapter + +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.analysis.{ResolveInsertionBase, TableOutputResolver} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PreprocessTableInsertion} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.hudi.ProvidesHoodieConfig +import org.apache.spark.sql.hudi.catalog.HoodieInternalV2Table +import org.apache.spark.sql.sources.InsertableRelation +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.PartitioningUtils.normalizePartitionSpec + +/** + * NOTE: PLEASE READ CAREFULLY + * + * Since Hudi relations don't currently implement DS V2 Read API, we have to fallback to V1 here. + * Such fallback will have considerable performance impact, therefore it's only performed in cases + * where V2 API have to be used. Currently only such use-case is using of Schema Evolution feature + * + * Check out HUDI-4178 for more details + */ +case class HoodieSpark42DataSourceV2ToV1Fallback(sparkSession: SparkSession) extends Rule[LogicalPlan] + with ProvidesHoodieConfig { + + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + // The only place we're avoiding fallback is in [[AlterTableCommand]]s since + // current implementation relies on DSv2 features + case _: AlterTableCommand => plan + + // NOTE: Unfortunately, [[InsertIntoStatement]] is implemented in a way that doesn't expose + // target relation as a child (even though there's no good reason for that) + case iis@InsertIntoStatement(rv2@DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _, _), _, _, _, _, _, _, _, _) => + iis.copy(table = convertToV1(rv2, v2Table)) + + case _ => + plan.resolveOperatorsDown { + case rv2@DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _, _) => convertToV1(rv2, v2Table) + } + } + + private def convertToV1(rv2: DataSourceV2Relation, v2Table: HoodieInternalV2Table) = { + val output = rv2.output + val catalogTable = v2Table.catalogTable.map(_ => v2Table.v1Table) + val relation = new DefaultSource().createRelation(sparkSession.sqlContext, + buildHoodieConfig(v2Table.hoodieCatalogTable), v2Table.hoodieCatalogTable.tableSchema) + + LogicalRelation(relation, output, catalogTable, isStreaming = false, Option.empty) + } +} + +/** + * In Spark 3.5, the following Resolution rules are removed, + * [[ResolveUserSpecifiedColumns]] and [[ResolveDefaultColumns]] + * (see code changes in [[org.apache.spark.sql.catalyst.analysis.Analyzer]] + * from https://github.com/apache/spark/pull/41262). + * The same logic of resolving the user specified columns and default values, + * which are required for a subset of columns as user specified compared to the table + * schema to work properly, are deferred to [[PreprocessTableInsertion]] for v1 INSERT. 
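+ * (For example, an INSERT that names only a subset of the table's columns, such as + * `INSERT INTO t (id, ts) SELECT ...`, depends on this resolution step to fill the remaining columns + * with their default values.)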
+ * + * Note that [[HoodieAnalysis]] intercepts the [[InsertIntoStatement]] after Spark's built-in + * Resolution rules are applies, the logic of resolving the user specified columns and default + * values may no longer be applied. To make INSERT with a subset of columns specified by user + * to work, this custom resolution rule [[HoodieSpark42ResolveColumnsForInsertInto]] is added + * to achieve the same, before converting [[InsertIntoStatement]] into + * [[InsertIntoHoodieTableCommand]]. + * + * The implementation is copied and adapted from [[PreprocessTableInsertion]] + * https://github.com/apache/spark/blob/d061aadf25fd258d2d3e7332a489c9c24a2b5530/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala#L373 + * + * Also note that, the project logic in [[ResolveImplementationsEarly]] for INSERT is still + * needed in the case of INSERT with all columns in a different ordering. + */ +case class HoodieSpark42ResolveColumnsForInsertInto() extends ResolveInsertionBase { + // NOTE: This is copied from [[PreprocessTableInsertion]] with additional handling of Hudi relations + override def apply(plan: LogicalPlan): LogicalPlan = { + plan match { + case i@InsertIntoStatement(table, _, _, query, _, _, _, _, _) + if table.resolved && query.resolved + && i.userSpecifiedCols.nonEmpty && i.table.isInstanceOf[LogicalRelation] + && sparkAdapter.isHoodieTable(i.table.asInstanceOf[LogicalRelation].catalogTable.get) => + table match { + case relation: HiveTableRelation => + val metadata = relation.tableMeta + preprocess(i, metadata.identifier.quotedString, metadata.partitionSchema, + Some(metadata)) + case LogicalRelation(h: HadoopFsRelation, _, catalogTable, _, _) => + preprocess(i, catalogTable, h.partitionSchema) + case LogicalRelation(_: InsertableRelation, _, catalogTable, _, _) => + preprocess(i, catalogTable, new StructType()) + // The two conditions below are adapted to Hudi relations + case LogicalRelation(_: EmptyRelation, _, catalogTable, _, _) => + preprocess(i, catalogTable) + case LogicalRelation(_: HoodieBaseRelation, _, catalogTable, _, _) => + preprocess(i, catalogTable) + case _ => i + } + case _ => plan + } + } + + private def preprocess(insert: InsertIntoStatement, + catalogTable: Option[CatalogTable]): InsertIntoStatement = { + preprocess(insert, catalogTable, catalogTable.map(_.partitionSchema).getOrElse(new StructType())) + } + + private def preprocess(insert: InsertIntoStatement, + catalogTable: Option[CatalogTable], + partitionSchema: StructType): InsertIntoStatement = { + val tblName = catalogTable.map(_.identifier.quotedString).getOrElse("unknown") + preprocess(insert, tblName, partitionSchema, catalogTable) + } + + // NOTE: this is copied from [[PreprocessTableInsertion]] with additional logic + // to unset user-specified columns at the end + private def preprocess(insert: InsertIntoStatement, + tblName: String, + partColNames: StructType, + catalogTable: Option[CatalogTable]): InsertIntoStatement = { + + val normalizedPartSpec = normalizePartitionSpec( + insert.partitionSpec, partColNames, tblName, conf.resolver) + + val staticPartCols = normalizedPartSpec.filter(_._2.isDefined).keySet + val expectedColumns = insert.table.output.filterNot(a => staticPartCols.contains(a.name)) + + val partitionsTrackedByCatalog = catalogTable.isDefined && + catalogTable.get.partitionColumnNames.nonEmpty && + catalogTable.get.tracksPartitionsInCatalog + if (partitionsTrackedByCatalog && normalizedPartSpec.nonEmpty) { + // empty partition column value + if 
(normalizedPartSpec.values.flatten.exists(v => v != null && v.isEmpty)) { + val spec = normalizedPartSpec.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]") + throw QueryCompilationErrors.invalidPartitionSpecError( + s"The spec ($spec) contains an empty partition column value") + } + } + + // Create a project if this INSERT has a user-specified column list. + val hasColumnList = insert.userSpecifiedCols.nonEmpty + val query = if (hasColumnList) { + createProjectForByNameQuery(tblName, insert) + } else { + insert.query + } + val newQuery = try { + TableOutputResolver.resolveOutputColumns( + tblName, + expectedColumns, + query, + byName = hasColumnList || insert.byName, + conf, + supportColDefaultValue = true) + } catch { + case e: AnalysisException if staticPartCols.nonEmpty && + (e.getErrorClass == "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS" || + e.getErrorClass == "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS") => + val newException = e.copy( + errorClass = Some("INSERT_PARTITION_COLUMN_ARITY_MISMATCH"), + messageParameters = e.messageParameters ++ Map( + "tableColumns" -> insert.table.output.map(c => toSQLId(c.name)).mkString(", "), + "staticPartCols" -> staticPartCols.toSeq.sorted.map(c => toSQLId(c)).mkString(", ") + )) + newException.setStackTrace(e.getStackTrace) + throw newException + } + if (normalizedPartSpec.nonEmpty) { + if (normalizedPartSpec.size != partColNames.length) { + throw QueryCompilationErrors.requestedPartitionsMismatchTablePartitionsError( + tblName, normalizedPartSpec, partColNames) + } + + // NOTE: Hudi converts [[InsertIntoStatement]] to [[InsertIntoHoodieTableCommand]] + // and the user specified is no longer need after resolution + // (`userSpecifiedCols = Seq()`) + insert.copy(query = newQuery, partitionSpec = normalizedPartSpec, userSpecifiedCols = Seq()) + } else { + // All partition columns are dynamic because the InsertIntoTable command does + // not explicitly specify partitioning columns. + // NOTE: Hudi converts [[InsertIntoStatement]] to [[InsertIntoHoodieTableCommand]] + // and the user specified is no longer need after resolution + // (`userSpecifiedCols = Seq()`) + insert.copy(query = newQuery, partitionSpec = partColNames.map(_.name).map(_ -> None).toMap, + userSpecifiedCols = Seq()) + } + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark4_2ExtendedSqlAstBuilder.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark4_2ExtendedSqlAstBuilder.scala new file mode 100644 index 0000000000000..cebef33848bef --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark4_2ExtendedSqlAstBuilder.scala @@ -0,0 +1,3567 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.parser + +import org.apache.hudi.common.schema.{HoodieSchema, HoodieSchemaType} +import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseVisitor, HoodieSqlBaseParser} +import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser._ + +import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} +import org.apache.spark.sql.catalyst.parser.{EnhancedLogicalPlan, ParseException, ParserInterface} +import org.apache.spark.sql.catalyst.parser.ParserUtils.{checkDuplicateClauses, checkDuplicateKeys, entry, escapedIdentifier, operationNotAllowed, source, string, stringWithoutUnescape, validate, withOrigin} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils, DateTimeUtils, IntervalUtils} +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.BucketSpecHelper +import org.apache.spark.sql.connector.catalog.TableCatalog +import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition +import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.Utils.isTesting +import org.apache.spark.util.random.RandomSampler + +import javax.xml.bind.DatatypeConverter + +import java.util.Locale +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +/** + * The AstBuilder for HoodieSqlParser to parser the AST tree to Logical Plan. + * Here we only do the parser for the extended sql syntax. e.g MergeInto. For + * other sql syntax we use the delegate sql parser which is the SparkSqlParser. + */ +class HoodieSpark4_2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface) + extends HoodieSqlBaseBaseVisitor[AnyRef] with Logging { + + protected def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } + + /** + * Override the default behavior for all visit methods. This will only return a non-null result + * when the context has only one child. This is done because there is no generic method to + * combine the results of the context children. In all other cases null is returned. + */ + override def visitChildren(node: RuleNode): AnyRef = { + if (node.getChildCount == 1) { + node.getChild(0).accept(this) + } else { + null + } + } + + /** + * Create an aliased table reference. This is typically used in FROM clauses. 
+ */ + override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { + val tableId = visitMultipartIdentifier(ctx.multipartIdentifier()) + val relation = UnresolvedRelation(tableId) + val table = mayApplyAliasPlan( + ctx.tableAlias, relation.optionalMap(ctx.temporalClause)(withTimeTravel)) + table.optionalMap(ctx.sample)(withSample) + } + + private def withTimeTravel( + ctx: TemporalClauseContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val v = ctx.version + val version = if (ctx.INTEGER_VALUE != null) { + Some(v.getText) + } else { + Option(v).map(string) + } + + val timestamp = Option(ctx.timestamp).map(expression) + if (timestamp.exists(_.references.nonEmpty)) { + throw new ParseException( + "timestamp expression cannot refer to any columns", ctx.timestamp) + } + if (timestamp.exists(e => SubqueryExpression.hasSubquery(e))) { + throw new ParseException( + "timestamp expression cannot contain subqueries", ctx.timestamp) + } + + TimeTravelRelation(plan, timestamp, version) + } + + // ============== The following code is fork from org.apache.spark.sql.catalyst.parser.AstBuilder + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) { + visitNamedExpression(ctx.namedExpression) + } + + override def visitSingleTableIdentifier( + ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) { + visitTableIdentifier(ctx.tableIdentifier) + } + + override def visitSingleFunctionIdentifier( + ctx: SingleFunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + visitFunctionIdentifier(ctx.functionIdentifier) + } + + override def visitSingleMultipartIdentifier( + ctx: SingleMultipartIdentifierContext): Seq[String] = withOrigin(ctx) { + visitMultipartIdentifier(ctx.multipartIdentifier) + } + + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { + typedVisit[DataType](ctx.dataType) + } + + override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = { + val schema = StructType(visitColTypeList(ctx.colTypeList)) + withOrigin(ctx)(schema) + } + + /* ******************************************************************************************** + * Plan parsing + * ******************************************************************************************** */ + protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) + + /** + * Create a top-level plan with Common Table Expressions. 
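+ * For example "WITH cte AS (SELECT 1) SELECT * FROM cte" wraps the parsed query in an UnresolvedWith node.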
+ */ + override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { + val query = plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) + + // Apply CTEs + query.optionalMap(ctx.ctes)(withCTE) + } + + override def visitDmlStatement(ctx: DmlStatementContext): AnyRef = withOrigin(ctx) { + val dmlStmt = plan(ctx.dmlStatementNoWith) + // Apply CTEs + dmlStmt.optionalMap(ctx.ctes)(withCTE) + } + + private def withCTE(ctx: CtesContext, plan: LogicalPlan): LogicalPlan = { + val ctes = ctx.namedQuery.asScala.map { nCtx => + val namedQuery = visitNamedQuery(nCtx) + val rowLevelLimit: Option[Int] = if (nCtx.integerValue() != null) { + if (ctx.RECURSIVE() == null) { + operationNotAllowed("Cannot specify MAX RECURSION LEVEL when the CTE is not marked as " + + "RECURSIVE", ctx) + } + Some(getIntegerValue(nCtx.integerValue())) + } else { + None + } + (namedQuery.alias, namedQuery, rowLevelLimit) + } + // Check for duplicate names. + val duplicates = ctes.groupBy(_._1).filter(_._2.size > 1).keys + if (duplicates.nonEmpty) { + throw new ParseException(s"CTE definition can't have duplicate names: ${duplicates.mkString("'", "', '", "'")}.", ctx) + } + UnresolvedWith(plan, ctes.toSeq, ctx.RECURSIVE() != null) + } + + /** + * Gets the integer value from an IntegerValueContext after parameter replacement. Asserts that + * parameter markers have been substituted before reaching DataTypeAstBuilder. + * + * @param ctx + * The IntegerValueContext to extract the integer from + * @return + * The integer value + */ + private def getIntegerValue(ctx: IntegerValueContext): Int = { + assert( + !ctx.isInstanceOf[ParameterIntegerValueContext], + "Parameter markers should be substituted before DataTypeAstBuilder processes the " + + s"parse tree. Found unsubstituted parameter: ${ctx.getText}") + ctx.getText.toInt + } + + /** + * Create a logical query plan for a hive-style FROM statement body. + */ + private def withFromStatementBody( + ctx: FromStatementBodyContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // two cases for transforms and selects + if (ctx.transformClause != null) { + withTransformQuerySpecification( + ctx, + ctx.transformClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + plan + ) + } else { + withSelectQuerySpecification( + ctx, + ctx.selectClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + plan + ) + } + } + + override def visitFromStatement(ctx: FromStatementContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + val selects = ctx.fromStatementBody.asScala.map { body => + withFromStatementBody(body, from). + // Add organization statements. + optionalMap(body.queryOrganization)(withQueryResultClauses) + } + // If there are multiple SELECT just UNION them together into one query. + if (selects.length == 1) { + selects.head + } else { + Union(selects.toSeq) + } + } + + /** + * Create a named logical plan. + * + * This is only used for Common Table Expressions. + */ + override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { + val subQuery: LogicalPlan = plan(ctx.query).optionalMap(ctx.columnAliases)( + (columnAliases, plan) => + UnresolvedSubqueryColumnAliases(visitIdentifierList(columnAliases), plan) + ) + SubqueryAlias(ctx.name.getText, subQuery) + } + + /** + * Create a logical plan which allows for multiple inserts using one 'from' statement. 
These + * queries have the following SQL form: + * {{{ + * [WITH cte...]? + * FROM src + * [INSERT INTO tbl1 SELECT *]+ + * }}} + * For example: + * {{{ + * FROM db.tbl1 A + * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5 + * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12 + * }}} + * This (Hive) feature cannot be combined with set-operators. + */ + override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + + // Build the insert clauses. + val inserts = ctx.multiInsertQueryBody.asScala.map { body => + withInsertInto(body.insertInto, + withFromStatementBody(body.fromStatementBody, from). + optionalMap(body.fromStatementBody.queryOrganization)(withQueryResultClauses)) + } + + // If there are multiple INSERTS just UNION them together into one query. + if (inserts.length == 1) { + inserts.head + } else { + Union(inserts.toSeq) + } + } + + /** + * Create a logical plan for a regular (single-insert) query. + */ + override def visitSingleInsertQuery( + ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { + withInsertInto( + ctx.insertInto(), + plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)) + } + + /** + * Parameters used for writing query to a table: + * (UnresolvedRelation, tableColumnList, partitionKeys, ifPartitionNotExists). + */ + type InsertTableParams = (UnresolvedRelation, Seq[String], Map[String, Option[String]], Boolean) + + /** + * Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider). + */ + type InsertDirParams = (Boolean, CatalogStorageFormat, Option[String]) + + /** + * Add an + * {{{ + * INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? [identifierList] + * INSERT INTO [TABLE] tableIdentifier [partitionSpec] [identifierList] + * INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat] + * INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList] + * }}} + * operation to logical plan + */ + private def withInsertInto( + ctx: InsertIntoContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + ctx match { + case table: InsertIntoTableContext => + val (relation, cols, partition, ifPartitionNotExists) = visitInsertIntoTable(table) + InsertIntoStatement( + relation, + partition, + cols, + query, + overwrite = false, + ifPartitionNotExists) + case table: InsertOverwriteTableContext => + val (relation, cols, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table) + InsertIntoStatement( + relation, + partition, + cols, + query, + overwrite = true, + ifPartitionNotExists) + case dir: InsertOverwriteDirContext => + val (isLocal, storage, provider) = visitInsertOverwriteDir(dir) + InsertIntoDir(isLocal, storage, provider, query, overwrite = true) + case hiveDir: InsertOverwriteHiveDirContext => + val (isLocal, storage, provider) = visitInsertOverwriteHiveDir(hiveDir) + InsertIntoDir(isLocal, storage, provider, query, overwrite = true) + case _ => + throw new ParseException("Invalid InsertIntoContext", ctx) + } + } + + /** + * Add an INSERT INTO TABLE operation to the logical plan. + */ + override def visitInsertIntoTable( + ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) { + val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + if (ctx.EXISTS != null) { + operationNotAllowed("INSERT INTO ... 
IF NOT EXISTS", ctx) + } + + (createUnresolvedRelation(ctx.multipartIdentifier), cols, partitionKeys, false) + } + + /** + * Add an INSERT OVERWRITE TABLE operation to the logical plan. + */ + override def visitInsertOverwriteTable( + ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) { + assert(ctx.OVERWRITE() != null) + val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) + if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { + operationNotAllowed("IF NOT EXISTS with dynamic partitions: " + + dynamicPartitionKeys.keys.mkString(", "), ctx) + } + + (createUnresolvedRelation(ctx.multipartIdentifier), cols, partitionKeys, ctx.EXISTS() != null) + } + + /** + * Write to a directory, returning a [[InsertIntoDir]] logical plan. + */ + override def visitInsertOverwriteDir( + ctx: InsertOverwriteDirContext): InsertDirParams = withOrigin(ctx) { + throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx) + } + + /** + * Write to a directory, returning a [[InsertIntoDir]] logical plan. + */ + override def visitInsertOverwriteHiveDir( + ctx: InsertOverwriteHiveDirContext): InsertDirParams = withOrigin(ctx) { + throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx) + } + + private def getTableAliasWithoutColumnAlias( + ctx: TableAliasContext, op: String): Option[String] = { + if (ctx == null) { + None + } else { + val ident = ctx.strictIdentifier() + if (ctx.identifierList() != null) { + throw new ParseException(s"Columns aliases are not allowed in $op.", ctx.identifierList()) + } + if (ident != null) Some(ident.getText) else None + } + } + + override def visitDeleteFromTable( + ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) { + val table = createUnresolvedRelation(ctx.multipartIdentifier()) + val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE") + val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) + val predicate = if (ctx.whereClause() != null) { + Some(expression(ctx.whereClause().booleanExpression())) + } else { + None + } + DeleteFromTable(aliasedTable, predicate.get) + } + + override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) { + val table = createUnresolvedRelation(ctx.multipartIdentifier()) + val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE") + val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) + val assignments = withAssignments(ctx.setClause().assignmentList()) + val predicate = if (ctx.whereClause() != null) { + Some(expression(ctx.whereClause().booleanExpression())) + } else { + None + } + + UpdateTable(aliasedTable, assignments, predicate) + } + + private def withAssignments(assignCtx: AssignmentListContext): Seq[Assignment] = + withOrigin(assignCtx) { + assignCtx.assignment().asScala.map { assign => + Assignment(UnresolvedAttribute(visitMultipartIdentifier(assign.key)), + expression(assign.value)) + }.toSeq + } + + override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) { + val targetTable = createUnresolvedRelation(ctx.target) + val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE") + val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable) + + val sourceTableOrQuery = 
if (ctx.source != null) { + createUnresolvedRelation(ctx.source) + } else if (ctx.sourceQuery != null) { + visitQuery(ctx.sourceQuery) + } else { + throw new ParseException("Empty source for merge: you should specify a source" + + " table/subquery in merge.", ctx.source) + } + val sourceTableAlias = getTableAliasWithoutColumnAlias(ctx.sourceAlias, "MERGE") + val aliasedSource = + sourceTableAlias.map(SubqueryAlias(_, sourceTableOrQuery)).getOrElse(sourceTableOrQuery) + + val mergeCondition = expression(ctx.mergeCondition) + + val matchedActions = ctx.matchedClause().asScala.map { + clause => { + if (clause.matchedAction().DELETE() != null) { + DeleteAction(Option(clause.matchedCond).map(expression)) + } else if (clause.matchedAction().UPDATE() != null) { + val condition = Option(clause.matchedCond).map(expression) + if (clause.matchedAction().ASTERISK() != null) { + UpdateStarAction(condition) + } else { + UpdateAction(condition, withAssignments(clause.matchedAction().assignmentList())) + } + } else { + // It should not be here. + throw new ParseException(s"Unrecognized matched action: ${clause.matchedAction().getText}", + clause.matchedAction()) + } + } + } + val notMatchedActions = ctx.notMatchedClause().asScala.map { + clause => { + if (clause.notMatchedAction().INSERT() != null) { + val condition = Option(clause.notMatchedCond).map(expression) + if (clause.notMatchedAction().ASTERISK() != null) { + InsertStarAction(condition) + } else { + val columns = clause.notMatchedAction().columns.multipartIdentifier() + .asScala.map(attr => UnresolvedAttribute(visitMultipartIdentifier(attr))) + val values = clause.notMatchedAction().expression().asScala.map(expression) + if (columns.size != values.size) { + throw new ParseException("The number of inserted values cannot match the fields.", + clause.notMatchedAction()) + } + InsertAction(condition, columns.zip(values).map(kv => Assignment(kv._1, kv._2)).toSeq) + } + } else { + // It should not be here. + throw new ParseException(s"Unrecognized not matched action: ${clause.notMatchedAction().getText}", + clause.notMatchedAction()) + } + } + } + if (matchedActions.isEmpty && notMatchedActions.isEmpty) { + throw new ParseException("There must be at least one WHEN clause in a MERGE statement", ctx) + } + // children being empty means that the condition is not set + val matchedActionSize = matchedActions.length + if (matchedActionSize >= 2 && !matchedActions.init.forall(_.condition.nonEmpty)) { + throw new ParseException("When there are more than one MATCHED clauses in a MERGE " + + "statement, only the last MATCHED clause can omit the condition.", ctx) + } + val notMatchedActionSize = notMatchedActions.length + if (notMatchedActionSize >= 2 && !notMatchedActions.init.forall(_.condition.nonEmpty)) { + throw new ParseException("When there are more than one NOT MATCHED clauses in a MERGE " + + "statement, only the last NOT MATCHED clause can omit the condition.", ctx) + } + + MergeIntoTable( + aliasedTarget, + aliasedSource, + mergeCondition, + matchedActions.toSeq, + notMatchedActions.toSeq, + Seq.empty, + false + ) + } + + /** + * Create a partition specification map. 
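+ * For example, PARTITION (dt='2021-01-01', hr) is parsed as Map("dt" -> Some("2021-01-01"), "hr" -> None).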
+ */ + override def visitPartitionSpec( + ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { + val legacyNullAsString = + conf.getConf(SQLConf.LEGACY_PARSE_NULL_PARTITION_SPEC_AS_STRING_LITERAL) + val parts = ctx.partitionVal.asScala.map { pVal => + val name = pVal.identifier.getText + val value = Option(pVal.constant).map(v => visitStringConstant(v, legacyNullAsString)) + name -> value + } + // Before calling `toMap`, we check duplicated keys to avoid silently ignore partition values + // in partition spec like PARTITION(a='1', b='2', a='3'). The real semantical check for + // partition columns will be done in analyzer. + if (conf.caseSensitiveAnalysis) { + checkDuplicateKeys(parts.toSeq, ctx) + } else { + checkDuplicateKeys(parts.map(kv => kv._1.toLowerCase(Locale.ROOT) -> kv._2).toSeq, ctx) + } + parts.toMap + } + + /** + * Create a partition specification map without optional values. + */ + protected def visitNonOptionalPartitionSpec( + ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { + visitPartitionSpec(ctx).map { + case (key, None) => throw new ParseException(s"Found an empty partition key '$key'.", ctx) + case (key, Some(value)) => key -> value + } + } + + /** + * Convert a constant of any type into a string. This is typically used in DDL commands, and its + * main purpose is to prevent slight differences due to back to back conversions i.e.: + * String -> Literal -> String. + */ + protected def visitStringConstant( + ctx: ConstantContext, + legacyNullAsString: Boolean): String = withOrigin(ctx) { + expression(ctx) match { + case Literal(null, _) if !legacyNullAsString => null + case l@Literal(null, _) => l.toString + case l: Literal => + // TODO For v2 commands, we will cast the string back to its actual value, + // which is a waste and can be improved in the future. + Cast(l, StringType, Some(conf.sessionLocalTimeZone)).eval().toString + case other => + throw new IllegalArgumentException(s"Only literals are allowed in the " + + s"partition spec, but got ${other.sql}") + } + } + + /** + * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These + * clauses determine the shape (ordering/partitioning/rows) of the query result. + */ + private def withQueryResultClauses( + ctx: QueryOrganizationContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. + val withOrder = if ( + !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // ORDER BY ... + Sort(order.asScala.map(visitSortItem).toSeq, global = true, query) + } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... + Sort(sort.asScala.map(visitSortItem).toSeq, global = false, query) + } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // DISTRIBUTE BY ... + withRepartitionByExpression(ctx, expressionList(distributeBy), query) + } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... DISTRIBUTE BY ... + Sort( + sort.asScala.map(visitSortItem).toSeq, + global = false, + withRepartitionByExpression(ctx, expressionList(distributeBy), query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { + // CLUSTER BY ... 
+ val expressions = expressionList(clusterBy) + Sort( + expressions.map(SortOrder(_, Ascending)), + global = false, + withRepartitionByExpression(ctx, expressions, query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // [EMPTY] + query + } else { + throw new ParseException( + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx) + } + + // WINDOWS + val withWindow = withOrder.optionalMap(windowClause)(withWindowClause) + + // LIMIT + // - LIMIT ALL is the same as omitting the LIMIT clause + withWindow.optional(limit) { + Limit(typedVisit(limit), withWindow) + } + } + + /** + * Create a clause for DISTRIBUTE BY. + */ + protected def withRepartitionByExpression( + ctx: QueryOrganizationContext, + expressions: Seq[Expression], + query: LogicalPlan): LogicalPlan = { + RepartitionByExpression(expressions, query, None) + } + + override def visitTransformQuerySpecification( + ctx: TransformQuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation().optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withTransformQuerySpecification( + ctx, + ctx.transformClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + from + ) + } + + override def visitRegularQuerySpecification( + ctx: RegularQuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation().optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withSelectQuerySpecification( + ctx, + ctx.selectClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + from + ) + } + + override def visitNamedExpressionSeq( + ctx: NamedExpressionSeqContext): Seq[Expression] = { + Option(ctx).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + } + + override def visitExpressionSeq(ctx: ExpressionSeqContext): Seq[Expression] = { + Option(ctx).toSeq + .flatMap(_.expression.asScala) + .map(typedVisit[Expression]) + } + + /** + * Create a logical plan using a having clause. + */ + private def withHavingClause( + ctx: HavingClauseContext, plan: LogicalPlan): LogicalPlan = { + // Note that we add a cast to non-predicate expressions. If the expression itself is + // already boolean, the optimizer will get rid of the unnecessary cast. + val predicate = expression(ctx.booleanExpression) match { + case p: Predicate => p + case e => Cast(e, BooleanType) + } + UnresolvedHaving(predicate, plan) + } + + /** + * Create a logical plan using a where clause. + */ + private def withWhereClause(ctx: WhereClauseContext, plan: LogicalPlan): LogicalPlan = { + Filter(expression(ctx.booleanExpression), plan) + } + + /** + * Add a hive-style transform (SELECT TRANSFORM/MAP/REDUCE) query specification to a logical plan. + */ + private def withTransformQuerySpecification( + ctx: ParserRuleContext, + transformClause: TransformClauseContext, + lateralView: java.util.List[LateralViewContext], + whereClause: WhereClauseContext, + aggregationClause: AggregationClauseContext, + havingClause: HavingClauseContext, + windowClause: WindowClauseContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + if (transformClause.setQuantifier != null) { + throw new ParseException("TRANSFORM does not support DISTINCT/ALL in inputs", transformClause.setQuantifier) + } + // Create the attributes. + val (attributes, schemaLess) = if (transformClause.colTypeList != null) { + // Typed return columns. 
+ (DataTypeUtils.toAttributes(createSchema(transformClause.colTypeList)), false) + } else if (transformClause.identifierSeq != null) { + // Untyped return columns. + val attrs = visitIdentifierSeq(transformClause.identifierSeq).map { name => + AttributeReference(name, StringType, nullable = true)() + } + (attrs, false) + } else { + (Seq(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) + } + + val plan = visitCommonSelectQueryClausePlan( + relation, + visitExpressionSeq(transformClause.expressionSeq), + lateralView, + whereClause, + aggregationClause, + havingClause, + windowClause, + isDistinct = false) + + ScriptTransformation( + string(transformClause.script), + attributes, + plan, + withScriptIOSchema( + ctx, + transformClause.inRowFormat, + transformClause.recordWriter, + transformClause.outRowFormat, + transformClause.recordReader, + schemaLess + ) + ) + } + + /** + * Add a regular (SELECT) query specification to a logical plan. The query specification + * is the core of the logical plan, this is where sourcing (FROM clause), projection (SELECT), + * aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. + * + * Note that query hints are ignored (both by the parser and the builder). + */ + private def withSelectQuerySpecification( + ctx: ParserRuleContext, + selectClause: SelectClauseContext, + lateralView: java.util.List[LateralViewContext], + whereClause: WhereClauseContext, + aggregationClause: AggregationClauseContext, + havingClause: HavingClauseContext, + windowClause: WindowClauseContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val isDistinct = selectClause.setQuantifier() != null && + selectClause.setQuantifier().DISTINCT() != null + + val plan = visitCommonSelectQueryClausePlan( + relation, + visitNamedExpressionSeq(selectClause.namedExpressionSeq), + lateralView, + whereClause, + aggregationClause, + havingClause, + windowClause, + isDistinct) + + // Hint + selectClause.hints.asScala.foldRight(plan)(withHints) + } + + def visitCommonSelectQueryClausePlan( + relation: LogicalPlan, + expressions: Seq[Expression], + lateralView: java.util.List[LateralViewContext], + whereClause: WhereClauseContext, + aggregationClause: AggregationClauseContext, + havingClause: HavingClauseContext, + windowClause: WindowClauseContext, + isDistinct: Boolean): LogicalPlan = { + // Add lateral views. + val withLateralView = lateralView.asScala.foldLeft(relation)(withGenerate) + + // Add where. + val withFilter = withLateralView.optionalMap(whereClause)(withWhereClause) + + // Add aggregation or a project. + val namedExpressions = expressions.map { + case e: NamedExpression => e + case e: Expression => UnresolvedAlias(e) + } + + def createProject() = if (namedExpressions.nonEmpty) { + Project(namedExpressions, withFilter) + } else { + withFilter + } + + val withProject = if (aggregationClause == null && havingClause != null) { + if (conf.getConf(SQLConf.LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE)) { + // If the legacy conf is set, treat HAVING without GROUP BY as WHERE. + val predicate = expression(havingClause.booleanExpression) match { + case p: Predicate => p + case e => Cast(e, BooleanType) + } + Filter(predicate, createProject()) + } else { + // According to SQL standard, HAVING without GROUP BY means global aggregate. 
+ withHavingClause(havingClause, Aggregate(Nil, namedExpressions, withFilter)) + } + } else if (aggregationClause != null) { + val aggregate = withAggregationClause(aggregationClause, namedExpressions, withFilter) + aggregate.optionalMap(havingClause)(withHavingClause) + } else { + // When hitting this branch, `having` must be null. + createProject() + } + + // Distinct + val withDistinct = if (isDistinct) { + Distinct(withProject) + } else { + withProject + } + + // Window + val withWindow = withDistinct.optionalMap(windowClause)(withWindowClause) + + withWindow + } + + // Script Transform's input/output format. + type ScriptIOFormat = + (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) + + protected def getRowFormatDelimited(ctx: RowFormatDelimitedContext): ScriptIOFormat = { + // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema + // expects a seq of pairs in which the old parsers' token names are used as keys. + // Transforming the result of visitRowFormatDelimited would be quite a bit messier than + // retrieving the key value pairs ourselves. + val entries = entry("TOK_TABLEROWFORMATFIELD", ctx.fieldsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATCOLLITEMS", ctx.collectionItemsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATMAPKEYS", ctx.keysTerminatedBy) ++ + entry("TOK_TABLEROWFORMATNULL", ctx.nullDefinedAs) ++ + Option(ctx.linesSeparatedBy).toSeq.map { token => + val value = string(token) + validate( + value == "\n", + s"LINES TERMINATED BY only supports newline '\\n' right now: $value", + ctx) + "TOK_TABLEROWFORMATLINES" -> value + } + + (entries, None, Seq.empty, None) + } + + /** + * Create a [[ScriptInputOutputSchema]]. + */ + protected def withScriptIOSchema( + ctx: ParserRuleContext, + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = { + + def format(fmt: RowFormatContext): ScriptIOFormat = fmt match { + case c: RowFormatDelimitedContext => + getRowFormatDelimited(c) + + case c: RowFormatSerdeContext => + throw new ParseException("TRANSFORM with serde is only supported in hive mode", ctx) + + // SPARK-32106: When there is no definition about format, we return empty result + // to use a built-in default Serde in SparkScriptTransformationExec. + case null => + (Nil, None, Seq.empty, None) + } + + val (inFormat, inSerdeClass, inSerdeProps, reader) = format(inRowFormat) + + val (outFormat, outSerdeClass, outSerdeProps, writer) = format(outRowFormat) + + ScriptInputOutputSchema( + inFormat, outFormat, + inSerdeClass, outSerdeClass, + inSerdeProps, outSerdeProps, + reader, writer, + schemaLess) + } + + /** + * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma + * separated) relations here, these get converted into a single plan by condition-less inner join. 
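+ * For example "FROM a, b" produces the same plan as an inner join of a and b without a join condition.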
+ */ + override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { + val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) => + val right = plan(relation.relationPrimary) + val join = right.optionalMap(left) { (left, right) => + if (relation.LATERAL != null) { + if (!relation.relationPrimary.isInstanceOf[AliasedQueryContext]) { + throw new ParseException(s"LATERAL can only be used with subquery", relation.relationPrimary) + } + LateralJoin(left, LateralSubquery(right), Inner, None) + } else { + Join(left, right, Inner, None, JoinHint.NONE) + } + } + withJoinRelations(join, relation) + } + if (ctx.pivotClause() != null) { + if (!ctx.lateralView.isEmpty) { + throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx) + } + withPivot(ctx.pivotClause, from) + } else { + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } + } + + /** + * Connect two queries by a Set operator. + * + * Supported Set operators are: + * - UNION [ DISTINCT | ALL ] + * - EXCEPT [ DISTINCT | ALL ] + * - MINUS [ DISTINCT | ALL ] + * - INTERSECT [DISTINCT | ALL] + */ + override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { + val left = plan(ctx.left) + val right = plan(ctx.right) + val all = Option(ctx.setQuantifier()).exists(_.ALL != null) + ctx.operator.getType match { + case HoodieSqlBaseParser.UNION if all => + Union(left, right) + case HoodieSqlBaseParser.UNION => + Distinct(Union(left, right)) + case HoodieSqlBaseParser.INTERSECT if all => + Intersect(left, right, isAll = true) + case HoodieSqlBaseParser.INTERSECT => + Intersect(left, right, isAll = false) + case HoodieSqlBaseParser.EXCEPT if all => + Except(left, right, isAll = true) + case HoodieSqlBaseParser.EXCEPT => + Except(left, right, isAll = false) + case HoodieSqlBaseParser.SETMINUS if all => + Except(left, right, isAll = true) + case HoodieSqlBaseParser.SETMINUS => + Except(left, right, isAll = false) + } + } + + /** + * Add a [[WithWindowDefinition]] operator to a logical plan. + */ + private def withWindowClause( + ctx: WindowClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Collect all window specifications defined in the WINDOW clause. + val baseWindowTuples = ctx.namedWindow.asScala.map { + wCtx => + (wCtx.name.getText, typedVisit[WindowSpec](wCtx.windowSpec)) + } + baseWindowTuples.groupBy(_._1).foreach { kv => + if (kv._2.size > 1) { + throw new ParseException(s"The definition of window '${kv._1}' is repetitive", ctx) + } + } + val baseWindowMap = baseWindowTuples.toMap + + // Handle cases like + // window w1 as (partition by p_mfgr order by p_name + // range between 2 preceding and 2 following), + // w2 as w1 + val windowMapView = baseWindowMap.mapValues { + case WindowSpecReference(name) => + baseWindowMap.get(name) match { + case Some(spec: WindowSpecDefinition) => + spec + case Some(ref) => + throw new ParseException(s"Window reference '$name' is not a window specification", ctx) + case None => + throw new ParseException(s"Cannot resolve window reference '$name'", ctx) + } + case spec: WindowSpecDefinition => spec + } + + // Note that mapValues creates a view instead of materialized map. We force materialization by + // mapping over identity. + WithWindowDefinition(windowMapView.map(identity).toMap, query, forPipeSQL = false) + } + + /** + * Add an [[Aggregate]] to a logical plan. 
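+ * For example "SELECT a, sum(b) FROM t GROUP BY a" becomes an Aggregate with grouping expression a over the filtered child plan.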
+ */ + private def withAggregationClause( + ctx: AggregationClauseContext, + selectExpressions: Seq[NamedExpression], + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + if (ctx.groupingExpressionsWithGroupingAnalytics.isEmpty) { + val groupByExpressions = expressionList(ctx.groupingExpressions) + if (ctx.GROUPING != null) { + // GROUP BY ... GROUPING SETS (...) + // `groupByExpressions` can be non-empty for Hive compatibility. It may add extra grouping + // expressions that do not exist in GROUPING SETS (...), and the value is always null. + // For example, `SELECT a, b, c FROM ... GROUP BY a, b, c GROUPING SETS (a, b)`, the output + // of column `c` is always null. + val groupingSets = + ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq) + Aggregate(Seq(GroupingSets(groupingSets.toSeq, groupByExpressions)), + selectExpressions, query) + } else { + // GROUP BY .... (WITH CUBE | WITH ROLLUP)? + val mappedGroupByExpressions = if (ctx.CUBE != null) { + Seq(Cube(groupByExpressions.map(Seq(_)))) + } else if (ctx.ROLLUP != null) { + Seq(Rollup(groupByExpressions.map(Seq(_)))) + } else { + groupByExpressions + } + Aggregate(mappedGroupByExpressions, selectExpressions, query) + } + } else { + val groupByExpressions = + ctx.groupingExpressionsWithGroupingAnalytics.asScala + .map(groupByExpr => { + val groupingAnalytics = groupByExpr.groupingAnalytics + if (groupingAnalytics != null) { + visitGroupingAnalytics(groupingAnalytics) + } else { + expression(groupByExpr.expression) + } + }) + Aggregate(groupByExpressions.toSeq, selectExpressions, query) + } + } + + override def visitGroupingAnalytics( + groupingAnalytics: GroupingAnalyticsContext): BaseGroupingSets = { + val groupingSets = groupingAnalytics.groupingSet.asScala + .map(_.expression.asScala.map(e => expression(e)).toSeq) + if (groupingAnalytics.CUBE != null) { + // CUBE(A, B, (A, B), ()) is not supported. + if (groupingSets.exists(_.isEmpty)) { + throw new ParseException(s"Empty set in CUBE grouping sets is not supported.", groupingAnalytics) + } + Cube(groupingSets.toSeq) + } else if (groupingAnalytics.ROLLUP != null) { + // ROLLUP(A, B, (A, B), ()) is not supported. + if (groupingSets.exists(_.isEmpty)) { + throw new ParseException(s"Empty set in ROLLUP grouping sets is not supported.", groupingAnalytics) + } + Rollup(groupingSets.toSeq) + } else { + assert(groupingAnalytics.GROUPING != null && groupingAnalytics.SETS != null) + val groupingSets = groupingAnalytics.groupingElement.asScala.flatMap { expr => + val groupingAnalytics = expr.groupingAnalytics() + if (groupingAnalytics != null) { + visitGroupingAnalytics(groupingAnalytics).selectedGroupByExprs + } else { + Seq(expr.groupingSet().expression().asScala.map(e => expression(e)).toSeq) + } + } + GroupingSets(groupingSets.toSeq) + } + } + + /** + * Add [[UnresolvedHint]]s to a logical plan. + */ + private def withHints( + ctx: HintContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + var plan = query + ctx.hintStatements.asScala.reverse.foreach { stmt => + plan = UnresolvedHint(stmt.hintName.getText, + stmt.parameters.asScala.map(expression).toSeq, plan) + } + plan + } + + /** + * Add a [[Pivot]] to a logical plan. 
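+ * For example "FROM t PIVOT (sum(v) FOR k IN ('a', 'b'))" pivots column k on the values 'a' and 'b', aggregating sum(v).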
+ */ + private def withPivot( + ctx: PivotClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val aggregates = Option(ctx.aggregates).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) { + UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText) + } else { + CreateStruct( + ctx.pivotColumn.identifiers.asScala.map( + identifier => UnresolvedAttribute.quoted(identifier.getText)).toSeq) + } + val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue) + Pivot(None, pivotColumn, pivotValues.toSeq, aggregates, query) + } + + /** + * Create a Pivot column value with or without an alias. + */ + override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else { + e + } + } + + /** + * Add a [[Generate]] (Lateral View) to a logical plan. + */ + private def withGenerate( + query: LogicalPlan, + ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) { + val expressions = expressionList(ctx.expression) + Generate( + UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions), + unrequiredChildIndex = Nil, + outer = ctx.OUTER != null, + // scalastyle:off caselocale + Some(ctx.tblName.getText.toLowerCase), + // scalastyle:on caselocale + ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.quoted).toSeq, + query) + } + + /** + * Create a single relation referenced in a FROM clause. This method is used when a part of the + * join condition is nested, for example: + * {{{ + * select * from t1 join (t2 cross join t3) on col1 = col2 + * }}} + */ + override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) { + withJoinRelations(plan(ctx.relationPrimary), ctx) + } + + /** + * Join one more [[LogicalPlan]]s to the current logical plan. 
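+ * For example "a JOIN b ON a.id = b.id LEFT JOIN c" folds each join relation in turn onto the plan built so far.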
+ */ + private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = { + ctx.joinRelation.asScala.foldLeft(base) { (left, join) => + withOrigin(join) { + val baseJoinType = join.joinType match { + case null => Inner + case jt if jt.CROSS != null => Cross + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } + + if (join.LATERAL != null && !join.right.isInstanceOf[AliasedQueryContext]) { + throw new ParseException(s"LATERAL can only be used with subquery", join.right) + } + + // Resolve the join type and join condition + val (joinType, condition) = Option(join.joinCriteria) match { + case Some(c) if c.USING != null => + if (join.LATERAL != null) { + throw new ParseException("LATERAL join with USING join is not supported", ctx) + } + (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case Some(c) => + throw new ParseException(s"Unimplemented joinCriteria: $c", ctx) + case None if join.NATURAL != null => + if (join.LATERAL != null) { + throw new ParseException("LATERAL join with NATURAL join is not supported", ctx) + } + if (baseJoinType == Cross) { + throw new ParseException("NATURAL CROSS JOIN is not supported", ctx) + } + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + if (join.LATERAL != null) { + if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) { + throw new ParseException(s"Unsupported LATERAL join type ${joinType.toString}", ctx) + } + LateralJoin(left, LateralSubquery(plan(join.right)), joinType, condition) + } else { + Join(left, plan(join.right), joinType, condition, JoinHint.NONE) + } + } + } + } + + /** + * Add a [[Sample]] to a logical plan. + * + * This currently supports the following sampling methods: + * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. + * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages + * are defined as a number between 0 and 100. + * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. + */ + private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Create a sampled plan if we need one. + def sample(fraction: Double): Sample = { + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. 
+ val eps = RandomSampler.roundingEpsilon + validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps, + s"Sampling fraction ($fraction) must be on interval [0, 1]", + ctx) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) + } + + if (ctx.sampleMethod() == null) { + throw new ParseException("TABLESAMPLE does not accept empty inputs.", ctx) + } + + ctx.sampleMethod() match { + case ctx: SampleByRowsContext => + Limit(expression(ctx.expression), query) + + case ctx: SampleByPercentileContext => + val fraction = ctx.percentage.getText.toDouble + val sign = if (ctx.negativeSign == null) 1 else -1 + sample(sign * fraction / 100.0d) + + case ctx: SampleByBytesContext => + val bytesStr = ctx.bytes.getText + if (bytesStr.matches("[0-9]+[bBkKmMgG]")) { + throw new ParseException(s"TABLESAMPLE(byteLengthLiteral) is not supported", ctx) + } else { + throw new ParseException(s"$bytesStr is not a valid byte length literal, " + + "expected syntax: DIGIT+ ('B' | 'K' | 'M' | 'G')", ctx) + } + + case ctx: SampleByBucketContext if ctx.ON() != null => + if (ctx.identifier != null) { + throw new ParseException(s"TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported", ctx) + } else { + throw new ParseException(s"TABLESAMPLE(BUCKET x OUT OF y ON function) is not supported", ctx) + } + + case ctx: SampleByBucketContext => + sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) + } + } + + /** + * Create a logical plan for a sub-query. + */ + override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.query) + } + + /** + * Create an un-aliased table reference. This is typically used for top-level table references, + * for example: + * {{{ + * INSERT INTO db.tbl2 + * TABLE db.tbl1 + * }}} + */ + override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { + UnresolvedRelation(visitMultipartIdentifier(ctx.multipartIdentifier)) + } + + /** + * Create a table-valued function call with arguments, e.g. range(1000) + */ + override def visitTableValuedFunction(ctx: TableValuedFunctionContext) + : LogicalPlan = withOrigin(ctx) { + val func = ctx.functionTable + val aliases = if (func.tableAlias.identifierList != null) { + visitIdentifierList(func.tableAlias.identifierList) + } else { + Seq.empty + } + val name = getFunctionIdentifier(func.functionName) + if (name.database.nonEmpty) { + operationNotAllowed(s"table valued function cannot specify database name: $name", ctx) + } + + val tvf = UnresolvedTableValuedFunction(name, func.expression.asScala.map(expression).toSeq) + + val tvfAliases = if (aliases.nonEmpty) UnresolvedTVFAliases(name, tvf, aliases) else tvf + + tvfAliases.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan) + } + + /** + * Create an inline table (a virtual table in Hive parlance). + */ + override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { + // Get the backing expressions. 
+ val rows = ctx.expression.asScala.map { e => + expression(e) match { + // inline table comes in two styles: + // style 1: values (1), (2), (3) -- multiple columns are supported + // style 2: values 1, 2, 3 -- only a single column is supported here + case struct: CreateNamedStruct => struct.valExprs // style 1 + case child => Seq(child) // style 2 + } + } + + val aliases = if (ctx.tableAlias.identifierList != null) { + visitIdentifierList(ctx.tableAlias.identifierList) + } else { + Seq.tabulate(rows.head.size)(i => s"col${i + 1}") + } + + val table = UnresolvedInlineTable(aliases, rows.toSeq) + table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a join relation. This is practically the same as + * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. We could add alias names for output columns, for example: + * {{{ + * SELECT a, b, c, d FROM (src1 s1 INNER JOIN src2 s2 ON s1.id = s2.id) dst(a, b, c, d) + * }}} + */ + override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { + val relation = plan(ctx.relation).optionalMap(ctx.sample)(withSample) + mayApplyAliasPlan(ctx.tableAlias, relation) + } + + /** + * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as + * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. We could add alias names for output columns, for example: + * {{{ + * SELECT col1, col2 FROM testData AS t(col1, col2) + * }}} + */ + override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { + val relation = plan(ctx.query).optionalMap(ctx.sample)(withSample) + if (ctx.tableAlias.strictIdentifier == null) { + // For un-aliased subqueries, use a default alias name that is not likely to conflict with + // normal subquery names, so that parent operators can only access the columns in subquery by + // unqualified names. Users can still use this special qualifier to access columns if they + // know it, but that's not recommended. + SubqueryAlias("__auto_generated_subquery_name", relation) + } else { + mayApplyAliasPlan(ctx.tableAlias, relation) + } + } + + /** + * Create an alias ([[SubqueryAlias]]) for a [[LogicalPlan]]. + */ + private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = { + SubqueryAlias(alias.getText, plan) + } + + /** + * If aliases specified in a FROM clause, create a subquery alias ([[SubqueryAlias]]) and + * column aliases for a [[LogicalPlan]]. + */ + private def mayApplyAliasPlan(tableAlias: TableAliasContext, plan: LogicalPlan): LogicalPlan = { + if (tableAlias.strictIdentifier != null) { + val alias = tableAlias.strictIdentifier.getText + if (tableAlias.identifierList != null) { + val columnNames = visitIdentifierList(tableAlias.identifierList) + SubqueryAlias(alias, UnresolvedSubqueryColumnAliases(columnNames, plan)) + } else { + SubqueryAlias(alias, plan) + } + } else { + plan + } + } + + /** + * Create a Sequence of Strings for a parenthesis enclosed alias list. + */ + override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { + visitIdentifierSeq(ctx.identifierSeq) + } + + /** + * Create a Sequence of Strings for an identifier list. 
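+ * For example "a, b, c" becomes Seq("a", "b", "c").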
+ */ + override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { + ctx.ident.asScala.map(_.getText).toSeq + } + + /* ******************************************************************************************** + * Table Identifier parsing + * ******************************************************************************************** */ + + /** + * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. + */ + override def visitTableIdentifier( + ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { + TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) + } + + /** + * Create a [[FunctionIdentifier]] from a 'functionName' or 'databaseName'.'functionName' pattern. + */ + override def visitFunctionIdentifier( + ctx: FunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText)) + } + + /** + * Create a multi-part identifier. + */ + override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = + withOrigin(ctx) { + ctx.parts.asScala.map(_.getText).toSeq + } + + /* ******************************************************************************************** + * Expression parsing + * ******************************************************************************************** */ + + /** + * Create an expression from the given context. This method just passes the context on to the + * visitor and only takes care of typing (We assume that the visitor returns an Expression here). + */ + protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) + + /** + * Create sequence of expressions from the given sequence of contexts. + */ + private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { + trees.asScala.map(expression).toSeq + } + + /** + * Create a star (i.e. all) expression; this selects all elements (in the specified object). + * Both un-targeted (global) and targeted aliases are supported. + */ + override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { + UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText).toSeq)) + } + + /** + * Create an aliased expression if an alias is specified. Both single and multi-aliases are + * supported. + */ + override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.name != null) { + Alias(e, ctx.name.getText)() + } else if (ctx.identifierList != null) { + MultiAlias(e, visitIdentifierList(ctx.identifierList)) + } else { + e + } + } + + /** + * Combine a number of boolean expressions into a balanced expression tree. These expressions are + * either combined by a logical [[And]] or a logical [[Or]]. + * + * A balanced binary tree is created because regular left recursive trees cause considerable + * performance degradations and can cause stack overflows. + */ + override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) { + val expressionType = ctx.operator.getType + val expressionCombiner = expressionType match { + case HoodieSqlBaseParser.AND => And.apply _ + case HoodieSqlBaseParser.OR => Or.apply _ + } + + // Collect all similar left hand contexts. 
+ val contexts = ArrayBuffer(ctx.right) + var current = ctx.left + + def collectContexts: Boolean = current match { + case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType => + contexts += lbc.right + current = lbc.left + true + case _ => + contexts += current + false + } + + while (collectContexts) { + // No body - all updates take place in the collectContexts. + } + + // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them + // into expressions. + val expressions = contexts.reverseMap(expression) + + // Create a balanced tree. + def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match { + case 0 => + expressions(low) + case 1 => + expressionCombiner(expressions(low), expressions(high)) + case x => + val mid = low + x / 2 + expressionCombiner( + reduceToExpressionTree(low, mid), + reduceToExpressionTree(mid + 1, high)) + } + + reduceToExpressionTree(0, expressions.size - 1) + } + + /** + * Invert a boolean expression. + */ + override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) { + Not(expression(ctx.booleanExpression())) + } + + /** + * Create a filtering correlated sub-query (EXISTS). + */ + override def visitExists(ctx: ExistsContext): Expression = { + Exists(plan(ctx.query)) + } + + /** + * Create a comparison expression. This compares two expressions. The following comparison + * operators are supported: + * - Equal: '=' or '==' + * - Null-safe Equal: '<=>' + * - Not Equal: '<>' or '!=' + * - Less than: '<' + * - Less then or Equal: '<=' + * - Greater than: '>' + * - Greater then or Equal: '>=' + */ + override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode] + operator.getSymbol.getType match { + case HoodieSqlBaseParser.EQ => + EqualTo(left, right) + case HoodieSqlBaseParser.NSEQ => + EqualNullSafe(left, right) + case HoodieSqlBaseParser.NEQ | HoodieSqlBaseParser.NEQJ => + Not(EqualTo(left, right)) + case HoodieSqlBaseParser.LT => + LessThan(left, right) + case HoodieSqlBaseParser.LTE => + LessThanOrEqual(left, right) + case HoodieSqlBaseParser.GT => + GreaterThan(left, right) + case HoodieSqlBaseParser.GTE => + GreaterThanOrEqual(left, right) + } + } + + /** + * Create a predicated expression. A predicated expression is a normal expression with a + * predicate attached to it, for example: + * {{{ + * a + 1 IS NULL + * }}} + */ + override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + if (ctx.predicate != null) { + withPredicate(e, ctx.predicate) + } else { + e + } + } + + /** + * Add a predicate to the given expression. Supported expressions are: + * - (NOT) BETWEEN + * - (NOT) IN + * - (NOT) LIKE (ANY | SOME | ALL) + * - (NOT) RLIKE + * - IS (NOT) NULL. + * - IS (NOT) (TRUE | FALSE | UNKNOWN) + * - IS (NOT) DISTINCT FROM + */ + private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) { + // Invert a predicate if it has a valid NOT clause. + def invertIfNotDefined(e: Expression): Expression = ctx.NOT match { + case null => e + case not => Not(e) + } + + def getValueExpressions(e: Expression): Seq[Expression] = e match { + case c: CreateNamedStruct => c.valExprs + case other => Seq(other) + } + + // Create the predicate. 
+ ctx.kind.getType match { + case HoodieSqlBaseParser.BETWEEN => + // BETWEEN is translated to lower <= e && e <= upper + invertIfNotDefined(And( + GreaterThanOrEqual(e, expression(ctx.lower)), + LessThanOrEqual(e, expression(ctx.upper)))) + case HoodieSqlBaseParser.IN if ctx.query != null => + invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) + case HoodieSqlBaseParser.IN => + invertIfNotDefined(In(e, ctx.expression.asScala.map(expression).toSeq)) + case HoodieSqlBaseParser.LIKE => + Option(ctx.quantifier).map(_.getType) match { + case Some(HoodieSqlBaseParser.ANY) | Some(HoodieSqlBaseParser.SOME) => + validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx) + val expressions = expressionList(ctx.expression) + if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) { + // If there are many pattern expressions, will throw StackOverflowError. + // So we use LikeAny or NotLikeAny instead. + val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String]) + ctx.NOT match { + case null => LikeAny(e, patterns) + case _ => NotLikeAny(e, patterns) + } + } else { + ctx.expression.asScala.map(expression) + .map(p => invertIfNotDefined(new Like(e, p))).toSeq.reduceLeft(Or) + } + case Some(HoodieSqlBaseParser.ALL) => + validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx) + val expressions = expressionList(ctx.expression) + if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) { + // If there are many pattern expressions, will throw StackOverflowError. + // So we use LikeAll or NotLikeAll instead. + val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String]) + ctx.NOT match { + case null => LikeAll(e, patterns) + case _ => NotLikeAll(e, patterns) + } + } else { + ctx.expression.asScala.map(expression) + .map(p => invertIfNotDefined(new Like(e, p))).toSeq.reduceLeft(And) + } + case _ => + val escapeChar = Option(ctx.escapeChar).map(string).map { str => + if (str.length != 1) { + throw new ParseException("Invalid escape string. Escape string must contain only one character.", ctx) + } + str.charAt(0) + }.getOrElse('\\') + invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar)) + } + case HoodieSqlBaseParser.RLIKE => + invertIfNotDefined(RLike(e, expression(ctx.pattern))) + case HoodieSqlBaseParser.NULL if ctx.NOT != null => + IsNotNull(e) + case HoodieSqlBaseParser.NULL => + IsNull(e) + case HoodieSqlBaseParser.TRUE => ctx.NOT match { + case null => EqualNullSafe(e, Literal(true)) + case _ => Not(EqualNullSafe(e, Literal(true))) + } + case HoodieSqlBaseParser.FALSE => ctx.NOT match { + case null => EqualNullSafe(e, Literal(false)) + case _ => Not(EqualNullSafe(e, Literal(false))) + } + case HoodieSqlBaseParser.UNKNOWN => ctx.NOT match { + case null => IsUnknown(e) + case _ => IsNotUnknown(e) + } + case HoodieSqlBaseParser.DISTINCT if ctx.NOT != null => + EqualNullSafe(e, expression(ctx.right)) + case HoodieSqlBaseParser.DISTINCT => + Not(EqualNullSafe(e, expression(ctx.right))) + } + } + + /** + * Create a binary arithmetic expression. 
The following arithmetic operators are supported: + * - Multiplication: '*' + * - Division: '/' + * - Hive Long Division: 'DIV' + * - Modulo: '%' + * - Addition: '+' + * - Subtraction: '-' + * - Binary AND: '&' + * - Binary XOR + * - Binary OR: '|' + */ + override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + ctx.operator.getType match { + case HoodieSqlBaseParser.ASTERISK => + Multiply(left, right) + case HoodieSqlBaseParser.SLASH => + Divide(left, right) + case HoodieSqlBaseParser.PERCENT => + Remainder(left, right) + case HoodieSqlBaseParser.DIV => + IntegralDivide(left, right) + case HoodieSqlBaseParser.PLUS => + Add(left, right) + case HoodieSqlBaseParser.MINUS => + Subtract(left, right) + case HoodieSqlBaseParser.CONCAT_PIPE => + Concat(left :: right :: Nil) + case HoodieSqlBaseParser.AMPERSAND => + BitwiseAnd(left, right) + case HoodieSqlBaseParser.HAT => + BitwiseXor(left, right) + case HoodieSqlBaseParser.PIPE => + BitwiseOr(left, right) + } + } + + /** + * Create a unary arithmetic expression. The following arithmetic operators are supported: + * - Plus: '+' + * - Minus: '-' + * - Bitwise Not: '~' + */ + override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) { + val value = expression(ctx.valueExpression) + ctx.operator.getType match { + case HoodieSqlBaseParser.PLUS => + UnaryPositive(value) + case HoodieSqlBaseParser.MINUS => + UnaryMinus(value) + case HoodieSqlBaseParser.TILDE => + BitwiseNot(value) + } + } + + override def visitCurrentLike(ctx: CurrentLikeContext): Expression = withOrigin(ctx) { + if (conf.ansiEnabled) { + ctx.name.getType match { + case HoodieSqlBaseParser.CURRENT_DATE => + CurrentDate() + case HoodieSqlBaseParser.CURRENT_TIMESTAMP => + CurrentTimestamp() + case HoodieSqlBaseParser.CURRENT_USER => + CurrentUser() + } + } else { + // If the parser is not in ansi mode, we should return `UnresolvedAttribute`, in case there + // are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP`. + UnresolvedAttribute.quoted(ctx.name.getText) + } + } + + /** + * Create a [[Cast]] expression. + */ + override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { + val rawDataType = typedVisit[DataType](ctx.dataType()) + val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType) + val cast = ctx.name.getType match { + case HoodieSqlBaseParser.CAST => + Cast(expression(ctx.expression), dataType) + + case HoodieSqlBaseParser.TRY_CAST => + Cast(expression(ctx.expression), dataType, evalMode = EvalMode.TRY) + } + cast.setTagValue(Cast.USER_SPECIFIED_CAST, ()) + cast + } + + /** + * Create a [[CreateStruct]] expression. + */ + override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) { + CreateStruct.create(ctx.argument.asScala.map(expression).toSeq) + } + + /** + * Create a [[First]] expression. + */ + override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + First(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression() + } + + /** + * Create a [[Last]] expression. + */ + override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + Last(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression() + } + + /** + * Create a Position expression. 
+ */ + override def visitPosition(ctx: PositionContext): Expression = withOrigin(ctx) { + new StringLocate(expression(ctx.substr), expression(ctx.str)) + } + + /** + * Create a Extract expression. + */ + override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) { + val arguments = Seq(Literal(ctx.field.getText), expression(ctx.source)) + UnresolvedFunction("extract", arguments, isDistinct = false) + } + + /** + * Create a Substring/Substr expression. + */ + override def visitSubstring(ctx: SubstringContext): Expression = withOrigin(ctx) { + if (ctx.len != null) { + Substring(expression(ctx.str), expression(ctx.pos), expression(ctx.len)) + } else { + new Substring(expression(ctx.str), expression(ctx.pos)) + } + } + + /** + * Create a Trim expression. + */ + override def visitTrim(ctx: TrimContext): Expression = withOrigin(ctx) { + val srcStr = expression(ctx.srcStr) + val trimStr = Option(ctx.trimStr).map(expression) + Option(ctx.trimOption).map(_.getType).getOrElse(HoodieSqlBaseParser.BOTH) match { + case HoodieSqlBaseParser.BOTH => + StringTrim(srcStr, trimStr) + case HoodieSqlBaseParser.LEADING => + StringTrimLeft(srcStr, trimStr) + case HoodieSqlBaseParser.TRAILING => + StringTrimRight(srcStr, trimStr) + case other => + throw new ParseException("Function trim doesn't support with " + + s"type $other. Please use BOTH, LEADING or TRAILING as trim type", ctx) + } + } + + /** + * Create a Overlay expression. + */ + override def visitOverlay(ctx: OverlayContext): Expression = withOrigin(ctx) { + val input = expression(ctx.input) + val replace = expression(ctx.replace) + val position = expression(ctx.position) + val lengthOpt = Option(ctx.length).map(expression) + lengthOpt match { + case Some(length) => Overlay(input, replace, position, length) + case None => new Overlay(input, replace, position) + } + } + + /** + * Create a (windowed) Function expression. + */ + override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { + // Create the function call. + val name = ctx.functionName.getText + val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) + // Call `toSeq`, otherwise `ctx.argument.asScala.map(expression)` is `Buffer` in Scala 2.13 + val arguments = ctx.argument.asScala.map(expression).toSeq match { + case Seq(UnresolvedStar(None)) + if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct => + // Transform COUNT(*) into COUNT(1). + Seq(Literal(1)) + case expressions => + expressions + } + val filter = Option(ctx.where).map(expression(_)) + val ignoreNulls = + Option(ctx.nullsOption).map(_.getType == HoodieSqlBaseParser.IGNORE) + val function = UnresolvedFunction( + getFunctionMultiparts(ctx.functionName), arguments, isDistinct, filter, ignoreNulls) + + // Check if the function is evaluated in a windowed context. + ctx.windowSpec match { + case spec: WindowRefContext => + UnresolvedWindowExpression(function, visitWindowRef(spec)) + case spec: WindowDefContext => + WindowExpression(function, visitWindowDef(spec)) + case _ => function + } + } + + /** + * Create a function database (optional) and name pair. + */ + protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = { + visitFunctionName(ctx, ctx.identifier().asScala.map(_.getText).toSeq) + } + + /** + * Create a function database (optional) and name pair. 
+ */
+ private def visitFunctionName(ctx: ParserRuleContext, texts: Seq[String]): FunctionIdentifier = {
+ texts match {
+ case Seq(db, fn) => FunctionIdentifier(fn, Option(db))
+ case Seq(fn) => FunctionIdentifier(fn, None)
+ case other =>
+ throw new ParseException(s"Unsupported function name '${texts.mkString(".")}'", ctx)
+ }
+ }
+
+ /**
+ * Get a function identifier consisting of an optional database and a name.
+ */
+ protected def getFunctionIdentifier(ctx: FunctionNameContext): FunctionIdentifier = {
+ if (ctx.qualifiedName != null) {
+ visitFunctionName(ctx.qualifiedName)
+ } else {
+ FunctionIdentifier(ctx.getText, None)
+ }
+ }
+
+ protected def getFunctionMultiparts(ctx: FunctionNameContext): Seq[String] = {
+ if (ctx.qualifiedName != null) {
+ ctx.qualifiedName().identifier().asScala.map(_.getText).toSeq
+ } else {
+ Seq(ctx.getText)
+ }
+ }
+
+ /**
+ * Create a [[LambdaFunction]].
+ */
+ override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
+ val arguments = ctx.identifier().asScala.map { name =>
+ UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
+ }
+ val function = expression(ctx.expression).transformUp {
+ case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
+ }
+ LambdaFunction(function, arguments.toSeq)
+ }
+
+ /**
+ * Create a reference to a window frame, i.e. [[WindowSpecReference]].
+ */
+ override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) {
+ WindowSpecReference(ctx.name.getText)
+ }
+
+ /**
+ * Create a window definition, i.e. [[WindowSpecDefinition]].
+ */
+ override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) {
+ // CLUSTER BY ... | PARTITION BY ... ORDER BY ...
+ val partition = ctx.partition.asScala.map(expression)
+ val order = ctx.sortItem.asScala.map(visitSortItem)
+
+ // RANGE/ROWS BETWEEN ...
+ val frameSpecOption = Option(ctx.windowFrame).map { frame =>
+ val frameType = frame.frameType.getType match {
+ case HoodieSqlBaseParser.RANGE => RangeFrame
+ case HoodieSqlBaseParser.ROWS => RowFrame
+ }
+
+ SpecifiedWindowFrame(
+ frameType,
+ visitFrameBound(frame.start),
+ Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow))
+ }
+
+ WindowSpecDefinition(
+ partition.toSeq,
+ order.toSeq,
+ frameSpecOption.getOrElse(UnspecifiedFrame))
+ }
+
+ /**
+ * Create or resolve a frame boundary expression.
+ */
+ override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) {
+ def value: Expression = {
+ val e = expression(ctx.expression)
+ validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx)
+ e
+ }
+
+ ctx.boundType.getType match {
+ case HoodieSqlBaseParser.PRECEDING if ctx.UNBOUNDED != null =>
+ UnboundedPreceding
+ case HoodieSqlBaseParser.PRECEDING =>
+ UnaryMinus(value)
+ case HoodieSqlBaseParser.CURRENT =>
+ CurrentRow
+ case HoodieSqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null =>
+ UnboundedFollowing
+ case HoodieSqlBaseParser.FOLLOWING =>
+ value
+ }
+ }
+
+ /**
+ * Create a [[CreateStruct]] expression.
+ */
+ override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) {
+ CreateStruct(ctx.namedExpression().asScala.map(expression).toSeq)
+ }
+
+ /**
+ * Create a [[ScalarSubquery]] expression.
+ */
+ override def visitSubqueryExpression(
+ ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) {
+ ScalarSubquery(plan(ctx.query))
+ }
+
+ /**
+ * Create a value based [[CaseWhen]] expression.
This has the following SQL form:
+ * {{{
+ * CASE [expression]
+ * WHEN [value] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ */
+ override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.value)
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result))
+ }
+ CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax:
+ * {{{
+ * CASE
+ * WHEN [predicate] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ *
+ * @param ctx the parse tree
+ * */
+ override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) {
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (expression(wCtx.condition), expression(wCtx.result))
+ }
+ CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Currently, regex is only supported in expressions of SELECT statements; in other
+ * places, e.g., where `(a)?+.+` = 2, regex is not meaningful.
+ */
+ private def canApplyRegex(ctx: ParserRuleContext): Boolean = withOrigin(ctx) {
+ var parent = ctx.getParent
+ var rtn = false
+ while (parent != null) {
+ if (parent.isInstanceOf[NamedExpressionContext]) {
+ rtn = true
+ }
+ parent = parent.getParent
+ }
+ rtn
+ }
+
+ /**
+ * Create a dereference expression. The return type depends on the type of the parent.
+ * If the parent is an [[UnresolvedAttribute]], it can be an [[UnresolvedAttribute]] or
+ * an [[UnresolvedRegex]] for regex quoted in ``; if the parent is some other expression,
+ * it can be an [[UnresolvedExtractValue]].
+ */
+ override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
+ val attr = ctx.fieldName.getText
+ expression(ctx.base) match {
+ case unresolved_attr@UnresolvedAttribute(nameParts) =>
+ ctx.fieldName.getStart.getText match {
+ case escapedIdentifier(columnNameRegex)
+ if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) =>
+ UnresolvedRegex(columnNameRegex, Some(unresolved_attr.name),
+ conf.caseSensitiveAnalysis)
+ case _ =>
+ UnresolvedAttribute(nameParts :+ attr)
+ }
+ case e =>
+ UnresolvedExtractValue(e, Literal(attr))
+ }
+ }
+
+ /**
+ * Create an [[UnresolvedAttribute]] expression or an [[UnresolvedRegex]] if it is a regex
+ * quoted in ``
+ */
+ override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) {
+ ctx.getStart.getText match {
+ case escapedIdentifier(columnNameRegex)
+ if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) =>
+ UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis)
+ case _ =>
+ UnresolvedAttribute.quoted(ctx.getText)
+ }
+
+ }
+
+ /**
+ * Create an [[UnresolvedExtractValue]] expression; this is used for subscript access to an array.
+ */
+ override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) {
+ UnresolvedExtractValue(expression(ctx.value), expression(ctx.index))
+ }
+
+ /**
+ * Create an expression for an expression between parentheses. This is needed because the ANTLR
+ * visitor cannot automatically convert the nested context into an expression.
+ */
+ override def visitParenthesizedExpression(
+ ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) {
+ expression(ctx.expression)
+ }
+
+ /**
+ * Create a [[SortOrder]] expression.
+ */ + override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { + val direction = if (ctx.DESC != null) { + Descending + } else { + Ascending + } + val nullOrdering = if (ctx.FIRST != null) { + NullsFirst + } else if (ctx.LAST != null) { + NullsLast + } else { + direction.defaultNullOrdering + } + SortOrder(expression(ctx.expression), direction, nullOrdering, Seq.empty) + } + + /** + * Create a typed Literal expression. A typed literal has the following SQL syntax: + * {{{ + * [TYPE] '[VALUE]' + * }}} + * Currently Date, Timestamp, Interval and Binary typed literals are supported. + */ + override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { + val value = string(ctx.STRING) + val valueType = ctx.identifier.getText.toUpperCase(Locale.ROOT) + + def toLiteral[T](f: UTF8String => Option[T], t: DataType): Literal = { + f(UTF8String.fromString(value)).map(Literal(_, t)).getOrElse { + throw new ParseException(s"Cannot parse the $valueType value: $value", ctx) + } + } + + def constructTimestampLTZLiteral(value: String): Literal = { + val zoneId = getZoneId(conf.sessionLocalTimeZone) + val specialTs = convertSpecialTimestamp(value, zoneId).map(Literal(_, TimestampType)) + specialTs.getOrElse(toLiteral(stringToTimestamp(_, zoneId), TimestampType)) + } + + try { + valueType match { + case "DATE" => + val zoneId = getZoneId(conf.sessionLocalTimeZone) + val specialDate = convertSpecialDate(value, zoneId).map(Literal(_, DateType)) + specialDate.getOrElse(toLiteral(stringToDate, DateType)) + // SPARK-36227: Remove TimestampNTZ type support in Spark 3.2 with minimal code changes. + case "TIMESTAMP_NTZ" if isTesting => + convertSpecialTimestampNTZ(value, getZoneId(conf.sessionLocalTimeZone)) + .map(Literal(_, TimestampNTZType)) + .getOrElse(toLiteral(stringToTimestampWithoutTimeZone, TimestampNTZType)) + case "TIMESTAMP_LTZ" if isTesting => + constructTimestampLTZLiteral(value) + case "TIMESTAMP" => + SQLConf.get.timestampType match { + case TimestampNTZType => + convertSpecialTimestampNTZ(value, getZoneId(conf.sessionLocalTimeZone)) + .map(Literal(_, TimestampNTZType)) + .getOrElse { + val containsTimeZonePart = + DateTimeUtils.parseTimestampString(UTF8String.fromString(value))._2.isDefined + // If the input string contains time zone part, return a timestamp with local time + // zone literal. 
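+ // For example, '2021-01-01 00:00:00 UTC' carries a zone and yields a TIMESTAMP_LTZ literal,
+ // while '2021-01-01 00:00:00' does not and falls back to TIMESTAMP_NTZ.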
+ if (containsTimeZonePart) { + constructTimestampLTZLiteral(value) + } else { + toLiteral(stringToTimestampWithoutTimeZone, TimestampNTZType) + } + } + + case TimestampType => + constructTimestampLTZLiteral(value) + } + + case "INTERVAL" => + val interval = try { + IntervalUtils.stringToInterval(UTF8String.fromString(value)) + } catch { + case e: IllegalArgumentException => + val ex = new ParseException(s"Cannot parse the INTERVAL value: $value", ctx) + ex.setStackTrace(e.getStackTrace) + throw ex + } + if (!conf.legacyIntervalEnabled) { + val units = value + .split("\\s") + .map(_.toLowerCase(Locale.ROOT).stripSuffix("s")) + .filter(s => s != "interval" && s.matches("[a-z]+")) + constructMultiUnitsIntervalLiteral(ctx, interval, units) + } else { + Literal(interval, CalendarIntervalType) + } + case "X" => + val padding = if (value.length % 2 != 0) "0" else "" + Literal(DatatypeConverter.parseHexBinary(padding + value)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } catch { + case e: IllegalArgumentException => + val message = Option(e.getMessage).getOrElse(s"Exception parsing $valueType") + throw new ParseException(message, ctx) + } + } + + /** + * Create a NULL literal expression. + */ + override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) { + Literal(null) + } + + /** + * Create a Boolean literal expression. + */ + override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) { + if (ctx.getText.toBoolean) { + Literal.TrueLiteral + } else { + Literal.FalseLiteral + } + } + + /** + * Create an integral literal expression. The code selects the most narrow integral type + * possible, either a BigDecimal, a Long or an Integer is returned. + */ + override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) { + BigDecimal(ctx.getText) match { + case v if v.isValidInt => + Literal(v.intValue) + case v if v.isValidLong => + Literal(v.longValue) + case v => Literal(v.underlying()) + } + } + + /** + * Create a decimal literal for a regular decimal number. + */ + override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** + * Create a decimal literal for a regular decimal number or a scientific decimal number. + */ + override def visitLegacyDecimalLiteral( + ctx: LegacyDecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** + * Create a double literal for number with an exponent, e.g. 1E-30 + */ + override def visitExponentLiteral(ctx: ExponentLiteralContext): Literal = { + numericLiteral(ctx, ctx.getText, /* exponent values don't have a suffix */ + Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble) + } + + /** Create a numeric literal expression. 
*/ + private def numericLiteral( + ctx: NumberContext, + rawStrippedQualifier: String, + minValue: BigDecimal, + maxValue: BigDecimal, + typeName: String)(converter: String => Any): Literal = withOrigin(ctx) { + try { + val rawBigDecimal = BigDecimal(rawStrippedQualifier) + if (rawBigDecimal < minValue || rawBigDecimal > maxValue) { + throw new ParseException(s"Numeric literal $rawStrippedQualifier does not " + + s"fit in range [$minValue, $maxValue] for type $typeName", ctx) + } + Literal(converter(rawStrippedQualifier)) + } catch { + case e: NumberFormatException => + throw new ParseException(e.getMessage, ctx) + } + } + + /** + * Create a Byte Literal expression. + */ + override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Byte.MinValue, Byte.MaxValue, ByteType.simpleString)(_.toByte) + } + + /** + * Create a Short Literal expression. + */ + override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Short.MinValue, Short.MaxValue, ShortType.simpleString)(_.toShort) + } + + /** + * Create a Long Literal expression. + */ + override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Long.MinValue, Long.MaxValue, LongType.simpleString)(_.toLong) + } + + /** + * Create a Float Literal expression. + */ + override def visitFloatLiteral(ctx: FloatLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Float.MinValue, Float.MaxValue, FloatType.simpleString)(_.toFloat) + } + + /** + * Create a Double Literal expression. + */ + override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble) + } + + /** + * Create a BigDecimal Literal expression. + */ + override def visitBigDecimalLiteral(ctx: BigDecimalLiteralContext): Literal = { + val raw = ctx.getText.substring(0, ctx.getText.length - 2) + try { + Literal(BigDecimal(raw).underlying()) + } catch { + case e: AnalysisException => + throw new ParseException(e.message, ctx) + } + } + + /** + * Create a String literal expression. + */ + override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { + Literal(createString(ctx)) + } + + /** + * Create a String from a string literal context. This supports multiple consecutive string + * literals, these are concatenated, for example this expression "'hello' 'world'" will be + * converted into "helloworld". + * + * Special characters can be escaped by using Hive/C-style escaping. + */ + private def createString(ctx: StringLiteralContext): String = { + if (conf.escapedStringLiterals) { + ctx.STRING().asScala.map(x => stringWithoutUnescape(x.getSymbol)).mkString + } else { + ctx.STRING().asScala.map(string).mkString + } + } + + /** + * Create an [[UnresolvedRelation]] from a multi-part identifier context. 
+ */ + private def createUnresolvedRelation( + ctx: MultipartIdentifierContext): UnresolvedRelation = withOrigin(ctx) { + UnresolvedRelation(visitMultipartIdentifier(ctx)) + } + + /** + * Construct an [[Literal]] from [[CalendarInterval]] and + * units represented as a [[Seq]] of [[String]]. + */ + private def constructMultiUnitsIntervalLiteral( + ctx: ParserRuleContext, + calendarInterval: CalendarInterval, + units: Seq[String]): Literal = { + var yearMonthFields = Set.empty[Byte] + var dayTimeFields = Set.empty[Byte] + for (unit <- units) { + if (YearMonthIntervalType.stringToField.contains(unit)) { + yearMonthFields += YearMonthIntervalType.stringToField(unit) + } else if (DayTimeIntervalType.stringToField.contains(unit)) { + dayTimeFields += DayTimeIntervalType.stringToField(unit) + } else if (unit == "week") { + dayTimeFields += DayTimeIntervalType.DAY + } else { + assert(unit == "millisecond" || unit == "microsecond") + dayTimeFields += DayTimeIntervalType.SECOND + } + } + if (yearMonthFields.nonEmpty) { + if (dayTimeFields.nonEmpty) { + val literalStr = source(ctx) + throw new ParseException(s"Cannot mix year-month and day-time fields: $literalStr", ctx) + } + Literal( + calendarInterval.months, + YearMonthIntervalType(yearMonthFields.min, yearMonthFields.max) + ) + } else { + Literal( + IntervalUtils.getDuration(calendarInterval, TimeUnit.MICROSECONDS), + DayTimeIntervalType(dayTimeFields.min, dayTimeFields.max)) + } + } + + /** + * Create a [[CalendarInterval]] or ANSI interval literal expression. + * Two syntaxes are supported: + * - multiple unit value pairs, for instance: interval 2 months 2 days. + * - from-to unit, for instance: interval '1-2' year to month. + */ + override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { + val calendarInterval = parseIntervalLiteral(ctx) + if (ctx.errorCapturingUnitToUnitInterval != null && !conf.legacyIntervalEnabled) { + // Check the `to` unit to distinguish year-month and day-time intervals because + // `CalendarInterval` doesn't have enough info. For instance, new CalendarInterval(0, 0, 0) + // can be derived from INTERVAL '0-0' YEAR TO MONTH as well as from + // INTERVAL '0 00:00:00' DAY TO SECOND. 
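+ // The `to` unit is sufficient to disambiguate: only year-month intervals can end in "month",
+ // so that case produces a YearMonthIntervalType literal and all other pairs a DayTimeIntervalType.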
+ val fromUnit = + ctx.errorCapturingUnitToUnitInterval.body.from.getText.toLowerCase(Locale.ROOT) + val toUnit = ctx.errorCapturingUnitToUnitInterval.body.to.getText.toLowerCase(Locale.ROOT) + if (toUnit == "month") { + assert(calendarInterval.days == 0 && calendarInterval.microseconds == 0) + val start = YearMonthIntervalType.stringToField(fromUnit) + Literal(calendarInterval.months, YearMonthIntervalType(start, YearMonthIntervalType.MONTH)) + } else { + assert(calendarInterval.months == 0) + val micros = IntervalUtils.getDuration(calendarInterval, TimeUnit.MICROSECONDS) + val start = DayTimeIntervalType.stringToField(fromUnit) + val end = DayTimeIntervalType.stringToField(toUnit) + Literal(micros, DayTimeIntervalType(start, end)) + } + } else if (ctx.errorCapturingMultiUnitsInterval != null && !conf.legacyIntervalEnabled) { + val units = + ctx.errorCapturingMultiUnitsInterval.body.unit.asScala.map( + _.getText.toLowerCase(Locale.ROOT).stripSuffix("s")).toSeq + constructMultiUnitsIntervalLiteral(ctx, calendarInterval, units) + } else { + Literal(calendarInterval, CalendarIntervalType) + } + } + + /** + * Create a [[CalendarInterval]] object + */ + protected def parseIntervalLiteral(ctx: IntervalContext): CalendarInterval = withOrigin(ctx) { + if (ctx.errorCapturingMultiUnitsInterval != null) { + val innerCtx = ctx.errorCapturingMultiUnitsInterval + if (innerCtx.unitToUnitInterval != null) { + throw new ParseException("Can only have a single from-to unit in the interval literal syntax", innerCtx.unitToUnitInterval) + } + visitMultiUnitsInterval(innerCtx.multiUnitsInterval) + } else if (ctx.errorCapturingUnitToUnitInterval != null) { + val innerCtx = ctx.errorCapturingUnitToUnitInterval + if (innerCtx.error1 != null || innerCtx.error2 != null) { + val errorCtx = if (innerCtx.error1 != null) innerCtx.error1 else innerCtx.error2 + throw new ParseException("Can only have a single from-to unit in the interval literal syntax", errorCtx) + } + visitUnitToUnitInterval(innerCtx.body) + } else { + throw new ParseException("at least one time unit should be given for interval literal", ctx) + } + } + + /** + * Creates a [[CalendarInterval]] with multiple unit value pairs, e.g. 1 YEAR 2 DAYS. + */ + override def visitMultiUnitsInterval(ctx: MultiUnitsIntervalContext): CalendarInterval = { + withOrigin(ctx) { + val units = ctx.unit.asScala + val values = ctx.intervalValue().asScala + try { + assert(units.length == values.length) + val kvs = units.indices.map { i => + val u = units(i).getText + val v = if (values(i).STRING() != null) { + val value = string(values(i).STRING()) + // SPARK-32840: For invalid cases, e.g. INTERVAL '1 day 2' hour, + // INTERVAL 'interval 1' day, we need to check ahead before they are concatenated with + // units and become valid ones, e.g. '1 day 2 hour'. + // Ideally, we only ensure the value parts don't contain any units here. 
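+ // Hence any value containing a letter (e.g. '1 day 2') is rejected right here.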
+ if (value.exists(Character.isLetter)) { + throw new ParseException("Can only use numbers in the interval value part for" + + s" multiple unit value pairs interval form, but got invalid value: $value", ctx) + } + if (values(i).MINUS() == null) { + value + } else { + value.startsWith("-") match { + case true => value.replaceFirst("-", "") + case false => s"-$value" + } + } + } else { + values(i).getText + } + UTF8String.fromString(" " + v + " " + u) + } + IntervalUtils.stringToInterval(UTF8String.concat(kvs: _*)) + } catch { + case i: IllegalArgumentException => + val e = new ParseException(i.getMessage, ctx) + e.setStackTrace(i.getStackTrace) + throw e + } + } + } + + /** + * Creates a [[CalendarInterval]] with from-to unit, e.g. '2-1' YEAR TO MONTH. + */ + override def visitUnitToUnitInterval(ctx: UnitToUnitIntervalContext): CalendarInterval = { + withOrigin(ctx) { + val value = Option(ctx.intervalValue.STRING).map(string).map { interval => + if (ctx.intervalValue().MINUS() == null) { + interval + } else { + interval.startsWith("-") match { + case true => interval.replaceFirst("-", "") + case false => s"-$interval" + } + } + }.getOrElse { + throw new ParseException("The value of from-to unit must be a string", ctx.intervalValue) + } + try { + val from = ctx.from.getText.toLowerCase(Locale.ROOT) + val to = ctx.to.getText.toLowerCase(Locale.ROOT) + (from, to) match { + case ("year", "month") => + IntervalUtils.fromYearMonthString(value) + case ("day", "hour") | ("day", "minute") | ("day", "second") | ("hour", "minute") | + ("hour", "second") | ("minute", "second") => + IntervalUtils.fromDayTimeString(value, + DayTimeIntervalType.stringToField(from), DayTimeIntervalType.stringToField(to)) + case _ => + throw new ParseException(s"Intervals FROM $from TO $to are not supported.", ctx) + } + } catch { + // Handle Exceptions thrown by CalendarInterval + case e: IllegalArgumentException => + val pe = new ParseException(e.getMessage, ctx) + pe.setStackTrace(e.getStackTrace) + throw pe + } + } + } + + /* ******************************************************************************************** + * DataType parsing + * ******************************************************************************************** */ + + /** + * Resolve/create a primitive type. + */ + override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { + val dataType = ctx.identifier.getText.toLowerCase(Locale.ROOT) + (dataType, ctx.INTEGER_VALUE().asScala.toList) match { + case ("boolean", Nil) => BooleanType + case ("tinyint" | "byte", Nil) => ByteType + case ("smallint" | "short", Nil) => ShortType + case ("int" | "integer", Nil) => IntegerType + case ("bigint" | "long", Nil) => LongType + case ("float" | "real", Nil) => FloatType + case ("double", Nil) => DoubleType + case ("date", Nil) => DateType + case ("timestamp", Nil) => SQLConf.get.timestampType + // SPARK-36227: Remove TimestampNTZ type support in Spark 3.2 with minimal code changes. 
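+ // The timestamp_ntz / timestamp_ltz keywords below are therefore only recognized under `isTesting`.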
+ case ("timestamp_ntz", Nil) if isTesting => TimestampNTZType + case ("timestamp_ltz", Nil) if isTesting => TimestampType + case ("string", Nil) => StringType + case ("character" | "char", length :: Nil) => CharType(length.getText.toInt) + case ("varchar", length :: Nil) => VarcharType(length.getText.toInt) + case ("binary", Nil) => BinaryType + case ("blob", Nil) => BlobType() + case ("vector", _ :: _) => + // Delegate validation to HoodieSchema.parseTypeDescriptor which handles dimension + // range checks, element type validation, and canonical normalization. + val vectorSchema = try { + HoodieSchema.parseTypeDescriptor(ctx.getText).asInstanceOf[HoodieSchema.Vector] + } catch { + case e: IllegalArgumentException => + throw new ParseException(s"Invalid VECTOR type: ${e.getMessage}", ctx) + } + val sparkElemType = vectorSchema.getVectorElementType match { + case HoodieSchema.Vector.VectorElementType.FLOAT => FloatType + case HoodieSchema.Vector.VectorElementType.DOUBLE => DoubleType + case HoodieSchema.Vector.VectorElementType.INT8 => ByteType + } + ArrayType(sparkElemType, containsNull = false) + case ("decimal" | "dec" | "numeric", Nil) => DecimalType.USER_DEFAULT + case ("decimal" | "dec" | "numeric", precision :: Nil) => + DecimalType(precision.getText.toInt, 0) + case ("decimal" | "dec" | "numeric", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case ("void", Nil) => NullType + case ("interval", Nil) => CalendarIntervalType + case (dt, params) => + val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt + throw new ParseException(s"DataType $dtStr is not supported.", ctx) + } + } + + override def visitYearMonthIntervalDataType(ctx: YearMonthIntervalDataTypeContext): DataType = { + val startStr = ctx.from.getText.toLowerCase(Locale.ROOT) + val start = YearMonthIntervalType.stringToField(startStr) + if (ctx.to != null) { + val endStr = ctx.to.getText.toLowerCase(Locale.ROOT) + val end = YearMonthIntervalType.stringToField(endStr) + if (end <= start) { + throw new ParseException(s"Intervals FROM $startStr TO $endStr are not supported.", ctx) + } + YearMonthIntervalType(start, end) + } else { + YearMonthIntervalType(start) + } + } + + override def visitDayTimeIntervalDataType(ctx: DayTimeIntervalDataTypeContext): DataType = { + val startStr = ctx.from.getText.toLowerCase(Locale.ROOT) + val start = DayTimeIntervalType.stringToField(startStr) + if (ctx.to != null) { + val endStr = ctx.to.getText.toLowerCase(Locale.ROOT) + val end = DayTimeIntervalType.stringToField(endStr) + if (end <= start) { + throw new ParseException(s"Intervals FROM $startStr TO $endStr are not supported.", ctx) + } + DayTimeIntervalType(start, end) + } else { + DayTimeIntervalType(start) + } + } + + /** + * Create a complex DataType. Arrays, Maps and Structures are supported. + */ + override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { + ctx.complex.getType match { + case HoodieSqlBaseParser.ARRAY => + ArrayType(typedVisit(ctx.dataType(0))) + case HoodieSqlBaseParser.MAP => + MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) + case HoodieSqlBaseParser.STRUCT => + StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList)) + } + } + + /** + * Create top level table schema. + */ + protected def createSchema(ctx: ColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. 
+ */ + override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.colType().asScala.map(visitColType).toSeq + } + + /** + * Create a top level [[StructField]] from a column definition. + */ + override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + + val builder = new MetadataBuilder + // Add comment to metadata + Option(commentSpec()).map(visitCommentSpec).foreach { + builder.putString("comment", _) + } + + val dataType = typedVisit[DataType](ctx.dataType) + + addMetadataForType(ctx.dataType(), builder) + + StructField( + name = colName.getText, + dataType = dataType, + nullable = NULL == null, + metadata = builder.build()) + } + + private def addMetadataForType(dataType: HoodieSqlBaseParser.DataTypeContext, builder: MetadataBuilder): Unit = { + val typeText = dataType.getText + val upperTypeText = typeText.toUpperCase(Locale.ROOT) + if (upperTypeText == HoodieSchemaType.BLOB.name()) { + builder.putString(HoodieSchema.TYPE_METADATA_FIELD, HoodieSchemaType.BLOB.name()) + } else if (upperTypeText.startsWith("VECTOR(")) { + // Normalize to canonical form (e.g. "VECTOR(128,FLOAT)" -> "VECTOR(128)") + val vectorSchema = HoodieSchema.parseTypeDescriptor(typeText).asInstanceOf[HoodieSchema.Vector] + builder.putString(HoodieSchema.TYPE_METADATA_FIELD, vectorSchema.toTypeDescriptor) + } + } + + /** + * Create a [[StructType]] from a sequence of [[StructField]]s. + */ + protected def createStructType(ctx: ComplexColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitComplexColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitComplexColTypeList( + ctx: ComplexColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.complexColType().asScala.map(visitComplexColType).toSeq + } + + /** + * Create a [[StructField]] from a column definition. + */ + override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + val builder = new MetadataBuilder + // Add comment to metadata + Option(commentSpec()).map(visitCommentSpec).foreach { + builder.putString("comment", _) + } + addMetadataForType(ctx.dataType(), builder) + + StructField( + name = identifier.getText, + dataType = typedVisit(dataType()), + nullable = NULL == null, + metadata = builder.build()) + } + + /** + * Create a location string. + */ + override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) { + string(ctx.STRING) + } + + /** + * Create an optional location string. + */ + protected def visitLocationSpecList(ctx: java.util.List[LocationSpecContext]): Option[String] = { + ctx.asScala.headOption.map(visitLocationSpec) + } + + /** + * Create a comment string. + */ + override def visitCommentSpec(ctx: CommentSpecContext): String = withOrigin(ctx) { + string(ctx.STRING) + } + + /** + * Create an optional comment string. + */ + protected def visitCommentSpecList(ctx: java.util.List[CommentSpecContext]): Option[String] = { + ctx.asScala.headOption.map(visitCommentSpec) + } + + /** + * Create a [[BucketSpec]]. 
+ */ + override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) { + BucketSpec( + ctx.INTEGER_VALUE.getText.toInt, + visitIdentifierList(ctx.identifierList), + Option(ctx.orderedIdentifierList) + .toSeq + .flatMap(_.orderedIdentifier.asScala) + .map { orderedIdCtx => + Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => + if (dir.toLowerCase(Locale.ROOT) != "asc") { + operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) + } + } + + orderedIdCtx.ident.getText + }) + } + + /** + * Convert a table property list into a key-value map. + * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. + */ + override def visitTablePropertyList( + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + val properties = ctx.tableProperty.asScala.map { property => + val key = visitTablePropertyKey(property.key) + val value = visitTablePropertyValue(property.value) + key -> value + } + // Check for duplicate property names. + checkDuplicateKeys(properties.toSeq, ctx) + properties.toMap + } + + /** + * Parse a key-value map from a [[TablePropertyListContext]], assuming all values are specified. + */ + def visitPropertyKeyValues(ctx: TablePropertyListContext): Map[String, String] = { + val props = visitTablePropertyList(ctx) + val badKeys = props.collect { case (key, null) => key } + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props + } + + /** + * Parse a list of keys from a [[TablePropertyListContext]], assuming no values are specified. + */ + def visitPropertyKeys(ctx: TablePropertyListContext): Seq[String] = { + val props = visitTablePropertyList(ctx) + val badKeys = props.filter { case (_, v) => v != null }.keys + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props.keys.toSeq + } + + /** + * A table property key can either be String or a collection of dot separated elements. This + * function extracts the property key based on whether its a string literal or a table property + * identifier. + */ + override def visitTablePropertyKey(key: TablePropertyKeyContext): String = { + if (key.STRING != null) { + string(key.STRING) + } else { + key.getText + } + } + + /** + * A table property value can be String, Integer, Boolean or Decimal. This function extracts + * the property value based on whether its a string, integer, boolean or decimal literal. + */ + override def visitTablePropertyValue(value: TablePropertyValueContext): String = { + if (value == null) { + null + } else if (value.STRING != null) { + string(value.STRING) + } else if (value.booleanValue != null) { + value.getText.toLowerCase(Locale.ROOT) + } else { + value.getText + } + } + + /** + * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). + */ + type TableHeader = (Seq[String], Boolean, Boolean, Boolean) + + /** + * Type to keep track of table clauses: + * - partition transforms + * - partition columns + * - bucketSpec + * - properties + * - options + * - location + * - comment + * - serde + * + * Note: Partition transforms are based on existing table schema definition. It can be simple + * column names, or functions like `year(date_col)`. Partition columns are column names with data + * types like `i INT`, which should be appended to the existing table schema. 
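+ * For example, `PARTITIONED BY (years(ts))` yields a partition transform, while
+ * `PARTITIONED BY (dt STRING)` adds `dt` as a typed partition column.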
+ */ + type TableClauses = ( + Seq[Transform], Seq[StructField], Option[BucketSpec], Map[String, String], + Map[String, String], Option[String], Option[String], Option[SerdeInfo]) + + /** + * Validate a create table statement and return the [[TableIdentifier]]. + */ + override def visitCreateTableHeader( + ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { + val temporary = ctx.TEMPORARY != null + val ifNotExists = ctx.EXISTS != null + if (temporary && ifNotExists) { + operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx) + } + val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq + (multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null) + } + + /** + * Validate a replace table statement and return the [[TableIdentifier]]. + */ + override def visitReplaceTableHeader( + ctx: ReplaceTableHeaderContext): TableHeader = withOrigin(ctx) { + val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq + (multipartIdentifier, false, false, false) + } + + /** + * Parse a qualified name to a multipart name. + */ + override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) { + ctx.identifier.asScala.map(_.getText).toSeq + } + + /** + * Parse a list of transforms or columns. + */ + override def visitPartitionFieldList( + ctx: PartitionFieldListContext): (Seq[Transform], Seq[StructField]) = withOrigin(ctx) { + val (transforms, columns) = ctx.fields.asScala.map { + case transform: PartitionTransformContext => + (Some(visitPartitionTransform(transform)), None) + case field: PartitionColumnContext => + (None, Some(visitColType(field.colType))) + }.unzip + + (transforms.flatten.toSeq, columns.flatten.toSeq) + } + + override def visitPartitionTransform( + ctx: PartitionTransformContext): Transform = withOrigin(ctx) { + def getFieldReference( + ctx: ApplyTransformContext, + arg: V2Expression): FieldReference = { + lazy val name: String = ctx.identifier.getText + arg match { + case ref: FieldReference => + ref + case nonRef => + throw new ParseException(s"Expected a column reference for transform $name: $nonRef.describe", ctx) + } + } + + def getSingleFieldReference( + ctx: ApplyTransformContext, + arguments: Seq[V2Expression]): FieldReference = { + lazy val name: String = ctx.identifier.getText + if (arguments.size > 1) { + throw new ParseException(s"Too many arguments for transform $name", ctx) + } else if (arguments.isEmpty) { + throw + + new ParseException(s"Not enough arguments for transform $name", ctx) + } else { + getFieldReference(ctx, arguments.head) + } + } + + ctx.transform match { + case identityCtx: IdentityTransformContext => + IdentityTransform(FieldReference(typedVisit[Seq[String]](identityCtx.qualifiedName))) + + case applyCtx: ApplyTransformContext => + val arguments = applyCtx.argument.asScala.map(visitTransformArgument).toSeq + + applyCtx.identifier.getText match { + case "bucket" => + val numBuckets: Int = arguments.head match { + case LiteralValue(shortValue, ShortType) => + shortValue.asInstanceOf[Short].toInt + case LiteralValue(intValue, IntegerType) => + intValue.asInstanceOf[Int] + case LiteralValue(longValue, LongType) => + longValue.asInstanceOf[Long].toInt + case lit => + throw new ParseException(s"Invalid number of buckets: ${lit.describe}", applyCtx) + } + + val fields = arguments.tail.map(arg => getFieldReference(applyCtx, arg)) + + BucketTransform(LiteralValue(numBuckets, IntegerType), fields) + + case "years" => + 
YearsTransform(getSingleFieldReference(applyCtx, arguments)) + + case "months" => + MonthsTransform(getSingleFieldReference(applyCtx, arguments)) + + case "days" => + DaysTransform(getSingleFieldReference(applyCtx, arguments)) + + case "hours" => + HoursTransform(getSingleFieldReference(applyCtx, arguments)) + + case name => + ApplyTransform(name, arguments) + } + } + } + + /** + * Parse an argument to a transform. An argument may be a field reference (qualified name) or + * a value literal. + */ + override def visitTransformArgument(ctx: TransformArgumentContext): V2Expression = { + withOrigin(ctx) { + val reference = Option(ctx.qualifiedName) + .map(typedVisit[Seq[String]]) + .map(FieldReference(_)) + val literal = Option(ctx.constant) + .map(typedVisit[Literal]) + .map(lit => LiteralValue(lit.value, lit.dataType)) + reference.orElse(literal) + .getOrElse(throw new ParseException("Invalid transform argument", ctx)) + } + } + + def cleanTableProperties( + ctx: ParserRuleContext, properties: Map[String, String]): Map[String, String] = { + import TableCatalog._ + val legacyOn = conf.getConf(SQLConf.LEGACY_PROPERTY_NON_RESERVED) + properties.filter { + case (PROP_PROVIDER, _) if !legacyOn => + throw new ParseException(s"$PROP_PROVIDER is a reserved table property, please use the USING clause to specify it.", ctx) + case (PROP_PROVIDER, _) => false + case (PROP_LOCATION, _) if !legacyOn => + throw new ParseException(s"$PROP_LOCATION is a reserved table property, please use the LOCATION clause to specify it.", ctx) + case (PROP_LOCATION, _) => false + case (PROP_OWNER, _) if !legacyOn => + throw new ParseException(s"$PROP_OWNER is a reserved table property, it will be set to the current user.", ctx) + case (PROP_OWNER, _) => false + case _ => true + } + } + + def cleanTableOptions( + ctx: ParserRuleContext, + options: Map[String, String], + location: Option[String]): (Map[String, String], Option[String]) = { + var path = location + val filtered = cleanTableProperties(ctx, options).filter { + case (k, v) if k.equalsIgnoreCase("path") && path.nonEmpty => + throw new ParseException(s"Duplicated table paths found: '${path.get}' and '$v'. LOCATION" + + s" and the case insensitive key 'path' in OPTIONS are all used to indicate the custom" + + s" table path, you can only specify one of them.", ctx) + case (k, v) if k.equalsIgnoreCase("path") => + path = Some(v) + false + case _ => true + } + (filtered, path) + } + + /** + * Create a [[SerdeInfo]] for creating tables. + * + * Format: STORED AS (name | INPUTFORMAT input_format OUTPUTFORMAT output_format) + */ + override def visitCreateFileFormat(ctx: CreateFileFormatContext): SerdeInfo = withOrigin(ctx) { + (ctx.fileFormat, ctx.storageHandler) match { + // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format + case (c: TableFileFormatContext, null) => + SerdeInfo(formatClasses = Some(FormatClasses(string(c.inFmt), string(c.outFmt)))) + // Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO + case (c: GenericFileFormatContext, null) => + SerdeInfo(storedAs = Some(c.identifier.getText)) + case (null, storageHandler) => + operationNotAllowed("STORED BY", ctx) + case _ => + throw new ParseException("Expected either STORED AS or STORED BY, not both", ctx) + } + } + + /** + * Create a [[SerdeInfo]] used for creating tables. 
+ * + * Example format: + * {{{ + * SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)] + * }}} + * + * OR + * + * {{{ + * DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]] + * [COLLECTION ITEMS TERMINATED BY char] + * [MAP KEYS TERMINATED BY char] + * [LINES TERMINATED BY char] + * [NULL DEFINED AS char] + * }}} + */ + def visitRowFormat(ctx: RowFormatContext): SerdeInfo = withOrigin(ctx) { + ctx match { + case serde: RowFormatSerdeContext => visitRowFormatSerde(serde) + case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited) + } + } + + /** + * Create SERDE row format name and properties pair. + */ + override def visitRowFormatSerde(ctx: RowFormatSerdeContext): SerdeInfo = withOrigin(ctx) { + import ctx._ + SerdeInfo( + serde = Some(string(name)), + serdeProperties = Option(tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)) + } + + /** + * Create a delimited row format properties object. + */ + override def visitRowFormatDelimited( + ctx: RowFormatDelimitedContext): SerdeInfo = withOrigin(ctx) { + // Collect the entries if any. + def entry(key: String, value: Token): Seq[(String, String)] = { + Option(value).toSeq.map(x => key -> string(x)) + } + + // TODO we need proper support for the NULL format. + val entries = + entry("field.delim", ctx.fieldsTerminatedBy) ++ + entry("serialization.format", ctx.fieldsTerminatedBy) ++ + entry("escape.delim", ctx.escapedBy) ++ + // The following typo is inherited from Hive... + entry("colelction.delim", ctx.collectionItemsTerminatedBy) ++ + entry("mapkey.delim", ctx.keysTerminatedBy) ++ + Option(ctx.linesSeparatedBy).toSeq.map { token => + val value = string(token) + validate( + value == "\n", + s"LINES TERMINATED BY only supports newline '\\n' right now: $value", + ctx) + "line.delim" -> value + } + SerdeInfo(serdeProperties = entries.toMap) + } + + /** + * Throw a [[ParseException]] if the user specified incompatible SerDes through ROW FORMAT + * and STORED AS. + * + * The following are allowed. Anything else is not: + * ROW FORMAT SERDE ... STORED AS [SEQUENCEFILE | RCFILE | TEXTFILE] + * ROW FORMAT DELIMITED ... STORED AS TEXTFILE + * ROW FORMAT ... STORED AS INPUTFORMAT ... OUTPUTFORMAT ... 
+ */ + protected def validateRowFormatFileFormat( + rowFormatCtx: RowFormatContext, + createFileFormatCtx: CreateFileFormatContext, + parentCtx: ParserRuleContext): Unit = { + if (!(rowFormatCtx == null || createFileFormatCtx == null)) { + (rowFormatCtx, createFileFormatCtx.fileFormat) match { + case (_, ffTable: TableFileFormatContext) => // OK + case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) => + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { + case ("sequencefile" | "textfile" | "rcfile") => // OK + case fmt => + operationNotAllowed( + s"ROW FORMAT SERDE is incompatible with format '$fmt', which also specifies a serde", + parentCtx) + } + case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) => + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { + case "textfile" => // OK + case fmt => operationNotAllowed( + s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx) + } + case _ => + // should never happen + def str(ctx: ParserRuleContext): String = { + (0 until ctx.getChildCount).map { i => ctx.getChild(i).getText }.mkString(" ") + } + + operationNotAllowed( + s"Unexpected combination of ${str(rowFormatCtx)} and ${str(createFileFormatCtx)}", + parentCtx) + } + } + } + + protected def validateRowFormatFileFormat( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { + validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) + } + } + + override def visitCreateTableClauses(ctx: CreateTableClausesContext): TableClauses = { + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx) + checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) + checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + + if (ctx.skewSpec.size > 0) { + operationNotAllowed("CREATE TABLE ... 
SKEWED BY", ctx) + } + + val (partTransforms, partCols) = + Option(ctx.partitioning).map(visitPartitionFieldList).getOrElse((Nil, Nil)) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) + val cleanedProperties = cleanTableProperties(ctx, properties) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) + val location = visitLocationSpecList(ctx.locationSpec()) + val (cleanedOptions, newLocation) = cleanTableOptions(ctx, options, location) + val comment = visitCommentSpecList(ctx.commentSpec()) + val serdeInfo = + getSerdeInfo(ctx.rowFormat.asScala.toSeq, ctx.createFileFormat.asScala.toSeq, ctx) + (partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment, + serdeInfo) + } + + protected def getSerdeInfo( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + ctx: ParserRuleContext): Option[SerdeInfo] = { + validateRowFormatFileFormat(rowFormatCtx, createFileFormatCtx, ctx) + val rowFormatSerdeInfo = rowFormatCtx.map(visitRowFormat) + val fileFormatSerdeInfo = createFileFormatCtx.map(visitCreateFileFormat) + (fileFormatSerdeInfo ++ rowFormatSerdeInfo).reduceLeftOption((l, r) => l.merge(r)) + } + + private def partitionExpressions( + partTransforms: Seq[Transform], + partCols: Seq[StructField], + ctx: ParserRuleContext): Seq[Transform] = { + if (partTransforms.nonEmpty) { + if (partCols.nonEmpty) { + val references = partTransforms.map(_.describe()).mkString(", ") + val columns = partCols + .map(field => s"${field.name} ${field.dataType.simpleString}") + .mkString(", ") + operationNotAllowed( + s"""PARTITION BY: Cannot mix partition expressions and partition columns: + |Expressions: $references + |Columns: $columns""".stripMargin, ctx) + + } + partTransforms + } else { + // columns were added to create the schema. convert to column references + partCols.map { column => + IdentityTransform(FieldReference(Seq(column.name))) + } + } + } + + /** + * Create a table, returning a [[CreateTable]] or [[CreateTableAsSelect]] logical plan. + * + * Expected format: + * {{{ + * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name + * [USING table_provider] + * create_table_clauses + * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [PARTITIONED BY (partition_fields)] + * [OPTIONS table_property_list] + * [ROW FORMAT row_format] + * [STORED AS file_format] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] + * + * partition_fields: + * col_name, transform(col_name), transform(constant, col_name), ... | + * col_name data_type [NOT NULL] [COMMENT col_comment], ... + * }}} + */ + override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { + val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + + val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil) + val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) + val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) = + visitCreateTableClauses(ctx.createTableClauses()) + + if (provider.isDefined && serdeInfo.isDefined) { + operationNotAllowed(s"CREATE TABLE ... USING ... 
${serdeInfo.get.describe}", ctx) + } + + if (temp) { + val asSelect = if (ctx.query == null) "" else " AS ..." + operationNotAllowed( + s"CREATE TEMPORARY TABLE ...$asSelect, use CREATE TEMPORARY VIEW instead", ctx) + } + + // partition transforms for BucketSpec was moved inside parser + // https://issues.apache.org/jira/browse/SPARK-37923 + val partitioning = + partitionExpressions(partTransforms, partCols, ctx) ++ bucketSpec.map(_.asTransform) + val tableSpec = TableSpec(properties, provider, options, location, comment, + Option.empty, serdeInfo, external) + + Option(ctx.query).map(plan) match { + case Some(_) if columns.nonEmpty => + operationNotAllowed( + "Schema may not be specified in a Create Table As Select (CTAS) statement", + ctx) + + case Some(_) if partCols.nonEmpty => + // non-reference partition columns are not allowed because schema can't be specified + operationNotAllowed( + "Partition column types may not be specified in Create Table As Select (CTAS)", + ctx) + + // CreateTable / CreateTableAsSelect was migrated to v2 in Spark 3.3.0 + // https://issues.apache.org/jira/browse/SPARK-36850 + case Some(query) => + CreateTableAsSelect( + UnresolvedIdentifier(table), + partitioning, query, tableSpec, Map.empty, ifNotExists) + + case _ => + // Note: table schema includes both the table columns list and the partition columns + // with data type. + val schema = StructType(columns ++ partCols) + CreateTable( + UnresolvedIdentifier(table), + schema.map(ColumnDefinition.fromV1Column(_, delegate)), partitioning, tableSpec, ignoreIfExists = ifNotExists) + } + } + + /** + * Parse new column info from ADD COLUMN into a QualifiedColType. + */ + override def visitQualifiedColTypeWithPosition( + ctx: QualifiedColTypeWithPositionContext): QualifiedColType = withOrigin(ctx) { + val name = typedVisit[Seq[String]](ctx.name) + QualifiedColType( + path = if (name.length > 1) Some(UnresolvedFieldName(name.init)) else None, + colName = name.last, + dataType = typedVisit[DataType](ctx.dataType), + nullable = ctx.NULL == null, + comment = Option(ctx.commentSpec()).map(visitCommentSpec), + position = Option(ctx.colPosition).map(pos => + UnresolvedFieldPosition(typedVisit[ColumnPosition](pos))), + default = Option(null)) + } + + + /** + * Create an index, returning a [[CreateIndex]] logical plan. + * For example: + * {{{ + * CREATE INDEX index_name ON [TABLE] table_name [USING index_type] (column_index_property_list) + * [OPTIONS indexPropertyList] + * column_index_property_list: column_name [OPTIONS(indexPropertyList)] [ , . . . ] + * indexPropertyList: index_property_name [= index_property_value] [ , . . . 
] + * }}} + */ + override def visitCreateIndex(ctx: CreateIndexContext): LogicalPlan = withOrigin(ctx) { + val (indexName, indexType) = if (ctx.identifier.size() == 1) { + (ctx.identifier(0).getText, "") + } else { + (ctx.identifier(0).getText, ctx.identifier(1).getText) + } + + val columns = ctx.columns.multipartIdentifierProperty.asScala + .map(_.multipartIdentifier).map(typedVisit[Seq[String]]).toSeq + val columnsProperties = ctx.columns.multipartIdentifierProperty.asScala + .map(x => (Option(x.options).map(visitPropertyKeyValues).getOrElse(Map.empty))).toSeq + val options = Option(ctx.indexOptions).map(visitPropertyKeyValues).getOrElse(Map.empty) + + CreateIndex( + UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier())), + indexName, + indexType, + ctx.EXISTS != null, + columns.map(UnresolvedFieldName).zip(columnsProperties), + options) + } + + /** + * Drop an index, returning a [[DropIndex]] logical plan. + * For example: + * {{{ + * DROP INDEX [IF EXISTS] index_name ON [TABLE] table_name + * }}} + */ + override def visitDropIndex(ctx: DropIndexContext): LogicalPlan = withOrigin(ctx) { + val indexName = ctx.identifier.getText + DropIndex( + UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier())), + indexName, + ctx.EXISTS != null) + } + + /** + * Show indexes, returning a [[HoodieShowIndexes]] logical plan. + * For example: + * {{{ + * SHOW INDEXES (FROM | IN) [TABLE] table_name + * }}} + */ + override def visitShowIndexes(ctx: ShowIndexesContext): LogicalPlan = withOrigin(ctx) { + HoodieShowIndexes(UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier()))) + } + + /** + * Refresh index, returning a [[RefreshIndex]] logical plan + * For example: + * {{{ + * REFRESH INDEX index_name ON [TABLE] table_name + * }}} + */ + override def visitRefreshIndex(ctx: RefreshIndexContext): LogicalPlan = withOrigin(ctx) { + RefreshIndex(UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier())), ctx.identifier.getText) + } + + /** + * Convert a property list into a key-value map. + * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. + */ + override def visitPropertyList(ctx: PropertyListContext): Map[String, String] = withOrigin(ctx) { + val properties = ctx.property.asScala.map { property => + val key = visitPropertyKey(property.key) + val value = visitPropertyValue(property.value) + key -> value + } + // Check for duplicate property names. + checkDuplicateKeys(properties.toSeq, ctx) + properties.toMap + } + + /** + * Parse a key-value map from a [[PropertyListContext]], assuming all values are specified. + */ + def visitPropertyKeyValues(ctx: PropertyListContext): Map[String, String] = { + val props = visitPropertyList(ctx) + val badKeys = props.collect { case (key, null) => key } + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props + } + + /** + * Parse a list of keys from a [[PropertyListContext]], assuming no values are specified. + */ + def visitPropertyKeys(ctx: PropertyListContext): Seq[String] = { + val props = visitPropertyList(ctx) + val badKeys = props.filter { case (_, v) => v != null }.keys + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props.keys.toSeq + } + + /** + * A property key can either be String or a collection of dot separated elements. 
This + * function extracts the property key based on whether its a string literal or a property + * identifier. + */ + override def visitPropertyKey(key: PropertyKeyContext): String = { + if (key.STRING != null) { + string(key.STRING) + } else { + key.getText + } + } + + /** + * A property value can be String, Integer, Boolean or Decimal. This function extracts + * the property value based on whether its a string, integer, boolean or decimal literal. + */ + override def visitPropertyValue(value: PropertyValueContext): String = { + if (value == null) { + null + } else if (value.STRING != null) { + string(value.STRING) + } else if (value.booleanValue != null) { + value.getText.toLowerCase(Locale.ROOT) + } else { + value.getText + } + } +} + +/** + * A container for holding named common table expressions (CTEs) and a query plan. + * This operator will be removed during analysis and the relations will be substituted into child. + * + * @param child The final query of this CTE. + * @param cteRelations A sequence of pair (alias, the CTE definition) that this CTE defined + * Each CTE can see the base tables and the previously defined CTEs only. + */ +case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) extends UnaryNode { + override def output: Seq[Attribute] = child.output + + override def simpleString(maxFields: Int): String = { + val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]", maxFields) + s"CTE $cteAliases" + } + + override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2) + + def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = this +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark4_2ExtendedSqlParser.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark4_2ExtendedSqlParser.scala new file mode 100644 index 0000000000000..b8841a3b4936d --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark4_2ExtendedSqlParser.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.parser + +import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseListener, HoodieSqlBaseLexer, HoodieSqlBaseParser} +import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser.{NonReservedContext, QuotedIdentifierContext} + +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException} +import org.antlr.v4.runtime.tree.TerminalNodeImpl +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.internal.VariableSubstitution +import org.apache.spark.sql.types._ + +import java.util.Locale + +import scala.jdk.CollectionConverters._ + +class HoodieSpark4_2ExtendedSqlParser(session: SparkSession, delegate: ParserInterface) + extends HoodieExtendedParserInterface with Logging { + + private lazy val conf = session.sessionState.conf + private lazy val builder = new HoodieSpark4_2ExtendedSqlAstBuilder(conf, delegate) + private val substitutor = new VariableSubstitution + + override def parsePlan(sqlText: String): LogicalPlan = { + val substitutionSql = substitutor.substitute(sqlText) + if (isHoodieCommand(substitutionSql)) { + parse(substitutionSql) { parser => + builder.visit(parser.singleStatement()) match { + case plan: LogicalPlan => plan + case _ => delegate.parsePlan(sqlText) + } + } + } else { + delegate.parsePlan(substitutionSql) + } + } + + override def parseQuery(sqlText: String): LogicalPlan = delegate.parseQuery(sqlText) + + override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText) + + override def parseTableIdentifier(sqlText: String): TableIdentifier = + delegate.parseTableIdentifier(sqlText) + + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = + delegate.parseFunctionIdentifier(sqlText) + + override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText) + + override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText) + + protected def parse[T](command: String)(toResult: HoodieSqlBaseParser => T): T = { + logDebug(s"Parsing command: $command") + + val lexer = new HoodieSqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command))) + lexer.removeErrorListeners() + lexer.addErrorListener(ParseErrorListener) + + val tokenStream = new CommonTokenStream(lexer) + val parser = new HoodieSqlBaseParser(tokenStream) + parser.addParseListener(PostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(ParseErrorListener) + // parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced + parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled + parser.SQL_standard_keyword_behavior = conf.ansiEnabled + + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) + } + catch { + case e: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.seek(0) // rewind input stream + parser.reset() + + // Try Again. 
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } + catch { + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new ParseException( + command = Option(command), + start = position, + errorClass = e.getErrorClass, + messageParameters = e.getMessageParameters.asScala.toMap, + queryContext = e.getQueryContext + ) + } + } + + override def parseMultipartIdentifier(sqlText: String): Seq[String] = { + delegate.parseMultipartIdentifier(sqlText) + } + + private def isHoodieCommand(sqlText: String): Boolean = { + val normalized = sqlText.toLowerCase(Locale.ROOT).trim().replaceAll("\\s+", " ") + normalized.contains("system_time as of") || + normalized.contains("timestamp as of") || + normalized.contains("system_version as of") || + normalized.contains("version as of") || + normalized.contains("create index") || + normalized.contains("drop index") || + normalized.contains("show indexes") || + normalized.contains("refresh index") || + normalized.contains(" blob") || + normalized.contains(" vector") + } + + override def parseRoutineParam(sqlText: String): StructType = throw new UnsupportedOperationException() +} + +/** + * Fork from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`. + */ +class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { + override def consume(): Unit = wrapped.consume + override def getSourceName(): String = wrapped.getSourceName + override def index(): Int = wrapped.index + override def mark(): Int = wrapped.mark + override def release(marker: Int): Unit = wrapped.release(marker) + override def seek(where: Int): Unit = wrapped.seek(where) + override def size(): Int = wrapped.size + + override def getText(interval: Interval): String = { + // ANTLR 4.7's CodePointCharStream implementations have bugs when + // getText() is called with an empty stream, or intervals where + // the start > end. See + // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix + // that is not yet in a released ANTLR artifact. + if (size() > 0 && (interval.b - interval.a >= 0)) { + wrapped.getText(interval) + } else { + "" + } + } + // scalastyle:off + override def LA(i: Int): Int = { + // scalastyle:on + val la = wrapped.LA(i) + if (la == 0 || la == IntStream.EOF) la + else Character.toUpperCase(la) + } +} + +/** + * Fork from `org.apache.spark.sql.catalyst.parser.PostProcessor`. + */ +case object PostProcessor extends HoodieSqlBaseBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. 
*/ + override def exitNonReserved(ctx: NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + val newToken = new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + HoodieSqlBaseParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins) + parent.addChild(new TerminalNodeImpl(f(newToken))) + } +} diff --git a/hudi-spark-datasource/hudi-spark4.2.x/src/test/scala/org/apache/hudi/TestHoodieStreamingSinkConstants.scala b/hudi-spark-datasource/hudi-spark4.2.x/src/test/scala/org/apache/hudi/TestHoodieStreamingSinkConstants.scala new file mode 100644 index 0000000000000..0229ec138dfb0 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark4.2.x/src/test/scala/org/apache/hudi/TestHoodieStreamingSinkConstants.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.hudi + +import org.apache.spark.sql.execution.streaming.runtime.StreamExecution +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +/** + * Validates that HoodieStreamingSink.QUERY_ID_KEY matches the actual + * StreamExecution.QUERY_ID_KEY from Spark 4.1, where StreamExecution + * moved to org.apache.spark.sql.execution.streaming.runtime. + */ +class TestHoodieStreamingSinkConstants { + + @Test + def testQueryIdKeyMatchesStreamExecution(): Unit = { + assertEquals(StreamExecution.QUERY_ID_KEY, HoodieStreamingSink.QUERY_ID_KEY, + "HoodieStreamingSink.QUERY_ID_KEY must match StreamExecution.QUERY_ID_KEY") + } +} diff --git a/packaging/bundle-validation/base/build_flink1200hive313spark420scala213.sh b/packaging/bundle-validation/base/build_flink1200hive313spark420scala213.sh new file mode 100755 index 0000000000000..63c8b45b2982d --- /dev/null +++ b/packaging/bundle-validation/base/build_flink1200hive313spark420scala213.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +docker build \ + --build-arg HIVE_VERSION=3.1.3 \ + --build-arg FLINK_VERSION=1.20.1 \ + --build-arg SPARK_VERSION=4.2.0-preview4 \ + --build-arg SPARK_HADOOP_VERSION=3 \ + --build-arg HADOOP_VERSION=3.4.3 \ + --build-arg SCALA_VERSION=2.13 \ + --build-arg DERBY_VERSION=10.16.1.1 \ + -t hudi-ci-bundle-validation-base:flink1200hive313spark420scala213 . +docker image tag hudi-ci-bundle-validation-base:flink1200hive313spark420scala213 apachehudi/hudi-ci-bundle-validation-base:flink1200hive313spark420scala213 diff --git a/packaging/bundle-validation/ci_run.sh b/packaging/bundle-validation/ci_run.sh index eef7cb12ba5d2..cfb998e72ca32 100755 --- a/packaging/bundle-validation/ci_run.sh +++ b/packaging/bundle-validation/ci_run.sh @@ -142,6 +142,16 @@ elif [[ ${SPARK_RUNTIME} == 'spark4.1.1' && ${SCALA_PROFILE} == 'scala-2.13' ]]; CONFLUENT_VERSION=5.5.12 KAFKA_CONNECT_HDFS_VERSION=10.1.13 IMAGE_TAG=flink1200hive313spark411scala213 +elif [[ ${SPARK_RUNTIME} == 'spark4.2.0-preview4' && ${SCALA_PROFILE} == 'scala-2.13' ]]; then + HADOOP_VERSION=3.4.3 + HIVE_VERSION=3.1.3 + DERBY_VERSION=10.14.1.0 + FLINK_VERSION=1.20.1 + SPARK_VERSION=4.2.0-preview4 + SPARK_HADOOP_VERSION=3 + CONFLUENT_VERSION=5.5.12 + KAFKA_CONNECT_HDFS_VERSION=10.1.13 + IMAGE_TAG=flink1200hive313spark420scala213 fi # Copy bundle jars to temp dir for mounting @@ -199,6 +209,11 @@ else HUDI_SPARK_BUNDLE_NAME=hudi-spark4.1-bundle_2.13 HUDI_UTILITIES_BUNDLE_NAME=hudi-utilities-bundle_2.13 HUDI_UTILITIES_SLIM_BUNDLE_NAME=hudi-utilities-slim-bundle_2.13 + elif [[ ${SPARK_PROFILE} == 'spark4.2' && ${SCALA_PROFILE} == 'scala-2.13' ]]; then + HUDI_CLI_BUNDLE_NAME=hudi-cli-bundle_2.13 + HUDI_SPARK_BUNDLE_NAME=hudi-spark4.2-bundle_2.13 + HUDI_UTILITIES_BUNDLE_NAME=hudi-utilities-bundle_2.13 + HUDI_UTILITIES_SLIM_BUNDLE_NAME=hudi-utilities-slim-bundle_2.13 elif [[ ${SPARK_PROFILE} == 'spark3' ]]; then HUDI_CLI_BUNDLE_NAME=hudi-cli-bundle_2.12 HUDI_SPARK_BUNDLE_NAME=hudi-spark3-bundle_2.12 diff --git a/packaging/bundle-validation/run_docker_java17.sh b/packaging/bundle-validation/run_docker_java17.sh index a380319a210ab..4844155ad4e03 100755 --- a/packaging/bundle-validation/run_docker_java17.sh +++ b/packaging/bundle-validation/run_docker_java17.sh @@ -93,6 +93,16 @@ elif [[ ${SPARK_RUNTIME} == 'spark4.1.1' && ${SCALA_PROFILE} == 'scala-2.13' ]]; CONFLUENT_VERSION=5.5.12 KAFKA_CONNECT_HDFS_VERSION=10.1.13 IMAGE_TAG=flink1200hive313spark411scala213 +elif [[ ${SPARK_RUNTIME} == 'spark4.2.0-preview4' && ${SCALA_PROFILE} == 'scala-2.13' ]]; then + HADOOP_VERSION=3.4.3 + HIVE_VERSION=3.1.3 + DERBY_VERSION=10.14.1.0 + FLINK_VERSION=1.20.1 + SPARK_VERSION=4.2.0-preview4 + SPARK_HADOOP_VERSION=3 + CONFLUENT_VERSION=5.5.12 + KAFKA_CONNECT_HDFS_VERSION=10.1.13 + IMAGE_TAG=flink1200hive313spark420scala213 fi # build docker image diff --git a/pom.xml b/pom.xml index 4afc759e0e378..1f11a658938c0 100644 --- a/pom.xml +++ b/pom.xml @@ -115,6 +115,8 @@ 5.5.0 2.17 3.0.1-b12 + org.lz4 + 1.8.0 1.13.1 5.14.1 5.14.1 @@ -177,6 +179,7 @@ 3.5.5 4.0.2 4.1.1 + 4.2.0-preview4 hudi-spark3.5.x hudi-spark3-common 1.11.4 @@ -2870,6 +2873,76 @@ + + spark4.2 + + 
${spark42.version} + ${spark4.version} + 4.2 + 2.13.18 + ${scala13.version} + 2.13 + hudi-spark4.2.x + + hudi-spark4-common + ${scalatest.spark4.version} + 3.4.3 + 3.9.2 + 2.8.1 + + lance-spark-4.0_2.13 + true + + 1.17.0 + 2.3.0 + 1.12.1 + 4.13.1 + 2.21.2 + ${fasterxml.spark4.version} + 2.21 + ${fasterxml.spark4.version} + ${fasterxml.spark4.version} + ${fasterxml.spark4.version} + + ${pulsar.spark.scala13.version} + 3.1.11 + 2.25.4 + 2.0.17 + + at.yawk.lz4 + 1.10.4 + true + true + + + hudi-spark-datasource/hudi-spark4.2.x + hudi-spark-datasource/hudi-spark4-common + + + + org.slf4j + slf4j-log4j12 + ${slf4j.version} + test + + + ${hive.groupid} + hive-storage-api + ${hive.storage.version} + + + + + spark4.2 + + + + flink2.1 diff --git a/scripts/release/deploy_staging_jars_java17.sh b/scripts/release/deploy_staging_jars_java17.sh index c034f15d3c97f..b70ca05635ec4 100755 --- a/scripts/release/deploy_staging_jars_java17.sh +++ b/scripts/release/deploy_staging_jars_java17.sh @@ -46,6 +46,11 @@ declare -a ALL_VERSION_OPTS=( # hudi-spark4.1.x_2.13 # hudi-spark4.1-bundle_2.13 "-T 1C -Djava17 -Djava.version=17 -Dscala-2.13 -Dspark4.1 -pl hudi-spark-datasource/hudi-spark4-common,hudi-spark-datasource/hudi-spark4.1.x,packaging/hudi-spark-bundle -am" +# For Spark 4.2, Scala 2.13: +# hudi-spark4-common +# hudi-spark4.2.x_2.13 +# hudi-spark4.2-bundle_2.13 +"-T 1C -Djava17 -Djava.version=17 -Dscala-2.13 -Dspark4.2 -pl hudi-spark-datasource/hudi-spark4-common,hudi-spark-datasource/hudi-spark4.2.x,packaging/hudi-spark-bundle -am" ) printf -v joined "'%s'\n" "${ALL_VERSION_OPTS[@]}" diff --git a/scripts/release/validate_staged_bundles.sh b/scripts/release/validate_staged_bundles.sh index 06763dc27e379..bb31f4241afd0 100755 --- a/scripts/release/validate_staged_bundles.sh +++ b/scripts/release/validate_staged_bundles.sh @@ -37,7 +37,7 @@ declare -a bundles=("hudi-aws-bundle" "hudi-cli-bundle_2.12" "hudi-cli-bundle_2. "hudi-flink2.0-bundle" "hudi-gcp-bundle" "hudi-hadoop-mr-bundle" "hudi-hive-sync-bundle" "hudi-integ-test-bundle" "hudi-kafka-connect-bundle" "hudi-metaserver-server-bundle" "hudi-presto-bundle" "hudi-spark3.3-bundle_2.12" "hudi-spark3.4-bundle_2.12" "hudi-spark3.5-bundle_2.12" -"hudi-spark3.5-bundle_2.13" "hudi-spark4.0-bundle_2.13" "hudi-spark4.1-bundle_2.13" "hudi-timeline-server-bundle" "hudi-trino-bundle" +"hudi-spark3.5-bundle_2.13" "hudi-spark4.0-bundle_2.13" "hudi-spark4.1-bundle_2.13" "hudi-spark4.2-bundle_2.13" "hudi-timeline-server-bundle" "hudi-trino-bundle" "hudi-utilities-bundle_2.12" "hudi-utilities-bundle_2.13" "hudi-utilities-slim-bundle_2.12" "hudi-utilities-slim-bundle_2.13")
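
Reviewer note: the visitCreateTable rule added above documents the expected CREATE TABLE shape for the new Spark 4.2 module. The following is a minimal, hypothetical sketch (not part of this patch) of the two paths it distinguishes: a plain CREATE TABLE with an explicit schema, and a CTAS where the schema and typed partition columns must be omitted. Table, column, and property names are illustrative only, and the session config assumes the Hudi Spark session extension is on the classpath (e.g. via hudi-spark4.2-bundle_2.13). Intended for a spark-shell:

// Sketch only: exercises the CREATE TABLE / CTAS paths described in
// visitCreateTable above. Names and configs are hypothetical.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("hudi-spark4.2-sql-sketch")
  .master("local[2]")
  .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .config("spark.sql.extensions", "org.apache.spark.sql.hudi.HoodieSparkSessionExtension")
  .getOrCreate()

// Plain CREATE TABLE: column list plus order-insensitive clauses
// (PARTITIONED BY, TBLPROPERTIES, ...).
spark.sql(
  """CREATE TABLE IF NOT EXISTS hudi_trips (
    |  uuid STRING,
    |  fare DOUBLE,
    |  ts BIGINT,
    |  city STRING
    |) USING hudi
    |PARTITIONED BY (city)
    |TBLPROPERTIES (primaryKey = 'uuid', preCombineField = 'ts')""".stripMargin)

// CTAS: the schema comes from the query, so neither a column list nor typed
// partition columns may be specified; only partition column references are allowed.
spark.sql(
  """CREATE TABLE hudi_trips_copy USING hudi
    |PARTITIONED BY (city)
    |TBLPROPERTIES (primaryKey = 'uuid', preCombineField = 'ts')
    |AS SELECT * FROM hudi_trips""".stripMargin)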
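
Likewise, a hedged sketch of the statements that isHoodieCommand in HoodieSpark4_2ExtendedSqlParser routes to the extended parser: the index DDL handled by visitCreateIndex, visitDropIndex, visitShowIndexes, and visitRefreshIndex, plus time-travel reads. Index, table, and column names are hypothetical, column_stats is only one plausible index type, and the snippet reuses the spark session from the previous sketch.

// Sketch only: statements matched by isHoodieCommand above. The index type,
// names, and the time-travel instant are illustrative.
spark.sql("CREATE INDEX idx_city ON hudi_trips USING column_stats (city)")
spark.sql("SHOW INDEXES FROM hudi_trips")
spark.sql("REFRESH INDEX idx_city ON hudi_trips")
spark.sql("DROP INDEX IF EXISTS idx_city ON hudi_trips")

// Time-travel predicates ("timestamp as of", "version as of", ...) also trigger
// the extended parser; everything else is delegated to the underlying Spark parser.
spark.sql("SELECT * FROM hudi_trips TIMESTAMP AS OF '2024-01-01 00:00:00'")

Per-column and index-level OPTIONS, when present, are parsed as key = value pairs through visitPropertyKeyValues, which rejects keys given without values.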