Skip to content

Commit

Permalink
nit
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengruifeng committed Feb 11, 2025
1 parent 85bd303 commit 7feab04
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 29 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/proto/ml_common_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class ObjectRef(google.protobuf.message.Message):
ID_FIELD_NUMBER: builtins.int
id: builtins.str
"""(Required) The ID is used to lookup the object on the server side.
Note that this 'id' is not the same as the 'uid' of a ML object.
Note it is different from the 'uid' of an ML object.
"""
def __init__(
self,
Expand Down
34 changes: 17 additions & 17 deletions python/pyspark/sql/connect/proto/ml_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xcb\t\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x1a\xa2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12/\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06params\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61taset\x1a;\n\x06\x44\x65lete\x12\x31\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x06objRef\x1a\x8a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12/\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06params\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x01R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xa7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 
\x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12/\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06params\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07\x63ommand"\x83\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\xb3\x01\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12/\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06paramsB\x06\n\x04typeB\x06\n\x04_uidB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xfb\t\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1a;\n\x06\x44\x65lete\x12\x31\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x06objRef\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 
\x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_paramsB\t\n\x07\x63ommand"\x93\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\xc3\x01\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
)

_globals = globals()
Expand All @@ -54,21 +54,21 @@
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001"
_globals["_MLCOMMAND"]._serialized_start = 137
_globals["_MLCOMMAND"]._serialized_end = 1364
_globals["_MLCOMMAND"]._serialized_end = 1412
_globals["_MLCOMMAND_FIT"]._serialized_start = 480
_globals["_MLCOMMAND_FIT"]._serialized_end = 642
_globals["_MLCOMMAND_DELETE"]._serialized_start = 644
_globals["_MLCOMMAND_DELETE"]._serialized_end = 703
_globals["_MLCOMMAND_WRITE"]._serialized_start = 706
_globals["_MLCOMMAND_WRITE"]._serialized_end = 1100
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1013
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1071
_globals["_MLCOMMAND_READ"]._serialized_start = 1102
_globals["_MLCOMMAND_READ"]._serialized_end = 1183
_globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1186
_globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1353
_globals["_MLCOMMANDRESULT"]._serialized_start = 1367
_globals["_MLCOMMANDRESULT"]._serialized_end = 1754
_globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1560
_globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1739
_globals["_MLCOMMAND_FIT"]._serialized_end = 658
_globals["_MLCOMMAND_DELETE"]._serialized_start = 660
_globals["_MLCOMMAND_DELETE"]._serialized_end = 719
_globals["_MLCOMMAND_WRITE"]._serialized_start = 722
_globals["_MLCOMMAND_WRITE"]._serialized_end = 1132
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1034
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1092
_globals["_MLCOMMAND_READ"]._serialized_start = 1134
_globals["_MLCOMMAND_READ"]._serialized_end = 1215
_globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1218
_globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1401
_globals["_MLCOMMANDRESULT"]._serialized_start = 1415
_globals["_MLCOMMANDRESULT"]._serialized_end = 1818
_globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1608
_globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1803
# @@protoc_insertion_point(module_scope)
60 changes: 55 additions & 5 deletions python/pyspark/sql/connect/proto/ml_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,32 @@ class MlCommand(google.protobuf.message.Message):
def HasField(
self,
field_name: typing_extensions.Literal[
"dataset", b"dataset", "estimator", b"estimator", "params", b"params"
"_params",
b"_params",
"dataset",
b"dataset",
"estimator",
b"estimator",
"params",
b"params",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"dataset", b"dataset", "estimator", b"estimator", "params", b"params"
"_params",
b"_params",
"dataset",
b"dataset",
"estimator",
b"estimator",
"params",
b"params",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_params", b"_params"]
) -> typing_extensions.Literal["params"] | None: ...

class Delete(google.protobuf.message.Message):
"""Command to delete the cached object which could be a model
Expand Down Expand Up @@ -174,6 +191,8 @@ class MlCommand(google.protobuf.message.Message):
def HasField(
self,
field_name: typing_extensions.Literal[
"_params",
b"_params",
"_should_overwrite",
b"_should_overwrite",
"obj_ref",
Expand All @@ -191,6 +210,8 @@ class MlCommand(google.protobuf.message.Message):
def ClearField(
self,
field_name: typing_extensions.Literal[
"_params",
b"_params",
"_should_overwrite",
b"_should_overwrite",
"obj_ref",
Expand All @@ -210,6 +231,10 @@ class MlCommand(google.protobuf.message.Message):
],
) -> None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_params", b"_params"]
) -> typing_extensions.Literal["params"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_should_overwrite", b"_should_overwrite"]
) -> typing_extensions.Literal["should_overwrite"] | None: ...
Expand Down Expand Up @@ -270,15 +295,32 @@ class MlCommand(google.protobuf.message.Message):
def HasField(
self,
field_name: typing_extensions.Literal[
"dataset", b"dataset", "evaluator", b"evaluator", "params", b"params"
"_params",
b"_params",
"dataset",
b"dataset",
"evaluator",
b"evaluator",
"params",
b"params",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"dataset", b"dataset", "evaluator", b"evaluator", "params", b"params"
"_params",
b"_params",
"dataset",
b"dataset",
"evaluator",
b"evaluator",
"params",
b"params",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_params", b"_params"]
) -> typing_extensions.Literal["params"] | None: ...

FIT_FIELD_NUMBER: builtins.int
FETCH_FIELD_NUMBER: builtins.int
Expand Down Expand Up @@ -375,7 +417,7 @@ class MlCommandResult(google.protobuf.message.Message):
"""Operator name"""
uid: builtins.str
"""(Optional) the 'uid' of a ML object
Note it is not the same as the 'id' of a cached object.
Note it is different from the 'id' of a cached object.
"""
@property
def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams:
Expand All @@ -391,6 +433,8 @@ class MlCommandResult(google.protobuf.message.Message):
def HasField(
self,
field_name: typing_extensions.Literal[
"_params",
b"_params",
"_uid",
b"_uid",
"name",
Expand All @@ -408,6 +452,8 @@ class MlCommandResult(google.protobuf.message.Message):
def ClearField(
self,
field_name: typing_extensions.Literal[
"_params",
b"_params",
"_uid",
b"_uid",
"name",
Expand All @@ -423,6 +469,10 @@ class MlCommandResult(google.protobuf.message.Message):
],
) -> None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_params", b"_params"]
) -> typing_extensions.Literal["params"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_uid", b"_uid"]
) -> typing_extensions.Literal["uid"] | None: ...
Expand Down
10 changes: 5 additions & 5 deletions sql/connect/common/src/main/protobuf/spark/connect/ml.proto
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ message MlCommand {
// (Required) Estimator information (its type should be OPERATOR_TYPE_ESTIMATOR)
MlOperator estimator = 1;
// (Optional) parameters of the Estimator
MlParams params = 2;
optional MlParams params = 2;
// (Required) the training dataset
Relation dataset = 3;
}
Expand All @@ -64,7 +64,7 @@ message MlCommand {
ObjectRef obj_ref = 2;
}
// (Optional) The parameters of the operator, which could be an estimator/evaluator or a cached model
MlParams params = 3;
optional MlParams params = 3;
// (Required) Save the ML instance to the path
string path = 4;
// (Optional) Overwrites if the output path already exists.
Expand All @@ -86,7 +86,7 @@ message MlCommand {
// (Required) Evaluator information (its type should be OPERATOR_TYPE_EVALUATOR)
MlOperator evaluator = 1;
// (Optional) parameters of the Evaluator
MlParams params = 2;
optional MlParams params = 2;
// (Required) the evaluating dataset
Relation dataset = 3;
}
Expand All @@ -112,9 +112,9 @@ message MlCommandResult {
string name = 2;
}
// (Optional) the 'uid' of an ML object
// Note it is not the same as the 'id' of a cached object.
// Note it is different from the 'id' of a cached object.
optional string uid = 3;
// (Optional) parameters
MlParams params = 4;
optional MlParams params = 4;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ message MlOperator {
// or summary evaluated by a model
message ObjectRef {
// (Required) The ID is used to look up the object on the server side.
// Note that this 'id' is not the same as the 'uid' of a ML object.
// Note it is different from the 'uid' of an ML object.
string id = 1;
}

0 comments on commit 7feab04

Please sign in to comment.