Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 30, 2024
1 parent be0fc6c commit 239a4cc
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 48 deletions.
6 changes: 5 additions & 1 deletion crates/polars-lazy/src/frame/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ impl LazyFrame {
// Should be a python function that returns a generator
scan_fn: Some(scan_fn.into()),
schema: Arc::new(schema),
is_pyarrow: pyarrow,
python_source: if pyarrow {
PythonScanSource::Pyarrow
} else {
PythonScanSource::IOPlugin
},
..Default::default()
},
}
Expand Down
7 changes: 6 additions & 1 deletion crates/polars-mem-engine/src/executors/scan/python_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ impl Executor for PythonScanExec {
},
};

let generator_init = if self.options.is_pyarrow {
let generator_init = if matches!(
self.options.python_source,
PythonScanSource::Pyarrow | PythonScanSource::Cuda
) {
let args = (python_scan_function, with_columns, predicate, n_rows);
callable.call1(args).map_err(to_compute_err)
} else {
Expand All @@ -86,6 +89,7 @@ impl Executor for PythonScanExec {
}?;

// This isn't a generator, but a `DataFrame`.
// This is the pyarrow and the CuDF path.
if generator_init.getattr(intern!(py, "_df")).is_ok() {
let df = python_df_to_rust(py, generator_init)?;
return if let Some(pred) = &self.predicate {
Expand All @@ -96,6 +100,7 @@ impl Executor for PythonScanExec {
};
}

// This is the IO plugin path.
let generator = generator_init
.get_item(0)
.map_err(|_| polars_err!(ComputeError: "expected tuple got {}", generator_init))?;
Expand Down
32 changes: 19 additions & 13 deletions crates/polars-mem-engine/src/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,34 +160,40 @@ fn create_physical_plan_impl(
#[cfg(feature = "python")]
PythonScan { mut options } => {
let mut predicate_serialized = None;

let predicate = if let PythonPredicate::Polars(e) = &options.predicate {
let phys_expr = || {
let mut state = ExpressionConversionState::new(true, state.expr_depth);
create_physical_expr(
e,
Context::Default,
expr_arena,
Some(&options.schema),
&mut state,
)
};

// Convert to a pyarrow eval string.
if options.is_pyarrow {
if matches!(options.python_source, PythonScanSource::Pyarrow) {
if let Some(eval_str) = polars_plan::plans::python::pyarrow::predicate_to_pa(
e.node(),
expr_arena,
Default::default(),
) {
options.predicate = PythonPredicate::PyArrow(eval_str)
options.predicate = PythonPredicate::PyArrow(eval_str);
// We don't have to use a physical expression as pyarrow deals with the filter.
None
} else {
Some(phys_expr()?)
}

// We don't have to use a physical expression as pyarrow deals with the filter.
None
}
// Convert to physical expression for the case the reader cannot consume the predicate.
else {
let dsl_expr = e.to_expr(expr_arena);
predicate_serialized =
polars_plan::plans::python::predicate::serialize(&dsl_expr)?;

let mut state = ExpressionConversionState::new(true, state.expr_depth);
Some(create_physical_expr(
e,
Context::Default,
expr_arena,
Some(&options.schema),
&mut state,
)?)
Some(phys_expr()?)
}
} else {
None
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-plan/src/plans/optimizer/fused.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ impl OptimizationRule for FusedArithmetic {
// We don't want to fuse arithmetic that we send to pyarrow.
#[cfg(feature = "python")]
if let IR::PythonScan { options } = lp_arena.get(lp_node) {
if options.is_pyarrow {
if matches!(
options.python_source,
PythonScanSource::Pyarrow | PythonScanSource::IOPlugin
) {
return Ok(None);
}
};
Expand Down
13 changes: 11 additions & 2 deletions crates/polars-plan/src/plans/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,24 @@ pub struct PythonOptions {
pub output_schema: Option<SchemaRef>,
// Projected column names.
pub with_columns: Option<Arc<[String]>>,
// Whether this is a pyarrow dataset source or a Polars source.
pub is_pyarrow: bool,
// Which interface is the python function.
pub python_source: PythonScanSource,
/// Optional predicate the reader must apply.
#[cfg_attr(feature = "serde", serde(skip))]
pub predicate: PythonPredicate,
/// A `head` call passed to the reader.
pub n_rows: Option<usize>,
}

#[derive(Clone, PartialEq, Eq, Debug, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum PythonScanSource {
Pyarrow,
Cuda,
#[default]
IOPlugin,
}

#[derive(Clone, PartialEq, Eq, Debug, Default)]
pub enum PythonPredicate {
// A pyarrow predicate python expression
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Mutex;

use polars_plan::plans::{to_aexpr, Context, IR};
use polars_plan::prelude::expr_ir::ExprIR;
use polars_plan::prelude::{AExpr, PythonOptions};
use polars_plan::prelude::{AExpr, PythonOptions, PythonScanSource};
use polars_utils::arena::{Arena, Node};
use pyo3::prelude::*;
use visitor::{expr_nodes, nodes};
Expand Down Expand Up @@ -164,7 +164,7 @@ impl NodeTraverser {
schema,
output_schema: None,
with_columns: None,
is_pyarrow: false,
python_source: PythonScanSource::Cuda,
predicate: Default::default(),
n_rows: None,
},
Expand Down
64 changes: 37 additions & 27 deletions py-polars/src/lazyframe/visitor/nodes.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use polars_core::prelude::{IdxSize, UniqueKeepStrategy};
use polars_ops::prelude::JoinType;
use polars_plan::plans::IR;
use polars_plan::prelude::{FileCount, FileScan, FileScanOptions, FunctionNode, PythonPredicate};
use polars_plan::prelude::{
FileCount, FileScan, FileScanOptions, FunctionNode, PythonPredicate, PythonScanSource,
};
use pyo3::exceptions::{PyNotImplementedError, PyValueError};
use pyo3::prelude::*;

Expand Down Expand Up @@ -255,33 +257,41 @@ pub struct Sink {

pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult<PyObject> {
let result = match plan {
IR::PythonScan { options } => PythonScan {
options: (
options
.scan_fn
.as_ref()
.map_or_else(|| py.None(), |s| s.0.clone()),
options
.with_columns
.as_ref()
.map_or_else(|| py.None(), |cols| cols.to_object(py)),
options.is_pyarrow,
match &options.predicate {
PythonPredicate::None => py.None(),
PythonPredicate::PyArrow(s) => s.to_object(py),
PythonPredicate::Polars(_) => {
return Err(PyNotImplementedError::new_err(
"polars native predicates not yet supported",
))
IR::PythonScan { options } => {
let python_src = match options.python_source {
PythonScanSource::Pyarrow => "pyarrow",
PythonScanSource::Cuda => "cuda",
PythonScanSource::IOPlugin => "io_plugin",
};

PythonScan {
options: (
options
.scan_fn
.as_ref()
.map_or_else(|| py.None(), |s| s.0.clone()),
options
.with_columns
.as_ref()
.map_or_else(|| py.None(), |cols| cols.to_object(py)),
python_src,
match &options.predicate {
PythonPredicate::None => py.None(),
PythonPredicate::PyArrow(s) => s.to_object(py),
PythonPredicate::Polars(_) => {
return Err(PyNotImplementedError::new_err(
"polars native predicates not yet supported",
))
},
},
},
options
.n_rows
.map_or_else(|| py.None(), |s| s.to_object(py)),
)
.to_object(py),
}
.into_py(py),
options
.n_rows
.map_or_else(|| py.None(), |s| s.to_object(py)),
)
.to_object(py),
}
.into_py(py)
},
IR::Slice { input, offset, len } => Slice {
input: input.0,
offset: *offset,
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/io/test_pyarrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def helper_dataset_test(


@pytest.mark.write_disk()
def test_dataset_foo(df: pl.DataFrame, tmp_path: Path) -> None:
def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None:
file_path = tmp_path / "small.ipc"
df.write_ipc(file_path)

Expand Down

0 comments on commit 239a4cc

Please sign in to comment.