diff --git a/SqlScriptDom/Parser/TSql/TSql160.g b/SqlScriptDom/Parser/TSql/TSql160.g index 4d794ca..78169c8 100644 --- a/SqlScriptDom/Parser/TSql/TSql160.g +++ b/SqlScriptDom/Parser/TSql/TSql160.g @@ -32171,6 +32171,9 @@ builtInFunctionCall returns [FunctionCall vResult = FragmentFactory.CreateFragme | aggregateBuiltInFunctionCall[vResult] ) + { + NormalizeDatePartFirstArgument(vResult); + } ; jsonArrayBuiltInFunctionCall [FunctionCall vParent] diff --git a/SqlScriptDom/Parser/TSql/TSql170.g b/SqlScriptDom/Parser/TSql/TSql170.g index 88ce660..e56bc8a 100644 --- a/SqlScriptDom/Parser/TSql/TSql170.g +++ b/SqlScriptDom/Parser/TSql/TSql170.g @@ -33172,6 +33172,9 @@ builtInFunctionCall returns [FunctionCall vResult = FragmentFactory.CreateFragme | aggregateBuiltInFunctionCall[vResult] ) + { + NormalizeDatePartFirstArgument(vResult); + } ; jsonArrayBuiltInFunctionCall [FunctionCall vParent] diff --git a/SqlScriptDom/Parser/TSql/TSql80ParserBaseInternal.cs b/SqlScriptDom/Parser/TSql/TSql80ParserBaseInternal.cs index 4f47429..a7cdf82 100644 --- a/SqlScriptDom/Parser/TSql/TSql80ParserBaseInternal.cs +++ b/SqlScriptDom/Parser/TSql/TSql80ParserBaseInternal.cs @@ -34,6 +34,17 @@ internal abstract class TSql80ParserBaseInternal : antlr.LLkParser private static readonly antlr.collections.impl.BitSet _ddlStatementBeginnerTokens = new antlr.collections.impl.BitSet(2); + private static readonly HashSet _datePartFirstArgumentBuiltInFunctions = new HashSet(StringComparer.OrdinalIgnoreCase) + { + "DATEADD", + "DATEDIFF", + "DATEDIFF_BIG", + "DATENAME", + "DATEPART", + "DATETRUNC", + "DATE_BUCKET" + }; + const int LookAhead = 2; //private static HashSet _languageString = new HashSet(StringComparer.OrdinalIgnoreCase) @@ -1979,6 +1990,48 @@ protected void PutIdentifiersIntoFunctionCall(FunctionCall functionCall, MultiPa } } + protected void NormalizeDatePartFirstArgument(FunctionCall functionCall) + { + if (functionCall == null || + functionCall.FunctionName == null || + functionCall.Parameters == null || + functionCall.Parameters.Count == 0 || + !_datePartFirstArgumentBuiltInFunctions.Contains(functionCall.FunctionName.Value)) + { + return; + } + + ColumnReferenceExpression firstParameter = functionCall.Parameters[0] as ColumnReferenceExpression; + if (firstParameter == null || + firstParameter.ColumnType != ColumnType.Regular || + firstParameter.MultiPartIdentifier == null || + firstParameter.MultiPartIdentifier.Count != 1 || + firstParameter.MultiPartIdentifier.Identifiers == null || + firstParameter.MultiPartIdentifier.Identifiers.Count != 1) + { + return; + } + + Identifier firstIdentifier = firstParameter.MultiPartIdentifier.Identifiers[0]; + if (firstIdentifier == null) + { + return; + } + + IdentifierLiteral identifierLiteral = FragmentFactory.CreateFragment(); + if (firstIdentifier.QuoteType == QuoteType.NotQuoted) + { + identifierLiteral.SetUnquotedIdentifier(firstIdentifier.Value); + } + else + { + identifierLiteral.SetIdentifier(Identifier.EncodeIdentifier(firstIdentifier.Value, firstIdentifier.QuoteType)); + } + + identifierLiteral.UpdateTokenInfo(firstParameter); + functionCall.Parameters[0] = identifierLiteral; + } + protected void VerifyColumnDataType(ColumnDefinition column) { // If the scalarDataType is not parsed, the ColumnIdentifier has to be a timestamp. @@ -2426,4 +2479,4 @@ protected static TSqlParseErrorException GetUnexpectedTokenErrorException(Identi #endregion } -} \ No newline at end of file +} diff --git a/SqlScriptDom/Parser/TSql/TSqlFabricDW.g b/SqlScriptDom/Parser/TSql/TSqlFabricDW.g index 21fd33d..8515ea2 100644 --- a/SqlScriptDom/Parser/TSql/TSqlFabricDW.g +++ b/SqlScriptDom/Parser/TSql/TSqlFabricDW.g @@ -32337,6 +32337,9 @@ builtInFunctionCall returns [FunctionCall vResult = FragmentFactory.CreateFragme | aggregateBuiltInFunctionCall[vResult] ) + { + NormalizeDatePartFirstArgument(vResult); + } ; jsonArrayBuiltInFunctionCall [FunctionCall vParent] diff --git a/Test/SqlDom/TSqlParserTest.cs b/Test/SqlDom/TSqlParserTest.cs index ef6575e..235932a 100644 --- a/Test/SqlDom/TSqlParserTest.cs +++ b/Test/SqlDom/TSqlParserTest.cs @@ -563,6 +563,64 @@ public void OpenRowsetBulkWithOneFile() )); } + [TestMethod] + [Priority(0)] + [SqlStudioTestCategory(Category.UnitTest)] + [Timeout(GlobalConstants.DefaultTestTimeout)] + public void DateDiffDatePartIsIdentifierLiteralIn160Parser() + { + AssertDateDiffDatePartIsIdentifierLiteral(new TSql160Parser(true)); + } + + [TestMethod] + [Priority(0)] + [SqlStudioTestCategory(Category.UnitTest)] + [Timeout(GlobalConstants.DefaultTestTimeout)] + public void DateDiffDatePartIsIdentifierLiteralIn170Parser() + { + AssertDateDiffDatePartIsIdentifierLiteral(new TSql170Parser(true)); + } + + [TestMethod] + [Priority(0)] + [SqlStudioTestCategory(Category.UnitTest)] + [Timeout(GlobalConstants.DefaultTestTimeout)] + public void DateDiffDatePartIsIdentifierLiteralInFabricDWParser() + { + AssertDateDiffDatePartIsIdentifierLiteral(new TSqlFabricDWParser(true)); + } + + private static void AssertDateDiffDatePartIsIdentifierLiteral(TSqlParser parser) + { + const string input = "SELECT DATEDIFF(mm, ColA, ColB) FROM my_table;"; + IList errors; + TSqlScript script = (TSqlScript)parser.Parse(new StringReader(input), out errors); + + Assert.AreEqual(0, errors.Count, "Unexpected parsing error"); + + SelectStatement selectStatement = script.Batches[0].Statements[0] as SelectStatement; + Assert.IsNotNull(selectStatement); + + QuerySpecification querySpecification = selectStatement.QueryExpression as QuerySpecification; + Assert.IsNotNull(querySpecification); + + SelectScalarExpression selectScalarExpression = querySpecification.SelectElements[0] as SelectScalarExpression; + Assert.IsNotNull(selectScalarExpression); + + FunctionCall functionCall = selectScalarExpression.Expression as FunctionCall; + Assert.IsNotNull(functionCall); + Assert.IsTrue(string.Equals(functionCall.FunctionName.Value, "DATEDIFF", StringComparison.OrdinalIgnoreCase)); + Assert.AreEqual(3, functionCall.Parameters.Count); + + Assert.IsInstanceOfType(functionCall.Parameters[0], typeof(IdentifierLiteral)); + IdentifierLiteral datePartLiteral = functionCall.Parameters[0] as IdentifierLiteral; + Assert.IsNotNull(datePartLiteral); + Assert.IsTrue(string.Equals(datePartLiteral.Value, "mm", StringComparison.OrdinalIgnoreCase)); + + Assert.IsInstanceOfType(functionCall.Parameters[1], typeof(ColumnReferenceExpression)); + Assert.IsInstanceOfType(functionCall.Parameters[2], typeof(ColumnReferenceExpression)); + } + [TestMethod] [Priority(0)] [SqlStudioTestCategory(Category.UnitTest)]