
Commit

v0.9.0
j6k4m8 committed Jun 11, 2024
1 parent f1ca6f7 commit 95b0748
Showing 3 changed files with 54 additions and 16 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,18 @@
# CHANGELOG

### **0.9.0** (June 11 2024)

> Support for aggregate functions like `COUNT`, `SUM`, `MIN`, `MAX`, and `AVG`.

#### Features

- Support for aggregate functions like `COUNT`, `SUM`, `MIN`, `MAX`, and `AVG` (#45, thanks @jackboyla!)
- Logical `OR` support in relationship matches (#44, thanks @jackboyla!)

#### Testing

- Combine tests for digraphs and multidigraphs (#43, thanks @jackboyla!)

### **0.8.0** (May 14 2024)

> Support for MultiDiGraphs.
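The two 0.9.0 features above are easiest to see in a query. The sketch below assumes the `GrandCypher(host).run(...)` entry point and the `__labels__` edge-attribute convention visible elsewhere in this commit; the sample graph, the `amount` property, and the exact `SUM(...)` / `:paid|friend` query forms are illustrative assumptions, not text taken from the changelog.

```python
# Hedged usage sketch for the 0.9.0 features (aggregates + OR in relationship
# matches). Graph contents and attribute names are invented for illustration.
import networkx as nx
from grandcypher import GrandCypher

host = nx.MultiDiGraph()
host.add_edge("alice", "bob", __labels__={"paid"}, amount=70)
host.add_edge("alice", "bob", __labels__={"friend"})
host.add_edge("bob", "carol", __labels__={"paid"}, amount=90)

# Aggregate over a relationship attribute (new in 0.9.0, #45):
print(GrandCypher(host).run("""
MATCH (a)-[r:paid]->(b)
RETURN a, SUM(r.amount)
"""))

# Logical OR between relationship types (new in 0.9.0, #44):
print(GrandCypher(host).run("""
MATCH (a)-[r:paid|friend]->(b)
RETURN a, b
"""))
```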
55 changes: 40 additions & 15 deletions grandcypher/__init__.py
@@ -161,7 +161,7 @@
start="start",
)

__version__ = "0.8.0"
__version__ = "0.9.0"


_ALPHABET = string.ascii_lowercase + string.digits
@@ -235,8 +235,8 @@ def _is_edge_attr_match(
motif_edges = _aggregate_edge_labels(motif_edges)
host_edges = _aggregate_edge_labels(host_edges)

motif_types = motif_edges.get('__labels__', set())
host_types = host_edges.get('__labels__', set())
motif_types = motif_edges.get("__labels__", set())
host_types = host_edges.get("__labels__", set())

if motif_types and not motif_types.intersection(host_types):
return False
@@ -246,7 +246,7 @@ def _is_edge_attr_match(
continue
if host_edges.get(attr) != val:
return False

return True


@@ -271,6 +271,7 @@ def _aggregate_edge_labels(edges: Dict) -> Dict:
aggregated[edge_id] = attrs
return aggregated
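For context on the hunks above: `_aggregate_edge_labels` pools the `__labels__` of parallel MultiDiGraph edges so that `_is_edge_attr_match` can require only a non-empty intersection between motif and host labels (the `motif_types.intersection(host_types)` check a few lines up). A self-contained sketch of that idea, not the library's exact internals:

```python
# Minimal restatement of the label-pooling idea, assuming MultiDiGraph-style
# edge dicts keyed by edge index; the helper name and data are illustrative.
from typing import Dict, Set

def pool_labels(edges: Dict) -> Set[str]:
    # Union the __labels__ sets of all parallel edges between two nodes.
    labels: Set[str] = set()
    for attrs in edges.values():
        labels |= attrs.get("__labels__", set())
    return labels

motif_edges = {0: {"__labels__": {"paid"}}}
host_edges = {0: {"__labels__": {"friend"}}, 1: {"__labels__": {"paid"}, "amount": 70}}

motif_types = pool_labels(motif_edges)
host_types = pool_labels(host_edges)
# Matches because at least one parallel host edge carries the 'paid' label.
print(bool(motif_types & host_types))  # True
```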


def _get_entity_from_host(
host: Union[nx.DiGraph, nx.MultiDiGraph], entity_name, entity_attribute=None
):
@@ -288,7 +289,7 @@ def _get_entity_from_host(
edge_data = host.get_edge_data(*entity_name)
if not edge_data:
return None # print(f"Nothing found for {entity_name} {entity_attribute}")

if entity_attribute:
# looking for edge attribute:
if isinstance(host, nx.MultiDiGraph):
@@ -491,15 +492,16 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
for r in ret:
r_attr = {}
for i, v in r.items():
r_attr[(i, list(v.get('__labels__'))[0])] = v.get(entity_attribute, None)
r_attr[(i, list(v.get("__labels__"))[0])] = v.get(
entity_attribute, None
)
# eg, [{(0, 'paid'): 70, (1, 'paid'): 90}, {(0, 'paid'): 400, (1, 'friend'): None, (2, 'paid'): 650}]
ret_with_attr.append(r_attr)

ret = ret_with_attr

result[data_path] = list(ret)[offset_limit]


return result
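The reformatted lines above are where `_lookup` re-keys each matched relationship by (position in path, edge label) before the values reach the aggregators, exactly as the inline `# eg, ...` comment illustrates. A small sketch of that row shape and how a later step filters it by label (values copied from that comment, helper code illustrative):

```python
# One row as built by the loop above: keys are (hop_index, edge_label),
# values are the requested attribute (None when the edge lacks it).
row = {(0, "paid"): 400, (1, "friend"): None, (2, "paid"): 650}

# Downstream, values are regrouped per label before applying COUNT/SUM/etc.
paid_values = [v for (i, label), v in row.items() if label == "paid"]
print(paid_values)  # [400, 650]
```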

def return_clause(self, clause):
@@ -519,7 +521,6 @@ def return_clause(self, clause):
item = str(item.value)
self._return_requests.append(item)


def order_clause(self, order_clause):
self._order_by = []
for item in order_clause[0].children:
@@ -544,7 +545,6 @@ def skip_clause(self, skip):
skip = int(skip[-1])
self._skip = skip


def aggregate(self, func, results, entity, group_keys):
# Collect data based on group keys
grouped_data = {}
@@ -558,12 +558,24 @@ def _collate_data(data, unique_labels, func):
# for ["COUNT", "SUM", "AVG"], we treat None as 0
if func in ["COUNT", "SUM", "AVG"]:
collated_data = {
label: [(v or 0) for rel in data for k, v in rel.items() if k[1] == label] for label in unique_labels
label: [
(v or 0)
for rel in data
for k, v in rel.items()
if k[1] == label
]
for label in unique_labels
}
# for ["MAX", "MIN"], we treat None as non-existent
elif func in ["MAX", "MIN"]:
collated_data = {
label: [v for rel in data for k, v in rel.items() if (k[1] == label and v is not None)] for label in unique_labels
label: [
v
for rel in data
for k, v in rel.items()
if (k[1] == label and v is not None)
]
for label in unique_labels
}

return collated_data
@@ -583,7 +595,14 @@ def _collate_data(data, unique_labels, func):
elif func == "AVG":
sum_data = {label: sum(data) for label, data in collated_data.items()}
count_data = {label: len(data) for label, data in collated_data.items()}
avg_data = {label: sum_data[label] / count_data[label] if count_data[label] > 0 else 0 for label in sum_data}
avg_data = {
label: (
sum_data[label] / count_data[label]
if count_data[label] > 0
else 0
)
for label in sum_data
}
aggregate_results[group] = avg_data
elif func == "MAX":
max_data = {label: max(data) for label, data in collated_data.items()}
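The two branches of `_collate_data` above encode the None-handling policy: `COUNT`, `SUM`, and `AVG` coerce missing attribute values to 0, while `MAX` and `MIN` drop them so a missing value can never become the extremum; `AVG` is then computed as the guarded `sum / count` shown just above. A standalone sketch of that policy applied to rows shaped like the earlier `# eg, ...` comment (the helper below is illustrative, not the library's exact code):

```python
# Illustrative restatement of the None-handling rule in _collate_data.
rows = [
    {(0, "paid"): 70, (1, "paid"): 90},
    {(0, "paid"): 400, (1, "friend"): None, (2, "paid"): 650},
]
labels = ["paid", "friend"]

def collate(rows, labels, func):
    if func in ("COUNT", "SUM", "AVG"):
        # None becomes 0 so every matched edge still contributes a value.
        return {
            lbl: [(v or 0) for row in rows for (_, l), v in row.items() if l == lbl]
            for lbl in labels
        }
    if func in ("MAX", "MIN"):
        # None is dropped so missing attributes cannot be picked as extrema.
        return {
            lbl: [v for row in rows for (_, l), v in row.items() if l == lbl and v is not None]
            for lbl in labels
        }

print(collate(rows, labels, "SUM"))  # {'paid': [70, 90, 400, 650], 'friend': [0]}
print(collate(rows, labels, "MAX"))  # {'paid': [70, 90, 400, 650], 'friend': []}
```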
@@ -602,7 +621,11 @@ def returns(self, ignore_limit=False):
offset_limit=slice(0, None),
)
if len(self._aggregate_functions) > 0:
group_keys = [key for key in results.keys() if not any(key.endswith(func[1]) for func in self._aggregate_functions)]
group_keys = [
key
for key in results.keys()
if not any(key.endswith(func[1]) for func in self._aggregate_functions)
]

aggregated_results = {}
for func, entity in self._aggregate_functions:
@@ -865,7 +888,9 @@ def flatten_tokens(edge_tokens):
flat_tokens = []
for token in edge_tokens:
if isinstance(token, Tree):
flat_tokens.extend(flatten_tokens(token.children)) # Recursively flatten the tree
flat_tokens.extend(
flatten_tokens(token.children)
) # Recursively flatten the tree
else:
flat_tokens.append(token)
return flat_tokens
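`flatten_tokens`, shown complete above, walks a Lark parse subtree depth-first and returns its leaf tokens in order, which lets the relationship-match code consume nested alternatives uniformly. A toy call, assuming the function above is in scope; the rule name and token values are invented for the example:

```python
# Toy input: a nested Lark tree such as an OR'd relationship-type match
# might produce (rule names and values here are made up for illustration).
from lark import Token, Tree

edge_tokens = [
    Token("TYPE", "paid"),
    Tree("rel_or", [Token("TYPE", "friend"), Tree("rel_or", [Token("TYPE", "owes")])]),
]

print(flatten_tokens(edge_tokens))
# [Token('TYPE', 'paid'), Token('TYPE', 'friend'), Token('TYPE', 'owes')]
```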
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setuptools.setup(
name="grand-cypher",
version="0.8.0",
version="0.9.0",
author="Jordan Matelsky",
author_email="opensource@matelsky.com",
description="Query Grand graphs using Cypher",
