diff --git a/core/src/test/scala/org/apache/spark/sql/delta/DeltaSuite.scala b/core/src/test/scala/org/apache/spark/sql/delta/DeltaSuite.scala index dd14d2b3ced..7fd4a804a52 100644 --- a/core/src/test/scala/org/apache/spark/sql/delta/DeltaSuite.scala +++ b/core/src/test/scala/org/apache/spark/sql/delta/DeltaSuite.scala @@ -3077,4 +3077,172 @@ class DeltaNameColumnMappingSuite extends DeltaSuite insertedDF.filter(col("id") >= 6).union(otherDF)) } } + + test("replaceWhere SQL - partition column - dynamic filter") { + withTempDir { dir => + // create partitioned table + spark.range(100).withColumn("part", 'id % 10) + .write + .format("delta") + .partitionBy("part") + .save(dir.toString) + + // ans will be used to replace the entire contents of the table + val ans = spark.range(10) + .withColumn("part", lit(0)) + + ans.createOrReplaceTempView("replace") + sql(s"INSERT INTO delta.`${dir.toString}` REPLACE WHERE part >=0 SELECT * FROM replace") + checkAnswer(spark.read.format("delta").load(dir.toString), ans) + } + } + + test("replaceWhere SQL - partition column - static filter") { + withTable("tbl") { + // create partitioned table + spark.range(100).withColumn("part", lit(0)) + .write + .format("delta") + .partitionBy("part") + .saveAsTable("tbl") + + val partEq1DF = spark.range(10, 20) + .withColumn("part", lit(1)) + partEq1DF.write.format("delta").mode("append").saveAsTable("tbl") + + + val replacer = spark.range(10) + .withColumn("part", lit(0)) + + replacer.createOrReplaceTempView("replace") + sql(s"INSERT INTO tbl REPLACE WHERE part=0 SELECT * FROM replace") + checkAnswer(spark.read.format("delta").table("tbl"), replacer.union(partEq1DF)) + } + } + + test("replaceWhere SQL - data column - dynamic") { + withTable("tbl") { + // write table + spark.range(100).withColumn("col", lit(1)) + .write + .format("delta") + .saveAsTable("tbl") + + val colGt2DF = spark.range(100, 200) + .withColumn("col", lit(3)) + + colGt2DF.write + .format("delta") + .mode("append") + .saveAsTable("tbl") + + val replacer = spark.range(10) + .withColumn("col", lit(1)) + + replacer.createOrReplaceTempView("replace") + sql(s"INSERT INTO tbl REPLACE WHERE col < 2 SELECT * FROM replace") + checkAnswer( + spark.read.format("delta").table("tbl"), + replacer.union(colGt2DF) + ) + } + } + + test("replaceWhere SQL - data column - static") { + withTempDir { dir => + // write table + spark.range(100).withColumn("col", lit(2)) + .write + .format("delta") + .save(dir.toString) + + val colEq2DF = spark.range(100, 200) + .withColumn("col", lit(1)) + + colEq2DF.write + .format("delta") + .mode("append") + .save(dir.toString) + + val replacer = spark.range(10) + .withColumn("col", lit(2)) + + replacer.createOrReplaceTempView("replace") + sql(s"INSERT INTO delta.`${dir.toString}` REPLACE WHERE col = 2 SELECT * FROM replace") + checkAnswer( + spark.read.format("delta").load(dir.toString), + replacer.union(colEq2DF) + ) + } + } + + test("replaceWhere SQL - multiple predicates - static") { + withTempDir { dir => + // write table + spark.range(100).withColumn("col", lit(2)) + .write + .format("delta") + .save(dir.toString) + + spark.range(100, 200).withColumn("col", lit(5)) + .write + .format("delta") + .mode("append") + .save(dir.toString) + + val colEq2DF = spark.range(100, 200) + .withColumn("col", lit(1)) + + colEq2DF.write + .format("delta") + .mode("append") + .save(dir.toString) + + val replacer = spark.range(10) + .withColumn("col", lit(2)) + + replacer.createOrReplaceTempView("replace") + sql(s"INSERT INTO delta.`${dir.toString}` REPLACE WHERE col = 2 OR col = 5 " + + s"SELECT * FROM replace") + checkAnswer( + spark.read.format("delta").load(dir.toString), + replacer.union(colEq2DF) + ) + } + } + + test("replaceWhere with less than predicate") { + withTempDir { dir => + val insertedDF = spark.range(10).toDF() + + insertedDF.write.format("delta").save(dir.toString) + + val otherDF = spark.range(start = 0, end = 4).toDF() + otherDF.write.format("delta").mode("overwrite") + .option(DeltaOptions.REPLACE_WHERE_OPTION, "id < 6") + .save(dir.toString) + checkAnswer(spark.read.load(dir.toString), + insertedDF.filter(col("id") >= 6).union(otherDF)) + } + } + + test("replaceWhere SQL with less than predicate") { + withTempDir { dir => + val insertedDF = spark.range(10).toDF() + + insertedDF.write.format("delta").save(dir.toString) + + val otherDF = spark.range(start = 0, end = 4).toDF() + otherDF.createOrReplaceTempView("replace") + + sql( + s""" + |INSERT INTO delta.`${dir.getAbsolutePath}` + |REPLACE WHERE id < 6 + |SELECT * FROM replace + |""".stripMargin) + checkAnswer(spark.read.load(dir.toString), + insertedDF.filter(col("id") >= 6).union(otherDF)) + } + } }