Skip to content

Commit

Permalink
[SPARK-51105][ML][PYTHON][CONNECT][TESTS] Add parity test for ml functions
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
Add parity test for ml functions

### Why are the changes needed?
for test coverage

### Does this PR introduce _any_ user-facing change?
no, test-only

### How was this patch tested?
ci

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#49824 from zhengruifeng/ml_connect_f_ut.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
  • Loading branch information
zhengruifeng authored and Zeyu Chen committed Feb 6, 2025
1 parent 3e218f8 commit f960e82
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 8 deletions.
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,7 @@ def __hash__(self):
"pyspark.ml.tests.connect.test_parity_clustering",
"pyspark.ml.tests.connect.test_parity_evaluation",
"pyspark.ml.tests.connect.test_parity_feature",
"pyspark.ml.tests.connect.test_parity_functions",
"pyspark.ml.tests.connect.test_parity_pipeline",
"pyspark.ml.tests.connect.test_parity_tuning",
"pyspark.ml.tests.connect.test_parity_ovr",
Expand Down
54 changes: 54 additions & 0 deletions python/pyspark/ml/tests/connect/test_parity_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.ml.tests.test_functions import (
ArrayVectorConversionTestsMixin,
PredictBatchUDFTestsMixin,
)
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.sqlutils import (
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
)


class ArrayVectorConversionParityTests(ArrayVectorConversionTestsMixin, ReusedConnectTestCase):
    """Run the shared array<->vector conversion tests against Spark Connect."""

    pass


# predict_batch_udf moves data through Arrow, so the whole suite is skipped
# when pandas or pyarrow is unavailable in the test environment.
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message,
)
class PredictBatchUDFParityTests(PredictBatchUDFTestsMixin, ReusedConnectTestCase):
    """Run the shared predict_batch_udf tests against Spark Connect."""

    pass


if __name__ == "__main__":
    from pyspark.ml.tests.connect.test_parity_functions import *  # noqa: F401

    # Prefer the XML-emitting runner (CI collects its reports from
    # target/test-reports); fall back to the default text runner when
    # xmlrunner is not installed.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]
    except ImportError:
        pass
    else:
        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    unittest.main(testRunner=testRunner, verbosity=2)
55 changes: 47 additions & 8 deletions python/pyspark/ml/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,58 @@
import numpy as np

from pyspark.loose_version import LooseVersion
from pyspark.ml.functions import predict_batch_udf
from pyspark.ml.linalg import DenseVector
from pyspark.ml.functions import array_to_vector, vector_to_array, predict_batch_udf
from pyspark.sql.functions import array, struct, col
from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType, StructField, FloatType
from pyspark.testing.mlutils import SparkSessionTestCase
from pyspark.testing.sqlutils import (
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
ReusedSQLTestCase,
)


@unittest.skipIf(
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message,
)
class PredictBatchUDFTests(SparkSessionTestCase):
class ArrayVectorConversionTestsMixin:
    """Shared assertions for array_to_vector / vector_to_array round-trips.

    Mixed into both the classic and the Spark Connect test cases; the host
    class is expected to provide ``self.spark``.
    """

    def test_array_vector_conversion(self):
        session = self.spark

        source = session.sql(
            """
            SELECT * FROM VALUES
            (1, ARRAY(1.0, 2.0, 3.0)),
            (1, ARRAY(-1.0, -2.0, -3.0))
            AS tab(a, b)
            """
        )

        # array -> vector: a new DenseVector column is appended.
        with_vectors = source.select("*", array_to_vector(source.b).alias("c"))
        self.assertEqual(with_vectors.columns, ["a", "b", "c"])
        self.assertEqual(with_vectors.count(), 2)
        expected_vectors = [DenseVector([1.0, 2.0, 3.0]), DenseVector([-1.0, -2.0, -3.0])]
        self.assertEqual(
            [row.c for row in with_vectors.select("c").collect()],
            expected_vectors,
        )

        # vector -> array: converting back recovers the original float lists.
        round_trip = with_vectors.select("*", vector_to_array(with_vectors.c).alias("d"))
        self.assertEqual(round_trip.columns, ["a", "b", "c", "d"])
        self.assertEqual(round_trip.count(), 2)
        self.assertEqual(
            [row.d for row in round_trip.select("d").collect()],
            [[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]],
        )


class ArrayVectorConversionTests(ArrayVectorConversionTestsMixin, ReusedSQLTestCase):
    """Run the shared array<->vector conversion tests on classic (non-Connect) Spark."""

    pass


class PredictBatchUDFTestsMixin:
def setUp(self):
import pandas as pd

super(PredictBatchUDFTests, self).setUp()
super(PredictBatchUDFTestsMixin, self).setUp()
self.data = np.arange(0, 1000, dtype=np.float64).reshape(-1, 4)

# 4 scalar columns
Expand Down Expand Up @@ -533,6 +564,14 @@ def predict(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
self.assertEqual(value, 9.0)


# predict_batch_udf moves data through Arrow, so the whole suite is skipped
# when pandas or pyarrow is unavailable in the test environment.
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message,
)
class PredictBatchUDFTests(PredictBatchUDFTestsMixin, ReusedSQLTestCase):
    """Run the shared predict_batch_udf tests on classic (non-Connect) Spark."""

    pass


if __name__ == "__main__":
from pyspark.ml.tests.test_functions import * # noqa: F401

Expand Down

0 comments on commit f960e82

Please sign in to comment.