Skip to content

Commit

Permalink
Bugfix + Unify digraph and multidigraph behaviour (#46)
Browse files Browse the repository at this point in the history
* Handles DiGraphs and MultiDiGraphs in the same way

* Adapts tests to new output and adds more tests for `order by`

* Removes commented code

* Adds support for aliases and normalises order by edge output

* Adds alias unit tests and updates order by tests to reflect new output
  • Loading branch information
jackboyla authored Jun 18, 2024
1 parent 95b0748 commit ed6bcf6
Show file tree
Hide file tree
Showing 2 changed files with 383 additions and 92 deletions.
139 changes: 110 additions & 29 deletions grandcypher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,12 @@
return_clause : "return"i distinct_return? return_item ("," return_item)*
return_item : entity_id | aggregation_function | entity_id "." attribute_id
return_item : (entity_id | aggregation_function | entity_id "." attribute_id) ( "AS"i alias )?
aggregation_function : AGGREGATE_FUNC "(" entity_id ( "." attribute_id )? ")"
AGGREGATE_FUNC : "COUNT" | "SUM" | "AVG" | "MAX" | "MIN"
attribute_id : CNAME
alias : CNAME
distinct_return : "DISTINCT"i
limit_clause : "limit"i NUMBER
Expand All @@ -97,7 +98,7 @@
order_items : order_item ("," order_item)*
order_item : entity_id order_direction?
order_item : (entity_id | aggregation_function) order_direction?
order_direction : "ASC"i -> asc
| "DESC"i -> desc
Expand Down Expand Up @@ -363,7 +364,7 @@ def inner(


def _data_path_to_entity_name_attribute(data_path):
if not isinstance(data_path, str):
if isinstance(data_path, Token):
data_path = data_path.value
if "." in data_path:
entity_name, entity_attribute = data_path.split(".")
Expand All @@ -376,7 +377,9 @@ def _data_path_to_entity_name_attribute(data_path):

class _GrandCypherTransformer(Transformer):
def __init__(self, target_graph: nx.Graph, limit=None):
self._target_graph = target_graph
self._target_graph = nx.MultiDiGraph(target_graph)
self._entity2alias = dict()
self._alias2entity = dict()
self._paths = []
self._where_condition: CONDITION = None
self._motif = nx.MultiDiGraph()
Expand All @@ -385,6 +388,7 @@ def __init__(self, target_graph: nx.Graph, limit=None):
self._return_requests = []
self._return_edges = {}
self._aggregate_functions = []
self._aggregation_attributes = set()
self._distinct = False
self._order_by = None
self._order_by_attributes = set()
Expand Down Expand Up @@ -491,12 +495,15 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
ret_with_attr = []
for r in ret:
r_attr = {}
for i, v in r.items():
r_attr[(i, list(v.get("__labels__"))[0])] = v.get(
entity_attribute, None
)
# eg, [{(0, 'paid'): 70, (1, 'paid'): 90}, {(0, 'paid'): 400, (1, 'friend'): None, (2, 'paid'): 650}]
ret_with_attr.append(r_attr)
if isinstance(r, dict):
r = [r]
for el in r:
for i, v in el.items():
r_attr[(i, list(v.get("__labels__", [i]))[0])] = v.get(
entity_attribute, None
)
# eg, [{(0, 'paid'): 70, (1, 'paid'): 90}, {(0, 'paid'): 400, (1, 'friend'): None, (2, 'paid'): 650}]
ret_with_attr.append(r_attr)

ret = ret_with_attr

Expand All @@ -508,31 +515,73 @@ def return_clause(self, clause):
# collect all entity identifiers to be returned
for item in clause:
if item:
alias = self._extract_alias(item)
item = item.children[0] if isinstance(item, Tree) else item
if isinstance(item, Tree) and item.data == "aggregation_function":
func = str(item.children[0].value) # AGGREGATE_FUNC
entity = str(item.children[1].value)
if len(item.children) > 2:
entity += "." + str(item.children[2].children[0].value)
func, entity = self._parse_aggregation_token(item)
if alias:
self._entity2alias[self._format_aggregation_key(func, entity)] = alias
self._aggregation_attributes.add(entity)
self._aggregate_functions.append((func, entity))
self._return_requests.append(entity)
else:
if not isinstance(item, str):
item = str(item.value)

if alias:
self._entity2alias[item] = alias
self._return_requests.append(item)

self._alias2entity.update({v: k for k, v in self._entity2alias.items()})

def _extract_alias(self, item: Tree):
    '''
    Return the alias declared on a return item, or None when there is none.

    An item with exactly one child is treated as having no alias. Otherwise
    the first child that is a Tree whose data is 'alias' supplies the alias:
    the value of that subtree's first child, converted to str.
    '''
    children = item.children
    if len(children) == 1:
        return None

    for child in children:
        # Non-Tree children (e.g. tokens) can never be the alias subtree.
        if isinstance(child, Tree) and child.data == 'alias':
            return str(child.children[0].value)

    return None

def _parse_aggregation_token(self, item: Tree):
    '''
    Split an aggregation_function subtree into (function, entity path).

    input: Tree('aggregation_function', [Token('AGGREGATE_FUNC', 'SUM'), Token('CNAME', 'r'), Tree('attribute_id', [Token('CNAME', 'value')])])
    output: ('SUM', 'r.value')

    When the subtree has only two children, the entity path is the bare
    entity name with no attribute suffix.
    '''
    children = item.children
    func = str(children[0].value)  # AGGREGATE_FUNC
    path_parts = [str(children[1].value)]  # entity name
    if len(children) > 2:
        # Third child wraps the attribute name in its own subtree.
        path_parts.append(str(children[2].children[0].value))

    return func, ".".join(path_parts)

def _format_aggregation_key(self, func, entity):
return f"{func}({entity})"

def order_clause(self, order_clause):
self._order_by = []
for item in order_clause[0].children:
field = str(item.children[0]) # assuming the field name is the first child
if isinstance(item.children[0], Tree) and item.children[0].data == "aggregation_function":
func, entity = self._parse_aggregation_token(item.children[0])
field = self._format_aggregation_key(func, entity)
self._order_by_attributes.add(entity)
else:
field = str(item.children[0]) # assuming the field name is the first child
self._order_by_attributes.add(field)

# Default to 'ASC' if not specified
if len(item.children) > 1 and str(item.children[1].data).lower() != "desc":
direction = "ASC"
else:
direction = "DESC"

self._order_by.append((field, direction)) # [('n.age', 'DESC'), ...]
self._order_by_attributes.add(field)

def distinct_return(self, distinct):
self._distinct = True
Expand Down Expand Up @@ -616,8 +665,11 @@ def _collate_data(data, unique_labels, func):

def returns(self, ignore_limit=False):

data_paths = self._return_requests + list(self._order_by_attributes) + list(self._aggregation_attributes)
# aliases should already be requested in their original form, so we will remove them for lookup
data_paths = [d for d in data_paths if d not in self._alias2entity]
results = self._lookup(
self._return_requests + list(self._order_by_attributes),
data_paths,
offset_limit=slice(0, None),
)
if len(self._aggregate_functions) > 0:
Expand All @@ -630,46 +682,75 @@ def returns(self, ignore_limit=False):
aggregated_results = {}
for func, entity in self._aggregate_functions:
aggregated_data = self.aggregate(func, results, entity, group_keys)
func_key = f"{func}({entity})"
func_key = self._format_aggregation_key(func, entity)
aggregated_results[func_key] = aggregated_data
self._return_requests.append(func_key)
results.update(aggregated_results)

# update the results with the given alias(es)
results = {self._entity2alias.get(k, k): v for k, v in results.items()}

if self._order_by:
results = self._apply_order_by(results)
if self._distinct:
results = self._apply_distinct(results)
results = self._apply_pagination(results, ignore_limit)

# Exclude order-by-only attributes from the final results
# Only include keys that were asked for in `RETURN` in the final results
results = {
key: values
for key, values in results.items()
if key in self._return_requests
if self._alias2entity.get(key, key) in self._return_requests
}

return results

def _apply_order_by(self, results):
if self._order_by:
sort_lists = [
(results[field], direction)
(results[field], field, direction)
for field, direction in self._order_by
if field in results
]

if sort_lists:
# Generate a list of indices sorted by the specified fields
indices = range(
len(next(iter(results.values())))
) # Safe because all lists are assumed to be of the same length
for sort_list, direction in reversed(
for (sort_list, field, direction) in reversed(
sort_lists
): # reverse to ensure the first sort key is primary
indices = sorted(
indices,
key=lambda i: sort_list[i],
reverse=(direction == "DESC"),
)

if all(isinstance(item, dict) for item in sort_list):
# (for edge attributes) If all items in sort_list are dictionaries
# example: ([{(0, 'paid'): 9, (1, 'paid'): 40}, {(0, 'paid'): 14}], 'DESC')

# sort within each edge first
sorted_sublists = []
for sublist in sort_list:
sorted_sublist = sorted(
sublist.items(),
key=lambda x: x[1] or 0, # 0 if `None`
reverse=(direction == "DESC"),
)
sorted_sublists.append({k: v for k, v in sorted_sublist})
sort_list = sorted_sublists

# then sort the indices based on the sorted sublists
indices = sorted(
indices,
key=lambda i: list(sort_list[i].values())[0] or 0, # 0 if `None`
reverse=(direction == "DESC"),
)
# update results with sorted edge attributes list
results[field] = sort_list
else:
# (for node attributes) single values
indices = sorted(
indices,
key=lambda i: sort_list[i],
reverse=(direction == "DESC"),
)

# Reorder all lists in results using sorted indices
for key in results:
Expand Down
Loading

0 comments on commit ed6bcf6

Please sign in to comment.