Skip to content

Commit

Permalink
The flatten command integration tests were extended with additional checks for logical plans.
Browse files Browse the repository at this point in the history

Signed-off-by: Lukasz Soszynski <lukasz.soszynski@eliatra.com>
  • Loading branch information
lukasz-soszynski-eliatra committed Oct 29, 2024
1 parent 8f2ade9 commit 03222ce
Showing 1 changed file with 66 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
*/
package org.opensearch.flint.spark.ppl

import java.nio.file.{Files, Path}
import java.nio.file.Files

import org.opensearch.flint.spark.FlattenGenerator
import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, GeneratorOuter, Literal, Or}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

Expand Down Expand Up @@ -56,9 +57,24 @@ class FlintSparkPPLFlattenITSuite
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(Row(35, 51.5074, -0.1278), Row(null, null, null))
// // Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1))
assert(results.sorted.sameElements(expectedResults.sorted))
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("flint_ppl_test"))
val filter = Filter(
Or(
EqualTo(UnresolvedAttribute("country"), Literal("England")),
EqualTo(UnresolvedAttribute("country"), Literal("Poland"))),
table)
val projectCoor = Project(Seq(UnresolvedAttribute("coor")), filter)
val flattenGenerator = new FlattenGenerator(UnresolvedAttribute("coor"))
val outerGenerator = GeneratorOuter(flattenGenerator)
val generate = Generate(outerGenerator, seq(), true, None, seq(), projectCoor)
val dropSourceColumn =
DataFrameDropColumns(Seq(UnresolvedAttribute("coor")), generate)
val expectedPlan = Project(Seq(UnresolvedStar(None)), dropSourceColumn)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test flatten for arrays") {
Expand Down Expand Up @@ -86,6 +102,16 @@ class FlintSparkPPLFlattenITSuite
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0))
assert(results.sorted.sameElements(expectedResults.sorted))
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("flint_ppl_test"))
val projectCoor = Project(Seq(UnresolvedAttribute("bridges")), table)
val flattenGenerator = new FlattenGenerator(UnresolvedAttribute("bridges"))
val outerGenerator = GeneratorOuter(flattenGenerator)
val generate = Generate(outerGenerator, seq(), true, None, seq(), projectCoor)
val dropSourceColumn =
DataFrameDropColumns(Seq(UnresolvedAttribute("bridges")), generate)
val expectedPlan = Project(Seq(UnresolvedStar(None)), dropSourceColumn)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test flatten for structs and arrays") {
Expand Down Expand Up @@ -174,9 +200,25 @@ class FlintSparkPPLFlattenITSuite
96,
47.4979,
19.0402))
// // Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](3))
assert(results.sorted.sameElements(expectedResults.sorted))

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("flint_ppl_test"))
val flattenGeneratorBridges = new FlattenGenerator(UnresolvedAttribute("bridges"))
val outerGeneratorBridges = GeneratorOuter(flattenGeneratorBridges)
val generateBridges = Generate(outerGeneratorBridges, seq(), true, None, seq(), table)
val dropSourceColumnBridges =
DataFrameDropColumns(Seq(UnresolvedAttribute("bridges")), generateBridges)
val flattenGeneratorCoor = new FlattenGenerator(UnresolvedAttribute("coor"))
val outerGeneratorCoor = GeneratorOuter(flattenGeneratorCoor)
val generateCoor =
Generate(outerGeneratorCoor, seq(), true, None, seq(), dropSourceColumnBridges)
val dropSourceColumnCoor =
DataFrameDropColumns(Seq(UnresolvedAttribute("coor")), generateCoor)
val expectedPlan = Project(Seq(UnresolvedStar(None)), dropSourceColumnCoor)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test flatten and stats") {
Expand All @@ -200,5 +242,24 @@ class FlintSparkPPLFlattenITSuite
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
assert(results.sorted.sameElements(expectedResults.sorted))
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("flint_ppl_test"))
val projectCountryBridges =
Project(Seq(UnresolvedAttribute("country"), UnresolvedAttribute("bridges")), table)
val flattenGenerator = new FlattenGenerator(UnresolvedAttribute("bridges"))
val outerGenerator = GeneratorOuter(flattenGenerator)
val generate = Generate(outerGenerator, seq(), true, None, seq(), projectCountryBridges)
val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("bridges")), generate)
val projectCountryLength = Project(
Seq(UnresolvedAttribute("country"), UnresolvedAttribute("length")),
dropSourceColumn)
val average = Alias(
UnresolvedFunction(seq("AVG"), seq(UnresolvedAttribute("length")), false, None, false),
"avg")()
val country = Alias(UnresolvedAttribute("country"), "country")()
val grouping = Alias(UnresolvedAttribute("country"), "country")()
val aggregate = Aggregate(Seq(grouping), Seq(average, country), projectCountryLength)
val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}

0 comments on commit 03222ce

Please sign in to comment.