From d11d5b95ef82a208c579daa0073bdc072a682be5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?=
Date: Fri, 1 May 2015 23:50:12 +0800
Subject: [PATCH] [SPARK-7294] Add between operator in Column

---
 python/pyspark/sql/dataframe.py                      | 12 ++++++++++++
 python/pyspark/sql/tests.py                          |  6 ++++++
 .../main/scala/org/apache/spark/sql/Column.scala     | 14 ++++++++++++++
 .../apache/spark/sql/ColumnExpressionSuite.scala     |  6 ++++++
 .../test/scala/org/apache/spark/sql/TestData.scala   | 11 +++++++++++
 5 files changed, 49 insertions(+)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 5908ebc990a56..a4cbc7396e386 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1289,6 +1289,18 @@ def cast(self, dataType):
             raise TypeError("unexpected type: %s" % type(dataType))
         return Column(jc)
 
+    @ignore_unicode_prefix
+    def between(self, col1, col2):
+        """ A boolean expression that evaluates to true if the value of
+        this expression is strictly between the given columns.
+
+        >>> df[df.col1.between(df.col2, df.col3)].collect()
+        [Row(col1=5, col2=6, col3=8)]
+        """
+        # Parentheses are required: & binds tighter than the comparison
+        # operators, so `self > col1 & self < col2` would not parse as intended.
+        return (self > col1) & (self < col2)
+
     def __repr__(self):
         return 'Column<%s>' % self._jc.toString().encode('utf8')
 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 5640bb5ea2346..206e3b7fd08f2 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -426,6 +426,12 @@ def test_rand_functions(self):
         for row in rndn:
             assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
 
+    def test_between_function(self):
+        df = self.sc.parallelize(
+            [Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=3)]).toDF()
+        self.assertEqual([False, True, False],
+                         [r[0] for r in df.select(df.a.between(df.b, df.c)).collect()])
+
     def test_save_and_load(self):
         df = self.df
         tmpPath = tempfile.mkdtemp()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 33f9d0b37d006..8e0eab7918131 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -295,6 +295,20 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    */
   def eqNullSafe(other: Any): Column = this <=> other
 
+  /**
+   * True if the current column's value is strictly between col1 and col2.
+   *
+   * @group java_expr_ops
+   */
+  def between(col1: String, col2: String): Column = between(Column(col1), Column(col2))
+
+  /**
+   * True if the current column's value is strictly between col1 and col2.
+   *
+   * @group java_expr_ops
+   */
+  def between(col1: Column, col2: Column): Column = (this > col1) && (this < col2)
+
   /**
    * True if the current expression is null.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 6322faf4d9907..0a81f884e9a16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -208,6 +208,12 @@ class ColumnExpressionSuite extends QueryTest {
       testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1)))
   }
 
+  test("between") {
+    checkAnswer(
+      testData4.filter($"a".between($"b", $"c")),
+      testData4.collect().toSeq.filter(r => r.getInt(0) > r.getInt(1) && r.getInt(0) < r.getInt(2)))
+  }
+
   val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
     Row(false, false) ::
     Row(false, true) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 225b51bd73d6c..487d07249922f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -57,6 +57,17 @@ object TestData {
       TestData2(3, 2) :: Nil, 2).toDF()
   testData2.registerTempTable("testData2")
 
+  case class TestData4(a: Int, b: Int, c: Int)
+  val testData4 =
+    TestSQLContext.sparkContext.parallelize(
+      TestData4(0, 1, 2) ::
+      TestData4(1, 2, 3) ::
+      TestData4(2, 1, 0) ::
+      TestData4(2, 2, 4) ::
+      TestData4(3, 1, 6) ::
+      TestData4(3, 2, 0) :: Nil, 2).toDF()
+  testData4.registerTempTable("testData4")
+
   case class DecimalData(a: BigDecimal, b: BigDecimal)
 
   val decimalData =
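
For context, the following is a minimal usage sketch, not part of the patch itself, showing how the two between overloads added above would be exercised. The Rec case class and the df value are hypothetical names introduced only for illustration; the sc and sqlContext bindings are the usual Spark 1.x shell boilerplate.

// Hypothetical usage sketch, not part of the patch: exercises the two
// between overloads added to Column above. Assumes a Spark 1.x shell
// where sc and sqlContext (with its implicits) are already in scope.
import sqlContext.implicits._

case class Rec(a: Int, b: Int, c: Int)
val df = sc.parallelize(Seq(Rec(3, 1, 6), Rec(2, 2, 4))).toDF()

// Column overload: keeps rows where a is strictly between b and c,
// i.e. (a > b) AND (a < c); only Rec(3, 1, 6) survives.
df.filter($"a".between($"b", $"c")).show()

// String overload: resolves the bound columns by name instead.
df.filter($"a".between("b", "c")).show()

Note one design choice worth flagging in review: both bounds in this patch are exclusive, whereas SQL's BETWEEN (and the Column.between that ultimately shipped in Spark) is inclusive on both ends.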