diff --git a/Cargo.toml b/Cargo.toml index ed977fc..5733845 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ rust-version = "1.57" [dependencies] tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } rand = "0.7" -pyo3 = { version = "0.15", features = ["extension-module", "abi3", "abi3-py36"] } +pyo3 = { version = "~0.15", features = ["extension-module", "abi3", "abi3-py36"] } datafusion = { version = "^7.0.0", features = ["pyarrow"] } datafusion-expr = { version = "^7.0.0" } datafusion-common = { version = "^7.0.0", features = ["pyarrow"] } diff --git a/datafusion/tests/test_indexing.py b/datafusion/tests/test_indexing.py new file mode 100644 index 0000000..6250e4b --- /dev/null +++ b/datafusion/tests/test_indexing.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pyarrow as pa +import pytest + +from datafusion import ExecutionContext + + +@pytest.fixture +def df(): + ctx = ExecutionContext() + + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 4, 6])], + names=["a", "b"], + ) + return ctx.create_dataframe([[batch]]) + + +def test_indexing(df): + assert df["a"] is not None + assert df["a", "b"] is not None + assert df[("a", "b")] is not None + assert df[["a"]] is not None + + +def test_err(df): + with pytest.raises(Exception) as e_info: + df["c"] + + assert "No field with unqualified name" in e_info.value.args[0] + + with pytest.raises(Exception) as e_info: + df[1] + + assert ( + "DataFrame can only be indexed by string index or indices" + in e_info.value.args[0] + ) diff --git a/src/dataframe.rs b/src/dataframe.rs index 964f042..c73b587 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -15,18 +15,18 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use pyo3::prelude::*; - +use crate::utils::wait_for_future; +use crate::{errors::DataFusionError, expression::PyExpr}; use datafusion::arrow::datatypes::Schema; use datafusion::arrow::pyarrow::PyArrowConvert; use datafusion::arrow::util::pretty; use datafusion::dataframe::DataFrame; use datafusion::logical_plan::JoinType; - -use crate::utils::wait_for_future; -use crate::{errors::DataFusionError, expression::PyExpr}; +use pyo3::exceptions::PyTypeError; +use pyo3::mapping::PyMappingProtocol; +use pyo3::prelude::*; +use pyo3::types::PyTuple; +use std::sync::Arc; /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. @@ -142,3 +142,25 @@ impl PyDataFrame { Ok(pretty::print_batches(&batches)?) } } + +#[pyproto] +impl PyMappingProtocol<'_> for PyDataFrame { + fn __getitem__(&self, key: PyObject) -> PyResult { + Python::with_gil(|py| { + if let Ok(key) = key.extract::<&str>(py) { + self.select_columns(vec![key]) + } else if let Ok(tuple) = key.extract::<&PyTuple>(py) { + let keys = tuple + .iter() + .map(|item| item.extract::<&str>()) + .collect::>>()?; + self.select_columns(keys) + } else if let Ok(keys) = key.extract::>(py) { + self.select_columns(keys) + } else { + let message = "DataFrame can only be indexed by string index or indices"; + Err(PyTypeError::new_err(message)) + } + }) + } +}