[SPARK-50915][PYTHON][CONNECT] Add `getCondition` and deprecate `getErrorClass` in `PySparkException`

### What changes were proposed in this pull request?

This PR proposes to add `getCondition` and deprecate `getErrorClass` in `PySparkException`.
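
For illustration, a minimal sketch of the resulting API surface (the `spark` session, table name, and condition value below are hypothetical, not from this PR):

```python
from pyspark.errors import PySparkException

try:
    spark.sql("SELECT * FROM nonexistent_table")  # assumes an active SparkSession named `spark`
except PySparkException as e:
    # New in 4.0.0; returns the same value `getErrorClass` used to return,
    # e.g. "TABLE_OR_VIEW_NOT_FOUND" (illustrative).
    print(e.getCondition())
    print(e.getErrorClass())  # still works, but now emits a FutureWarning
```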

### Why are the changes needed?

To follow the new naming convention proposed by SPARK-46810 and to match the behavior on the JVM side.

### Does this PR introduce _any_ user-facing change?

Yes. Calling `getErrorClass` now issues a deprecation warning that encourages using `getCondition` instead.
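
A quick way to observe the new warning (a sketch; the `message`/`errorClass` constructor arguments are inferred from the `__init__` fields visible in the diff below and may differ in detail):

```python
import warnings

from pyspark.errors import PySparkException

exc = PySparkException(message="boom", errorClass="MY_ERROR", messageParameters={})

print(exc.getCondition())  # "MY_ERROR", no warning emitted
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    exc.getErrorClass()    # same value, but records a deprecation warning
print(caught[0].category)  # <class 'FutureWarning'>
```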

### How was this patch tested?

Updated the existing tests, so the existing CI should pass.
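
For reference, a self-contained sketch (not taken verbatim from this PR) of how a test could pin down both the new method and the deprecation warning:

```python
import unittest

from pyspark.errors import PySparkException


class GetConditionTest(unittest.TestCase):
    def test_get_error_class_is_deprecated(self):
        # Constructor arguments inferred from the diff; details may differ.
        exc = PySparkException(
            message="dummy", errorClass="UNSUPPORTED_OPERATION", messageParameters={}
        )
        self.assertEqual(exc.getCondition(), "UNSUPPORTED_OPERATION")
        with self.assertWarns(FutureWarning):
            self.assertEqual(exc.getErrorClass(), "UNSUPPORTED_OPERATION")


if __name__ == "__main__":
    unittest.main()
```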

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#49594 from itholic/get_condition.

Lead-authored-by: Haejoon Lee <haejoon.lee@databricks.com>
Co-authored-by: Hyukjin Kwon <gurwls223@gmail.com>
Signed-off-by: Haejoon Lee <haejoon.lee@databricks.com>
itholic and HyukjinKwon committed Jan 22, 2025
1 parent 8313320 commit 8611d0f
Showing 8 changed files with 44 additions and 19 deletions.
1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.errors.rst
@@ -69,6 +69,7 @@ Methods
.. autosummary::
:toctree: api/

PySparkException.getCondition
PySparkException.getErrorClass
PySparkException.getMessage
PySparkException.getMessageParameters
37 changes: 28 additions & 9 deletions python/pyspark/errors/exceptions/base.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, cast, Iterable, TYPE_CHECKING, List
@@ -53,20 +54,38 @@ def __init__(
self._messageParameters = messageParameters
self._contexts = contexts

def getCondition(self) -> Optional[str]:
"""
Returns an error condition.
.. versionadded:: 4.0.0
See Also
--------
:meth:`PySparkException.getMessage`
:meth:`PySparkException.getMessageParameters`
:meth:`PySparkException.getQueryContext`
:meth:`PySparkException.getSqlState`
"""
return self._errorClass

def getErrorClass(self) -> Optional[str]:
"""
Returns an error class as a string.
.. versionadded:: 3.4.0
.. deprecated:: 4.0.0
See Also
--------
:meth:`PySparkException.getMessage`
:meth:`PySparkException.getMessageParameters`
:meth:`PySparkException.getQueryContext`
:meth:`PySparkException.getSqlState`
"""
return self._errorClass
warnings.warn("Deprecated in 4.0.0, use getCondition instead.", FutureWarning)
return self.getCondition()

def getMessageParameters(self) -> Optional[Dict[str, str]]:
"""
@@ -76,7 +95,7 @@ def getMessageParameters(self) -> Optional[Dict[str, str]]:
See Also
--------
:meth:`PySparkException.getErrorClass`
:meth:`PySparkException.getCondition`
:meth:`PySparkException.getMessage`
:meth:`PySparkException.getQueryContext`
:meth:`PySparkException.getSqlState`
@@ -93,7 +112,7 @@ def getSqlState(self) -> Optional[str]:
See Also
--------
:meth:`PySparkException.getErrorClass`
:meth:`PySparkException.getCondition`
:meth:`PySparkException.getMessage`
:meth:`PySparkException.getMessageParameters`
:meth:`PySparkException.getQueryContext`
@@ -108,12 +127,12 @@ def getMessage(self) -> str:
See Also
--------
:meth:`PySparkException.getErrorClass`
:meth:`PySparkException.getCondition`
:meth:`PySparkException.getMessageParameters`
:meth:`PySparkException.getQueryContext`
:meth:`PySparkException.getSqlState`
"""
return f"[{self.getErrorClass()}] {self._message}"
return f"[{self.getCondition()}] {self._message}"

def getQueryContext(self) -> List["QueryContext"]:
"""
@@ -123,7 +142,7 @@ def getQueryContext(self) -> List["QueryContext"]:
See Also
--------
:meth:`PySparkException.getErrorClass`
:meth:`PySparkException.getCondition`
:meth:`PySparkException.getMessageParameters`
:meth:`PySparkException.getMessage`
:meth:`PySparkException.getSqlState`
@@ -143,17 +162,17 @@ def _log_exception(self) -> None:
file=call_site[0],
line=line,
fragment=context.fragment(),
errorClass=self.getErrorClass(),
errorClass=self.getCondition(),
)
else:
logger = PySparkLogger.getLogger("SQLQueryContextLogger")
logger.exception(
self.getMessage(),
errorClass=self.getErrorClass(),
errorClass=self.getCondition(),
)

def __str__(self) -> str:
if self.getErrorClass() is not None:
if self.getCondition() is not None:
return self.getMessage()
else:
return self._message
11 changes: 8 additions & 3 deletions python/pyspark/errors/exceptions/captured.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import warnings
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, Optional, cast, List, TYPE_CHECKING

@@ -95,7 +96,7 @@ def __str__(self) -> str:
desc = desc + "\n\nJVM stacktrace:\n%s" % self._stackTrace
return str(desc)

def getErrorClass(self) -> Optional[str]:
def getCondition(self) -> Optional[str]:
from pyspark import SparkContext
from py4j.java_gateway import is_instance_of

@@ -105,10 +106,14 @@ def getErrorClass(self) -> Optional[str]:
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
return self._origin.getErrorClass()
return self._origin.getCondition()
else:
return None

def getErrorClass(self) -> Optional[str]:
warnings.warn("Deprecated in 4.0.0, use getCondition instead.", FutureWarning)
return self.getCondition()

def getMessageParameters(self) -> Optional[Dict[str, str]]:
from pyspark import SparkContext
from py4j.java_gateway import is_instance_of
@@ -146,7 +151,7 @@ def getMessage(self) -> str:
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
errorClass = self._origin.getErrorClass()
errorClass = self._origin.getCondition()
messageParameters = self._origin.getMessageParameters()

error_message = getattr(gw.jvm, "org.apache.spark.SparkThrowableHelper").getMessage(
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/udf.py
@@ -69,7 +69,7 @@ def _create_py_udf(
== "true"
)
except PySparkRuntimeError as e:
if e.getErrorClass() == "NO_ACTIVE_OR_DEFAULT_SESSION":
if e.getCondition() == "NO_ACTIVE_OR_DEFAULT_SESSION":
pass # Just uses the default if no session found.
else:
raise e
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/udtf.py
@@ -78,7 +78,7 @@ def _create_py_udtf(
== "true"
)
except PySparkRuntimeError as e:
if e.getErrorClass() == "NO_ACTIVE_OR_DEFAULT_SESSION":
if e.getCondition() == "NO_ACTIVE_OR_DEFAULT_SESSION":
pass # Just uses the default if no session found.
else:
raise e
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/test_types.py
@@ -2345,7 +2345,7 @@ def schema_from_udf(ddl):
with self.assertRaises(ParseException) as udf_pe:
schema_from_udf(test)
self.assertEqual(
from_ddl_pe.exception.getErrorClass(), udf_pe.exception.getErrorClass()
from_ddl_pe.exception.getCondition(), udf_pe.exception.getCondition()
)

def test_collated_string(self):
6 changes: 3 additions & 3 deletions python/pyspark/sql/tests/test_utils.py
@@ -1763,7 +1763,7 @@ def test_get_error_class_state(self):
exception = e

self.assertIsNotNone(exception)
self.assertEqual(exception.getErrorClass(), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION")
self.assertEqual(exception.getCondition(), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION")
self.assertEqual(exception.getSqlState(), "42703")
self.assertEqual(exception.getMessageParameters(), {"objectName": "`a`"})
self.assertIn(
@@ -1785,7 +1785,7 @@ def test_get_error_class_state(self):
try:
self.spark.sql("""SELECT assert_true(FALSE)""")
except AnalysisException as e:
self.assertIsNone(e.getErrorClass())
self.assertIsNone(e.getCondition())
self.assertIsNone(e.getSqlState())
self.assertEqual(e.getMessageParameters(), {})
self.assertEqual(e.getMessage(), "")
@@ -1797,7 +1797,7 @@ def test_assert_data_frame_equal_not_support_streaming(self):
try:
assertDataFrameEqual(df1, df2)
except PySparkAssertionError as e:
self.assertEqual(e.getErrorClass(), "UNSUPPORTED_OPERATION")
self.assertEqual(e.getCondition(), "UNSUPPORTED_OPERATION")
exception_thrown = True

self.assertTrue(exception_thrown)
2 changes: 1 addition & 1 deletion python/pyspark/testing/utils.py
@@ -325,7 +325,7 @@ def check_error(

# Test error class
expected = errorClass
actual = exception.getErrorClass()
actual = exception.getCondition()
self.assertEqual(
expected, actual, f"Expected error class was '{expected}', got '{actual}'."
)
