Skip to content

Commit 23eae89

Browse files
committed
update pyo3
1 parent c9c841d commit 23eae89

File tree

3 files changed

+68
-40
lines changed

3 files changed

+68
-40
lines changed

python/src/dataset.rs

+19-12
Original file line numberDiff line numberDiff line change
@@ -1182,14 +1182,15 @@ impl Dataset {
11821182
#[pyo3(signature = (columns, index_type, name = None, replace = None, storage_options = None, fragment_ids = None, kwargs = None))]
11831183
fn create_fragment_index(
11841184
&mut self,
1185-
columns: Vec<&str>,
1185+
columns: Vec<PyBackedStr>,
11861186
index_type: &str,
11871187
name: Option<String>,
11881188
replace: Option<bool>,
11891189
storage_options: Option<HashMap<String, String>>,
11901190
fragment_ids: Option<Vec<u32>>,
11911191
kwargs: Option<&Bound<PyDict>>,
11921192
) -> PyResult<PyLance<Index>> {
1193+
let columns: Vec<&str> = columns.iter().map(|s| &**s).collect();
11931194
let index_type = index_type.to_uppercase();
11941195
let idx_type = self.parse_index_type(&index_type)?;
11951196
log::info!("Creating index: type={}", index_type);
@@ -1258,22 +1259,27 @@ impl Dataset {
12581259
fn unindexed_fragments(&self, name: &str) -> PyResult<PyObject> {
12591260
let result = RT
12601261
.block_on(None, self.ds.unindexed_fragments(name))?
1261-
.map_err(|err| PyIOError::new_err(err.to_string()));
1262+
.map_err(|err| PyIOError::new_err(err.to_string()))?;
12621263

1263-
Python::with_gil(|py| result.map(|vec| export_vec(py, &vec).to_object(py)))
1264+
Python::with_gil(|py| {
1265+
let py_vec = export_vec(py, &result)?;
1266+
PyList::new(py, py_vec).map(|list| list.into())
1267+
})
12641268
}
12651269

12661270
fn indexed_fragments(&self, name: &str) -> PyResult<PyObject> {
12671271
let result = RT
12681272
.block_on(None, self.ds.indexed_fragments(name))?
1269-
.map_err(|err| PyIOError::new_err(err.to_string()));
1273+
.map_err(|err| PyIOError::new_err(err.to_string()))?;
12701274
Python::with_gil(|py| {
1271-
result.map(|vec2| {
1272-
vec2.iter()
1273-
.map(|vec| export_vec(py, vec).to_object(py))
1274-
.collect::<Vec<_>>()
1275-
.to_object(py)
1276-
})
1275+
let result = result
1276+
.iter()
1277+
.map(|vec| {
1278+
let py_vec = export_vec(py, vec)?;
1279+
PyList::new(py, py_vec).map(|list| list.into())
1280+
})
1281+
.collect::<Result<Vec<PyObject>, _>>()?;
1282+
PyList::new(py, result).map(|list| list.into())
12771283
})
12781284
}
12791285

@@ -1645,9 +1651,10 @@ impl Dataset {
16451651
.base_tokenizer(base_tokenizer.extract()?);
16461652
}
16471653
if let Some(language) = kwargs.get_item("language")? {
1648-
let language = language.extract()?;
1654+
let language: PyBackedStr =
1655+
language.downcast::<PyString>()?.clone().try_into()?;
16491656
params.tokenizer_config =
1650-
params.tokenizer_config.language(language).map_err(|e| {
1657+
params.tokenizer_config.language(&language).map_err(|e| {
16511658
PyValueError::new_err(format!(
16521659
"can't set tokenizer language to {}: {:?}",
16531660
language, e

python/src/transaction.rs

+42-25
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ use lance::dataset::transaction::{
99
use lance::datatypes::Schema;
1010
use lance_table::format::{DataFile, Fragment, Index};
1111
use pyo3::exceptions::PyValueError;
12-
use pyo3::types::{PyDict, PyNone};
13-
use pyo3::{intern, prelude::*};
12+
use pyo3::types::{PyDict, PyList, PyNone, PySet};
13+
use pyo3::{intern, prelude::*, PyTypeCheck};
1414
use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};
1515
use uuid::Uuid;
1616

@@ -49,19 +49,30 @@ impl<'py> IntoPyObject<'py> for PyLance<&DataReplacementGroup> {
4949

5050
impl FromPyObject<'_> for PyLance<Index> {
5151
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
52-
let uuid = ob.get_item("uuid")?.extract()?;
52+
let uuid = ob.get_item("uuid")?.to_string();
5353
let name = ob.get_item("name")?.extract()?;
5454
let fields = ob.get_item("fields")?.extract()?;
5555
let dataset_version = ob.get_item("version")?.extract()?;
5656

5757
let fragment_ids = ob.get_item("fragment_ids")?;
58-
let fragment_ids = fragment_ids
59-
.iter()?
60-
.map(|id| id?.extract::<u32>())
61-
.collect::<PyResult<Vec<u32>>>()?;
58+
let fragment_ids = if PySet::type_check(&fragment_ids) {
59+
let fragment_ids_ref: &Bound<'_, PySet> = fragment_ids.downcast()?;
60+
fragment_ids_ref
61+
.into_iter()
62+
.map(|id| id.extract())
63+
.collect::<PyResult<Vec<u32>>>()?
64+
} else if PyList::type_check(&fragment_ids) {
65+
let fragment_ids_ref: &Bound<'_, PyList> = fragment_ids.downcast()?;
66+
fragment_ids_ref
67+
.into_iter()
68+
.map(|id| id.extract())
69+
.collect::<PyResult<Vec<u32>>>()?
70+
} else {
71+
return Err(PyValueError::new_err("Invalid fragment_ids"));
72+
};
6273
let fragment_bitmap = Some(fragment_ids.into_iter().collect());
6374
Ok(Self(Index {
64-
uuid: Uuid::parse_str(uuid).map_err(|e| PyValueError::new_err(e.to_string()))?,
75+
uuid: Uuid::parse_str(&uuid).map_err(|e| PyValueError::new_err(e.to_string()))?,
6576
name,
6677
fields,
6778
dataset_version,
@@ -73,30 +84,38 @@ impl FromPyObject<'_> for PyLance<Index> {
7384
}
7485
}
7586

76-
impl ToPyObject for PyLance<&Index> {
77-
fn to_object(&self, py: Python<'_>) -> PyObject {
78-
let uuid = self.0.uuid.to_string().to_object(py);
79-
let name = self.0.name.to_object(py);
80-
let fields = export_vec(py, &self.0.fields).to_object(py);
81-
let dataset_version = self.0.dataset_version.to_object(py);
87+
impl<'py> IntoPyObject<'py> for PyLance<&Index> {
88+
type Target = PyDict;
89+
type Output = Bound<'py, Self::Target>;
90+
type Error = PyErr;
91+
92+
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
93+
let uuid = self.0.uuid.to_string().into_pyobject(py)?;
94+
let name = self.0.name.clone().into_pyobject(py)?;
95+
let fields = export_vec(py, &self.0.fields)?;
96+
let dataset_version = self.0.dataset_version.into_pyobject(py)?;
8297
let fragment_ids = match &self.0.fragment_bitmap {
83-
Some(bitmap) => bitmap.into_iter().collect::<Vec<_>>().to_object(py),
84-
None => PyNone::get_bound(py).to_object(py),
98+
Some(bitmap) => bitmap.into_iter().collect::<Vec<_>>().into_pyobject(py)?,
99+
None => PyNone::get(py).to_owned().into_any(),
85100
};
86101

87-
let kwargs = PyDict::new_bound(py);
102+
let kwargs = PyDict::new(py);
88103
kwargs.set_item("uuid", uuid).unwrap();
89104
kwargs.set_item("name", name).unwrap();
90105
kwargs.set_item("fields", fields).unwrap();
91106
kwargs.set_item("version", dataset_version).unwrap();
92107
kwargs.set_item("fragment_ids", fragment_ids).unwrap();
93-
kwargs.into()
108+
Ok(kwargs)
94109
}
95110
}
96111

97-
impl ToPyObject for PyLance<Index> {
98-
fn to_object(&self, py: Python<'_>) -> PyObject {
99-
PyLance(&self.0).to_object(py)
112+
impl<'py> IntoPyObject<'py> for PyLance<Index> {
113+
type Target = PyDict;
114+
type Output = Bound<'py, Self::Target>;
115+
type Error = PyErr;
116+
117+
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
118+
PyLance(&self.0).into_pyobject(py)
100119
}
101120
}
102121

@@ -245,14 +264,12 @@ impl<'py> IntoPyObject<'py> for PyLance<&Operation> {
245264
removed_indices,
246265
new_indices,
247266
} => {
248-
let removed_indices = export_vec(py, removed_indices.as_slice());
249-
let new_indices = export_vec(py, new_indices.as_slice());
267+
let removed_indices = export_vec(py, removed_indices.as_slice())?;
268+
let new_indices = export_vec(py, new_indices.as_slice())?;
250269
let cls = namespace
251270
.getattr("CreateIndex")
252271
.expect("Failed to get CreateIndex class");
253272
cls.call1((removed_indices, new_indices))
254-
.unwrap()
255-
.to_object(py)
256273
}
257274
Operation::DataReplacement { replacements } => {
258275
let replacements = export_vec(py, replacements.as_slice())?;

python/src/utils.rs

+7-3
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,12 @@ pub fn class_name(ob: &Bound<'_, PyAny>) -> PyResult<String> {
282282
}
283283
}
284284

285-
impl ToPyObject for PyLance<&i32> {
286-
fn to_object(&self, py: Python) -> PyObject {
287-
self.0.to_object(py)
285+
impl<'py> IntoPyObject<'py> for PyLance<&i32> {
286+
type Target = PyAny;
287+
type Output = Bound<'py, Self::Target>;
288+
type Error = PyErr;
289+
290+
fn into_pyobject(self, py: Python<'py>) -> PyResult<Self::Output> {
291+
self.0.into_bound_py_any(py)
288292
}
289293
}

0 commit comments

Comments
 (0)