Skip to content

Commit

Permalink
feat(bindings/python): support stream_load (#284)
Browse files Browse the repository at this point in the history
  • Loading branch information
everpcpc authored Oct 31, 2023
1 parent 108bb3c commit ca44bf3
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 3 deletions.
1 change: 1 addition & 0 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ name = "databend_driver"
doc = false

[dependencies]
csv = "1.2"
databend-driver = { workspace = true, features = ["rustls", "flight-sql"] }
pyo3 = { version = "0.19", features = ["abi3-py37"] }
pyo3-asyncio = { version = "0.19", features = ["tokio-runtime"] }
Expand Down
33 changes: 33 additions & 0 deletions bindings/python/package/databend_driver/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.

class ServerStats:
@property
def total_rows(self) -> int: ...
@property
def total_bytes(self) -> int: ...
@property
def read_rows(self) -> int: ...
@property
def read_bytes(self) -> int: ...
@property
def write_rows(self) -> int: ...
@property
def write_bytes(self) -> int: ...
@property
def running_time_ms(self) -> float: ...

class ConnectionInfo:
@property
def handler(self) -> str: ...
@property
def host(self) -> str: ...
@property
def port(self) -> int: ...
@property
def user(self) -> str: ...
@property
def database(self) -> str | None: ...
@property
def warehouse(self) -> str | None: ...

# flake8: noqa
class Row:
def values(self) -> tuple: ...
Expand All @@ -22,9 +52,12 @@ class RowIterator:

# flake8: noqa
class AsyncDatabendConnection:
async def info(self) -> ConnectionInfo: ...
async def version(self) -> str: ...
async def exec(self, sql: str) -> int: ...
async def query_row(self, sql: str) -> Row: ...
async def query_iter(self, sql: str) -> RowIterator: ...
async def stream_load(self, sql: str, data: list[list[str]]) -> ServerStats: ...

# flake8: noqa
class AsyncDatabendClient:
Expand Down
108 changes: 108 additions & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,22 @@ pub struct AsyncDatabendConnection(Box<dyn databend_driver::Connection>);

#[pymethods]
impl AsyncDatabendConnection {
pub fn info<'p>(&'p self, py: Python<'p>) -> PyResult<&'p PyAny> {
let this = self.0.clone();
future_into_py(py, async move {
let info = this.info().await;
Ok(ConnectionInfo(info))
})
}

pub fn version<'p>(&'p self, py: Python<'p>) -> PyResult<&'p PyAny> {
let this = self.0.clone();
future_into_py(py, async move {
let version = this.version().await.unwrap();
Ok(version)
})
}

pub fn exec<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<&'p PyAny> {
let this = self.0.clone();
future_into_py(py, async move {
Expand All @@ -86,6 +102,32 @@ impl AsyncDatabendConnection {
Ok(RowIterator(Arc::new(Mutex::new(streamer))))
})
}

pub fn stream_load<'p>(
&self,
py: Python<'p>,
sql: String,
data: Vec<Vec<String>>,
) -> PyResult<&'p PyAny> {
let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
for row in data {
wtr.write_record(row)
.map_err(|e| PyException::new_err(format!("{}", e)))?;
}
let bytes = wtr
.into_inner()
.map_err(|e| PyException::new_err(format!("{}", e)))?;
let size = bytes.len() as u64;
let reader = Box::new(std::io::Cursor::new(bytes));
let this = self.0.clone();
future_into_py(py, async move {
let ss = this
.stream_load(&sql, reader, size, None, None)
.await
.map_err(|e| PyException::new_err(format!("{}", e)))?;
Ok(ServerStats(ss))
})
}
}

#[pyclass(module = "databend_driver")]
Expand Down Expand Up @@ -204,3 +246,69 @@ impl RowIterator {
Ok(Some(future?.into()))
}
}

#[pyclass(module = "databend_driver")]
pub struct ConnectionInfo(databend_driver::ConnectionInfo);

#[pymethods]
impl ConnectionInfo {
#[getter]
pub fn handler(&self) -> String {
self.0.handler.to_string()
}
#[getter]
pub fn host(&self) -> String {
self.0.host.to_string()
}
#[getter]
pub fn port(&self) -> u16 {
self.0.port
}
#[getter]
pub fn user(&self) -> String {
self.0.user.to_string()
}
#[getter]
pub fn database(&self) -> Option<String> {
self.0.database.clone()
}
#[getter]
pub fn warehouse(&self) -> Option<String> {
self.0.warehouse.clone()
}
}

#[pyclass(module = "databend_driver")]
pub struct ServerStats(databend_driver::ServerStats);

#[pymethods]
impl ServerStats {
#[getter]
pub fn total_rows(&self) -> usize {
self.0.total_rows
}
#[getter]
pub fn total_bytes(&self) -> usize {
self.0.total_bytes
}
#[getter]
pub fn read_rows(&self) -> usize {
self.0.read_rows
}
#[getter]
pub fn read_bytes(&self) -> usize {
self.0.read_bytes
}
#[getter]
pub fn write_rows(&self) -> usize {
self.0.write_rows
}
#[getter]
pub fn write_bytes(&self) -> usize {
self.0.write_bytes
}
#[getter]
pub fn running_time_ms(&self) -> f64 {
self.0.running_time_ms
}
}
22 changes: 20 additions & 2 deletions bindings/python/tests/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,23 @@ async def _(context):
@then("Stream load and Select should be equal")
@async_run_until_complete
async def _(context):
# TODO:
pass
values = [
["-1", "1", "1.0", "1", "1", "2011-03-06", "2011-03-06T06:20:00Z"],
["-2", "2", "2.0", "2", "2", "2012-05-31", "2012-05-31T11:20:00Z"],
["-3", "3", "3.0", "3", "2", "2016-04-04", "2016-04-04T11:30:00Z"],
]
progress = await context.conn.stream_load("INSERT INTO test VALUES", values)
assert progress.write_rows == 3
assert progress.write_bytes == 185

rows = await context.conn.query_iter("SELECT * FROM test")
ret = []
async for row in rows:
ret.append(row.values())
expected = [
(-1, 1, 1.0, "1", "1", "2011-03-06", "2011-03-06 06:20:00"),
(-2, 2, 2.0, "2", "2", "2012-05-31", "2012-05-31 11:20:00"),
(-3, 3, 3.0, "3", "2", "2016-04-04", "2016-04-04 11:30:00"),
]
print("==>", ret)
assert ret == expected
2 changes: 1 addition & 1 deletion bindings/tests/features/binding.feature
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Feature: Databend Driver

Scenario: Select Simple
Given A new Databend Driver Client
Then Select string "Hello, World!" should be equal to "Hello, World!"
Then Select string "Hello, Databend!" should be equal to "Hello, Databend!"

Scenario: Select Iter
Given A new Databend Driver Client
Expand Down

0 comments on commit ca44bf3

Please sign in to comment.