From e92f2873b900c4df3fb0b97e2e558eb94d21d94e Mon Sep 17 00:00:00 2001 From: Nathaniel Cook Date: Thu, 5 Sep 2024 08:36:59 -0600 Subject: [PATCH] feat: add catalog/schema subcommands to flight_sql_client. (#6332) * feat: add catalog/schema subcommands to flight_sql_client. With this change basic commands are added to query the catalogs and schemas of a Flight SQL server. * fix: adds tests for flight_sql_client cli Additionally adds a builder pattern for the CommandGetTableTypes similar to CommandGetDbSchemas, while its implementation is trivial it helps to have a pattern to follow when implementing the command. * fix: add default to GetTableTypesBuilder --- arrow-flight/src/bin/flight_sql_client.rs | 80 +++- arrow-flight/src/sql/metadata/mod.rs | 1 + arrow-flight/src/sql/metadata/table_types.rs | 158 ++++++++ arrow-flight/tests/flight_sql_client_cli.rs | 382 +++++++++++++++++++ 4 files changed, 620 insertions(+), 1 deletion(-) create mode 100644 arrow-flight/src/sql/metadata/table_types.rs diff --git a/arrow-flight/src/bin/flight_sql_client.rs b/arrow-flight/src/bin/flight_sql_client.rs index 296efc1c308e..c334b95a9a96 100644 --- a/arrow-flight/src/bin/flight_sql_client.rs +++ b/arrow-flight/src/bin/flight_sql_client.rs @@ -20,7 +20,10 @@ use std::{sync::Arc, time::Duration}; use anyhow::{bail, Context, Result}; use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray}; use arrow_cast::{cast_with_options, pretty::pretty_format_batches, CastOptions}; -use arrow_flight::{sql::client::FlightSqlServiceClient, FlightInfo}; +use arrow_flight::{ + sql::{client::FlightSqlServiceClient, CommandGetDbSchemas, CommandGetTables}, + FlightInfo, +}; use arrow_schema::Schema; use clap::{Parser, Subcommand}; use futures::TryStreamExt; @@ -111,6 +114,51 @@ struct Args { /// Different available commands. #[derive(Debug, Subcommand)] enum Command { + /// Get catalogs. + Catalogs, + /// Get db schemas for a catalog. + DbSchemas { + /// Name of a catalog. + /// + /// Required. + catalog: String, + /// Specifies a filter pattern for schemas to search for. + /// When no schema_filter is provided, the pattern will not be used to narrow the search. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + #[clap(short, long)] + db_schema_filter: Option, + }, + /// Get tables for a catalog. + Tables { + /// Name of a catalog. + /// + /// Required. + catalog: String, + /// Specifies a filter pattern for schemas to search for. + /// When no schema_filter is provided, the pattern will not be used to narrow the search. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + #[clap(short, long)] + db_schema_filter: Option, + /// Specifies a filter pattern for tables to search for. + /// When no table_filter is provided, all tables matching other filters are searched. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + #[clap(short, long)] + table_filter: Option, + /// Specifies a filter of table types which must match. + /// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. + /// TABLE, VIEW, and SYSTEM TABLE are commonly supported. + #[clap(long)] + table_types: Vec, + }, + /// Get table types. + TableTypes, + /// Execute given statement. StatementQuery { /// SQL query. @@ -150,6 +198,36 @@ async fn main() -> Result<()> { .context("setup client")?; let flight_info = match args.cmd { + Command::Catalogs => client.get_catalogs().await.context("get catalogs")?, + Command::DbSchemas { + catalog, + db_schema_filter, + } => client + .get_db_schemas(CommandGetDbSchemas { + catalog: Some(catalog), + db_schema_filter_pattern: db_schema_filter, + }) + .await + .context("get db schemas")?, + Command::Tables { + catalog, + db_schema_filter, + table_filter, + table_types, + } => client + .get_tables(CommandGetTables { + catalog: Some(catalog), + db_schema_filter_pattern: db_schema_filter, + table_name_filter_pattern: table_filter, + table_types, + // Schema is returned as ipc encoded bytes. + // We do not support returning the schema as there is no trivial mechanism + // to display the information to the user. + include_schema: false, + }) + .await + .context("get tables")?, + Command::TableTypes => client.get_table_types().await.context("get table types")?, Command::StatementQuery { query } => client .execute(query, None) .await diff --git a/arrow-flight/src/sql/metadata/mod.rs b/arrow-flight/src/sql/metadata/mod.rs index 1e9881ffa70e..fd71149a3180 100644 --- a/arrow-flight/src/sql/metadata/mod.rs +++ b/arrow-flight/src/sql/metadata/mod.rs @@ -33,6 +33,7 @@ mod catalogs; mod db_schemas; mod sql_info; +mod table_types; mod tables; mod xdbc_info; diff --git a/arrow-flight/src/sql/metadata/table_types.rs b/arrow-flight/src/sql/metadata/table_types.rs new file mode 100644 index 000000000000..54cfe6fe27a7 --- /dev/null +++ b/arrow-flight/src/sql/metadata/table_types.rs @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`GetTableTypesBuilder`] for building responses to [`CommandGetTableTypes`] queries. +//! +//! [`CommandGetTableTypes`]: crate::sql::CommandGetTableTypes + +use std::sync::Arc; + +use arrow_array::{builder::StringBuilder, ArrayRef, RecordBatch}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow_select::take::take; +use once_cell::sync::Lazy; + +use crate::error::*; +use crate::sql::CommandGetTableTypes; + +use super::lexsort_to_indices; + +/// A builder for a [`CommandGetTableTypes`] response. +/// +/// Builds rows like this: +/// +/// * table_type: utf8, +#[derive(Default)] +pub struct GetTableTypesBuilder { + // array builder for table types + table_type: StringBuilder, +} + +impl CommandGetTableTypes { + /// Create a builder suitable for constructing a response + pub fn into_builder(self) -> GetTableTypesBuilder { + self.into() + } +} + +impl From for GetTableTypesBuilder { + fn from(_value: CommandGetTableTypes) -> Self { + Self::new() + } +} + +impl GetTableTypesBuilder { + /// Create a new instance of [`GetTableTypesBuilder`] + pub fn new() -> Self { + Self { + table_type: StringBuilder::new(), + } + } + + /// Append a row + pub fn append(&mut self, table_type: impl AsRef) { + self.table_type.append_value(table_type); + } + + /// builds a `RecordBatch` with the correct schema for a `CommandGetTableTypes` response + pub fn build(self) -> Result { + let schema = self.schema(); + let Self { mut table_type } = self; + + // Make the arrays + let table_type = table_type.finish(); + + let batch = RecordBatch::try_new(schema, vec![Arc::new(table_type) as ArrayRef])?; + + // Order filtered results by table_type + let indices = lexsort_to_indices(batch.columns()); + let columns = batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>()?; + + Ok(RecordBatch::try_new(batch.schema(), columns)?) + } + + /// Return the schema of the RecordBatch that will be returned + /// from [`CommandGetTableTypes`] + pub fn schema(&self) -> SchemaRef { + get_table_types_schema() + } +} + +fn get_table_types_schema() -> SchemaRef { + Arc::clone(&GET_TABLE_TYPES_SCHEMA) +} + +/// The schema for [`CommandGetTableTypes`]. +static GET_TABLE_TYPES_SCHEMA: Lazy = Lazy::new(|| { + Arc::new(Schema::new(vec![Field::new( + "table_type", + DataType::Utf8, + false, + )])) +}); + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::StringArray; + + fn get_ref_batch() -> RecordBatch { + RecordBatch::try_new( + get_table_types_schema(), + vec![Arc::new(StringArray::from(vec![ + "a_table_type", + "b_table_type", + "c_table_type", + "d_table_type", + ])) as ArrayRef], + ) + .unwrap() + } + + #[test] + fn test_table_types_are_sorted() { + let ref_batch = get_ref_batch(); + + let mut builder = GetTableTypesBuilder::new(); + builder.append("b_table_type"); + builder.append("a_table_type"); + builder.append("d_table_type"); + builder.append("c_table_type"); + let schema_batch = builder.build().unwrap(); + + assert_eq!(schema_batch, ref_batch) + } + + #[test] + fn test_builder_from_query() { + let ref_batch = get_ref_batch(); + let query = CommandGetTableTypes {}; + + let mut builder = query.into_builder(); + builder.append("a_table_type"); + builder.append("b_table_type"); + builder.append("c_table_type"); + builder.append("d_table_type"); + let schema_batch = builder.build().unwrap(); + + assert_eq!(schema_batch, ref_batch) + } +} diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs index 168015d07e2d..6e1f6142c8b6 100644 --- a/arrow-flight/tests/flight_sql_client_cli.rs +++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -23,10 +23,12 @@ use crate::common::fixture::TestFixture; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; use arrow_flight::{ decode::FlightRecordBatchStream, + encode::FlightDataEncoderBuilder, flight_service_server::{FlightService, FlightServiceServer}, sql::{ server::{FlightSqlService, PeekableFlightDataStream}, ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, Any, + CommandGetCatalogs, CommandGetDbSchemas, CommandGetTableTypes, CommandGetTables, CommandPreparedStatementQuery, CommandStatementQuery, DoPutPreparedStatementResult, ProstMessageExt, SqlInfo, }, @@ -85,6 +87,205 @@ async fn test_simple() { ); } +#[tokio::test] +async fn test_get_catalogs() { + let test_server = FlightSqlServiceImpl::default(); + let fixture = TestFixture::new(test_server.service()).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("catalogs") + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+\ + \n| catalog_name |\ + \n+--------------+\ + \n| catalog_a |\ + \n| catalog_b |\ + \n+--------------+", + ); +} + +#[tokio::test] +async fn test_get_db_schemas() { + let test_server = FlightSqlServiceImpl::default(); + let fixture = TestFixture::new(test_server.service()).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("db-schemas") + .arg("catalog_a") + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+----------------+\ + \n| catalog_name | db_schema_name |\ + \n+--------------+----------------+\ + \n| catalog_a | schema_1 |\ + \n| catalog_a | schema_2 |\ + \n+--------------+----------------+", + ); +} + +#[tokio::test] +async fn test_get_tables() { + let test_server = FlightSqlServiceImpl::default(); + let fixture = TestFixture::new(test_server.service()).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("tables") + .arg("catalog_a") + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+----------------+------------+------------+\ + \n| catalog_name | db_schema_name | table_name | table_type |\ + \n+--------------+----------------+------------+------------+\ + \n| catalog_a | schema_1 | table_1 | TABLE |\ + \n| catalog_a | schema_2 | table_2 | VIEW |\ + \n+--------------+----------------+------------+------------+", + ); +} +#[tokio::test] +async fn test_get_tables_db_filter() { + let test_server = FlightSqlServiceImpl::default(); + let fixture = TestFixture::new(test_server.service()).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("tables") + .arg("catalog_a") + .arg("--db-schema-filter") + .arg("schema_2") + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+----------------+------------+------------+\ + \n| catalog_name | db_schema_name | table_name | table_type |\ + \n+--------------+----------------+------------+------------+\ + \n| catalog_a | schema_2 | table_2 | VIEW |\ + \n+--------------+----------------+------------+------------+", + ); +} + +#[tokio::test] +async fn test_get_tables_types() { + let test_server = FlightSqlServiceImpl::default(); + let fixture = TestFixture::new(test_server.service()).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("table-types") + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+\ + \n| table_type |\ + \n+--------------+\ + \n| SYSTEM_TABLE |\ + \n| TABLE |\ + \n| VIEW |\ + \n+--------------+", + ); +} + const PREPARED_QUERY: &str = "SELECT * FROM table WHERE field = $1"; const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle"; const UPDATED_PREPARED_STATEMENT_HANDLE: &str = "updated_prepared_statement_handle"; @@ -278,6 +479,84 @@ impl FlightSqlService for FlightSqlServiceImpl { Ok(resp) } + async fn get_flight_info_catalogs( + &self, + query: CommandGetCatalogs, + request: Request, + ) -> Result, Status> { + let flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.as_any().encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .unwrap() + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(Response::new(flight_info)) + } + async fn get_flight_info_schemas( + &self, + query: CommandGetDbSchemas, + request: Request, + ) -> Result, Status> { + let flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.as_any().encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .unwrap() + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(Response::new(flight_info)) + } + + async fn get_flight_info_tables( + &self, + query: CommandGetTables, + request: Request, + ) -> Result, Status> { + let flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.as_any().encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .unwrap() + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(Response::new(flight_info)) + } + + async fn get_flight_info_table_types( + &self, + query: CommandGetTableTypes, + request: Request, + ) -> Result, Status> { + let flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.as_any().encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .unwrap() + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(Response::new(flight_info)) + } async fn get_flight_info_statement( &self, query: CommandStatementQuery, @@ -309,6 +588,109 @@ impl FlightSqlService for FlightSqlServiceImpl { Ok(resp) } + async fn do_get_catalogs( + &self, + query: CommandGetCatalogs, + _request: Request, + ) -> Result::DoGetStream>, Status> { + let mut builder = query.into_builder(); + for catalog_name in ["catalog_a", "catalog_b"] { + builder.append(catalog_name); + } + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) + } + + async fn do_get_schemas( + &self, + query: CommandGetDbSchemas, + _request: Request, + ) -> Result::DoGetStream>, Status> { + let mut builder = query.into_builder(); + for (catalog_name, schema_name) in [ + ("catalog_a", "schema_1"), + ("catalog_a", "schema_2"), + ("catalog_b", "schema_3"), + ] { + builder.append(catalog_name, schema_name); + } + + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) + } + + async fn do_get_tables( + &self, + query: CommandGetTables, + _request: Request, + ) -> Result::DoGetStream>, Status> { + let mut builder = query.into_builder(); + for (catalog_name, schema_name, table_name, table_type, schema) in [ + ( + "catalog_a", + "schema_1", + "table_1", + "TABLE", + Arc::new(Schema::empty()), + ), + ( + "catalog_a", + "schema_2", + "table_2", + "VIEW", + Arc::new(Schema::empty()), + ), + ( + "catalog_b", + "schema_3", + "table_3", + "TABLE", + Arc::new(Schema::empty()), + ), + ] { + builder + .append(catalog_name, schema_name, table_name, table_type, &schema) + .unwrap(); + } + + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) + } + + async fn do_get_table_types( + &self, + query: CommandGetTableTypes, + _request: Request, + ) -> Result::DoGetStream>, Status> { + let mut builder = query.into_builder(); + for table_type in ["TABLE", "VIEW", "SYSTEM_TABLE"] { + builder.append(table_type); + } + + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) + } + async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery,