Skip to content

Commit

Permalink
fix aggregation count function
Browse files Browse the repository at this point in the history
  • Loading branch information
thanh-nguyen-dang committed Jan 29, 2024
1 parent 8a09251 commit dd96cf7
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
9 changes: 6 additions & 3 deletions tube/etl/indexers/aggregation/new_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
lit,
sort_array,
struct,
sum,
count,
when,
)
from copy import deepcopy

Expand Down Expand Up @@ -121,16 +122,18 @@ def aggregate_with_count_on_edge_tbl(self, node_name, df, edge_df, child):
break

node_id = get_node_id_name(node_name)
child_id = get_node_id_name(child)
if count_reducer is None:
# if there is no count reducer, keep only the distinct parent keys (no count column is produced)
count_df = edge_df.select(node_id).drop_duplicates([node_id])
else:
# if there is a count reducer, group by the parent key and count the number of children
# only non-leaf nodes go through this step
child_prop_name = count_reducer.prop.name
count_df = (
edge_df.groupBy(node_id)
.count()
.select(node_id, col("count").alias(count_reducer.prop.name))
.agg(count(when(col(child_id).isNotNull(), 1)).alias(child_prop_name))
.select(node_id, col(child_prop_name))
)
count_reducer.done = True
# combine value lists new counted dataframe to existing one
Expand Down
4 changes: 2 additions & 2 deletions tube/etl/indexers/base/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pyspark.sql.context import SQLContext
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql.functions import col, min, sum, count, collect_set, collect_list, first
from pyspark.sql.functions import col, min, sum, count, collect_set, collect_list, first, when

from .lambdas import (
extract_link,
Expand Down Expand Up @@ -278,7 +278,7 @@ def reducer_to_agg_func_expr(
if func_name == "count":
if is_merging:
return sum(col(value)).alias(col_alias)
return count(col(value)).alias(col_alias)
return count(when(col(value).isNotNull(), 1)).alias(col_alias)
if func_name == "sum":
return sum(col(value)).alias(col_alias)
if func_name == "set":
Expand Down

0 comments on commit dd96cf7

Please sign in to comment.