diff --git a/velox/py/plan_builder/PyPlanBuilder.cpp b/velox/py/plan_builder/PyPlanBuilder.cpp index fdaf8f79cb84..39535e4c9471 100644 --- a/velox/py/plan_builder/PyPlanBuilder.cpp +++ b/velox/py/plan_builder/PyPlanBuilder.cpp @@ -97,19 +97,24 @@ std::optional PyPlanBuilder::planNode() const { } PyPlanBuilder& PyPlanBuilder::tableWrite( - const PyType& outputSchema, const PyFile& outputFile, - const std::string& connectorId) { + const std::string& connectorId, + const std::optional& outputSchema) { exec::test::PlanBuilder::TableWriterBuilder builder(planBuilder_); // Try to convert the output type. - auto outputRowSchema = asRowType(outputSchema.type()); - if (outputRowSchema == nullptr) { - throw std::runtime_error("Output schema must be a ROW()."); + RowTypePtr outputRowSchema; + + if (outputSchema != std::nullopt) { + outputRowSchema = asRowType(outputSchema->type()); + + if (outputRowSchema == nullptr) { + throw std::runtime_error("Output schema must be a ROW()."); + } + builder.outputType(outputRowSchema); } - builder.outputType(outputRowSchema) - .outputFileName(outputFile.filePath()) + builder.outputFileName(outputFile.filePath()) .fileFormat(outputFile.fileFormat()) .connectorId(connectorId) .endTableWriter(); diff --git a/velox/py/plan_builder/PyPlanBuilder.h b/velox/py/plan_builder/PyPlanBuilder.h index 31955ad7869a..e02c7a943b25 100644 --- a/velox/py/plan_builder/PyPlanBuilder.h +++ b/velox/py/plan_builder/PyPlanBuilder.h @@ -150,15 +150,16 @@ class PyPlanBuilder { /// Adds a table writer node to write to an output file(s). /// - /// @param outputSchema The schema to be used when writing the file (columns - /// and types). /// @param outputFile The output file to be written. /// @param connectorId The id of the connector to use during the write /// process. + /// @param outputSchema An optional schema to be used when writing the file + /// (columns and types); must be a ROW type. If std::nullopt, the schema + /// produced by the upstream operator is used. 
PyPlanBuilder& tableWrite( - const PyType& outputSchema, const PyFile& outputFile, - const std::string& connectorId); + const std::string& connectorId, + const std::optional& outputSchema); // Add the provided vectors straight into the operator tree. PyPlanBuilder& values(const std::vector& values); diff --git a/velox/py/plan_builder/plan_builder.cpp b/velox/py/plan_builder/plan_builder.cpp index 1296cce79604..6d557253362d 100644 --- a/velox/py/plan_builder/plan_builder.cpp +++ b/velox/py/plan_builder/plan_builder.cpp @@ -108,17 +108,18 @@ PYBIND11_MODULE(plan_builder, m) { .def( "table_write", &velox::py::PyPlanBuilder::tableWrite, - py::arg("output_schema"), py::arg("output_file"), py::arg("connector_id") = "hive", + py::arg("output_schema") = std::nullopt, py::doc(R"( Adds a table write node to the plan. Args: - output_schema: A RowType containing the schema to be written to - the file. output_file: Name of the file to be written. connector_id: ID of the connector to use for this scan. + output_schema: An optional RowType containing the schema to be + written to the file. By default write the schema + produced by the operator upstream. )")) .def( "values", diff --git a/velox/py/tests/test_runner.py b/velox/py/tests/test_runner.py index 0b83e43412f5..30b7e3dcc625 100644 --- a/velox/py/tests/test_runner.py +++ b/velox/py/tests/test_runner.py @@ -100,7 +100,6 @@ def test_write_read_file(self): plan_builder = PlanBuilder() plan_builder.values([input_batch]).table_write( - output_schema=ROW(["c0"], [BIGINT()]), output_file=DWRF(output_file), connector_id="hive", )