diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 49d6499c5..2454928d5 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -4594,13 +4594,35 @@ impl fmt::Display for GrantObjects { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct Assignment { - pub id: Vec, + pub target: AssignmentTarget, pub value: Expr, } impl fmt::Display for Assignment { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} = {}", display_separated(&self.id, "."), self.value) + write!(f, "{} = {}", self.target, self.value) + } +} + +/// Left-hand side of an assignment in an UPDATE statement, +/// e.g. `foo` in `foo = 5` (ColumnName assignment) or +/// `(a, b)` in `(a, b) = (1, 2)` (Tuple assignment). +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum AssignmentTarget { + /// A single column + ColumnName(ObjectName), + /// A tuple of columns + Tuple(Vec), +} + +impl fmt::Display for AssignmentTarget { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + AssignmentTarget::ColumnName(column) => write!(f, "{}", column), + AssignmentTarget::Tuple(columns) => write!(f, "({})", display_comma_separated(columns)), + } } } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index c591b8116..8406eced5 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -9946,10 +9946,22 @@ impl<'a> Parser<'a> { /// Parse a `var = expr` assignment, used in an UPDATE statement pub fn parse_assignment(&mut self) -> Result { - let id = self.parse_identifiers()?; + let target = self.parse_assignment_target()?; self.expect_token(&Token::Eq)?; let value = self.parse_expr()?; - Ok(Assignment { id, value }) + Ok(Assignment { target, value }) + } + + /// Parse the left-hand side of an assignment, used in an UPDATE statement + pub fn parse_assignment_target(&mut self) -> Result { + if self.consume_token(&Token::LParen) { + let columns = self.parse_comma_separated(|p| p.parse_object_name(false))?; + self.expect_token(&Token::RParen)?; + Ok(AssignmentTarget::Tuple(columns)) + } else { + let column = self.parse_object_name(false)?; + Ok(AssignmentTarget::ColumnName(column)) + } } pub fn parse_function_args(&mut self) -> Result { diff --git a/tests/sqlparser_bigquery.rs b/tests/sqlparser_bigquery.rs index 171439d19..fb6e3b88a 100644 --- a/tests/sqlparser_bigquery.rs +++ b/tests/sqlparser_bigquery.rs @@ -1590,11 +1590,11 @@ fn parse_merge() { let update_action = MergeAction::Update { assignments: vec![ Assignment { - id: vec![Ident::new("a")], + target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("a")])), value: Expr::Value(number("1")), }, Assignment { - id: vec![Ident::new("b")], + target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("b")])), value: Expr::Value(number("2")), }, ], diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index f6518e276..1727494c2 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -296,15 +296,15 @@ fn parse_update() { assignments, vec![ Assignment { - id: vec!["a".into()], + target: AssignmentTarget::ColumnName(ObjectName(vec!["a".into()])), value: Expr::Value(number("1")), }, Assignment { - id: vec!["b".into()], + target: AssignmentTarget::ColumnName(ObjectName(vec!["b".into()])), value: Expr::Value(number("2")), }, Assignment { - id: vec!["c".into()], + target: AssignmentTarget::ColumnName(ObjectName(vec!["c".into()])), value: Expr::Value(number("3")), }, ] @@ -363,7 +363,7 @@ fn parse_update_set_from() { joins: vec![], }, assignments: vec![Assignment { - id: vec![Ident::new("name")], + target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("name")])), value: Expr::CompoundIdentifier(vec![Ident::new("t2"), Ident::new("name")]) }], from: Some(TableWithJoins { @@ -466,7 +466,10 @@ fn parse_update_with_table_alias() { ); assert_eq!( vec![Assignment { - id: vec![Ident::new("u"), Ident::new("username")], + target: AssignmentTarget::ColumnName(ObjectName(vec![ + Ident::new("u"), + Ident::new("username") + ])), value: Expr::Value(Value::SingleQuotedString("new_user".to_string())), }], assignments @@ -7696,14 +7699,20 @@ fn parse_merge() { action: MergeAction::Update { assignments: vec![ Assignment { - id: vec![Ident::new("dest"), Ident::new("F")], + target: AssignmentTarget::ColumnName(ObjectName(vec![ + Ident::new("dest"), + Ident::new("F") + ])), value: Expr::CompoundIdentifier(vec![ Ident::new("stg"), Ident::new("F"), ]), }, Assignment { - id: vec![Ident::new("dest"), Ident::new("G")], + target: AssignmentTarget::ColumnName(ObjectName(vec![ + Ident::new("dest"), + Ident::new("G") + ])), value: Expr::CompoundIdentifier(vec![ Ident::new("stg"), Ident::new("G"), diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index e65fc181b..ff8a49de7 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -1639,23 +1639,33 @@ fn parse_insert_with_on_duplicate_update() { assert_eq!( Some(OnInsert::DuplicateKeyUpdate(vec![ Assignment { - id: vec![Ident::new("description".to_string())], + target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new( + "description".to_string() + )])), value: call("VALUES", [Expr::Identifier(Ident::new("description"))]), }, Assignment { - id: vec![Ident::new("perm_create".to_string())], + target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new( + "perm_create".to_string() + )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_create"))]), }, Assignment { - id: vec![Ident::new("perm_read".to_string())], + target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new( + "perm_read".to_string() + )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_read"))]), }, Assignment { - id: vec![Ident::new("perm_update".to_string())], + target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new( + "perm_update".to_string() + )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_update"))]), }, Assignment { - id: vec![Ident::new("perm_delete".to_string())], + target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new( + "perm_delete".to_string() + )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_delete"))]), }, ])), @@ -1835,7 +1845,10 @@ fn parse_update_with_joins() { ); assert_eq!( vec![Assignment { - id: vec![Ident::new("o"), Ident::new("completed")], + target: AssignmentTarget::ColumnName(ObjectName(vec![ + Ident::new("o"), + Ident::new("completed") + ])), value: Expr::Value(Value::Boolean(true)) }], assignments diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 5343fe5e0..fe735b8b2 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1557,7 +1557,7 @@ fn parse_pg_on_conflict() { assert_eq!( OnConflictAction::DoUpdate(DoUpdate { assignments: vec![Assignment { - id: vec!["dname".into()], + target: AssignmentTarget::ColumnName(ObjectName(vec!["dname".into()])), value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "dname".into()]) },], selection: None @@ -1588,14 +1588,14 @@ fn parse_pg_on_conflict() { OnConflictAction::DoUpdate(DoUpdate { assignments: vec![ Assignment { - id: vec!["dname".into()], + target: AssignmentTarget::ColumnName(ObjectName(vec!["dname".into()])), value: Expr::CompoundIdentifier(vec![ "EXCLUDED".into(), "dname".into() ]) }, Assignment { - id: vec!["area".into()], + target: AssignmentTarget::ColumnName(ObjectName(vec!["area".into()])), value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "area".into()]) }, ], @@ -1645,7 +1645,7 @@ fn parse_pg_on_conflict() { assert_eq!( OnConflictAction::DoUpdate(DoUpdate { assignments: vec![Assignment { - id: vec!["dname".into()], + target: AssignmentTarget::ColumnName(ObjectName(vec!["dname".into()])), value: Expr::Value(Value::Placeholder("$1".to_string())) },], selection: Some(Expr::BinaryOp { @@ -1682,7 +1682,7 @@ fn parse_pg_on_conflict() { assert_eq!( OnConflictAction::DoUpdate(DoUpdate { assignments: vec![Assignment { - id: vec!["dname".into()], + target: AssignmentTarget::ColumnName(ObjectName(vec!["dname".into()])), value: Expr::Value(Value::Placeholder("$1".to_string())) },], selection: Some(Expr::BinaryOp { diff --git a/tests/sqlparser_sqlite.rs b/tests/sqlparser_sqlite.rs index 16ea9eb8c..1181c480b 100644 --- a/tests/sqlparser_sqlite.rs +++ b/tests/sqlparser_sqlite.rs @@ -373,6 +373,40 @@ fn parse_attach_database() { } } +#[test] +fn parse_update_tuple_row_values() { + // See https://github.com/sqlparser-rs/sqlparser-rs/issues/1311 + assert_eq!( + sqlite().verified_stmt("UPDATE x SET (a, b) = (1, 2)"), + Statement::Update { + assignments: vec![Assignment { + target: AssignmentTarget::Tuple(vec![ + ObjectName(vec![Ident::new("a"),]), + ObjectName(vec![Ident::new("b"),]), + ]), + value: Expr::Tuple(vec![ + Expr::Value(Value::Number("1".parse().unwrap(), false)), + Expr::Value(Value::Number("2".parse().unwrap(), false)) + ]) + }], + selection: None, + table: TableWithJoins { + relation: TableFactor::Table { + name: ObjectName(vec![Ident::new("x")]), + alias: None, + args: None, + with_hints: vec![], + version: None, + partitions: vec![] + }, + joins: vec![], + }, + from: None, + returning: None + } + ); +} + #[test] fn parse_where_in_empty_list() { let sql = "SELECT * FROM t1 WHERE a IN ()";