diff --git a/python/pyspark/ml/connect/proto.py b/python/pyspark/ml/connect/proto.py index 3a81e74b6aec3..b0e012964fc4a 100644 --- a/python/pyspark/ml/connect/proto.py +++ b/python/pyspark/ml/connect/proto.py @@ -50,7 +50,9 @@ def plan(self, session: "SparkConnectClient") -> pb2.Relation: plan.ml_relation.transform.obj_ref.CopyFrom(pb2.ObjectRef(id=self._name)) else: plan.ml_relation.transform.transformer.CopyFrom( - pb2.MlOperator(name=self._name, uid=self._uid, type=pb2.MlOperator.TRANSFORMER) + pb2.MlOperator( + name=self._name, uid=self._uid, type=pb2.MlOperator.OPERATOR_TYPE_TRANSFORMER + ) ) if self._ml_params is not None: diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py index 584ff3237a0a5..c2367282b7c40 100644 --- a/python/pyspark/ml/connect/readwrite.py +++ b/python/pyspark/ml/connect/readwrite.py @@ -118,13 +118,13 @@ def saveInstance( elif isinstance(instance, (JavaEstimator, JavaTransformer, JavaEvaluator)): operator: Union[JavaEstimator, JavaTransformer, JavaEvaluator] if isinstance(instance, JavaEstimator): - ml_type = pb2.MlOperator.ESTIMATOR + ml_type = pb2.MlOperator.OPERATOR_TYPE_ESTIMATOR operator = cast("JavaEstimator", instance) elif isinstance(instance, JavaEvaluator): - ml_type = pb2.MlOperator.EVALUATOR + ml_type = pb2.MlOperator.OPERATOR_TYPE_EVALUATOR operator = cast("JavaEvaluator", instance) else: - ml_type = pb2.MlOperator.TRANSFORMER + ml_type = pb2.MlOperator.OPERATOR_TYPE_TRANSFORMER operator = cast("JavaTransformer", instance) params = serialize_ml_params(operator, session.client) @@ -249,13 +249,13 @@ def loadInstance( or issubclass(clazz, JavaTransformer) ): if issubclass(clazz, JavaModel): - ml_type = pb2.MlOperator.MODEL + ml_type = pb2.MlOperator.OPERATOR_TYPE_MODEL elif issubclass(clazz, JavaEstimator): - ml_type = pb2.MlOperator.ESTIMATOR + ml_type = pb2.MlOperator.OPERATOR_TYPE_ESTIMATOR elif issubclass(clazz, JavaEvaluator): - ml_type = pb2.MlOperator.EVALUATOR + ml_type = pb2.MlOperator.OPERATOR_TYPE_EVALUATOR else: - ml_type = pb2.MlOperator.TRANSFORMER + ml_type = pb2.MlOperator.OPERATOR_TYPE_TRANSFORMER # to get the java corresponding qualified class name java_qualified_class_name = ( @@ -281,7 +281,7 @@ def _get_class() -> Type[RL]: py_type = _get_class() # It must be JavaWrapper, since we're passing the string to the _java_obj if issubclass(py_type, JavaWrapper): - if ml_type == pb2.MlOperator.MODEL: + if ml_type == pb2.MlOperator.OPERATOR_TYPE_MODEL: session.client.add_ml_cache(result.obj_ref.id) instance = py_type(result.obj_ref.id) else: @@ -358,7 +358,8 @@ def _get_class() -> Type[RL]: command.ml_command.read.CopyFrom( pb2.MlCommand.Read( operator=pb2.MlOperator( - name=java_qualified_class_name, type=pb2.MlOperator.TRANSFORMER + name=java_qualified_class_name, + type=pb2.MlOperator.OPERATOR_TYPE_TRANSFORMER, ), path=path, ) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 7b8ba57a1f8ae..9eab45239b8f5 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -136,7 +136,7 @@ def wrapped(self: "JavaEstimator", dataset: "ConnectDataFrame") -> Any: input = dataset._plan.plan(client) assert isinstance(self._java_obj, str) estimator = pb2.MlOperator( - name=self._java_obj, uid=self.uid, type=pb2.MlOperator.ESTIMATOR + name=self._java_obj, uid=self.uid, type=pb2.MlOperator.OPERATOR_TYPE_ESTIMATOR ) command = pb2.Command() command.ml_command.fit.CopyFrom( @@ -361,7 +361,7 @@ def wrapped(self: "JavaEvaluator", dataset: "ConnectDataFrame") -> Any: input = dataset._plan.plan(client) assert isinstance(self._java_obj, str) evaluator = pb2.MlOperator( - name=self._java_obj, uid=self.uid, type=pb2.MlOperator.EVALUATOR + name=self._java_obj, uid=self.uid, type=pb2.MlOperator.OPERATOR_TYPE_EVALUATOR ) command = pb2.Command() command.ml_command.evaluate.CopyFrom( diff --git a/python/pyspark/sql/connect/proto/ml_common_pb2.py b/python/pyspark/sql/connect/proto/ml_common_pb2.py index 43d6a512f48f8..b61e1bcb205ce 100644 --- a/python/pyspark/sql/connect/proto/ml_common_pb2.py +++ b/python/pyspark/sql/connect/proto/ml_common_pb2.py @@ -38,7 +38,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/ml_common.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xa5\x01\n\x08MlParams\x12;\n\x06params\x18\x01 \x03(\x0b\x32#.spark.connect.MlParams.ParamsEntryR\x06params\x1a\\\n\x0bParamsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01"\xc9\x01\n\nMlOperator\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x10\n\x03uid\x18\x02 \x01(\tR\x03uid\x12:\n\x04type\x18\x03 \x01(\x0e\x32&.spark.connect.MlOperator.OperatorTypeR\x04type"Y\n\x0cOperatorType\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\r\n\tESTIMATOR\x10\x01\x12\x0f\n\x0bTRANSFORMER\x10\x02\x12\r\n\tEVALUATOR\x10\x03\x12\t\n\x05MODEL\x10\x04"\x1b\n\tObjectRef\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02idB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1dspark/connect/ml_common.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xa5\x01\n\x08MlParams\x12;\n\x06params\x18\x01 \x03(\x0b\x32#.spark.connect.MlParams.ParamsEntryR\x06params\x1a\\\n\x0bParamsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01"\x90\x02\n\nMlOperator\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x10\n\x03uid\x18\x02 \x01(\tR\x03uid\x12:\n\x04type\x18\x03 \x01(\x0e\x32&.spark.connect.MlOperator.OperatorTypeR\x04type"\x9f\x01\n\x0cOperatorType\x12\x1d\n\x19OPERATOR_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17OPERATOR_TYPE_ESTIMATOR\x10\x01\x12\x1d\n\x19OPERATOR_TYPE_TRANSFORMER\x10\x02\x12\x1b\n\x17OPERATOR_TYPE_EVALUATOR\x10\x03\x12\x17\n\x13OPERATOR_TYPE_MODEL\x10\x04"\x1b\n\tObjectRef\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02idB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -58,9 +58,9 @@ _globals["_MLPARAMS_PARAMSENTRY"]._serialized_start = 155 _globals["_MLPARAMS_PARAMSENTRY"]._serialized_end = 247 _globals["_MLOPERATOR"]._serialized_start = 250 - _globals["_MLOPERATOR"]._serialized_end = 451 - _globals["_MLOPERATOR_OPERATORTYPE"]._serialized_start = 362 - _globals["_MLOPERATOR_OPERATORTYPE"]._serialized_end = 451 - _globals["_OBJECTREF"]._serialized_start = 453 - _globals["_OBJECTREF"]._serialized_end = 480 + _globals["_MLOPERATOR"]._serialized_end = 522 + _globals["_MLOPERATOR_OPERATORTYPE"]._serialized_start = 363 + _globals["_MLOPERATOR_OPERATORTYPE"]._serialized_end = 522 + _globals["_OBJECTREF"]._serialized_start = 524 + _globals["_OBJECTREF"]._serialized_end = 551 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/ml_common_pb2.pyi b/python/pyspark/sql/connect/proto/ml_common_pb2.pyi index f4688e94c3d55..bc540028eb08b 100644 --- a/python/pyspark/sql/connect/proto/ml_common_pb2.pyi +++ b/python/pyspark/sql/connect/proto/ml_common_pb2.pyi @@ -112,28 +112,36 @@ class MlOperator(google.protobuf.message.Message): builtins.type, ): # noqa: F821 DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - UNSPECIFIED: MlOperator._OperatorType.ValueType # 0 - ESTIMATOR: MlOperator._OperatorType.ValueType # 1 - TRANSFORMER: MlOperator._OperatorType.ValueType # 2 - EVALUATOR: MlOperator._OperatorType.ValueType # 3 - MODEL: MlOperator._OperatorType.ValueType # 4 + OPERATOR_TYPE_UNSPECIFIED: MlOperator._OperatorType.ValueType # 0 + OPERATOR_TYPE_ESTIMATOR: MlOperator._OperatorType.ValueType # 1 + """ML estimator""" + OPERATOR_TYPE_TRANSFORMER: MlOperator._OperatorType.ValueType # 2 + """ML transformer (non-model)""" + OPERATOR_TYPE_EVALUATOR: MlOperator._OperatorType.ValueType # 3 + """ML evaluator""" + OPERATOR_TYPE_MODEL: MlOperator._OperatorType.ValueType # 4 + """ML model""" class OperatorType(_OperatorType, metaclass=_OperatorTypeEnumTypeWrapper): ... - UNSPECIFIED: MlOperator.OperatorType.ValueType # 0 - ESTIMATOR: MlOperator.OperatorType.ValueType # 1 - TRANSFORMER: MlOperator.OperatorType.ValueType # 2 - EVALUATOR: MlOperator.OperatorType.ValueType # 3 - MODEL: MlOperator.OperatorType.ValueType # 4 + OPERATOR_TYPE_UNSPECIFIED: MlOperator.OperatorType.ValueType # 0 + OPERATOR_TYPE_ESTIMATOR: MlOperator.OperatorType.ValueType # 1 + """ML estimator""" + OPERATOR_TYPE_TRANSFORMER: MlOperator.OperatorType.ValueType # 2 + """ML transformer (non-model)""" + OPERATOR_TYPE_EVALUATOR: MlOperator.OperatorType.ValueType # 3 + """ML evaluator""" + OPERATOR_TYPE_MODEL: MlOperator.OperatorType.ValueType # 4 + """ML model""" NAME_FIELD_NUMBER: builtins.int UID_FIELD_NUMBER: builtins.int TYPE_FIELD_NUMBER: builtins.int name: builtins.str - """The qualified name of the ML operator.""" + """(Required) The qualified name of the ML operator.""" uid: builtins.str - """Unique id of the ML operator""" + """(Required) Unique id of the ML operator""" type: global___MlOperator.OperatorType.ValueType - """Represents what the ML operator is""" + """(Required) Represents what the ML operator is""" def __init__( self, *, @@ -156,7 +164,9 @@ class ObjectRef(google.protobuf.message.Message): ID_FIELD_NUMBER: builtins.int id: builtins.str - """The ID is used to lookup the object on the server side.""" + """(Required) The ID is used to lookup the object on the server side. + Note it is different from the 'uid' of a ML object. + """ def __init__( self, *, diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py b/python/pyspark/sql/connect/proto/ml_pb2.py index 8e8bc34a7a97e..666cb1efdd2b4 100644 --- a/python/pyspark/sql/connect/proto/ml_pb2.py +++ b/python/pyspark/sql/connect/proto/ml_pb2.py @@ -40,7 +40,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb1\t\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x1a\xa2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12/\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06params\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61taset\x1a;\n\x06\x44\x65lete\x12\x31\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x06objRef\x1a\xf0\x02\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12/\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06params\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12)\n\x10should_overwrite\x18\x05 \x01(\x08R\x0fshouldOverwrite\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04type\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xa7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12/\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06params\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07\x63ommand"\xf6\x02\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\xa6\x01\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x10\n\x03uid\x18\x03 \x01(\tR\x03uid\x12/\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06paramsB\x06\n\x04typeB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xfb\t\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1a;\n\x06\x44\x65lete\x12\x31\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x06objRef\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_paramsB\t\n\x07\x63ommand"\x93\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\xc3\x01\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -54,21 +54,21 @@ _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001" _globals["_MLCOMMAND"]._serialized_start = 137 - _globals["_MLCOMMAND"]._serialized_end = 1338 + _globals["_MLCOMMAND"]._serialized_end = 1412 _globals["_MLCOMMAND_FIT"]._serialized_start = 480 - _globals["_MLCOMMAND_FIT"]._serialized_end = 642 - _globals["_MLCOMMAND_DELETE"]._serialized_start = 644 - _globals["_MLCOMMAND_DELETE"]._serialized_end = 703 - _globals["_MLCOMMAND_WRITE"]._serialized_start = 706 - _globals["_MLCOMMAND_WRITE"]._serialized_end = 1074 - _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1008 - _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1066 - _globals["_MLCOMMAND_READ"]._serialized_start = 1076 - _globals["_MLCOMMAND_READ"]._serialized_end = 1157 - _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1160 - _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1327 - _globals["_MLCOMMANDRESULT"]._serialized_start = 1341 - _globals["_MLCOMMANDRESULT"]._serialized_end = 1715 - _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1534 - _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1700 + _globals["_MLCOMMAND_FIT"]._serialized_end = 658 + _globals["_MLCOMMAND_DELETE"]._serialized_start = 660 + _globals["_MLCOMMAND_DELETE"]._serialized_end = 719 + _globals["_MLCOMMAND_WRITE"]._serialized_start = 722 + _globals["_MLCOMMAND_WRITE"]._serialized_end = 1132 + _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1034 + _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1092 + _globals["_MLCOMMAND_READ"]._serialized_start = 1134 + _globals["_MLCOMMAND_READ"]._serialized_end = 1215 + _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1218 + _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1401 + _globals["_MLCOMMANDRESULT"]._serialized_start = 1415 + _globals["_MLCOMMANDRESULT"]._serialized_end = 1818 + _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1608 + _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1803 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi b/python/pyspark/sql/connect/proto/ml_pb2.pyi index e8ae0be8dded8..3a1e9155d71dc 100644 --- a/python/pyspark/sql/connect/proto/ml_pb2.pyi +++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi @@ -42,6 +42,7 @@ import pyspark.sql.connect.proto.expressions_pb2 import pyspark.sql.connect.proto.ml_common_pb2 import pyspark.sql.connect.proto.relations_pb2 import sys +import typing if sys.version_info >= (3, 8): import typing as typing_extensions @@ -65,13 +66,13 @@ class MlCommand(google.protobuf.message.Message): DATASET_FIELD_NUMBER: builtins.int @property def estimator(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlOperator: - """Estimator information""" + """(Required) Estimator information (its type should be OPERATOR_TYPE_ESTIMATOR)""" @property def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams: - """parameters of the Estimator""" + """(Optional) parameters of the Estimator""" @property def dataset(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: - """the training dataset""" + """(Required) the training dataset""" def __init__( self, *, @@ -82,15 +83,32 @@ class MlCommand(google.protobuf.message.Message): def HasField( self, field_name: typing_extensions.Literal[ - "dataset", b"dataset", "estimator", b"estimator", "params", b"params" + "_params", + b"_params", + "dataset", + b"dataset", + "estimator", + b"estimator", + "params", + b"params", ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "dataset", b"dataset", "estimator", b"estimator", "params", b"params" + "_params", + b"_params", + "dataset", + b"dataset", + "estimator", + b"estimator", + "params", + b"params", ], ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_params", b"_params"] + ) -> typing_extensions.Literal["params"] | None: ... class Delete(google.protobuf.message.Message): """Command to delete the cached object which could be a model @@ -150,16 +168,16 @@ class MlCommand(google.protobuf.message.Message): """The cached model""" @property def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams: - """The parameters of operator which could be estimator/evaluator or a cached model""" + """(Optional) The parameters of operator which could be estimator/evaluator or a cached model""" path: builtins.str - """Save the ML instance to the path""" + """(Required) Save the ML instance to the path""" should_overwrite: builtins.bool - """Overwrites if the output path already exists.""" + """(Optional) Overwrites if the output path already exists.""" @property def options( self, ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: - """The options of the writer""" + """(Optional) The options of the writer""" def __init__( self, *, @@ -167,18 +185,35 @@ class MlCommand(google.protobuf.message.Message): obj_ref: pyspark.sql.connect.proto.ml_common_pb2.ObjectRef | None = ..., params: pyspark.sql.connect.proto.ml_common_pb2.MlParams | None = ..., path: builtins.str = ..., - should_overwrite: builtins.bool = ..., + should_overwrite: builtins.bool | None = ..., options: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ - "obj_ref", b"obj_ref", "operator", b"operator", "params", b"params", "type", b"type" + "_params", + b"_params", + "_should_overwrite", + b"_should_overwrite", + "obj_ref", + b"obj_ref", + "operator", + b"operator", + "params", + b"params", + "should_overwrite", + b"should_overwrite", + "type", + b"type", ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ + "_params", + b"_params", + "_should_overwrite", + b"_should_overwrite", "obj_ref", b"obj_ref", "operator", @@ -195,6 +230,15 @@ class MlCommand(google.protobuf.message.Message): b"type", ], ) -> None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_params", b"_params"] + ) -> typing_extensions.Literal["params"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_should_overwrite", b"_should_overwrite"] + ) -> typing_extensions.Literal["should_overwrite"] | None: ... + @typing.overload def WhichOneof( self, oneof_group: typing_extensions.Literal["type", b"type"] ) -> typing_extensions.Literal["operator", "obj_ref"] | None: ... @@ -208,9 +252,9 @@ class MlCommand(google.protobuf.message.Message): PATH_FIELD_NUMBER: builtins.int @property def operator(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlOperator: - """ML operator information""" + """(Required) ML operator information""" path: builtins.str - """Load the ML instance from the input path""" + """(Required) Load the ML instance from the input path""" def __init__( self, *, @@ -234,13 +278,13 @@ class MlCommand(google.protobuf.message.Message): DATASET_FIELD_NUMBER: builtins.int @property def evaluator(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlOperator: - """Evaluator information""" + """(Required) Evaluator information (its type should be OPERATOR_TYPE_EVALUATOR)""" @property def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams: - """parameters of the Evaluator""" + """(Optional) parameters of the Evaluator""" @property def dataset(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: - """the evaluating dataset""" + """(Required) the evaluating dataset""" def __init__( self, *, @@ -251,15 +295,32 @@ class MlCommand(google.protobuf.message.Message): def HasField( self, field_name: typing_extensions.Literal[ - "dataset", b"dataset", "evaluator", b"evaluator", "params", b"params" + "_params", + b"_params", + "dataset", + b"dataset", + "evaluator", + b"evaluator", + "params", + b"params", ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "dataset", b"dataset", "evaluator", b"evaluator", "params", b"params" + "_params", + b"_params", + "dataset", + b"dataset", + "evaluator", + b"evaluator", + "params", + b"params", ], ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_params", b"_params"] + ) -> typing_extensions.Literal["params"] | None: ... FIT_FIELD_NUMBER: builtins.int FETCH_FIELD_NUMBER: builtins.int @@ -355,25 +416,46 @@ class MlCommandResult(google.protobuf.message.Message): name: builtins.str """Operator name""" uid: builtins.str + """(Optional) the 'uid' of a ML object + Note it is different from the 'id' of a cached object. + """ @property - def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams: ... + def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams: + """(Optional) parameters""" def __init__( self, *, obj_ref: pyspark.sql.connect.proto.ml_common_pb2.ObjectRef | None = ..., name: builtins.str = ..., - uid: builtins.str = ..., + uid: builtins.str | None = ..., params: pyspark.sql.connect.proto.ml_common_pb2.MlParams | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ - "name", b"name", "obj_ref", b"obj_ref", "params", b"params", "type", b"type" + "_params", + b"_params", + "_uid", + b"_uid", + "name", + b"name", + "obj_ref", + b"obj_ref", + "params", + b"params", + "type", + b"type", + "uid", + b"uid", ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ + "_params", + b"_params", + "_uid", + b"_uid", "name", b"name", "obj_ref", @@ -386,6 +468,15 @@ class MlCommandResult(google.protobuf.message.Message): b"uid", ], ) -> None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_params", b"_params"] + ) -> typing_extensions.Literal["params"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_uid", b"_uid"] + ) -> typing_extensions.Literal["uid"] | None: ... + @typing.overload def WhichOneof( self, oneof_group: typing_extensions.Literal["type", b"type"] ) -> typing_extensions.Literal["obj_ref", "name"] | None: ... diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto index 20a5cafebb367..6e469bb9027e1 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto @@ -40,11 +40,11 @@ message MlCommand { // Command for estimator.fit(dataset) message Fit { - // Estimator information + // (Required) Estimator information (its type should be OPERATOR_TYPE_ESTIMATOR) MlOperator estimator = 1; - // parameters of the Estimator - MlParams params = 2; - // the training dataset + // (Optional) parameters of the Estimator + optional MlParams params = 2; + // (Required) the training dataset Relation dataset = 3; } @@ -63,31 +63,31 @@ message MlCommand { // The cached model ObjectRef obj_ref = 2; } - // The parameters of operator which could be estimator/evaluator or a cached model - MlParams params = 3; - // Save the ML instance to the path + // (Optional) The parameters of operator which could be estimator/evaluator or a cached model + optional MlParams params = 3; + // (Required) Save the ML instance to the path string path = 4; - // Overwrites if the output path already exists. - bool should_overwrite = 5; - // The options of the writer + // (Optional) Overwrites if the output path already exists. + optional bool should_overwrite = 5; + // (Optional) The options of the writer map options = 6; } // Command to load ML operator. message Read { - // ML operator information + // (Required) ML operator information MlOperator operator = 1; - // Load the ML instance from the input path + // (Required) Load the ML instance from the input path string path = 2; } // Command for evaluator.evaluate(dataset) message Evaluate { - // Evaluator information + // (Required) Evaluator information (its type should be OPERATOR_TYPE_EVALUATOR) MlOperator evaluator = 1; - // parameters of the Evaluator - MlParams params = 2; - // the evaluating dataset + // (Optional) parameters of the Evaluator + optional MlParams params = 2; + // (Required) the evaluating dataset Relation dataset = 3; } } @@ -111,8 +111,10 @@ message MlCommandResult { // Operator name string name = 2; } - string uid = 3; - MlParams params = 4; + // (Optional) the 'uid' of a ML object + // Note it is different from the 'id' of a cached object. + optional string uid = 3; + // (Optional) parameters + optional MlParams params = 4; } - } diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml_common.proto b/sql/connect/common/src/main/protobuf/spark/connect/ml_common.proto index 48b5fa8135cc9..06ca4e5db697c 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/ml_common.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/ml_common.proto @@ -33,24 +33,32 @@ message MlParams { // MLOperator represents the ML operators like (Estimator, Transformer or Evaluator) message MlOperator { - // The qualified name of the ML operator. + // (Required) The qualified name of the ML operator. string name = 1; - // Unique id of the ML operator + + // (Required) Unique id of the ML operator string uid = 2; - // Represents what the ML operator is + + // (Required) Represents what the ML operator is OperatorType type = 3; + enum OperatorType { - UNSPECIFIED = 0; - ESTIMATOR = 1; - TRANSFORMER = 2; - EVALUATOR = 3; - MODEL = 4; + OPERATOR_TYPE_UNSPECIFIED = 0; + // ML estimator + OPERATOR_TYPE_ESTIMATOR = 1; + // ML transformer (non-model) + OPERATOR_TYPE_TRANSFORMER = 2; + // ML evaluator + OPERATOR_TYPE_EVALUATOR = 3; + // ML model + OPERATOR_TYPE_MODEL = 4; } } // Represents a reference to the cached object which could be a model // or summary evaluated by a model message ObjectRef { - // The ID is used to lookup the object on the server side. + // (Required) The ID is used to lookup the object on the server side. + // Note it is different from the 'uid' of a ML object. string id = 1; } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala index d4ef1eee5c249..08080c099200a 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala @@ -114,7 +114,7 @@ private[connect] object MLHandler extends Logging { case proto.MlCommand.CommandCase.FIT => val fitCmd = mlCommand.getFit val estimatorProto = fitCmd.getEstimator - assert(estimatorProto.getType == proto.MlOperator.OperatorType.ESTIMATOR) + assert(estimatorProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR) val dataset = MLUtils.parseRelationProto(fitCmd.getDataset, sessionHolder) val estimator = @@ -197,21 +197,21 @@ private[connect] object MLHandler extends Logging { val params = Some(writer.getParams) operatorType match { - case proto.MlOperator.OperatorType.ESTIMATOR => + case proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR => val estimator = MLUtils.getEstimator(sessionHolder, writer.getOperator, params) estimator match { case writable: MLWritable => MLUtils.write(writable, mlCommand.getWrite) case other => throw MlUnsupportedException(s"Estimator $other is not writable") } - case proto.MlOperator.OperatorType.EVALUATOR => + case proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR => val evaluator = MLUtils.getEvaluator(sessionHolder, writer.getOperator, params) evaluator match { case writable: MLWritable => MLUtils.write(writable, mlCommand.getWrite) case other => throw MlUnsupportedException(s"Evaluator $other is not writable") } - case proto.MlOperator.OperatorType.TRANSFORMER => + case proto.MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER => val transformer = MLUtils.getTransformer(sessionHolder, writer.getOperator, params) transformer match { @@ -232,7 +232,7 @@ private[connect] object MLHandler extends Logging { val name = operator.getName val path = mlCommand.getRead.getPath - if (operator.getType == proto.MlOperator.OperatorType.MODEL) { + if (operator.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_MODEL) { val model = MLUtils.loadTransformer(sessionHolder, name, path) val id = mlCache.register(model) return proto.MlCommandResult @@ -244,18 +244,21 @@ private[connect] object MLHandler extends Logging { .setUid(model.uid) .setParams(Serializer.serializeParams(model))) .build() - } - val mlOperator = if (operator.getType == proto.MlOperator.OperatorType.ESTIMATOR) { - MLUtils.loadEstimator(sessionHolder, name, path).asInstanceOf[Params] - } else if (operator.getType == proto.MlOperator.OperatorType.EVALUATOR) { - MLUtils.loadEvaluator(sessionHolder, name, path).asInstanceOf[Params] - } else if (operator.getType == proto.MlOperator.OperatorType.TRANSFORMER) { - MLUtils.loadTransformer(sessionHolder, name, path).asInstanceOf[Params] - } else { - throw MlUnsupportedException(s"${operator.getType} read not supported") - } + val mlOperator = + if (operator.getType == + proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR) { + MLUtils.loadEstimator(sessionHolder, name, path).asInstanceOf[Params] + } else if (operator.getType == + proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR) { + MLUtils.loadEvaluator(sessionHolder, name, path).asInstanceOf[Params] + } else if (operator.getType == + proto.MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER) { + MLUtils.loadTransformer(sessionHolder, name, path).asInstanceOf[Params] + } else { + throw MlUnsupportedException(s"${operator.getType} read not supported") + } proto.MlCommandResult .newBuilder() @@ -270,7 +273,7 @@ private[connect] object MLHandler extends Logging { case proto.MlCommand.CommandCase.EVALUATE => val evalCmd = mlCommand.getEvaluate val evalProto = evalCmd.getEvaluator - assert(evalProto.getType == proto.MlOperator.OperatorType.EVALUATOR) + assert(evalProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR) val dataset = MLUtils.parseRelationProto(evalCmd.getDataset, sessionHolder) val evaluator = @@ -295,7 +298,7 @@ private[connect] object MLHandler extends Logging { val transformProto = relation.getTransform assert( transformProto.getTransformer.getType == - proto.MlOperator.OperatorType.TRANSFORMER) + proto.MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER) val dataset = MLUtils.parseRelationProto(transformProto.getInput, sessionHolder) val transformer = MLUtils.getTransformer(sessionHolder, transformProto) transformer.transform(dataset) @@ -323,5 +326,4 @@ private[connect] object MLHandler extends Logging { case other => throw MlUnsupportedException(s"$other not supported") } } - } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala index c999772b7d826..3647fa3d9dae9 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala @@ -693,7 +693,7 @@ private[ml] object MLUtils { } def write(instance: MLWritable, writeProto: proto.MlCommand.Write): Unit = { - val writer = if (writeProto.getShouldOverwrite) { + val writer = if (writeProto.hasShouldOverwrite && writeProto.getShouldOverwrite) { instance.write.overwrite() } else { instance.write diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala index 5b2b5e6dd793f..f7788fb3cd1a7 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala @@ -42,7 +42,7 @@ class MLBackendSuite extends MLHelper { .newBuilder() .setName(name) .setUid(name) - .setType(proto.MlOperator.OperatorType.ESTIMATOR) + .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR) } private def getMaxIterBuilder: proto.MlParams.Builder = { diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala index 5a447189d8702..5939b673501b5 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala @@ -98,7 +98,7 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest { .newBuilder() .setName("org.apache.spark.ml.classification.LogisticRegression") .setUid("LogisticRegression") - .setType(proto.MlOperator.OperatorType.ESTIMATOR) + .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR) def getMaxIter: proto.MlParams.Builder = proto.MlParams @@ -110,7 +110,7 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest { .newBuilder() .setName("org.apache.spark.ml.evaluation.RegressionEvaluator") .setUid("RegressionEvaluator") - .setType(proto.MlOperator.OperatorType.EVALUATOR) + .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR) def getMetricName: proto.MlParams.Builder = proto.MlParams @@ -149,7 +149,7 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest { .newBuilder() .setUid("vec") .setName("org.apache.spark.ml.feature.VectorAssembler") - .setType(proto.MlOperator.OperatorType.TRANSFORMER) + .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER) def getVectorAssemblerParams: proto.MlParams.Builder = proto.MlParams @@ -220,7 +220,7 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest { proto.MlOperator .newBuilder() .setName(clsName) - .setType(proto.MlOperator.OperatorType.MODEL)) + .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_MODEL)) .setPath(path)) .build() diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala index cc24a2a67439f..0d0fbc4b1b7b3 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala @@ -250,7 +250,7 @@ class MLSuite extends MLHelper { .newBuilder() .setName("org.apache.spark.ml.NotExistingML") .setUid("FakedUid") - .setType(proto.MlOperator.OperatorType.ESTIMATOR))) + .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR))) .build() MLHandler.handleMlCommand(sessionHolder, command) } @@ -280,7 +280,7 @@ class MLSuite extends MLHelper { .setOperator(proto.MlOperator .newBuilder() .setName("org.apache.spark.sql.connect.ml.NotImplementingMLReadble") - .setType(proto.MlOperator.OperatorType.ESTIMATOR)) + .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR)) .setPath("/tmp/fake")) .build() MLHandler.handleMlCommand(sessionHolder, readCmd)