Skip to content

Commit

Permalink
Skip long running model train test.
Browse files Browse the repository at this point in the history
  • Loading branch information
xzdandy committed Oct 24, 2023
1 parent 81ab7a3 commit c6970bd
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions test/integration_tests/long/test_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,21 @@ def tearDownClass(cls):

# clean up
execute_query_fetch_all(cls.evadb, "DROP TABLE IF EXISTS HomeRentals;")
execute_query_fetch_all(cls.evadb, "DROP TABLE IF EXISTS Employee;")
execute_query_fetch_all(
cls.evadb, "DROP FUNCTION IF EXISTS PredictHouseRentLudwig;"
)
execute_query_fetch_all(
cls.evadb, "DROP FUNCTION IF EXISTS PredictHouseRentSklearn;"
)
execute_query_fetch_all(
cls.evadb, "DROP FUNCTION IF EXISTS PredictRentXgboost;"
)
execute_query_fetch_all(
cls.evadb, "DROP FUNCTION IF EXISTS PredictEmployeeXgboost;"
)

@pytest.marker.skip(reason="Model training intergration test takes too long to complete.")
@ludwig_skip_marker
def test_ludwig_automl(self):
create_predict_function = """
Expand All @@ -97,6 +105,8 @@ def test_ludwig_automl(self):
self.assertEqual(len(result.columns), 1)
self.assertEqual(len(result), 10)


@pytest.marker.skip(reason="Model training intergration test takes too long to complete.")
@sklearn_skip_marker
def test_sklearn_regression(self):
create_predict_function = """
Expand All @@ -114,10 +124,11 @@ def test_sklearn_regression(self):
self.assertEqual(len(result.columns), 1)
self.assertEqual(len(result), 10)


@xgboost_skip_marker
def test_xgboost_regression(self):
create_predict_function = """
CREATE FUNCTION IF NOT EXISTS PredictRent FROM
CREATE OR REPLACE FUNCTION PredictRentXgboost FROM
( SELECT number_of_rooms, number_of_bathrooms, days_on_market, rental_price FROM HomeRentals )
TYPE XGBoost
PREDICT 'rental_price'
Expand All @@ -128,7 +139,7 @@ def test_xgboost_regression(self):
execute_query_fetch_all(self.evadb, create_predict_function)

predict_query = """
SELECT PredictRent(number_of_rooms, number_of_bathrooms, days_on_market, rental_price) FROM HomeRentals LIMIT 10;
SELECT PredictRentXgboost(number_of_rooms, number_of_bathrooms, days_on_market, rental_price) FROM HomeRentals LIMIT 10;
"""
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result.columns), 1)
Expand All @@ -137,7 +148,7 @@ def test_xgboost_regression(self):
@xgboost_skip_marker
def test_xgboost_classification(self):
create_predict_function = """
CREATE FUNCTION IF NOT EXISTS PredictEmployee FROM
CREATE OR REPLACE FUNCTION PredictEmployeeXgboost FROM
( SELECT payment_tier, age, gender, experience_in_current_domain, leave_or_not FROM Employee )
TYPE XGBoost
PREDICT 'leave_or_not'
Expand All @@ -148,7 +159,7 @@ def test_xgboost_classification(self):
execute_query_fetch_all(self.evadb, create_predict_function)

predict_query = """
SELECT PredictEmployee(payment_tier, age, gender, experience_in_current_domain, leave_or_not) FROM Employee LIMIT 10;
SELECT PredictEmployeeXgboost(payment_tier, age, gender, experience_in_current_domain, leave_or_not) FROM Employee LIMIT 10;
"""
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result.columns), 1)
Expand Down

0 comments on commit c6970bd

Please sign in to comment.