Skip to content

Commit

Permalink
Modify COUNT() to output agtype (#1311) (#1335)
Browse files Browse the repository at this point in the history
Modified the make_function_expr logic to wrap the PG COUNT()
function with a cast to agtype. This enables COUNT() to be used as
a subquery in CASE.

Also added casting logic to support PG functions that may be
added in the future.

This modification passes all regression tests. Also added
regression tests for COUNT in CASE statements.
  • Loading branch information
dehowef authored Oct 31, 2023
1 parent 05888e6 commit f87fa6b
Show file tree
Hide file tree
Showing 3 changed files with 329 additions and 9 deletions.
109 changes: 109 additions & 0 deletions regress/expected/expr.out
Original file line number Diff line number Diff line change
Expand Up @@ -6388,6 +6388,115 @@ $$ ) AS (case_statement agtype);
{"id": 844424930131970, "label": "connected_to", "end_id": 281474976710660, "start_id": 281474976710659, "properties": {"k": 1, "id": 2}}::edge
(2 rows)

--CASE with count()
--count(*)
SELECT * FROM cypher('case_statement', $$
MATCH (n)
RETURN n, CASE n.j
WHEN 1 THEN count(*)
ELSE 'not count'
END
$$ ) AS (j agtype, case_statement agtype);
j | case_statement
------------------------------------------------------------------------------------------------+----------------
{"id": 281474976710658, "label": "", "properties": {"i": "a", "j": "b", "id": 2}}::vertex | "not count"
{"id": 281474976710661, "label": "", "properties": {"i": [], "j": [0, 1, 2], "id": 5}}::vertex | "not count"
{"id": 281474976710657, "label": "", "properties": {"i": 1, "id": 1}}::vertex | "not count"
{"id": 281474976710659, "label": "", "properties": {"i": 0, "j": 1, "id": 3}}::vertex | 1
{"id": 281474976710660, "label": "", "properties": {"i": true, "j": false, "id": 4}}::vertex | "not count"
{"id": 281474976710662, "label": "", "properties": {"i": {}, "j": {"i": 1}, "id": 6}}::vertex | "not count"
(6 rows)

--concatenated
SELECT * FROM cypher('case_statement', $$
MATCH (n) MATCH (m)
RETURN n, CASE n.j
WHEN 1 THEN count(*)
ELSE 'not count'
END
$$ ) AS (j agtype, case_statement agtype);
j | case_statement
------------------------------------------------------------------------------------------------+----------------
{"id": 281474976710658, "label": "", "properties": {"i": "a", "j": "b", "id": 2}}::vertex | "not count"
{"id": 281474976710661, "label": "", "properties": {"i": [], "j": [0, 1, 2], "id": 5}}::vertex | "not count"
{"id": 281474976710657, "label": "", "properties": {"i": 1, "id": 1}}::vertex | "not count"
{"id": 281474976710659, "label": "", "properties": {"i": 0, "j": 1, "id": 3}}::vertex | 6
{"id": 281474976710660, "label": "", "properties": {"i": true, "j": false, "id": 4}}::vertex | "not count"
{"id": 281474976710662, "label": "", "properties": {"i": {}, "j": {"i": 1}, "id": 6}}::vertex | "not count"
(6 rows)

--count(n)
SELECT * FROM cypher('case_statement', $$
MATCH (n)
RETURN n, CASE n.j
WHEN 1 THEN count(n)
ELSE 'not count'
END
$$ ) AS (j agtype, case_statement agtype);
j | case_statement
------------------------------------------------------------------------------------------------+----------------
{"id": 281474976710658, "label": "", "properties": {"i": "a", "j": "b", "id": 2}}::vertex | "not count"
{"id": 281474976710661, "label": "", "properties": {"i": [], "j": [0, 1, 2], "id": 5}}::vertex | "not count"
{"id": 281474976710657, "label": "", "properties": {"i": 1, "id": 1}}::vertex | "not count"
{"id": 281474976710659, "label": "", "properties": {"i": 0, "j": 1, "id": 3}}::vertex | 1
{"id": 281474976710660, "label": "", "properties": {"i": true, "j": false, "id": 4}}::vertex | "not count"
{"id": 281474976710662, "label": "", "properties": {"i": {}, "j": {"i": 1}, "id": 6}}::vertex | "not count"
(6 rows)

--concatenated
SELECT * FROM cypher('case_statement', $$
MATCH (n) MATCH (m)
RETURN n, CASE n.j
WHEN 1 THEN count(n)
ELSE 'not count'
END
$$ ) AS (j agtype, case_statement agtype);
j | case_statement
------------------------------------------------------------------------------------------------+----------------
{"id": 281474976710658, "label": "", "properties": {"i": "a", "j": "b", "id": 2}}::vertex | "not count"
{"id": 281474976710661, "label": "", "properties": {"i": [], "j": [0, 1, 2], "id": 5}}::vertex | "not count"
{"id": 281474976710657, "label": "", "properties": {"i": 1, "id": 1}}::vertex | "not count"
{"id": 281474976710659, "label": "", "properties": {"i": 0, "j": 1, "id": 3}}::vertex | 6
{"id": 281474976710660, "label": "", "properties": {"i": true, "j": false, "id": 4}}::vertex | "not count"
{"id": 281474976710662, "label": "", "properties": {"i": {}, "j": {"i": 1}, "id": 6}}::vertex | "not count"
(6 rows)

--count(1)
SELECT * FROM cypher('case_statement', $$
MATCH (n)
RETURN n, CASE n.j
WHEN 1 THEN count(1)
ELSE 'not count'
END
$$ ) AS (j agtype, case_statement agtype);
j | case_statement
------------------------------------------------------------------------------------------------+----------------
{"id": 281474976710658, "label": "", "properties": {"i": "a", "j": "b", "id": 2}}::vertex | "not count"
{"id": 281474976710661, "label": "", "properties": {"i": [], "j": [0, 1, 2], "id": 5}}::vertex | "not count"
{"id": 281474976710657, "label": "", "properties": {"i": 1, "id": 1}}::vertex | "not count"
{"id": 281474976710659, "label": "", "properties": {"i": 0, "j": 1, "id": 3}}::vertex | 1
{"id": 281474976710660, "label": "", "properties": {"i": true, "j": false, "id": 4}}::vertex | "not count"
{"id": 281474976710662, "label": "", "properties": {"i": {}, "j": {"i": 1}, "id": 6}}::vertex | "not count"
(6 rows)

--concatenated
SELECT * FROM cypher('case_statement', $$
MATCH (n) MATCH (m)
RETURN n, CASE n.j
WHEN 1 THEN count(1)
ELSE 'not count'
END
$$ ) AS (j agtype, case_statement agtype);
j | case_statement
------------------------------------------------------------------------------------------------+----------------
{"id": 281474976710658, "label": "", "properties": {"i": "a", "j": "b", "id": 2}}::vertex | "not count"
{"id": 281474976710661, "label": "", "properties": {"i": [], "j": [0, 1, 2], "id": 5}}::vertex | "not count"
{"id": 281474976710657, "label": "", "properties": {"i": 1, "id": 1}}::vertex | "not count"
{"id": 281474976710659, "label": "", "properties": {"i": 0, "j": 1, "id": 3}}::vertex | 6
{"id": 281474976710660, "label": "", "properties": {"i": true, "j": false, "id": 4}}::vertex | "not count"
{"id": 281474976710662, "label": "", "properties": {"i": {}, "j": {"i": 1}, "id": 6}}::vertex | "not count"
(6 rows)

-- RETURN * and (u)--(v) optional forms
SELECT create_graph('opt_forms');
NOTICE: graph "opt_forms" has been created
Expand Down
56 changes: 56 additions & 0 deletions regress/sql/expr.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2676,6 +2676,62 @@ SELECT * FROM cypher('case_statement', $$
END
$$ ) AS (case_statement agtype);

--CASE with count()

--count(*)
SELECT * FROM cypher('case_statement', $$
MATCH (n)
RETURN n, CASE n.j
WHEN 1 THEN count(*)
ELSE 'not count'
END
$$ ) AS (j agtype, case_statement agtype);

--concatenated
SELECT * FROM cypher('case_statement', $$
MATCH (n) MATCH (m)
RETURN n, CASE n.j
WHEN 1 THEN count(*)
ELSE 'not count'
END
$$ ) AS (j agtype, case_statement agtype);

--count(n)
SELECT * FROM cypher('case_statement', $$
MATCH (n)
RETURN n, CASE n.j
WHEN 1 THEN count(n)
ELSE 'not count'
END
$$ ) AS (j agtype, case_statement agtype);

--concatenated
SELECT * FROM cypher('case_statement', $$
MATCH (n) MATCH (m)
RETURN n, CASE n.j
WHEN 1 THEN count(n)
ELSE 'not count'
END
$$ ) AS (j agtype, case_statement agtype);

--count(1)
SELECT * FROM cypher('case_statement', $$
MATCH (n)
RETURN n, CASE n.j
WHEN 1 THEN count(1)
ELSE 'not count'
END
$$ ) AS (j agtype, case_statement agtype);

--concatenated
SELECT * FROM cypher('case_statement', $$
MATCH (n) MATCH (m)
RETURN n, CASE n.j
WHEN 1 THEN count(1)
ELSE 'not count'
END
$$ ) AS (j agtype, case_statement agtype);


-- RETURN * and (u)--(v) optional forms
SELECT create_graph('opt_forms');
Expand Down
173 changes: 164 additions & 9 deletions src/backend/parser/cypher_gram.y
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ static Node *make_typecast_expr(Node *expr, char *typecast, int location);

// functions
static Node *make_function_expr(List *func_name, List *exprs, int location);
static Node *make_star_function_expr(List *func_name, List *exprs, int location);
static Node *make_distinct_function_expr(List *func_name, List *exprs, int location);
static FuncCall *wrap_pg_funccall_to_agtype(Node* fnode, char *type, int location);

// setops
static Node *make_set_op(SetOperation op, bool all_or_distinct, List *larg,
Expand Down Expand Up @@ -1687,19 +1690,16 @@ expr_func_norm:
* and there are no other aggregates in SQL that accept
* '*' as parameter.
*
* The FuncCall node is also marked agg_star = true,
* The FuncCall node is marked agg_star = true by make_star_function_expr,
* so that later processing can detect what the argument
* really was.
*/
FuncCall *n = (FuncCall *)make_function_expr($1, NIL, @1);
n->agg_star = true;
$$ = (Node *)n;
FuncCall *n = (FuncCall *)make_star_function_expr($1, NIL, @1);
$$ = (Node *)n;
}
| func_name '(' DISTINCT expr_list ')'
{
FuncCall *n = (FuncCall *)make_function_expr($1, $4, @1);
n->agg_order = NIL;
n->agg_distinct = true;
FuncCall *n = (FuncCall *)make_distinct_function_expr($1, $4, @1);
$$ = (Node *)n;
}
;
Expand Down Expand Up @@ -2263,6 +2263,70 @@ static Node *make_function_expr(List *func_name, List *exprs, int location)
if (pg_strcasecmp(name, "count") == 0)
{
funcname = SystemFuncName("count");

/* build the function call */
fnode = makeFuncCall(funcname, exprs, location);

/* build the cast to wrap the function call to return agtype. */
fnode = wrap_pg_funccall_to_agtype((Node *)fnode, "integer", location);

return (Node *)fnode;
}
else
{
/*
* We don't qualify AGE functions here. This is done in the
* transform layer and allows us to know which functions are ours.
*/
funcname = func_name;

/* build the function call */
fnode = makeFuncCall(funcname, exprs, location);
}
}
/* all other functions are passed as is */
else
{
fnode = makeFuncCall(func_name, exprs, location);
}

/* return the node */
return (Node *)fnode;
}

/*
 * function to build the function call node for a function that was invoked
 * with a star argument, e.g. count(*)
 */
static Node *make_star_function_expr(List *func_name, List *exprs, int location)
{
FuncCall *fnode;

/* AGE function names are unqualified. So, their list size = 1 */
if (list_length(func_name) == 1)
{
List *funcname;
char *name;

/* get the name of the function */
name = ((Value*)linitial(func_name))->val.str;

/*
* Check for openCypher functions that are directly mapped to PG
* functions. We may want to find a better way to do this, as there
* could be many.
*/
if (pg_strcasecmp(name, "count") == 0)
{
funcname = SystemFuncName("count");

/* build the function call */
fnode = makeFuncCall(funcname, exprs, location);
fnode->agg_star = true;

/* build the cast to wrap the function call to return agtype. */
fnode = wrap_pg_funccall_to_agtype((Node *)fnode, "integer", location);

return (Node *)fnode;
}
else
{
Expand All @@ -2271,10 +2335,67 @@ static Node *make_function_expr(List *func_name, List *exprs, int location)
* transform layer and allows us to know which functions are ours.
*/
funcname = func_name;

/* build the function call */
fnode = makeFuncCall(funcname, exprs, location);
}
}
/* all other functions are passed as is */
else
{
fnode = makeFuncCall(func_name, exprs, location);
}

/* return the node */
fnode->agg_star = true;
return (Node *)fnode;
}

/*
 * function to build the function call node for a function that was invoked
 * with the DISTINCT keyword
 */
static Node *make_distinct_function_expr(List *func_name, List *exprs, int location)
{
FuncCall *fnode;

/* AGE function names are unqualified. So, their list size = 1 */
if (list_length(func_name) == 1)
{
List *funcname;
char *name;

/* get the name of the function */
name = ((Value*)linitial(func_name))->val.str;

/*
* Check for openCypher functions that are directly mapped to PG
* functions. We may want to find a better way to do this, as there
* could be many.
*/
if (pg_strcasecmp(name, "count") == 0)
{
funcname = SystemFuncName("count");

/* build the function call */
fnode = makeFuncCall(funcname, exprs, location);
fnode->agg_order = NIL;
fnode->agg_distinct = true;

/* build the cast to wrap the function call to return agtype. */
fnode = wrap_pg_funccall_to_agtype((Node *)fnode, "integer", location);
return (Node *)fnode;
}
else
{
/*
* We don't qualify AGE functions here. This is done in the
* transform layer and allows us to know which functions are ours.
*/
funcname = func_name;

/* build the function call */
fnode = makeFuncCall(funcname, exprs, location);
/* build the function call */
fnode = makeFuncCall(funcname, exprs, location);
}
}
/* all other functions are passed as is */
else
Expand All @@ -2283,9 +2404,43 @@ static Node *make_function_expr(List *func_name, List *exprs, int location)
}

/* return the node */
fnode->agg_order = NIL;
fnode->agg_distinct = true;
return (Node *)fnode;
}

/*
 * Helper that wraps a PG function call node in the appropriate ag_catalog
 * typecast function, so that its result can interface with AGE components.
 *
 * fnode    - node to wrap; it becomes the single argument of the typecast
 *            function call
 * type     - type name used to select the typecast function. It is compared
 *            with pg_strcasecmp: "float" selects float8_to_agtype, "int" and
 *            "integer" select int8_to_agtype, "bool" and "boolean" select
 *            bool_to_agtype
 * location - location handed to makeFuncCall for the wrapping call
 *
 * Returns the FuncCall node of the schema-qualified typecast function. Any
 * other type name is reported through ereport with level ERROR.
 */
static FuncCall *wrap_pg_funccall_to_agtype(Node * fnode, char *type, int location)
{
    /* the typecast functions all live in the ag_catalog schema */
    List *cast_funcname = list_make1(makeString("ag_catalog"));
    char *cast_func = NULL;

    /* pick the typecast function that matches the requested type name */
    if (pg_strcasecmp(type, "float") == 0)
    {
        cast_func = "float8_to_agtype";
    }
    else if (pg_strcasecmp(type, "int") == 0 ||
             pg_strcasecmp(type, "integer") == 0)
    {
        cast_func = "int8_to_agtype";
    }
    else if (pg_strcasecmp(type, "bool") == 0 ||
             pg_strcasecmp(type, "boolean") == 0)
    {
        cast_func = "bool_to_agtype";
    }
    else
    {
        ereport(ERROR,
                (errmsg_internal("type \'%s\' not supported by AGE functions",
                                 type)));
    }

    /* complete the qualified name: ag_catalog.<cast_func> */
    cast_funcname = lappend(cast_funcname, makeString(cast_func));

    /* the wrapped node is the only argument of the typecast call */
    return makeFuncCall(cast_funcname, list_make1(fnode), location);
}

/* function to create a unique name given a prefix */
static char *create_unique_name(char *prefix_name)
{
Expand Down

0 comments on commit f87fa6b

Please sign in to comment.