feat(connect): support DdlParse (#3580)
Co-authored-by: Cory Grinstead <cory.grinstead@gmail.com>
andrewgazelka and universalmind303 authored Dec 17, 2024
1 parent e148248 commit 47f5897
Showing 8 changed files with 160 additions and 6 deletions.
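In short, this commit teaches daft-connect's AnalyzePlan handler to answer DdlParse requests: the DDL string is parsed by a new daft_sql::sql_schema entry point, the resulting schema is folded into a struct dtype, and that dtype is translated into a Spark DataType for the response. A minimal sketch of the new server-side flow, assuming only the public APIs added in this diff (output shown is illustrative):

use daft_sql::sql_schema;

fn main() {
    // Parse a Spark-style DDL column list into a Daft schema.
    let schema = sql_schema("Value int, Total int").expect("valid DDL");

    // The connect handler folds the schema's fields into a single
    // DataType::Struct before translating it to a Spark DataType.
    let dtype = schema.to_struct();
    println!("{dtype:?}");
}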
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -204,6 +204,7 @@ daft-logical-plan = {path = "src/daft-logical-plan"}
 daft-micropartition = {path = "src/daft-micropartition"}
 daft-scan = {path = "src/daft-scan"}
 daft-schema = {path = "src/daft-schema"}
+daft-sql = {path = "src/daft-sql"}
 daft-table = {path = "src/daft-table"}
 derivative = "2.2.0"
 derive_builder = "0.20.2"
1 change: 1 addition & 0 deletions src/daft-connect/Cargo.toml
@@ -11,6 +11,7 @@ daft-logical-plan = {workspace = true}
 daft-micropartition = {workspace = true}
 daft-scan = {workspace = true}
 daft-schema = {workspace = true}
+daft-sql = {workspace = true}
 daft-table = {workspace = true}
 dashmap = "6.1.0"
 eyre = "0.6.12"
24 changes: 23 additions & 1 deletion src/daft-connect/src/lib.rs
@@ -323,7 +323,29 @@ impl SparkConnectService for DaftSparkConnectService {
 
                Ok(Response::new(response))
            }
-            _ => unimplemented_err!("Analyze plan operation is not yet implemented"),
+            Analyze::DdlParse(DdlParse { ddl_string }) => {
+                let daft_schema = match daft_sql::sql_schema(&ddl_string) {
+                    Ok(daft_schema) => daft_schema,
+                    Err(e) => return invalid_argument_err!("{e}"),
+                };
+
+                let daft_schema = daft_schema.to_struct();
+
+                let schema = translation::to_spark_datatype(&daft_schema);
+
+                let schema = analyze_plan_response::Schema {
+                    schema: Some(schema),
+                };
+
+                let response = AnalyzePlanResponse {
+                    session_id,
+                    server_side_session_id: String::new(),
+                    result: Some(analyze_plan_response::Result::Schema(schema)),
+                };
+
+                Ok(Response::new(response))
+            }
+            other => unimplemented_err!("Analyze plan operation is not yet implemented: {other:?}"),
        }
    }
 
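Note the error path above: when the DDL string fails to parse, daft_sql::sql_schema returns an error and the handler maps it into an invalid-argument response via invalid_argument_err!. A hedged sketch of that behavior, using only the sql_schema function added later in this diff:

use daft_sql::sql_schema;

fn main() {
    // "b" has no data type, so sqlparser's parse_column_def fails and
    // sql_schema propagates the error instead of producing a schema.
    match sql_schema("a int, b") {
        Ok(_) => println!("unexpectedly parsed"),
        Err(e) => println!("rejected as expected: {e}"),
    }
}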
7 changes: 6 additions & 1 deletion src/daft-schema/src/schema.rs
@@ -13,7 +13,7 @@ use derive_more::Display;
 use indexmap::IndexMap;
 use serde::{Deserialize, Serialize};
 
-use crate::field::Field;
+use crate::{field::Field, prelude::DataType};
 
 pub type SchemaRef = Arc<Schema>;
 
@@ -48,6 +48,11 @@ impl Schema {
         Ok(Self { fields: map })
     }
 
+    pub fn to_struct(&self) -> DataType {
+        let fields = self.fields.values().cloned().collect();
+        DataType::Struct(fields)
+    }
+
     pub fn exclude<S: AsRef<str>>(&self, names: &[S]) -> DaftResult<Self> {
         let mut fields = IndexMap::new();
         let names = names.iter().map(|s| s.as_ref()).collect::<HashSet<&str>>();
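The new to_struct helper simply clones the schema's fields, in order, into a DataType::Struct. A small sketch, assuming the crate's prelude re-exports Schema, Field, and DataType:

use daft_schema::prelude::*;

fn main() {
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32),
        Field::new("b", DataType::Utf8),
    ])
    .expect("no duplicate field names");

    // to_struct collects every field into a single struct dtype.
    let dtype = schema.to_struct();
    assert!(matches!(dtype, DataType::Struct(ref fields) if fields.len() == 2));
}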
3 changes: 3 additions & 0 deletions src/daft-sql/src/lib.rs
@@ -4,7 +4,10 @@ pub mod catalog;
 pub mod error;
 pub mod functions;
 mod modules;
+
 mod planner;
+pub use planner::*;
+
 #[cfg(feature = "python")]
 pub mod python;
 mod table_provider;
111 changes: 107 additions & 4 deletions src/daft-sql/src/planner.rs
@@ -21,10 +21,10 @@ use daft_functions::{
 use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef};
 use sqlparser::{
     ast::{
-        ArrayElemTypeDef, BinaryOperator, CastKind, DateTimeField, Distinct, ExactNumberInfo,
-        ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, SetExpr, Statement, StructField,
-        Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, Value,
-        WildcardAdditionalOptions, With,
+        ArrayElemTypeDef, BinaryOperator, CastKind, ColumnDef, DateTimeField, Distinct,
+        ExactNumberInfo, ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, SetExpr,
+        Statement, StructField, Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator,
+        Value, WildcardAdditionalOptions, With,
     },
     dialect::GenericDialect,
     parser::{Parser, ParserOptions},
@@ -1262,6 +1262,28 @@ impl<'a> SQLPlanner<'a> {
         }
     }
 
+    fn column_to_field(&self, column_def: &ColumnDef) -> SQLPlannerResult<Field> {
+        let ColumnDef {
+            name,
+            data_type,
+            collation,
+            options,
+        } = column_def;
+
+        if let Some(collation) = collation {
+            unsupported_sql_err!("collation operation ({collation:?}) is not supported")
+        }
+
+        if !options.is_empty() {
+            unsupported_sql_err!("unsupported options: {options:?}")
+        }
+
+        let name = ident_to_str(name);
+        let data_type = self.sql_dtype_to_dtype(data_type)?;
+
+        Ok(Field::new(name, data_type))
+    }
+
     fn value_to_lit(&self, value: &Value) -> SQLPlannerResult<LiteralValue> {
         Ok(match value {
             Value::SingleQuotedString(s) => LiteralValue::Utf8(s.clone()),
@@ -2114,6 +2136,32 @@ fn check_wildcard_options(
 
     Ok(())
 }
+
+pub fn sql_schema<S: AsRef<str>>(s: S) -> SQLPlannerResult<SchemaRef> {
+    let planner = SQLPlanner::default();
+
+    let tokens = Tokenizer::new(&GenericDialect, s.as_ref()).tokenize()?;
+
+    let mut parser = Parser::new(&GenericDialect)
+        .with_options(ParserOptions {
+            trailing_commas: true,
+            ..Default::default()
+        })
+        .with_tokens(tokens);
+
+    let column_defs = parser.parse_comma_separated(Parser::parse_column_def)?;
+
+    let fields: Result<Vec<_>, _> = column_defs
+        .into_iter()
+        .map(|c| planner.column_to_field(&c))
+        .collect();
+
+    let fields = fields?;
+
+    let schema = Schema::new(fields)?;
+    Ok(Arc::new(schema))
+}
 
 pub fn sql_expr<S: AsRef<str>>(s: S) -> SQLPlannerResult<ExprRef> {
     let mut planner = SQLPlanner::default();
 
@@ -2138,6 +2186,12 @@ pub fn sql_expr<S: AsRef<str>>(s: S) -> SQLPlannerResult<ExprRef> {
 // ----------------
 // Helper functions
 // ----------------
+
+/// # Examples
+/// ```
+/// // Quoted identifier "MyCol" -> "MyCol"
+/// // Unquoted identifier MyCol -> "MyCol"
+/// ```
 fn ident_to_str(ident: &Ident) -> String {
     if ident.quote_style == Some('"') {
         ident.value.to_string()
Expand Down Expand Up @@ -2190,3 +2244,52 @@ fn unresolve_alias(expr: ExprRef, projection: &[ExprRef]) -> SQLPlannerResult<Ex
})
.ok_or_else(|| PlannerError::column_not_found(expr.name(), "projection"))
}

#[cfg(test)]
mod tests {
use daft_core::prelude::*;

use crate::sql_schema;

#[test]
fn test_sql_schema_creates_expected_schema() {
let result =
sql_schema("Year int, First_Name STRING, County STRING, Sex STRING, Count int")
.unwrap();

let expected = Schema::new(vec![
Field::new("Year", DataType::Int32),
Field::new("First_Name", DataType::Utf8),
Field::new("County", DataType::Utf8),
Field::new("Sex", DataType::Utf8),
Field::new("Count", DataType::Int32),
])
.unwrap();

assert_eq!(&*result, &expected);
}

#[test]
fn test_duplicate_column_names_in_schema() {
// This test checks that sql_schema fails or handles duplicates gracefully.
// The planner currently returns errors if schema construction fails, so we expect an Err here.
let result = sql_schema("col1 INT, col1 STRING");

assert_eq!(
result.unwrap_err().to_string(),
"Daft error: DaftError::ValueError Attempting to make a Schema with duplicate field names: col1"
);
}

#[test]
fn test_degenerate_empty_schema() {
assert!(sql_schema("").is_err());
}

#[test]
fn test_single_field_schema() {
let result = sql_schema("col1 INT").unwrap();
let expected = Schema::new(vec![Field::new("col1", DataType::Int32)]).unwrap();
assert_eq!(&*result, &expected);
}
}
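One limitation worth noting from column_to_field above: collations and column options (e.g. NOT NULL, DEFAULT) are rejected rather than ignored, so DDL strings that carry them currently fail. A hedged sketch of that behavior:

use daft_sql::sql_schema;

fn main() {
    // Bare column definitions parse fine...
    assert!(sql_schema("col1 INT, col2 STRING").is_ok());

    // ...but NOT NULL lands in ColumnDef::options, which column_to_field
    // rejects with an "unsupported options" error.
    assert!(sql_schema("col1 INT NOT NULL").is_err());
}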
18 changes: 18 additions & 0 deletions tests/connect/test_analyze_plan.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+import pytest
+
+
+@pytest.mark.skip(
+    reason="Currently an issue in the spark connect code. It always passes the inferred schema instead of the supplied schema."
+)
+def test_analyze_plan(spark_session):
+    data = [[1000, 99]]
+    df1 = spark_session.createDataFrame(data, schema="Value int, Total int")
+    s = df1.schema
+
+    # todo: this is INCORRECT but it is an issue with pyspark client
+    # right now it is assert str(s) == "StructType([StructField('_1', LongType(), True), StructField('_2', LongType(), True)])"
+    assert (
+        str(s) == "StructType([StructField('Value', IntegerType(), True), StructField('Total', IntegerType(), True)])"
+    )
