diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index 67aefb392..206f0f442 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -5757,3 +5757,156 @@ impl From for crate::ast::Statement { crate::ast::Statement::AlterPolicy(v) } } + +/// CREATE AGGREGATE statement. +/// See +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct CreateAggregate { + /// True if `OR REPLACE` was specified. + pub or_replace: bool, + /// The aggregate name (can be schema-qualified). + pub name: ObjectName, + /// Input argument types. Empty for zero-argument aggregates. + pub args: Vec, + /// The options listed inside the required parentheses after the argument + /// list (e.g. `SFUNC`, `STYPE`, `FINALFUNC`, `PARALLEL`, …). + pub options: Vec, +} + +impl fmt::Display for CreateAggregate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "CREATE")?; + if self.or_replace { + write!(f, " OR REPLACE")?; + } + write!(f, " AGGREGATE {}", self.name)?; + write!(f, " ({})", display_comma_separated(&self.args))?; + write!(f, " (")?; + for (i, option) in self.options.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{option}")?; + } + write!(f, ")") + } +} + +impl From for crate::ast::Statement { + fn from(v: CreateAggregate) -> Self { + crate::ast::Statement::CreateAggregate(v) + } +} + +/// A single option in a `CREATE AGGREGATE` options list. +/// +/// See +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum CreateAggregateOption { + /// `SFUNC = state_transition_function` + Sfunc(ObjectName), + /// `STYPE = state_data_type` + Stype(DataType), + /// `SSPACE = state_data_size` (in bytes) + Sspace(u64), + /// `FINALFUNC = final_function` + Finalfunc(ObjectName), + /// `FINALFUNC_EXTRA` — pass extra dummy arguments to the final function. + FinalfuncExtra, + /// `FINALFUNC_MODIFY = { READ_ONLY | SHAREABLE | READ_WRITE }` + FinalfuncModify(AggregateModifyKind), + /// `COMBINEFUNC = combine_function` + Combinefunc(ObjectName), + /// `SERIALFUNC = serial_function` + Serialfunc(ObjectName), + /// `DESERIALFUNC = deserial_function` + Deserialfunc(ObjectName), + /// `INITCOND = initial_condition` (a string literal) + Initcond(ValueWithSpan), + /// `MSFUNC = moving_state_transition_function` + Msfunc(ObjectName), + /// `MINVFUNC = moving_inverse_transition_function` + Minvfunc(ObjectName), + /// `MSTYPE = moving_state_data_type` + Mstype(DataType), + /// `MSSPACE = moving_state_data_size` (in bytes) + Msspace(u64), + /// `MFINALFUNC = moving_final_function` + Mfinalfunc(ObjectName), + /// `MFINALFUNC_EXTRA` + MfinalfuncExtra, + /// `MFINALFUNC_MODIFY = { READ_ONLY | SHAREABLE | READ_WRITE }` + MfinalfuncModify(AggregateModifyKind), + /// `MINITCOND = moving_initial_condition` (a string literal) + Minitcond(ValueWithSpan), + /// `SORTOP = sort_operator` + Sortop(ObjectName), + /// `PARALLEL = { SAFE | RESTRICTED | UNSAFE }` + Parallel(FunctionParallel), + /// `HYPOTHETICAL` — marks the aggregate as hypothetical-set. + Hypothetical, +} + +impl fmt::Display for CreateAggregateOption { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Sfunc(name) => write!(f, "SFUNC = {name}"), + Self::Stype(data_type) => write!(f, "STYPE = {data_type}"), + Self::Sspace(size) => write!(f, "SSPACE = {size}"), + Self::Finalfunc(name) => write!(f, "FINALFUNC = {name}"), + Self::FinalfuncExtra => write!(f, "FINALFUNC_EXTRA"), + Self::FinalfuncModify(kind) => write!(f, "FINALFUNC_MODIFY = {kind}"), + Self::Combinefunc(name) => write!(f, "COMBINEFUNC = {name}"), + Self::Serialfunc(name) => write!(f, "SERIALFUNC = {name}"), + Self::Deserialfunc(name) => write!(f, "DESERIALFUNC = {name}"), + Self::Initcond(cond) => write!(f, "INITCOND = {cond}"), + Self::Msfunc(name) => write!(f, "MSFUNC = {name}"), + Self::Minvfunc(name) => write!(f, "MINVFUNC = {name}"), + Self::Mstype(data_type) => write!(f, "MSTYPE = {data_type}"), + Self::Msspace(size) => write!(f, "MSSPACE = {size}"), + Self::Mfinalfunc(name) => write!(f, "MFINALFUNC = {name}"), + Self::MfinalfuncExtra => write!(f, "MFINALFUNC_EXTRA"), + Self::MfinalfuncModify(kind) => write!(f, "MFINALFUNC_MODIFY = {kind}"), + Self::Minitcond(cond) => write!(f, "MINITCOND = {cond}"), + Self::Sortop(name) => write!(f, "SORTOP = {name}"), + Self::Parallel(parallel) => { + let kind = match parallel { + FunctionParallel::Safe => "SAFE", + FunctionParallel::Restricted => "RESTRICTED", + FunctionParallel::Unsafe => "UNSAFE", + }; + write!(f, "PARALLEL = {kind}") + } + Self::Hypothetical => write!(f, "HYPOTHETICAL"), + } + } +} + +/// Modifier kind for `FINALFUNC_MODIFY` / `MFINALFUNC_MODIFY` in `CREATE AGGREGATE`. +/// +/// See +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum AggregateModifyKind { + /// The final function does not modify the transition state. + ReadOnly, + /// The transition state may be shared between aggregate calls. + Shareable, + /// The final function may modify the transition state. + ReadWrite, +} + +impl fmt::Display for AggregateModifyKind { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::ReadOnly => write!(f, "READ_ONLY"), + Self::Shareable => write!(f, "SHAREABLE"), + Self::ReadWrite => write!(f, "READ_WRITE"), + } + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 886bea26d..ccd9cab21 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -60,16 +60,17 @@ pub use self::dcl::{ SetConfigValue, Use, }; pub use self::ddl::{ - Alignment, AlterCollation, AlterCollationOperation, AlterColumnOperation, AlterConnectorOwner, - AlterFunction, AlterFunctionAction, AlterFunctionKind, AlterFunctionOperation, - AlterIndexOperation, AlterOperator, AlterOperatorClass, AlterOperatorClassOperation, - AlterOperatorFamily, AlterOperatorFamilyOperation, AlterOperatorOperation, AlterPolicy, - AlterPolicyOperation, AlterSchema, AlterSchemaOperation, AlterTable, AlterTableAlgorithm, - AlterTableLock, AlterTableOperation, AlterTableType, AlterType, AlterTypeAddValue, - AlterTypeAddValuePosition, AlterTypeOperation, AlterTypeRename, AlterTypeRenameValue, - ClusteredBy, ColumnDef, ColumnOption, ColumnOptionDef, ColumnOptions, ColumnPolicy, - ColumnPolicyProperty, ConstraintCharacteristics, CreateCollation, CreateCollationDefinition, - CreateConnector, CreateDomain, CreateExtension, CreateFunction, CreateIndex, CreateOperator, + AggregateModifyKind, Alignment, AlterCollation, AlterCollationOperation, AlterColumnOperation, + AlterConnectorOwner, AlterFunction, AlterFunctionAction, AlterFunctionKind, + AlterFunctionOperation, AlterIndexOperation, AlterOperator, AlterOperatorClass, + AlterOperatorClassOperation, AlterOperatorFamily, AlterOperatorFamilyOperation, + AlterOperatorOperation, AlterPolicy, AlterPolicyOperation, AlterSchema, AlterSchemaOperation, + AlterTable, AlterTableAlgorithm, AlterTableLock, AlterTableOperation, AlterTableType, + AlterType, AlterTypeAddValue, AlterTypeAddValuePosition, AlterTypeOperation, AlterTypeRename, + AlterTypeRenameValue, ClusteredBy, ColumnDef, ColumnOption, ColumnOptionDef, ColumnOptions, + ColumnPolicy, ColumnPolicyProperty, ConstraintCharacteristics, CreateAggregate, + CreateAggregateOption, CreateCollation, CreateCollationDefinition, CreateConnector, + CreateDomain, CreateExtension, CreateFunction, CreateIndex, CreateOperator, CreateOperatorClass, CreateOperatorFamily, CreatePolicy, CreatePolicyCommand, CreatePolicyType, CreateTable, CreateTrigger, CreateView, Deduplicate, DeferrableInitial, DistStyle, DropBehavior, DropExtension, DropFunction, DropOperator, DropOperatorClass, DropOperatorFamily, @@ -3762,6 +3763,11 @@ pub enum Statement { /// See [PostgreSQL](https://www.postgresql.org/docs/current/sql-createopclass.html) CreateOperatorClass(CreateOperatorClass), /// ```sql + /// CREATE AGGREGATE + /// ``` + /// See [PostgreSQL](https://www.postgresql.org/docs/current/sql-createaggregate.html) + CreateAggregate(CreateAggregate), + /// ```sql /// ALTER TABLE /// ``` AlterTable(AlterTable), @@ -5549,6 +5555,7 @@ impl fmt::Display for Statement { create_operator_family.fmt(f) } Statement::CreateOperatorClass(create_operator_class) => create_operator_class.fmt(f), + Statement::CreateAggregate(create_aggregate) => create_aggregate.fmt(f), Statement::AlterTable(alter_table) => write!(f, "{alter_table}"), Statement::AlterIndex { name, operation } => { write!(f, "ALTER INDEX {name} {operation}") diff --git a/src/ast/spans.rs b/src/ast/spans.rs index adc1443fc..34ec5d233 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -518,6 +518,7 @@ impl Spanned for Statement { Statement::Vacuum(..) => Span::empty(), Statement::AlterUser(..) => Span::empty(), Statement::Reset(..) => Span::empty(), + Statement::CreateAggregate(stmt) => stmt.name.span(), } } } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 7501919a0..c1224a126 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -5172,9 +5172,11 @@ impl<'a> Parser<'a> { self.parse_create_secret(or_replace, temporary, persistent) } else if self.parse_keyword(Keyword::USER) { self.parse_create_user(or_replace).map(Into::into) + } else if self.parse_keyword(Keyword::AGGREGATE) { + self.parse_create_aggregate(or_replace).map(Into::into) } else if or_replace { self.expected_ref( - "[EXTERNAL] TABLE or [MATERIALIZED] VIEW or FUNCTION after CREATE OR REPLACE", + "[EXTERNAL] TABLE or [MATERIALIZED] VIEW or FUNCTION or AGGREGATE after CREATE OR REPLACE", self.peek_token_ref(), ) } else if self.parse_keyword(Keyword::EXTENSION) { @@ -7209,6 +7211,175 @@ impl<'a> Parser<'a> { }) } + /// Parse a [Statement::CreateAggregate] + /// + /// [PostgreSQL Documentation](https://www.postgresql.org/docs/current/sql-createaggregate.html) + pub fn parse_create_aggregate( + &mut self, + or_replace: bool, + ) -> Result { + let name = self.parse_object_name(false)?; + + // Argument type list: `(input_data_type [, ...])` or `(*)` for zero-arg. + self.expect_token(&Token::LParen)?; + let args = if self.consume_token(&Token::Mul) || self.peek_token().token == Token::RParen { + vec![] + } else { + self.parse_comma_separated(|p| p.parse_data_type())? + }; + self.expect_token(&Token::RParen)?; + + // Options block: `( SFUNC = ..., STYPE = ..., ... )`. + self.expect_token(&Token::LParen)?; + let options = self.parse_comma_separated(|parser| { + let key = parser.parse_identifier()?; + parser.parse_create_aggregate_option(&key.value.to_uppercase()) + })?; + self.expect_token(&Token::RParen)?; + + Ok(CreateAggregate { + or_replace, + name, + args, + options, + }) + } + + fn parse_create_aggregate_option( + &mut self, + key: &str, + ) -> Result { + match key { + "SFUNC" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::Sfunc(self.parse_object_name(false)?)) + } + "STYPE" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::Stype(self.parse_data_type()?)) + } + "SSPACE" => { + self.expect_token(&Token::Eq)?; + let size = self.parse_literal_uint()?; + Ok(CreateAggregateOption::Sspace(size)) + } + "FINALFUNC" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::Finalfunc( + self.parse_object_name(false)?, + )) + } + "FINALFUNC_EXTRA" => Ok(CreateAggregateOption::FinalfuncExtra), + "FINALFUNC_MODIFY" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::FinalfuncModify( + self.parse_aggregate_modify_kind()?, + )) + } + "COMBINEFUNC" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::Combinefunc( + self.parse_object_name(false)?, + )) + } + "SERIALFUNC" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::Serialfunc( + self.parse_object_name(false)?, + )) + } + "DESERIALFUNC" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::Deserialfunc( + self.parse_object_name(false)?, + )) + } + "INITCOND" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::Initcond(self.parse_value()?)) + } + "MSFUNC" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::Msfunc( + self.parse_object_name(false)?, + )) + } + "MINVFUNC" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::Minvfunc( + self.parse_object_name(false)?, + )) + } + "MSTYPE" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::Mstype(self.parse_data_type()?)) + } + "MSSPACE" => { + self.expect_token(&Token::Eq)?; + let size = self.parse_literal_uint()?; + Ok(CreateAggregateOption::Msspace(size)) + } + "MFINALFUNC" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::Mfinalfunc( + self.parse_object_name(false)?, + )) + } + "MFINALFUNC_EXTRA" => Ok(CreateAggregateOption::MfinalfuncExtra), + "MFINALFUNC_MODIFY" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::MfinalfuncModify( + self.parse_aggregate_modify_kind()?, + )) + } + "MINITCOND" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::Minitcond(self.parse_value()?)) + } + "SORTOP" => { + self.expect_token(&Token::Eq)?; + Ok(CreateAggregateOption::Sortop(self.parse_operator_name()?)) + } + "PARALLEL" => { + self.expect_token(&Token::Eq)?; + let parallel = if self.parse_keyword(Keyword::SAFE) { + FunctionParallel::Safe + } else if self.parse_keyword(Keyword::RESTRICTED) { + FunctionParallel::Restricted + } else if self.parse_keyword(Keyword::UNSAFE) { + FunctionParallel::Unsafe + } else { + return self.expected_ref( + "SAFE, RESTRICTED, or UNSAFE after PARALLEL =", + self.peek_token_ref(), + ); + }; + Ok(CreateAggregateOption::Parallel(parallel)) + } + "HYPOTHETICAL" => Ok(CreateAggregateOption::Hypothetical), + other => Err(ParserError::ParserError(format!( + "Unknown CREATE AGGREGATE option: {other}" + ))), + } + } + + fn parse_aggregate_modify_kind(&mut self) -> Result { + let token = self.next_token(); + match &token.token { + Token::Word(word) => match word.value.to_uppercase().as_str() { + "READ_ONLY" => Ok(AggregateModifyKind::ReadOnly), + "SHAREABLE" => Ok(AggregateModifyKind::Shareable), + "READ_WRITE" => Ok(AggregateModifyKind::ReadWrite), + other => Err(ParserError::ParserError(format!( + "Expected READ_ONLY, SHAREABLE, or READ_WRITE, got: {other}" + ))), + }, + other => Err(ParserError::ParserError(format!( + "Expected READ_ONLY, SHAREABLE, or READ_WRITE, got: {other:?}" + ))), + } + } + /// Parse a [Statement::CreateOperatorFamily] /// /// [PostgreSQL Documentation](https://www.postgresql.org/docs/current/sql-createopfamily.html) diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 86315b1ef..0382f2fc4 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -9221,3 +9221,57 @@ fn parse_lock_table() { } } } + +#[test] +fn parse_create_aggregate_basic() { + let sql = "CREATE AGGREGATE myavg (NUMERIC) (SFUNC = numeric_avg_accum, STYPE = internal, FINALFUNC = numeric_avg, INITCOND = '0')"; + let stmt = pg().verified_stmt(sql); + match stmt { + Statement::CreateAggregate(agg) => { + assert!(!agg.or_replace); + assert_eq!(agg.name.to_string(), "myavg"); + assert_eq!(agg.args.len(), 1); + assert_eq!(agg.args[0].to_string(), "NUMERIC"); + assert_eq!(agg.options.len(), 4); + assert_eq!(agg.options[0].to_string(), "SFUNC = numeric_avg_accum"); + assert_eq!(agg.options[1].to_string(), "STYPE = internal"); + assert_eq!(agg.options[2].to_string(), "FINALFUNC = numeric_avg"); + assert_eq!(agg.options[3].to_string(), "INITCOND = '0'"); + } + _ => panic!("Expected CreateAggregate, got: {stmt:?}"), + } +} + +#[test] +fn parse_create_aggregate_or_replace_with_parallel() { + let sql = "CREATE OR REPLACE AGGREGATE sum2 (INT4, INT4) (SFUNC = int4pl, STYPE = INT4, PARALLEL = SAFE)"; + let stmt = pg().verified_stmt(sql); + match stmt { + Statement::CreateAggregate(agg) => { + assert!(agg.or_replace); + assert_eq!(agg.name.to_string(), "sum2"); + assert_eq!(agg.args.len(), 2); + assert_eq!(agg.options.len(), 3); + assert_eq!(agg.options[2].to_string(), "PARALLEL = SAFE"); + } + _ => panic!("Expected CreateAggregate, got: {stmt:?}"), + } +} + +#[test] +fn parse_create_aggregate_with_moving_aggregate_options() { + let sql = "CREATE AGGREGATE moving_sum (FLOAT8) (SFUNC = float8pl, STYPE = FLOAT8, MSFUNC = float8pl, MINVFUNC = float8mi, MSTYPE = FLOAT8, MFINALFUNC_EXTRA, MFINALFUNC_MODIFY = READ_ONLY)"; + let stmt = pg().verified_stmt(sql); + match stmt { + Statement::CreateAggregate(agg) => { + assert!(!agg.or_replace); + assert_eq!(agg.name.to_string(), "moving_sum"); + assert_eq!(agg.args.len(), 1); + assert_eq!(agg.options.len(), 7); + assert_eq!(agg.options[4].to_string(), "MSTYPE = FLOAT8"); + assert_eq!(agg.options[5].to_string(), "MFINALFUNC_EXTRA"); + assert_eq!(agg.options[6].to_string(), "MFINALFUNC_MODIFY = READ_ONLY"); + } + _ => panic!("Expected CreateAggregate, got: {stmt:?}"), + } +}