Skip to content

Commit

Permalink
sql: support LATERAL joins
Browse files Browse the repository at this point in the history
A LATERAL join allows the right-hand side of the join to access columns
defined on the left-hand side of the join. A simple example from the
PostgreSQL docs is the query

    SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) ss

which is equivalent to:

    SELECT * FROM foo, bar WHERE bar.id = foo.bar_id;

The hope is that LATERAL joins will be useful for expressing "top-k
within a group", as in:

   SELECT * FROM
       (SELECT DISTINCT cat FROM foo) grp,
       JOIN LATERAL (SELECT * FROM foo WHERE foo.cat = grp.cat ORDER BY foo.val LIMIT $k)
  • Loading branch information
benesch committed Jul 23, 2020
1 parent d90a52a commit 038a614
Show file tree
Hide file tree
Showing 9 changed files with 899 additions and 119 deletions.
45 changes: 42 additions & 3 deletions src/sql/src/plan/decorrelate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,13 @@ impl RelationExpr {
}
input
}
FlatMap { input, func, exprs } => {
CallTable { func, exprs } => {
// FlatMap expressions may contain correlated subqueries. Unlike Map they are not
// allowed to refer to the results of previous expressions, and we have a simpler
// implementation that appends all relevant columns first, then applies the flatmap
// operator to the result, then strips off any columns introduce by subqueries.

let mut input = input.applied_to(id_gen, get_outer, col_map);
let mut input = get_outer;
let old_arity = input.arity();

let exprs = exprs
Expand Down Expand Up @@ -209,6 +209,45 @@ impl RelationExpr {
}
input
}
Join {
left,
right,
on,
kind,
} if kind.is_lateral() => {
let left = left.applied_to(id_gen, get_outer, col_map);
let mut join = branch(
id_gen,
left,
col_map,
*right,
|id_gen, right, get_left, col_map| {
let join = right.applied_to(id_gen, get_left.clone(), col_map);
if let JoinKind::LeftOuter { .. } = kind {
let default = join
.typ()
.column_types
.into_iter()
.skip(get_left.arity())
.map(|typ| (Datum::Null, typ.nullable(true)))
.collect();
get_left.lookup(id_gen, join, default)
} else {
join
}
},
);
let old_arity = join.arity();
let on = on.applied_to(id_gen, col_map, &mut join);
join = join.filter(vec![on]);
let new_arity = join.arity();
if old_arity != new_arity {
// This means we added some columns to handle
// subqueries, and now we need to get rid of them.
join = join.project((0..old_arity).collect());
}
join
}
Join {
left,
right,
Expand Down Expand Up @@ -257,7 +296,7 @@ impl RelationExpr {
}
join.let_in(id_gen, |id_gen, get_join| {
let mut result = get_join.clone();
if let JoinKind::LeftOuter | JoinKind::FullOuter = kind {
if let JoinKind::LeftOuter { .. } | JoinKind::FullOuter { .. } = kind {
let left_outer = get_left.clone().anti_lookup(
id_gen,
get_join.clone(),
Expand Down
14 changes: 8 additions & 6 deletions src/sql/src/plan/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ impl RelationExpr {
Some(parent_expr) => match parent_expr {
Project { .. }
| Map { .. }
| FlatMap { .. }
| CallTable { .. }
| Filter { .. }
| Reduce { .. }
| TopK { .. }
Expand Down Expand Up @@ -163,7 +163,7 @@ impl RelationExpr {
| TopK { .. } => (),
Map { scalars, .. } => scalar_exprs.extend(scalars),
Filter { predicates, .. } => scalar_exprs.extend(predicates),
FlatMap { exprs, .. } => scalar_exprs.extend(exprs),
CallTable { exprs, .. } => scalar_exprs.extend(exprs),
Join { on, .. } => scalar_exprs.push(on),
Reduce { aggregates, .. } => {
scalar_exprs.extend(aggregates.iter().map(|a| &*a.expr))
Expand Down Expand Up @@ -241,10 +241,10 @@ impl RelationExpr {
)
.unwrap();
}
FlatMap { func, exprs, .. } => {
CallTable { func, exprs } => {
write!(
pretty,
"FlatMap {}({})",
"CallTable {}({})",
func,
Separated(
", ",
Expand Down Expand Up @@ -445,8 +445,10 @@ impl std::fmt::Display for JoinKind {
f,
"{}",
match self {
JoinKind::Inner => "Inner",
JoinKind::LeftOuter => "LeftOuter",
JoinKind::Inner { lateral: false } => "Inner",
JoinKind::Inner { lateral: true } => "InnerLateral",
JoinKind::LeftOuter { lateral: false } => "LeftOuter",
JoinKind::LeftOuter { lateral: true } => "LeftOuterLateral",
JoinKind::RightOuter => "RightOuter",
JoinKind::FullOuter => "FullOuter",
}
Expand Down
76 changes: 43 additions & 33 deletions src/sql/src/plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//! similar to that file, with some differences which are noted below. It gets turned into that
//! representation via a call to decorrelate().
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::mem;

Expand Down Expand Up @@ -54,8 +55,7 @@ pub enum RelationExpr {
input: Box<RelationExpr>,
scalars: Vec<ScalarExpr>,
},
FlatMap {
input: Box<RelationExpr>,
CallTable {
func: TableFunc,
exprs: Vec<ScalarExpr>,
},
Expand Down Expand Up @@ -255,12 +255,21 @@ pub struct ColumnRef {

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinKind {
Inner,
LeftOuter,
Inner { lateral: bool },
LeftOuter { lateral: bool },
RightOuter,
FullOuter,
}

impl JoinKind {
pub fn is_lateral(&self) -> bool {
match self {
JoinKind::Inner { lateral } | JoinKind::LeftOuter { lateral } => *lateral,
JoinKind::RightOuter | JoinKind::FullOuter => false,
}
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AggregateExpr {
pub func: AggregateFunc,
Expand Down Expand Up @@ -293,32 +302,35 @@ impl RelationExpr {
}
typ
}
RelationExpr::FlatMap {
input,
func,
exprs: _,
} => {
let mut typ = input.typ(outers, params);
typ.column_types.extend(func.output_type().column_types);
// FlatMap can add duplicate rows, so input keys are no longer valid
RelationType::new(typ.column_types)
}
RelationExpr::CallTable { func, exprs: _ } => func.output_type(),
RelationExpr::Filter { input, .. } | RelationExpr::TopK { input, .. } => {
input.typ(outers, params)
}
RelationExpr::Join {
left, right, kind, ..
} => {
let left_nullable = *kind == JoinKind::RightOuter || *kind == JoinKind::FullOuter;
let right_nullable = *kind == JoinKind::LeftOuter || *kind == JoinKind::FullOuter;
let left_nullable = matches!(kind, JoinKind::RightOuter | JoinKind::FullOuter);
let right_nullable =
matches!(kind, JoinKind::LeftOuter { .. } | JoinKind::FullOuter);
let lt = left.typ(outers, params).column_types.into_iter().map(|t| {
let nullable = t.nullable || left_nullable;
t.nullable(nullable)
});
let rt = right.typ(outers, params).column_types.into_iter().map(|t| {
let nullable = t.nullable || right_nullable;
t.nullable(nullable)
});
let outers = if kind.is_lateral() {
let mut outers = outers.to_vec();
outers.push(RelationType::new(lt.clone().collect()));
Cow::Owned(outers)
} else {
Cow::Borrowed(outers)
};
let rt = right
.typ(&outers, params)
.column_types
.into_iter()
.map(|t| {
let nullable = t.nullable || right_nullable;
t.nullable(nullable)
});
RelationType::new(lt.chain(rt).collect())
}
RelationExpr::Reduce {
Expand Down Expand Up @@ -364,7 +376,7 @@ impl RelationExpr {
RelationExpr::Get { typ, .. } => typ.column_types.len(),
RelationExpr::Project { outputs, .. } => outputs.len(),
RelationExpr::Map { input, scalars } => input.arity() + scalars.len(),
RelationExpr::FlatMap { input, func, .. } => input.arity() + func.output_arity(),
RelationExpr::CallTable { func, .. } => func.output_arity(),
RelationExpr::Filter { input, .. }
| RelationExpr::TopK { input, .. }
| RelationExpr::Distinct { input }
Expand Down Expand Up @@ -507,16 +519,15 @@ impl RelationExpr {
F: FnMut(&'a Self),
{
match self {
RelationExpr::Constant { .. } | RelationExpr::Get { .. } => (),
RelationExpr::Constant { .. }
| RelationExpr::Get { .. }
| RelationExpr::CallTable { .. } => (),
RelationExpr::Project { input, .. } => {
f(input);
}
RelationExpr::Map { input, .. } => {
f(input);
}
RelationExpr::FlatMap { input, .. } => {
f(input);
}
RelationExpr::Filter { input, .. } => {
f(input);
}
Expand Down Expand Up @@ -559,16 +570,15 @@ impl RelationExpr {
F: FnMut(&'a mut Self),
{
match self {
RelationExpr::Constant { .. } | RelationExpr::Get { .. } => (),
RelationExpr::Constant { .. }
| RelationExpr::Get { .. }
| RelationExpr::CallTable { .. } => (),
RelationExpr::Project { input, .. } => {
f(input);
}
RelationExpr::Map { input, .. } => {
f(input);
}
RelationExpr::FlatMap { input, .. } => {
f(input);
}
RelationExpr::Filter { input, .. } => {
f(input);
}
Expand Down Expand Up @@ -609,11 +619,12 @@ impl RelationExpr {
{
match self {
RelationExpr::Join {
kind: _,
kind,
on,
left,
right,
} => {
let depth = if kind.is_lateral() { depth + 1 } else { depth };
on.visit_columns(depth, f);
left.visit_columns(depth, f);
right.visit_columns(depth, f);
Expand All @@ -624,11 +635,10 @@ impl RelationExpr {
}
input.visit_columns(depth, f);
}
RelationExpr::FlatMap { exprs, input, .. } => {
RelationExpr::CallTable { exprs, .. } => {
for expr in exprs {
expr.visit_columns(depth, f);
}
input.visit_columns(depth, f);
}
RelationExpr::Filter { predicates, input } => {
for predicate in predicates {
Expand Down Expand Up @@ -669,7 +679,7 @@ impl RelationExpr {
scalar.bind_parameters(parameters);
}
}
RelationExpr::FlatMap { exprs, .. } => {
RelationExpr::CallTable { exprs, .. } => {
for expr in exprs {
expr.bind_parameters(parameters);
}
Expand Down
Loading

0 comments on commit 038a614

Please sign in to comment.