diff --git a/README.md b/README.md
index c19f7def..e363f194 100755
--- a/README.md
+++ b/README.md
@@ -20,22 +20,20 @@ An explainable inference software supporting annotated, real valued, graph based
[🗎 Documentation](https://pyreason.readthedocs.io/en/latest/)
-Check out the [PyReason Hello World](https://pyreason.readthedocs.io/en/latest/tutorials/Basic%20tutorial.html) program if you're new, or want get get a feel for the software.
-
## Table of Contents
1. [Introduction](#1-introduction)
2. [Documentation](#2-documentation)
3. [Install](#3-install)
-5. [Bibtex](#4-bibtex)
-6. [License](#5-license)
-7. [Contact](#6-contact)
+4. [Bibtex](#4-bibtex)
+5. [License](#5-license)
+6. [Contact](#6-contact)
## 1. Introduction
PyReason is a graphical inference tool that uses a set of logical rules and facts (initial conditions) to reason over graph structures. To get more details, refer to the paper/video/hello-world-example mentioned above.
-
+
## 2. Documentation
All API documentation and code examples can be found on [ReadTheDocs](https://pyreason.readthedocs.io/en/latest/)
diff --git a/docs/advanced_graph.ipynb b/docs/advanced_graph.ipynb
index 3908af7f..63412c50 100644
--- a/docs/advanced_graph.ipynb
+++ b/docs/advanced_graph.ipynb
@@ -2,38 +2,65 @@
"cells": [
{
"cell_type": "code",
- "outputs": [],
- "source": [
- "from pprint import pprint\n",
- "import networkx as nx\n",
- "import pyreason as pr"
- ],
+ "execution_count": 1,
+ "id": "262e6f13c9d84198",
"metadata": {
- "collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-06T16:55:01.204690Z",
"start_time": "2024-03-06T16:54:59.111005Z"
- }
+ },
+ "collapsed": false
},
- "id": "262e6f13c9d84198",
- "execution_count": 1
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Imported PyReason for the first time. Initializing caches for faster runtimes ... this will take a minute\n",
+ "PyReason initialized!\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from pprint import pprint\n",
+ "import networkx as nx\n",
+ "import pyreason as pr"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "## Customers\n"
- ],
+ "id": "1a5145b4d65f6368",
"metadata": {
"collapsed": false
},
- "id": "1a5145b4d65f6368"
+ "source": [
+ "## Customers\n"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 2,
+ "id": "665d462215b1aace",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-03-06T16:55:01.211610Z",
+ "start_time": "2024-03-06T16:55:01.205564Z"
+ },
+ "collapsed": false
+ },
"outputs": [
{
"data": {
- "text/plain": "{0: ('John', 'M', 'New York', 'NY'),\n 1: ('Mary', 'F', 'Los Angeles', 'CA'),\n 2: ('Justin', 'M', 'Chicago', 'IL'),\n 3: ('Alice', 'F', 'Houston', 'TX'),\n 4: ('Bob', 'M', 'Phoenix', 'AZ'),\n 5: ('Eva', 'F', 'San Diego', 'CA'),\n 6: ('Mike', 'M', 'Dallas', 'TX')}"
+ "text/plain": [
+ "{0: ('John', 'M', 'New York', 'NY'),\n",
+ " 1: ('Mary', 'F', 'Los Angeles', 'CA'),\n",
+ " 2: ('Justin', 'M', 'Chicago', 'IL'),\n",
+ " 3: ('Alice', 'F', 'Houston', 'TX'),\n",
+ " 4: ('Bob', 'M', 'Phoenix', 'AZ'),\n",
+ " 5: ('Eva', 'F', 'San Diego', 'CA'),\n",
+ " 6: ('Mike', 'M', 'Dallas', 'TX')}"
+ ]
},
"execution_count": 2,
"metadata": {},
@@ -55,33 +82,39 @@
"customer_dict = {i: customer for i, customer in enumerate(customer_details)}\n",
"\n",
"customer_dict"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-03-06T16:55:01.211610Z",
- "start_time": "2024-03-06T16:55:01.205564Z"
- }
- },
- "id": "665d462215b1aace",
- "execution_count": 2
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "## Pets"
- ],
+ "id": "6c045353bd1abf62",
"metadata": {
"collapsed": false
},
- "id": "6c045353bd1abf62"
+ "source": [
+ "## Pets"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 3,
+ "id": "d76169fae6692dd9",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-03-06T16:55:01.218314Z",
+ "start_time": "2024-03-06T16:55:01.212164Z"
+ },
+ "collapsed": false
+ },
"outputs": [
{
"data": {
- "text/plain": "{0: ('Dog', 'Mammal'),\n 1: ('Cat', 'Mammal'),\n 2: ('Rabbit', 'Mammal'),\n 3: ('Parrot', 'Bird'),\n 4: ('Fish', 'Fish')}"
+ "text/plain": [
+ "{0: ('Dog', 'Mammal'),\n",
+ " 1: ('Cat', 'Mammal'),\n",
+ " 2: ('Rabbit', 'Mammal'),\n",
+ " 3: ('Parrot', 'Bird'),\n",
+ " 4: ('Fish', 'Fish')}"
+ ]
},
"execution_count": 3,
"metadata": {},
@@ -99,33 +132,46 @@
"]\n",
"pet_dict = {i: pet for i, pet in enumerate(pet_details)}\n",
"pet_dict"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-03-06T16:55:01.218314Z",
- "start_time": "2024-03-06T16:55:01.212164Z"
- }
- },
- "id": "d76169fae6692dd9",
- "execution_count": 3
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "## Cars"
- ],
+ "id": "ef9acbe66746de4b",
"metadata": {
"collapsed": false
},
- "id": "ef9acbe66746de4b"
+ "source": [
+ "## Cars"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 4,
+ "id": "247a351dd20e3dff",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-03-06T16:55:01.226020Z",
+ "start_time": "2024-03-06T16:55:01.218880Z"
+ },
+ "collapsed": false
+ },
"outputs": [
{
"data": {
- "text/plain": "{0: ('Toyota Camry', 'Red'),\n 1: ('Honda Civic', 'Blue'),\n 2: ('Ford Focus', 'Red'),\n 3: ('BMW 3 Series', 'Black'),\n 4: ('Tesla Model S', 'Red'),\n 5: ('Chevrolet Bolt EV', 'White'),\n 6: ('Ford Mustang', 'Yellow'),\n 7: ('Audi A4', 'Silver'),\n 8: ('Mercedes-Benz C-Class', 'Grey'),\n 9: ('Subaru Outback', 'Green'),\n 10: ('Volkswagen Golf', 'Blue'),\n 11: ('Porsche 911', 'Black')}"
+ "text/plain": [
+ "{0: ('Toyota Camry', 'Red'),\n",
+ " 1: ('Honda Civic', 'Blue'),\n",
+ " 2: ('Ford Focus', 'Red'),\n",
+ " 3: ('BMW 3 Series', 'Black'),\n",
+ " 4: ('Tesla Model S', 'Red'),\n",
+ " 5: ('Chevrolet Bolt EV', 'White'),\n",
+ " 6: ('Ford Mustang', 'Yellow'),\n",
+ " 7: ('Audi A4', 'Silver'),\n",
+ " 8: ('Mercedes-Benz C-Class', 'Grey'),\n",
+ " 9: ('Subaru Outback', 'Green'),\n",
+ " 10: ('Volkswagen Golf', 'Blue'),\n",
+ " 11: ('Porsche 911', 'Black')}"
+ ]
},
"execution_count": 4,
"metadata": {},
@@ -152,29 +198,29 @@
"\n",
"car_dict = {i: car for i, car in enumerate(car_details)}\n",
"car_dict"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-03-06T16:55:01.226020Z",
- "start_time": "2024-03-06T16:55:01.218880Z"
- }
- },
- "id": "247a351dd20e3dff",
- "execution_count": 4
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "## Travels"
- ],
+ "id": "74f8acc79272b7f1",
"metadata": {
"collapsed": false
},
- "id": "74f8acc79272b7f1"
+ "source": [
+ "## Travels"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 5,
+ "id": "fed68df1e81535c5",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-03-06T16:55:01.232158Z",
+ "start_time": "2024-03-06T16:55:01.227192Z"
+ },
+ "collapsed": false
+ },
"outputs": [],
"source": [
"travels = [\n",
@@ -183,29 +229,29 @@
" ('Eva', 'San Diego', 'CA', 'Dallas', 'TX', 1),\n",
" ('Mike', 'Dallas', 'TX', 'Chicago', 'IL', 3)\n",
"]"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-03-06T16:55:01.232158Z",
- "start_time": "2024-03-06T16:55:01.227192Z"
- }
- },
- "id": "fed68df1e81535c5",
- "execution_count": 5
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "## Friendships, Car ownerships, and Pet ownerships"
- ],
+ "id": "2f1c2b312698c7f5",
"metadata": {
"collapsed": false
},
- "id": "2f1c2b312698c7f5"
+ "source": [
+ "## Friendships, Car ownerships, and Pet ownerships"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 6,
+ "id": "3ba07b7f85501dfd",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-03-06T16:55:01.243663Z",
+ "start_time": "2024-03-06T16:55:01.232941Z"
+ },
+ "collapsed": false
+ },
"outputs": [],
"source": [
"friendships = [('customer_2', 'customer_1'), ('customer_0', 'customer_1'), ('customer_0', 'customer_2'),\n",
@@ -220,29 +266,29 @@
"pet_ownerships = [('customer_1', 'Pet_1'), ('customer_2', 'Pet_1'), ('customer_2', 'Pet_0'), ('customer_0', 'Pet_0'),\n",
" ('customer_3', 'Pet_2'), ('customer_4', 'Pet_2'), ('customer_5', 'Pet_3'), ('customer_6', 'Pet_4'),\n",
" ('customer_0', 'Pet_4')]"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-03-06T16:55:01.243663Z",
- "start_time": "2024-03-06T16:55:01.232941Z"
- }
- },
- "id": "3ba07b7f85501dfd",
- "execution_count": 6
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "## Creating the Nodes"
- ],
+ "id": "9a42a34f10bc90fe",
"metadata": {
"collapsed": false
},
- "id": "9a42a34f10bc90fe"
+ "source": [
+ "## Creating the Nodes"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 8,
+ "id": "1c05c386c4820dbb",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-03-06T16:55:01.257039Z",
+ "start_time": "2024-03-06T16:55:01.246002Z"
+ },
+ "collapsed": false
+ },
"outputs": [],
"source": [
"g = nx.DiGraph()\n",
@@ -279,29 +325,29 @@
" }\n",
" name = \"Car_\" + str(car_id)\n",
" g.add_node(name, **attributes)\n"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-03-06T16:55:01.257039Z",
- "start_time": "2024-03-06T16:55:01.246002Z"
- }
- },
- "id": "1c05c386c4820dbb",
- "execution_count": 7
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "## Creating the Edges"
- ],
+ "id": "9617d6e1f43ad7ea",
"metadata": {
"collapsed": false
},
- "id": "9617d6e1f43ad7ea"
+ "source": [
+ "## Creating the Edges"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 9,
+ "id": "29f4b3b461d64dfc",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-03-06T17:18:54.458084Z",
+ "start_time": "2024-03-06T17:18:54.454189Z"
+ },
+ "collapsed": false
+ },
"outputs": [],
"source": [
"for f1, f2 in friendships:\n",
@@ -310,19 +356,19 @@
" g.add_edge(owner, car, owns_car=1, car_color_id=int(car.split('_')[1]))\n",
"for owner, pet in pet_ownerships:\n",
" g.add_edge(owner, pet, owns_pet=1)\n"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-03-06T17:18:54.458084Z",
- "start_time": "2024-03-06T17:18:54.454189Z"
- }
- },
- "id": "29f4b3b461d64dfc",
- "execution_count": 37
+ ]
},
{
"cell_type": "code",
+ "execution_count": 14,
+ "id": "b9b96d7d6734019e",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-03-06T17:34:59.142836Z",
+ "start_time": "2024-03-06T17:34:59.115848Z"
+ },
+ "collapsed": false
+ },
"outputs": [],
"source": [
"pr.load_graph(g)\n",
@@ -341,26 +387,34 @@
" pr.Rule(\"same_color_car(x, y) <- owns_car(x, c1) , owns_car(y, c2), car_color_id(x,c1) == car_color_id(y,c2) , c_id(x) != c_id(y)\",\n",
" \"same_car_color_rule\"))\n",
"\n",
- "pr.add_fact(pr.Fact(name='popular-fact', component='customer_0', attribute='popular', bound=[1, 1],start_time=0,end_time=20))"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-03-06T17:34:59.142836Z",
- "start_time": "2024-03-06T17:34:59.115848Z"
- }
- },
- "id": "b9b96d7d6734019e",
- "execution_count": 57
+ "#pr.add_fact(pr.Fact(name='popular-fact', component='customer_0', attribute='popular', bound=[1, 1],start_time=0,end_time=20))\n",
+ "pr.add_fact(pr.Fact('popular-fact', 'popular(customer_0)', 0, 20))"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 18,
+ "id": "53496c463f89efa4",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-03-06T17:35:01.734472Z",
+ "start_time": "2024-03-06T17:34:59.471245Z"
+ },
+ "collapsed": false
+ },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Timestep: 0\n",
+ "Optimizing rules by moving node clauses ahead of edge clauses\n",
+ "Timestep: 0\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
"Timestep: 1\n",
"Timestep: 2\n",
"Timestep: 3\n",
@@ -378,978 +432,29 @@
"source": [
"interpretation = pr.reason(timesteps=10)\n",
"# pr.save_rule_trace(interpretation)"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-03-06T17:35:01.734472Z",
- "start_time": "2024-03-06T17:34:59.471245Z"
- }
- },
- "id": "53496c463f89efa4",
- "execution_count": 58
+ ]
},
{
"cell_type": "code",
+ "execution_count": 17,
+ "id": "a50b05a88124bddd",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-03-06T17:35:04.407063Z",
+ "start_time": "2024-03-06T17:35:04.344424Z"
+ },
+ "collapsed": false
+ },
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{0: {'Car_0': {},\n",
- " 'Car_1': {},\n",
- " 'Car_10': {},\n",
- " 'Car_11': {},\n",
- " 'Car_2': {},\n",
- " 'Car_3': {},\n",
- " 'Car_4': {},\n",
- " 'Car_5': {},\n",
- " 'Car_6': {},\n",
- " 'Car_7': {},\n",
- " 'Car_8': {},\n",
- " 'Car_9': {},\n",
- " 'Pet_0': {},\n",
- " 'Pet_1': {},\n",
- " 'Pet_2': {},\n",
- " 'Pet_3': {},\n",
- " 'Pet_4': {},\n",
- " 'customer_0': {'popular': (1.0, 1.0)},\n",
- " 'customer_1': {},\n",
- " 'customer_2': {},\n",
- " 'customer_3': {},\n",
- " 'customer_4': {},\n",
- " 'customer_5': {},\n",
- " 'customer_6': {},\n",
- " ('customer_0', 'Car_2'): {},\n",
- " ('customer_0', 'Car_7'): {},\n",
- " ('customer_0', 'Pet_0'): {},\n",
- " ('customer_0', 'Pet_4'): {},\n",
- " ('customer_0', 'customer_1'): {},\n",
- " ('customer_0', 'customer_2'): {},\n",
- " ('customer_1', 'Car_0'): {},\n",
- " ('customer_1', 'Car_8'): {},\n",
- " ('customer_1', 'Pet_1'): {},\n",
- " ('customer_2', 'Car_1'): {},\n",
- " ('customer_2', 'Car_11'): {},\n",
- " ('customer_2', 'Car_3'): {},\n",
- " ('customer_2', 'Pet_0'): {},\n",
- " ('customer_2', 'Pet_1'): {},\n",
- " ('customer_2', 'customer_1'): {},\n",
- " ('customer_2', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_3'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'Car_0'): {},\n",
- " ('customer_3', 'Car_10'): {},\n",
- " ('customer_3', 'Car_3'): {},\n",
- " ('customer_3', 'Pet_2'): {},\n",
- " ('customer_3', 'customer_1'): {},\n",
- " ('customer_3', 'customer_2'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'Car_4'): {},\n",
- " ('customer_4', 'Car_9'): {},\n",
- " ('customer_4', 'Pet_2'): {},\n",
- " ('customer_4', 'customer_0'): {},\n",
- " ('customer_4', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_6'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_5', 'Car_2'): {},\n",
- " ('customer_5', 'Car_5'): {},\n",
- " ('customer_5', 'Pet_3'): {},\n",
- " ('customer_5', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'Car_4'): {},\n",
- " ('customer_6', 'Car_6'): {},\n",
- " ('customer_6', 'Pet_4'): {},\n",
- " ('customer_6', 'customer_0'): {},\n",
- " ('customer_6', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_4'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_6'): {'car_friend': (1.0, 1.0)}},\n",
- " 1: {'Car_0': {},\n",
- " 'Car_1': {},\n",
- " 'Car_10': {},\n",
- " 'Car_11': {},\n",
- " 'Car_2': {},\n",
- " 'Car_3': {},\n",
- " 'Car_4': {},\n",
- " 'Car_5': {},\n",
- " 'Car_6': {},\n",
- " 'Car_7': {},\n",
- " 'Car_8': {},\n",
- " 'Car_9': {},\n",
- " 'Pet_0': {},\n",
- " 'Pet_1': {},\n",
- " 'Pet_2': {},\n",
- " 'Pet_3': {},\n",
- " 'Pet_4': {},\n",
- " 'customer_0': {'popular': (1.0, 1.0)},\n",
- " 'customer_1': {},\n",
- " 'customer_2': {},\n",
- " 'customer_3': {'cool_pet': (1.0, 1.0)},\n",
- " 'customer_4': {'cool_car': (1.0, 1.0),\n",
- " 'cool_pet': (1.0, 1.0),\n",
- " 'popular': (1.0, 1.0),\n",
- " 'trendy': (1.0, 1.0)},\n",
- " 'customer_5': {},\n",
- " 'customer_6': {'popular': (1.0, 1.0), 'cool_car': (1.0, 1.0)},\n",
- " ('customer_0', 'Car_2'): {},\n",
- " ('customer_0', 'Car_7'): {},\n",
- " ('customer_0', 'Pet_0'): {},\n",
- " ('customer_0', 'Pet_4'): {},\n",
- " ('customer_0', 'customer_1'): {},\n",
- " ('customer_0', 'customer_2'): {},\n",
- " ('customer_1', 'Car_0'): {},\n",
- " ('customer_1', 'Car_8'): {},\n",
- " ('customer_1', 'Pet_1'): {},\n",
- " ('customer_2', 'Car_1'): {},\n",
- " ('customer_2', 'Car_11'): {},\n",
- " ('customer_2', 'Car_3'): {},\n",
- " ('customer_2', 'Pet_0'): {},\n",
- " ('customer_2', 'Pet_1'): {},\n",
- " ('customer_2', 'customer_1'): {},\n",
- " ('customer_2', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_3'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'Car_0'): {},\n",
- " ('customer_3', 'Car_10'): {},\n",
- " ('customer_3', 'Car_3'): {},\n",
- " ('customer_3', 'Pet_2'): {},\n",
- " ('customer_3', 'customer_1'): {},\n",
- " ('customer_3', 'customer_2'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'Car_4'): {},\n",
- " ('customer_4', 'Car_9'): {},\n",
- " ('customer_4', 'Pet_2'): {},\n",
- " ('customer_4', 'customer_0'): {},\n",
- " ('customer_4', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_6'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_5', 'Car_2'): {},\n",
- " ('customer_5', 'Car_5'): {},\n",
- " ('customer_5', 'Pet_3'): {},\n",
- " ('customer_5', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'Car_4'): {},\n",
- " ('customer_6', 'Car_6'): {},\n",
- " ('customer_6', 'Pet_4'): {},\n",
- " ('customer_6', 'customer_0'): {},\n",
- " ('customer_6', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_4'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_6'): {'car_friend': (1.0, 1.0)}},\n",
- " 2: {'Car_0': {},\n",
- " 'Car_1': {},\n",
- " 'Car_10': {},\n",
- " 'Car_11': {},\n",
- " 'Car_2': {},\n",
- " 'Car_3': {},\n",
- " 'Car_4': {},\n",
- " 'Car_5': {},\n",
- " 'Car_6': {},\n",
- " 'Car_7': {},\n",
- " 'Car_8': {},\n",
- " 'Car_9': {},\n",
- " 'Pet_0': {},\n",
- " 'Pet_1': {},\n",
- " 'Pet_2': {},\n",
- " 'Pet_3': {},\n",
- " 'Pet_4': {},\n",
- " 'customer_0': {'popular': (1.0, 1.0)},\n",
- " 'customer_1': {},\n",
- " 'customer_2': {},\n",
- " 'customer_3': {'popular': (1.0, 1.0), 'cool_pet': (1.0, 1.0)},\n",
- " 'customer_4': {'cool_car': (1.0, 1.0),\n",
- " 'cool_pet': (1.0, 1.0),\n",
- " 'popular': (1.0, 1.0),\n",
- " 'trendy': (1.0, 1.0)},\n",
- " 'customer_5': {'popular': (1.0, 1.0)},\n",
- " 'customer_6': {'popular': (1.0, 1.0), 'cool_car': (1.0, 1.0)},\n",
- " ('customer_0', 'Car_2'): {},\n",
- " ('customer_0', 'Car_7'): {},\n",
- " ('customer_0', 'Pet_0'): {},\n",
- " ('customer_0', 'Pet_4'): {},\n",
- " ('customer_0', 'customer_1'): {},\n",
- " ('customer_0', 'customer_2'): {},\n",
- " ('customer_1', 'Car_0'): {},\n",
- " ('customer_1', 'Car_8'): {},\n",
- " ('customer_1', 'Pet_1'): {},\n",
- " ('customer_2', 'Car_1'): {},\n",
- " ('customer_2', 'Car_11'): {},\n",
- " ('customer_2', 'Car_3'): {},\n",
- " ('customer_2', 'Pet_0'): {},\n",
- " ('customer_2', 'Pet_1'): {},\n",
- " ('customer_2', 'customer_1'): {},\n",
- " ('customer_2', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_3'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'Car_0'): {},\n",
- " ('customer_3', 'Car_10'): {},\n",
- " ('customer_3', 'Car_3'): {},\n",
- " ('customer_3', 'Pet_2'): {},\n",
- " ('customer_3', 'customer_1'): {},\n",
- " ('customer_3', 'customer_2'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'Car_4'): {},\n",
- " ('customer_4', 'Car_9'): {},\n",
- " ('customer_4', 'Pet_2'): {},\n",
- " ('customer_4', 'customer_0'): {},\n",
- " ('customer_4', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_6'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_5', 'Car_2'): {},\n",
- " ('customer_5', 'Car_5'): {},\n",
- " ('customer_5', 'Pet_3'): {},\n",
- " ('customer_5', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'Car_4'): {},\n",
- " ('customer_6', 'Car_6'): {},\n",
- " ('customer_6', 'Pet_4'): {},\n",
- " ('customer_6', 'customer_0'): {},\n",
- " ('customer_6', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_4'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_6'): {'car_friend': (1.0, 1.0)}},\n",
- " 3: {'Car_0': {},\n",
- " 'Car_1': {},\n",
- " 'Car_10': {},\n",
- " 'Car_11': {},\n",
- " 'Car_2': {},\n",
- " 'Car_3': {},\n",
- " 'Car_4': {},\n",
- " 'Car_5': {},\n",
- " 'Car_6': {},\n",
- " 'Car_7': {},\n",
- " 'Car_8': {},\n",
- " 'Car_9': {},\n",
- " 'Pet_0': {},\n",
- " 'Pet_1': {},\n",
- " 'Pet_2': {},\n",
- " 'Pet_3': {},\n",
- " 'Pet_4': {},\n",
- " 'customer_0': {'popular': (1.0, 1.0)},\n",
- " 'customer_1': {},\n",
- " 'customer_2': {},\n",
- " 'customer_3': {'popular': (1.0, 1.0), 'cool_pet': (1.0, 1.0)},\n",
- " 'customer_4': {'cool_car': (1.0, 1.0),\n",
- " 'cool_pet': (1.0, 1.0),\n",
- " 'popular': (1.0, 1.0),\n",
- " 'trendy': (1.0, 1.0)},\n",
- " 'customer_5': {'popular': (1.0, 1.0)},\n",
- " 'customer_6': {'popular': (1.0, 1.0), 'cool_car': (1.0, 1.0)},\n",
- " ('customer_0', 'Car_2'): {},\n",
- " ('customer_0', 'Car_7'): {},\n",
- " ('customer_0', 'Pet_0'): {},\n",
- " ('customer_0', 'Pet_4'): {},\n",
- " ('customer_0', 'customer_1'): {},\n",
- " ('customer_0', 'customer_2'): {},\n",
- " ('customer_1', 'Car_0'): {},\n",
- " ('customer_1', 'Car_8'): {},\n",
- " ('customer_1', 'Pet_1'): {},\n",
- " ('customer_2', 'Car_1'): {},\n",
- " ('customer_2', 'Car_11'): {},\n",
- " ('customer_2', 'Car_3'): {},\n",
- " ('customer_2', 'Pet_0'): {},\n",
- " ('customer_2', 'Pet_1'): {},\n",
- " ('customer_2', 'customer_1'): {},\n",
- " ('customer_2', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_3'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'Car_0'): {},\n",
- " ('customer_3', 'Car_10'): {},\n",
- " ('customer_3', 'Car_3'): {},\n",
- " ('customer_3', 'Pet_2'): {},\n",
- " ('customer_3', 'customer_1'): {},\n",
- " ('customer_3', 'customer_2'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'Car_4'): {},\n",
- " ('customer_4', 'Car_9'): {},\n",
- " ('customer_4', 'Pet_2'): {},\n",
- " ('customer_4', 'customer_0'): {},\n",
- " ('customer_4', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_6'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_5', 'Car_2'): {},\n",
- " ('customer_5', 'Car_5'): {},\n",
- " ('customer_5', 'Pet_3'): {},\n",
- " ('customer_5', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'Car_4'): {},\n",
- " ('customer_6', 'Car_6'): {},\n",
- " ('customer_6', 'Pet_4'): {},\n",
- " ('customer_6', 'customer_0'): {},\n",
- " ('customer_6', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_4'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_6'): {'car_friend': (1.0, 1.0)}},\n",
- " 4: {'Car_0': {},\n",
- " 'Car_1': {},\n",
- " 'Car_10': {},\n",
- " 'Car_11': {},\n",
- " 'Car_2': {},\n",
- " 'Car_3': {},\n",
- " 'Car_4': {},\n",
- " 'Car_5': {},\n",
- " 'Car_6': {},\n",
- " 'Car_7': {},\n",
- " 'Car_8': {},\n",
- " 'Car_9': {},\n",
- " 'Pet_0': {},\n",
- " 'Pet_1': {},\n",
- " 'Pet_2': {},\n",
- " 'Pet_3': {},\n",
- " 'Pet_4': {},\n",
- " 'customer_0': {'popular': (1.0, 1.0)},\n",
- " 'customer_1': {},\n",
- " 'customer_2': {},\n",
- " 'customer_3': {'popular': (1.0, 1.0), 'cool_pet': (1.0, 1.0)},\n",
- " 'customer_4': {'cool_car': (1.0, 1.0),\n",
- " 'cool_pet': (1.0, 1.0),\n",
- " 'popular': (1.0, 1.0),\n",
- " 'trendy': (1.0, 1.0)},\n",
- " 'customer_5': {'popular': (1.0, 1.0)},\n",
- " 'customer_6': {'popular': (1.0, 1.0), 'cool_car': (1.0, 1.0)},\n",
- " ('customer_0', 'Car_2'): {},\n",
- " ('customer_0', 'Car_7'): {},\n",
- " ('customer_0', 'Pet_0'): {},\n",
- " ('customer_0', 'Pet_4'): {},\n",
- " ('customer_0', 'customer_1'): {},\n",
- " ('customer_0', 'customer_2'): {},\n",
- " ('customer_1', 'Car_0'): {},\n",
- " ('customer_1', 'Car_8'): {},\n",
- " ('customer_1', 'Pet_1'): {},\n",
- " ('customer_2', 'Car_1'): {},\n",
- " ('customer_2', 'Car_11'): {},\n",
- " ('customer_2', 'Car_3'): {},\n",
- " ('customer_2', 'Pet_0'): {},\n",
- " ('customer_2', 'Pet_1'): {},\n",
- " ('customer_2', 'customer_1'): {},\n",
- " ('customer_2', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_3'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'Car_0'): {},\n",
- " ('customer_3', 'Car_10'): {},\n",
- " ('customer_3', 'Car_3'): {},\n",
- " ('customer_3', 'Pet_2'): {},\n",
- " ('customer_3', 'customer_1'): {},\n",
- " ('customer_3', 'customer_2'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'Car_4'): {},\n",
- " ('customer_4', 'Car_9'): {},\n",
- " ('customer_4', 'Pet_2'): {},\n",
- " ('customer_4', 'customer_0'): {},\n",
- " ('customer_4', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_6'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_5', 'Car_2'): {},\n",
- " ('customer_5', 'Car_5'): {},\n",
- " ('customer_5', 'Pet_3'): {},\n",
- " ('customer_5', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'Car_4'): {},\n",
- " ('customer_6', 'Car_6'): {},\n",
- " ('customer_6', 'Pet_4'): {},\n",
- " ('customer_6', 'customer_0'): {},\n",
- " ('customer_6', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_4'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_6'): {'car_friend': (1.0, 1.0)}},\n",
- " 5: {'Car_0': {},\n",
- " 'Car_1': {},\n",
- " 'Car_10': {},\n",
- " 'Car_11': {},\n",
- " 'Car_2': {},\n",
- " 'Car_3': {},\n",
- " 'Car_4': {},\n",
- " 'Car_5': {},\n",
- " 'Car_6': {},\n",
- " 'Car_7': {},\n",
- " 'Car_8': {},\n",
- " 'Car_9': {},\n",
- " 'Pet_0': {},\n",
- " 'Pet_1': {},\n",
- " 'Pet_2': {},\n",
- " 'Pet_3': {},\n",
- " 'Pet_4': {},\n",
- " 'customer_0': {'popular': (1.0, 1.0)},\n",
- " 'customer_1': {},\n",
- " 'customer_2': {},\n",
- " 'customer_3': {'popular': (1.0, 1.0), 'cool_pet': (1.0, 1.0)},\n",
- " 'customer_4': {'cool_car': (1.0, 1.0),\n",
- " 'cool_pet': (1.0, 1.0),\n",
- " 'popular': (1.0, 1.0),\n",
- " 'trendy': (1.0, 1.0)},\n",
- " 'customer_5': {'popular': (1.0, 1.0)},\n",
- " 'customer_6': {'popular': (1.0, 1.0), 'cool_car': (1.0, 1.0)},\n",
- " ('customer_0', 'Car_2'): {},\n",
- " ('customer_0', 'Car_7'): {},\n",
- " ('customer_0', 'Pet_0'): {},\n",
- " ('customer_0', 'Pet_4'): {},\n",
- " ('customer_0', 'customer_1'): {},\n",
- " ('customer_0', 'customer_2'): {},\n",
- " ('customer_1', 'Car_0'): {},\n",
- " ('customer_1', 'Car_8'): {},\n",
- " ('customer_1', 'Pet_1'): {},\n",
- " ('customer_2', 'Car_1'): {},\n",
- " ('customer_2', 'Car_11'): {},\n",
- " ('customer_2', 'Car_3'): {},\n",
- " ('customer_2', 'Pet_0'): {},\n",
- " ('customer_2', 'Pet_1'): {},\n",
- " ('customer_2', 'customer_1'): {},\n",
- " ('customer_2', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_3'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'Car_0'): {},\n",
- " ('customer_3', 'Car_10'): {},\n",
- " ('customer_3', 'Car_3'): {},\n",
- " ('customer_3', 'Pet_2'): {},\n",
- " ('customer_3', 'customer_1'): {},\n",
- " ('customer_3', 'customer_2'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'Car_4'): {},\n",
- " ('customer_4', 'Car_9'): {},\n",
- " ('customer_4', 'Pet_2'): {},\n",
- " ('customer_4', 'customer_0'): {},\n",
- " ('customer_4', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_6'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_5', 'Car_2'): {},\n",
- " ('customer_5', 'Car_5'): {},\n",
- " ('customer_5', 'Pet_3'): {},\n",
- " ('customer_5', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'Car_4'): {},\n",
- " ('customer_6', 'Car_6'): {},\n",
- " ('customer_6', 'Pet_4'): {},\n",
- " ('customer_6', 'customer_0'): {},\n",
- " ('customer_6', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_4'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_6'): {'car_friend': (1.0, 1.0)}},\n",
- " 6: {'Car_0': {},\n",
- " 'Car_1': {},\n",
- " 'Car_10': {},\n",
- " 'Car_11': {},\n",
- " 'Car_2': {},\n",
- " 'Car_3': {},\n",
- " 'Car_4': {},\n",
- " 'Car_5': {},\n",
- " 'Car_6': {},\n",
- " 'Car_7': {},\n",
- " 'Car_8': {},\n",
- " 'Car_9': {},\n",
- " 'Pet_0': {},\n",
- " 'Pet_1': {},\n",
- " 'Pet_2': {},\n",
- " 'Pet_3': {},\n",
- " 'Pet_4': {},\n",
- " 'customer_0': {'popular': (1.0, 1.0)},\n",
- " 'customer_1': {},\n",
- " 'customer_2': {},\n",
- " 'customer_3': {'popular': (1.0, 1.0), 'cool_pet': (1.0, 1.0)},\n",
- " 'customer_4': {'cool_car': (1.0, 1.0),\n",
- " 'cool_pet': (1.0, 1.0),\n",
- " 'popular': (1.0, 1.0),\n",
- " 'trendy': (1.0, 1.0)},\n",
- " 'customer_5': {'popular': (1.0, 1.0)},\n",
- " 'customer_6': {'popular': (1.0, 1.0), 'cool_car': (1.0, 1.0)},\n",
- " ('customer_0', 'Car_2'): {},\n",
- " ('customer_0', 'Car_7'): {},\n",
- " ('customer_0', 'Pet_0'): {},\n",
- " ('customer_0', 'Pet_4'): {},\n",
- " ('customer_0', 'customer_1'): {},\n",
- " ('customer_0', 'customer_2'): {},\n",
- " ('customer_1', 'Car_0'): {},\n",
- " ('customer_1', 'Car_8'): {},\n",
- " ('customer_1', 'Pet_1'): {},\n",
- " ('customer_2', 'Car_1'): {},\n",
- " ('customer_2', 'Car_11'): {},\n",
- " ('customer_2', 'Car_3'): {},\n",
- " ('customer_2', 'Pet_0'): {},\n",
- " ('customer_2', 'Pet_1'): {},\n",
- " ('customer_2', 'customer_1'): {},\n",
- " ('customer_2', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_3'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'Car_0'): {},\n",
- " ('customer_3', 'Car_10'): {},\n",
- " ('customer_3', 'Car_3'): {},\n",
- " ('customer_3', 'Pet_2'): {},\n",
- " ('customer_3', 'customer_1'): {},\n",
- " ('customer_3', 'customer_2'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'Car_4'): {},\n",
- " ('customer_4', 'Car_9'): {},\n",
- " ('customer_4', 'Pet_2'): {},\n",
- " ('customer_4', 'customer_0'): {},\n",
- " ('customer_4', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_6'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_5', 'Car_2'): {},\n",
- " ('customer_5', 'Car_5'): {},\n",
- " ('customer_5', 'Pet_3'): {},\n",
- " ('customer_5', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'Car_4'): {},\n",
- " ('customer_6', 'Car_6'): {},\n",
- " ('customer_6', 'Pet_4'): {},\n",
- " ('customer_6', 'customer_0'): {},\n",
- " ('customer_6', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_4'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_6'): {'car_friend': (1.0, 1.0)}},\n",
- " 7: {'Car_0': {},\n",
- " 'Car_1': {},\n",
- " 'Car_10': {},\n",
- " 'Car_11': {},\n",
- " 'Car_2': {},\n",
- " 'Car_3': {},\n",
- " 'Car_4': {},\n",
- " 'Car_5': {},\n",
- " 'Car_6': {},\n",
- " 'Car_7': {},\n",
- " 'Car_8': {},\n",
- " 'Car_9': {},\n",
- " 'Pet_0': {},\n",
- " 'Pet_1': {},\n",
- " 'Pet_2': {},\n",
- " 'Pet_3': {},\n",
- " 'Pet_4': {},\n",
- " 'customer_0': {'popular': (1.0, 1.0)},\n",
- " 'customer_1': {},\n",
- " 'customer_2': {},\n",
- " 'customer_3': {'popular': (1.0, 1.0), 'cool_pet': (1.0, 1.0)},\n",
- " 'customer_4': {'cool_car': (1.0, 1.0),\n",
- " 'cool_pet': (1.0, 1.0),\n",
- " 'popular': (1.0, 1.0),\n",
- " 'trendy': (1.0, 1.0)},\n",
- " 'customer_5': {'popular': (1.0, 1.0)},\n",
- " 'customer_6': {'popular': (1.0, 1.0), 'cool_car': (1.0, 1.0)},\n",
- " ('customer_0', 'Car_2'): {},\n",
- " ('customer_0', 'Car_7'): {},\n",
- " ('customer_0', 'Pet_0'): {},\n",
- " ('customer_0', 'Pet_4'): {},\n",
- " ('customer_0', 'customer_1'): {},\n",
- " ('customer_0', 'customer_2'): {},\n",
- " ('customer_1', 'Car_0'): {},\n",
- " ('customer_1', 'Car_8'): {},\n",
- " ('customer_1', 'Pet_1'): {},\n",
- " ('customer_2', 'Car_1'): {},\n",
- " ('customer_2', 'Car_11'): {},\n",
- " ('customer_2', 'Car_3'): {},\n",
- " ('customer_2', 'Pet_0'): {},\n",
- " ('customer_2', 'Pet_1'): {},\n",
- " ('customer_2', 'customer_1'): {},\n",
- " ('customer_2', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_3'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'Car_0'): {},\n",
- " ('customer_3', 'Car_10'): {},\n",
- " ('customer_3', 'Car_3'): {},\n",
- " ('customer_3', 'Pet_2'): {},\n",
- " ('customer_3', 'customer_1'): {},\n",
- " ('customer_3', 'customer_2'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'Car_4'): {},\n",
- " ('customer_4', 'Car_9'): {},\n",
- " ('customer_4', 'Pet_2'): {},\n",
- " ('customer_4', 'customer_0'): {},\n",
- " ('customer_4', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_6'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_5', 'Car_2'): {},\n",
- " ('customer_5', 'Car_5'): {},\n",
- " ('customer_5', 'Pet_3'): {},\n",
- " ('customer_5', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'Car_4'): {},\n",
- " ('customer_6', 'Car_6'): {},\n",
- " ('customer_6', 'Pet_4'): {},\n",
- " ('customer_6', 'customer_0'): {},\n",
- " ('customer_6', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_4'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_6'): {'car_friend': (1.0, 1.0)}},\n",
- " 8: {'Car_0': {},\n",
- " 'Car_1': {},\n",
- " 'Car_10': {},\n",
- " 'Car_11': {},\n",
- " 'Car_2': {},\n",
- " 'Car_3': {},\n",
- " 'Car_4': {},\n",
- " 'Car_5': {},\n",
- " 'Car_6': {},\n",
- " 'Car_7': {},\n",
- " 'Car_8': {},\n",
- " 'Car_9': {},\n",
- " 'Pet_0': {},\n",
- " 'Pet_1': {},\n",
- " 'Pet_2': {},\n",
- " 'Pet_3': {},\n",
- " 'Pet_4': {},\n",
- " 'customer_0': {'popular': (1.0, 1.0)},\n",
- " 'customer_1': {},\n",
- " 'customer_2': {},\n",
- " 'customer_3': {'popular': (1.0, 1.0), 'cool_pet': (1.0, 1.0)},\n",
- " 'customer_4': {'cool_car': (1.0, 1.0),\n",
- " 'cool_pet': (1.0, 1.0),\n",
- " 'popular': (1.0, 1.0),\n",
- " 'trendy': (1.0, 1.0)},\n",
- " 'customer_5': {'popular': (1.0, 1.0)},\n",
- " 'customer_6': {'popular': (1.0, 1.0), 'cool_car': (1.0, 1.0)},\n",
- " ('customer_0', 'Car_2'): {},\n",
- " ('customer_0', 'Car_7'): {},\n",
- " ('customer_0', 'Pet_0'): {},\n",
- " ('customer_0', 'Pet_4'): {},\n",
- " ('customer_0', 'customer_1'): {},\n",
- " ('customer_0', 'customer_2'): {},\n",
- " ('customer_1', 'Car_0'): {},\n",
- " ('customer_1', 'Car_8'): {},\n",
- " ('customer_1', 'Pet_1'): {},\n",
- " ('customer_2', 'Car_1'): {},\n",
- " ('customer_2', 'Car_11'): {},\n",
- " ('customer_2', 'Car_3'): {},\n",
- " ('customer_2', 'Pet_0'): {},\n",
- " ('customer_2', 'Pet_1'): {},\n",
- " ('customer_2', 'customer_1'): {},\n",
- " ('customer_2', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_3'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'Car_0'): {},\n",
- " ('customer_3', 'Car_10'): {},\n",
- " ('customer_3', 'Car_3'): {},\n",
- " ('customer_3', 'Pet_2'): {},\n",
- " ('customer_3', 'customer_1'): {},\n",
- " ('customer_3', 'customer_2'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'Car_4'): {},\n",
- " ('customer_4', 'Car_9'): {},\n",
- " ('customer_4', 'Pet_2'): {},\n",
- " ('customer_4', 'customer_0'): {},\n",
- " ('customer_4', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_6'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_5', 'Car_2'): {},\n",
- " ('customer_5', 'Car_5'): {},\n",
- " ('customer_5', 'Pet_3'): {},\n",
- " ('customer_5', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'Car_4'): {},\n",
- " ('customer_6', 'Car_6'): {},\n",
- " ('customer_6', 'Pet_4'): {},\n",
- " ('customer_6', 'customer_0'): {},\n",
- " ('customer_6', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_4'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_6'): {'car_friend': (1.0, 1.0)}},\n",
- " 9: {'Car_0': {},\n",
- " 'Car_1': {},\n",
- " 'Car_10': {},\n",
- " 'Car_11': {},\n",
- " 'Car_2': {},\n",
- " 'Car_3': {},\n",
- " 'Car_4': {},\n",
- " 'Car_5': {},\n",
- " 'Car_6': {},\n",
- " 'Car_7': {},\n",
- " 'Car_8': {},\n",
- " 'Car_9': {},\n",
- " 'Pet_0': {},\n",
- " 'Pet_1': {},\n",
- " 'Pet_2': {},\n",
- " 'Pet_3': {},\n",
- " 'Pet_4': {},\n",
- " 'customer_0': {'popular': (1.0, 1.0)},\n",
- " 'customer_1': {},\n",
- " 'customer_2': {},\n",
- " 'customer_3': {'popular': (1.0, 1.0), 'cool_pet': (1.0, 1.0)},\n",
- " 'customer_4': {'cool_car': (1.0, 1.0),\n",
- " 'cool_pet': (1.0, 1.0),\n",
- " 'popular': (1.0, 1.0),\n",
- " 'trendy': (1.0, 1.0)},\n",
- " 'customer_5': {'popular': (1.0, 1.0)},\n",
- " 'customer_6': {'popular': (1.0, 1.0), 'cool_car': (1.0, 1.0)},\n",
- " ('customer_0', 'Car_2'): {},\n",
- " ('customer_0', 'Car_7'): {},\n",
- " ('customer_0', 'Pet_0'): {},\n",
- " ('customer_0', 'Pet_4'): {},\n",
- " ('customer_0', 'customer_1'): {},\n",
- " ('customer_0', 'customer_2'): {},\n",
- " ('customer_1', 'Car_0'): {},\n",
- " ('customer_1', 'Car_8'): {},\n",
- " ('customer_1', 'Pet_1'): {},\n",
- " ('customer_2', 'Car_1'): {},\n",
- " ('customer_2', 'Car_11'): {},\n",
- " ('customer_2', 'Car_3'): {},\n",
- " ('customer_2', 'Pet_0'): {},\n",
- " ('customer_2', 'Pet_1'): {},\n",
- " ('customer_2', 'customer_1'): {},\n",
- " ('customer_2', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_3'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'Car_0'): {},\n",
- " ('customer_3', 'Car_10'): {},\n",
- " ('customer_3', 'Car_3'): {},\n",
- " ('customer_3', 'Pet_2'): {},\n",
- " ('customer_3', 'customer_1'): {},\n",
- " ('customer_3', 'customer_2'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'Car_4'): {},\n",
- " ('customer_4', 'Car_9'): {},\n",
- " ('customer_4', 'Pet_2'): {},\n",
- " ('customer_4', 'customer_0'): {},\n",
- " ('customer_4', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_6'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_5', 'Car_2'): {},\n",
- " ('customer_5', 'Car_5'): {},\n",
- " ('customer_5', 'Pet_3'): {},\n",
- " ('customer_5', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'Car_4'): {},\n",
- " ('customer_6', 'Car_6'): {},\n",
- " ('customer_6', 'Pet_4'): {},\n",
- " ('customer_6', 'customer_0'): {},\n",
- " ('customer_6', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_4'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_6'): {'car_friend': (1.0, 1.0)}},\n",
- " 10: {'Car_0': {},\n",
- " 'Car_1': {},\n",
- " 'Car_10': {},\n",
- " 'Car_11': {},\n",
- " 'Car_2': {},\n",
- " 'Car_3': {},\n",
- " 'Car_4': {},\n",
- " 'Car_5': {},\n",
- " 'Car_6': {},\n",
- " 'Car_7': {},\n",
- " 'Car_8': {},\n",
- " 'Car_9': {},\n",
- " 'Pet_0': {},\n",
- " 'Pet_1': {},\n",
- " 'Pet_2': {},\n",
- " 'Pet_3': {},\n",
- " 'Pet_4': {},\n",
- " 'customer_0': {'popular': (1.0, 1.0)},\n",
- " 'customer_1': {},\n",
- " 'customer_2': {},\n",
- " 'customer_3': {'popular': (1.0, 1.0), 'cool_pet': (1.0, 1.0)},\n",
- " 'customer_4': {'cool_car': (1.0, 1.0),\n",
- " 'cool_pet': (1.0, 1.0),\n",
- " 'popular': (1.0, 1.0),\n",
- " 'trendy': (1.0, 1.0)},\n",
- " 'customer_5': {'popular': (1.0, 1.0)},\n",
- " 'customer_6': {'popular': (1.0, 1.0), 'cool_car': (1.0, 1.0)},\n",
- " ('customer_0', 'Car_2'): {},\n",
- " ('customer_0', 'Car_7'): {},\n",
- " ('customer_0', 'Pet_0'): {},\n",
- " ('customer_0', 'Pet_4'): {},\n",
- " ('customer_0', 'customer_1'): {},\n",
- " ('customer_0', 'customer_2'): {},\n",
- " ('customer_1', 'Car_0'): {},\n",
- " ('customer_1', 'Car_8'): {},\n",
- " ('customer_1', 'Pet_1'): {},\n",
- " ('customer_2', 'Car_1'): {},\n",
- " ('customer_2', 'Car_11'): {},\n",
- " ('customer_2', 'Car_3'): {},\n",
- " ('customer_2', 'Pet_0'): {},\n",
- " ('customer_2', 'Pet_1'): {},\n",
- " ('customer_2', 'customer_1'): {},\n",
- " ('customer_2', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_3'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_2', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'Car_0'): {},\n",
- " ('customer_3', 'Car_10'): {},\n",
- " ('customer_3', 'Car_3'): {},\n",
- " ('customer_3', 'Pet_2'): {},\n",
- " ('customer_3', 'customer_1'): {},\n",
- " ('customer_3', 'customer_2'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_3', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'Car_4'): {},\n",
- " ('customer_4', 'Car_9'): {},\n",
- " ('customer_4', 'Pet_2'): {},\n",
- " ('customer_4', 'customer_0'): {},\n",
- " ('customer_4', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_4', 'customer_6'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_5', 'Car_2'): {},\n",
- " ('customer_5', 'Car_5'): {},\n",
- " ('customer_5', 'Pet_3'): {},\n",
- " ('customer_5', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_4'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_5', 'customer_6'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'Car_4'): {},\n",
- " ('customer_6', 'Car_6'): {},\n",
- " ('customer_6', 'Pet_4'): {},\n",
- " ('customer_6', 'customer_0'): {},\n",
- " ('customer_6', 'customer_2'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_3'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_4'): {'car_friend': (1.0, 1.0),\n",
- " 'same_color_car': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_5'): {'car_friend': (1.0, 1.0)},\n",
- " ('customer_6', 'customer_6'): {'car_friend': (1.0, 1.0)}}}\n"
+ "ename": "AttributeError",
+ "evalue": "'Interpretation' object has no attribute 'get_interpretation_dict'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[1;32m/Users/ashbysteward-nolan/Documents/PyReason/pyreason/docs/advanced_graph.ipynb Cell 18\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m interpretations_dict \u001b[39m=\u001b[39m interpretation\u001b[39m.\u001b[39;49mget_interpretation_dict()\n\u001b[1;32m 2\u001b[0m pprint(interpretations_dict)\n\u001b[1;32m 4\u001b[0m df1 \u001b[39m=\u001b[39m pr\u001b[39m.\u001b[39mfilter_and_sort_nodes(interpretation, [\u001b[39m'\u001b[39m\u001b[39mtrendy\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mcool_car\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mcool_pet\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39mpopular\u001b[39m\u001b[39m'\u001b[39m])\n",
+ "\u001b[0;31mAttributeError\u001b[0m: 'Interpretation' object has no attribute 'get_interpretation_dict'"
]
}
],
@@ -1359,23 +464,85 @@
"\n",
"df1 = pr.filter_and_sort_nodes(interpretation, ['trendy', 'cool_car', 'cool_pet', 'popular'])\n",
"df2 = pr.filter_and_sort_edges(interpretation, ['car_friend', 'same_color_car'])\n"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-03-06T17:35:04.407063Z",
- "start_time": "2024-03-06T17:35:04.344424Z"
- }
- },
- "id": "a50b05a88124bddd",
- "execution_count": 59
+ ]
},
{
"cell_type": "code",
+ "execution_count": 60,
+ "id": "c694d8d383419288",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-03-06T17:35:42.163566Z",
+ "start_time": "2024-03-06T17:35:42.124941Z"
+ },
+ "collapsed": false
+ },
"outputs": [
{
"data": {
- "text/plain": "[ component trendy cool_car cool_pet popular\n 0 customer_0 [0, 1] [0, 1] [1.0, 1.0] [0, 1],\n component trendy cool_car cool_pet popular\n 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 1 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n 2 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0]\n 3 customer_3 [0, 1] [0, 1] [1.0, 1.0] [0, 1],\n component trendy cool_car cool_pet popular\n 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n component trendy cool_car cool_pet popular\n 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n component trendy cool_car cool_pet popular\n 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n component trendy cool_car cool_pet popular\n 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n component trendy cool_car cool_pet popular\n 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n component trendy cool_car cool_pet popular\n 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] 
[1.0, 1.0]\n 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n component trendy cool_car cool_pet popular\n 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n component trendy cool_car cool_pet popular\n 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n component trendy cool_car cool_pet popular\n 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0]]"
+ "text/plain": [
+ "[ component trendy cool_car cool_pet popular\n",
+ " 0 customer_0 [0, 1] [0, 1] [1.0, 1.0] [0, 1],\n",
+ " component trendy cool_car cool_pet popular\n",
+ " 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 1 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n",
+ " 2 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0]\n",
+ " 3 customer_3 [0, 1] [0, 1] [1.0, 1.0] [0, 1],\n",
+ " component trendy cool_car cool_pet popular\n",
+ " 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n",
+ " 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n",
+ " component trendy cool_car cool_pet popular\n",
+ " 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n",
+ " 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n",
+ " component trendy cool_car cool_pet popular\n",
+ " 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n",
+ " 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n",
+ " component trendy cool_car cool_pet popular\n",
+ " 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n",
+ " 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n",
+ " component trendy cool_car cool_pet popular\n",
+ " 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n",
+ " 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n",
+ " component trendy cool_car cool_pet popular\n",
+ " 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n",
+ " 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n",
+ " component trendy cool_car cool_pet popular\n",
+ " 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n",
+ " 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n",
+ " component trendy cool_car cool_pet popular\n",
+ " 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n",
+ " 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0],\n",
+ " component trendy cool_car cool_pet popular\n",
+ " 0 customer_0 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 1 customer_3 [0, 1] [0, 1] [1.0, 1.0] [1.0, 1.0]\n",
+ " 2 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 customer_5 [0, 1] [0, 1] [0, 1] [1.0, 1.0]\n",
+ " 4 customer_6 [0, 1] [1.0, 1.0] [0, 1] [1.0, 1.0]]"
+ ]
},
"execution_count": 60,
"metadata": {},
@@ -1384,23 +551,310 @@
],
"source": [
"df1"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-03-06T17:35:42.163566Z",
- "start_time": "2024-03-06T17:35:42.124941Z"
- }
- },
- "id": "c694d8d383419288",
- "execution_count": 60
+ ]
},
{
"cell_type": "code",
+ "execution_count": 56,
+ "id": "8a5ec736f5b6e354",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-03-06T17:22:44.809373Z",
+ "start_time": "2024-03-06T17:22:44.738073Z"
+ },
+ "collapsed": false
+ },
"outputs": [
{
"data": {
- "text/plain": "[ component car_friend same_color_car\n 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n component car_friend same_color_car\n 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n 14 (customer_4, 
customer_2) [1.0, 1.0] [0, 1]\n 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n component car_friend same_color_car\n 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n component car_friend same_color_car\n 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n 4 
(customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n component car_friend same_color_car\n 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n 19 (customer_5, customer_4) [1.0, 
1.0] [0, 1]\n 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n component car_friend same_color_car\n 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n component car_friend same_color_car\n 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n 9 (customer_2, 
customer_5) [1.0, 1.0] [0, 1]\n 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n component car_friend same_color_car\n 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n 
component car_friend same_color_car\n 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n component car_friend same_color_car\n 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n 14 (customer_4, customer_2) [1.0, 1.0] 
[0, 1]\n 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n component car_friend same_color_car\n 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n 24 (customer_6, customer_5) [1.0, 1.0] [0, 1]]"
+ "text/plain": [
+ "[ component car_friend same_color_car\n",
+ " 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n",
+ " 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n",
+ " 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n",
+ " 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n",
+ " component car_friend same_color_car\n",
+ " 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n",
+ " 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n",
+ " 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n",
+ " 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n",
+ " component car_friend same_color_car\n",
+ " 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n",
+ " 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n",
+ " 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n",
+ " 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n",
+ " component car_friend same_color_car\n",
+ " 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n",
+ " 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n",
+ " 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n",
+ " 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n",
+ " component car_friend same_color_car\n",
+ " 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n",
+ " 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n",
+ " 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n",
+ " 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n",
+ " component car_friend same_color_car\n",
+ " 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n",
+ " 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n",
+ " 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n",
+ " 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n",
+ " component car_friend same_color_car\n",
+ " 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n",
+ " 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n",
+ " 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n",
+ " 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n",
+ " component car_friend same_color_car\n",
+ " 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n",
+ " 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n",
+ " 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n",
+ " 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n",
+ " component car_friend same_color_car\n",
+ " 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n",
+ " 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n",
+ " 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n",
+ " 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n",
+ " component car_friend same_color_car\n",
+ " 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n",
+ " 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n",
+ " 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n",
+ " 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 24 (customer_6, customer_5) [1.0, 1.0] [0, 1],\n",
+ " component car_friend same_color_car\n",
+ " 0 (customer_2, customer_3) [1.0, 1.0] [1.0, 1.0]\n",
+ " 1 (customer_2, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 2 (customer_3, customer_2) [1.0, 1.0] [1.0, 1.0]\n",
+ " 3 (customer_3, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 4 (customer_4, customer_6) [1.0, 1.0] [1.0, 1.0]\n",
+ " 5 (customer_4, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 6 (customer_6, customer_4) [1.0, 1.0] [1.0, 1.0]\n",
+ " 7 (customer_6, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 8 (customer_2, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 9 (customer_2, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 10 (customer_2, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 11 (customer_3, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 12 (customer_3, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 13 (customer_3, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 14 (customer_4, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 15 (customer_4, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 16 (customer_4, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 17 (customer_5, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 18 (customer_5, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 19 (customer_5, customer_4) [1.0, 1.0] [0, 1]\n",
+ " 20 (customer_5, customer_5) [1.0, 1.0] [0, 1]\n",
+ " 21 (customer_5, customer_6) [1.0, 1.0] [0, 1]\n",
+ " 22 (customer_6, customer_2) [1.0, 1.0] [0, 1]\n",
+ " 23 (customer_6, customer_3) [1.0, 1.0] [0, 1]\n",
+ " 24 (customer_6, customer_5) [1.0, 1.0] [0, 1]]"
+ ]
},
"execution_count": 56,
"metadata": {},
@@ -1409,30 +863,21 @@
],
"source": [
"df2"
- ],
- "metadata": {
- "collapsed": false,
- "ExecuteTime": {
- "end_time": "2024-03-06T17:22:44.809373Z",
- "start_time": "2024-03-06T17:22:44.738073Z"
- }
- },
- "id": "8a5ec736f5b6e354",
- "execution_count": 56
+ ]
},
{
"cell_type": "code",
- "outputs": [],
- "source": [],
+ "execution_count": 13,
+ "id": "76fc9b15eb441ebe",
"metadata": {
- "collapsed": false,
"ExecuteTime": {
"end_time": "2024-03-06T16:55:09.117129Z",
"start_time": "2024-03-06T16:55:09.115313Z"
- }
+ },
+ "collapsed": false
},
- "id": "76fc9b15eb441ebe",
- "execution_count": 13
+ "outputs": [],
+ "source": []
}
],
"metadata": {
@@ -1451,7 +896,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
- "version": "2.7.6"
+ "version": "3.9.6"
}
},
"nbformat": 4,
diff --git a/docs/group-chat-example.md b/docs/group-chat-example.md
index c2aaa0ce..2e29e578 100755
--- a/docs/group-chat-example.md
+++ b/docs/group-chat-example.md
@@ -3,7 +3,7 @@
Here is an example that utilizes custom thresholds.
The following graph represents a network of People and a Text Message in their group chat.
-
+
In this case, we want to know when a text message has been viewed by all members of the group chat.
@@ -14,7 +14,7 @@ First, lets create the group chat.
import networkx as nx
# Create an empty graph
-G = nx.Graph()
+G = nx.DiGraph()
# Add nodes
nodes = ["TextMessage", "Zach", "Justin", "Michelle", "Amy"]
@@ -35,7 +35,7 @@ G.add_edges_from(edges)
Considering that we only want a text message to be considered viewed by all if it has been viewed by everyone that can view it, we define the rule as follows:
```text
-ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)
+ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)
```
The `head` of the rule is `ViewedByAll(x)` and the body is `HaveAccess(x,y), Viewed(y)`. The head and body are separated by an arrow which means the rule will start evaluating from
@@ -79,10 +79,10 @@ We add the facts in PyReason as below:
```python
import pyreason as pr
-pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 0, static=True))
-pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 0, static=True))
-pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 1, static=True))
-pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 2, static=True))
+pr.add_fact(pr.Fact("Viewed(Zach)", "seen-fact-zach", 0, 3))
+pr.add_fact(pr.Fact("Viewed(Justin)", "seen-fact-justin", 0, 3))
+pr.add_fact(pr.Fact("Viewed(Michelle)", "seen-fact-michelle", 1, 3))
+pr.add_fact(pr.Fact("Viewed(Amy)", "seen-fact-amy", 2, 3))
```
This allows us to specify the component that has an initial condition, the initial condition itself in the form of bounds
diff --git a/docs/hello-world.md b/docs/hello-world.md
index f2bf1311..cc4a75cf 100755
--- a/docs/hello-world.md
+++ b/docs/hello-world.md
@@ -88,7 +88,7 @@ We add a fact in PyReason like so:
```python
import pyreason as pr
-pr.add_fact(pr.Fact(name='popular-fact', component='Mary', attribute='popular', bound=[1, 1], start_time=0, end_time=2))
+pr.add_fact(pr.Fact(fact_text='popular(Mary) : true', name='popular_fact', start_time=0, end_time=2))
```
This allows us to specify the component that has an initial condition, the initial condition itself in the form of bounds
diff --git a/docs/rules.txt b/docs/rules.txt
index 061a0ece..cb6cbaa5 100644
--- a/docs/rules.txt
+++ b/docs/rules.txt
@@ -1 +1 @@
-popular(x) <-1 popular(y), Friends(x,y), owns_pet(y,z), owns_pet(x,z)
\ No newline at end of file
+popular(x) <-1 popular(y), Friends(x,y), owns_pet(y,z), owns_pet(x,z)
diff --git a/docs/source/_static/pyreason_logo.jpg b/docs/source/_static/pyreason_logo.jpg
new file mode 100755
index 00000000..233618a4
Binary files /dev/null and b/docs/source/_static/pyreason_logo.jpg differ
diff --git a/docs/source/tutorials/Rule_image.png b/docs/source/_static/rule_image.png
similarity index 100%
rename from docs/source/tutorials/Rule_image.png
rename to docs/source/_static/rule_image.png
diff --git a/docs/source/about.rst b/docs/source/about.rst
new file mode 100644
index 00000000..37c21658
--- /dev/null
+++ b/docs/source/about.rst
@@ -0,0 +1,36 @@
+About PyReason
+==============
+
+**PyReason** is a modern Python-based software framework designed for open-world temporal logic reasoning using generalized annotated logic. It addresses the growing needs of neuro-symbolic reasoning frameworks that incorporate differentiable logics and temporal extensions, allowing inference over finite periods with open-world capabilities. PyReason is particularly suited for reasoning over graphical structures such as knowledge graphs, social networks, and biological networks, offering fully explainable inference processes.
+
+Key Capabilities
+----------------
+
+1. **Graph-Based Reasoning**: PyReason supports direct reasoning over knowledge graphs, a popular representation of symbolic data. Unlike black-box frameworks, PyReason provides full explainability of the reasoning process.
+
+2. **Annotated Logic**: It extends classical logic with annotations, supporting various types of logic including fuzzy logic, real-valued intervals, and temporal logic. PyReason's framework goes beyond traditional logic systems like Prolog, allowing for arbitrary functions over reals, enhancing its capability to handle constructs in neuro-symbolic reasoning.
+
+3. **Temporal Reasoning**: PyReason includes temporal extensions to handle reasoning over sequences of time points. This feature enables the creation of rules that incorporate temporal dependencies, such as "if condition A, then condition B after a certain number of time steps."
+
+4. **Open World Reasoning**: Unlike closed-world assumptions where anything not explicitly stated is false, PyReason considers unknowns as a valid state, making it more flexible and suitable for real-world applications where information may be incomplete.
+
+5. **Handling Logical Inconsistencies**: PyReason can detect and resolve inconsistencies in the reasoning process. When inconsistencies are found, it can reset affected interpretations to a state of complete uncertainty, ensuring that the reasoning process remains robust.
+
+6. **Scalability and Performance**: PyReason is optimized for scalability, supporting exact deductive inference with memory-efficient implementations. It leverages sparsity in graphical structures and employs predicate-constant type checking to reduce computational complexity.
+
+7. **Explainability**: All inference results produced by PyReason are fully explainable, as the software maintains a trace of the inference steps that led to each conclusion. This feature is critical for applications where transparency of the reasoning process is necessary.
+
+8. **Integration and Extensibility**: PyReason is implemented in Python and supports integration with other tools and frameworks, making it easy to extend and adapt for specific needs. It can work with popular graph formats like GraphML and is compatible with tools like NetworkX and Neo4j.
+
+Use Cases
+--------------
+
+- **Knowledge Graph Reasoning**: PyReason can be used to perform logical inferences over knowledge graphs, aiding in tasks like knowledge completion, entity classification, and relationship extraction.
+
+- **Temporal Logic Applications**: Its temporal reasoning capabilities are useful in domains requiring time-based analysis, such as monitoring system states over time, or reasoning about events and their sequences.
+
+- **Social and Biological Network Analysis**: PyReason's support for annotated logic and reasoning over complex network structures makes it suitable for applications in social network analysis, supply chain management, and biological systems modeling.
+
+PyReason is open-source and available at: `GitHub - PyReason <https://github.com/lab-v2/pyreason>`_
+
+For more detailed information on PyReason’s logical framework, implementation details, and experimental results, refer to the full documentation or visit the project's GitHub repository.
diff --git a/docs/source/api_reference/index.rst b/docs/source/api_reference/index.rst
new file mode 100644
index 00000000..3c1b4a08
--- /dev/null
+++ b/docs/source/api_reference/index.rst
@@ -0,0 +1,8 @@
+API Documentation
+=================
+
+
+.. automodule:: pyreason
+ :members:
+ :undoc-members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 75626cdb..dc1279d2 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -5,7 +5,13 @@
import os
import sys
-sys.path.insert(0, os.path.abspath('../..'))
+#sys.path.insert(0, os.path.abspath('../..'))
+#sys.path.insert(0, os.path.abspath('pyreason/pyreason.py'))
+# Calculate the absolute path to the pyreason directory
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'pyreason'))
+# Add the pyreason directory to sys.path
+sys.path.insert(0, project_root)
+
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
@@ -20,27 +26,34 @@
extensions = ['sphinx.ext.autodoc', 'sphinx_rtd_theme', 'sphinx.ext.autosummary', 'sphinx.ext.doctest',
'sphinx.ext.todo', 'sphinx.ext.coverage', 'sphinx.ext.mathjax', 'sphinx.ext.ifconfig',
- 'sphinx.ext.viewcode', 'sphinx.ext.napoleon', 'autoapi.extension']
+ 'sphinx.ext.viewcode', 'sphinx.ext.napoleon', 'autoapi.extension',] # Just this line
autosummary_generate = True
-autoapi_dirs = ['../../pyreason']
-autoapi_template_dir = '_templates/autoapi'
+#autoapi_template_dir = '_templates/autoapi'
+# Ensure autoapi_dirs points to the folder containing pyreason.py
+#autoapi_dirs = [project_root]
+autoapi_dirs = [os.path.join(project_root)] # Only include the pyreason directory
-autoapi_options = [
- "members",
- "undoc-members",
- "show-inheritance",
- "show-module-summary",
- "imported-members",
-]
+#autoapi_dirs = [os.path.join(project_root)] # Include only 'pyreason.pyreason'
+#autoapi_dirs = ['../pyreason/pyreason']
+
+autoapi_root = 'pyreason'
+autoapi_ignore = ['*/scripts/*', '*/examples/*', '*/pyreason.pyreason/*']
+# Ignore modules in the 'scripts' folder
+# autoapi_ignore_modules = ['pyreason.scripts']
+autoapi_options = [
+ "members", # Include all class members (functions)
+ "undoc-members", # Include undocumented members
+ "show-inheritance", # Show inheritance tree for methods/functions
+ # "private-members", # Include private members (e.g., _method)
+]
+
templates_path = ['_templates']
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'pyreason.examples.rst',
- 'pyreason.scripts.numba_wrapper.numba_types.rst',
- 'pyreason.scripts.numba_wrapper.rst', 'pyreason.scripts.program.rst',
- 'pyreason.scripts.interpretation.rst']
+
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
diff --git a/docs/source/examples_rst/advanced_example.rst b/docs/source/examples_rst/advanced_example.rst
new file mode 100644
index 00000000..b7762bad
--- /dev/null
+++ b/docs/source/examples_rst/advanced_example.rst
@@ -0,0 +1,167 @@
+Advanced Example
+============================
+
+
+.. code:: python
+
+ from pprint import pprint
+ import networkx as nx
+ import pyreason as pr
+
+ # Customer Data
+ customers = ['John', 'Mary', 'Justin', 'Alice', 'Bob', 'Eva', 'Mike']
+ customer_details = [
+ ('John', 'M', 'New York', 'NY'),
+ ('Mary', 'F', 'Los Angeles', 'CA'),
+ ('Justin', 'M', 'Chicago', 'IL'),
+ ('Alice', 'F', 'Houston', 'TX'),
+ ('Bob', 'M', 'Phoenix', 'AZ'),
+ ('Eva', 'F', 'San Diego', 'CA'),
+ ('Mike', 'M', 'Dallas', 'TX')
+ ]
+
+ # Creating a dictionary of customers with their details
+ customer_dict = {i: customer for i, customer in enumerate(customer_details)}
+
+ # Pet Data
+ pet_details = [
+ ('Dog', 'Mammal'),
+ ('Cat', 'Mammal'),
+ ('Rabbit', 'Mammal'),
+ ('Parrot', 'Bird'),
+ ('Fish', 'Fish')
+ ]
+
+ # Creating a dictionary of pets with their details
+ pet_dict = {i: pet for i, pet in enumerate(pet_details)}
+
+ # Car Data
+ car_details = [
+ ('Toyota Camry', 'Red'),
+ ('Honda Civic', 'Blue'),
+ ('Ford Focus', 'Red'),
+ ('BMW 3 Series', 'Black'),
+ ('Tesla Model S', 'Red'),
+ ('Chevrolet Bolt EV', 'White'),
+ ('Ford Mustang', 'Yellow'),
+ ('Audi A4', 'Silver'),
+ ('Mercedes-Benz C-Class', 'Grey'),
+ ('Subaru Outback', 'Green'),
+ ('Volkswagen Golf', 'Blue'),
+ ('Porsche 911', 'Black')
+ ]
+
+ # Creating a dictionary of cars with their details
+ car_dict = {i: car for i, car in enumerate(car_details)}
+
+ # Travel Data (customer movements between cities)
+ travels = [
+ ('John', 'Los Angeles', 'CA', 'New York', 'NY', 2),
+ ('Alice', 'Houston', 'TX', 'Phoenix', 'AZ', 5),
+ ('Eva', 'San Diego', 'CA', 'Dallas', 'TX', 1),
+ ('Mike', 'Dallas', 'TX', 'Chicago', 'IL', 3)
+ ]
+
+ # Friendships (who is friends with whom)
+ friendships = [('customer_2', 'customer_1'), ('customer_0', 'customer_1'), ('customer_0', 'customer_2'),
+ ('customer_3', 'customer_4'), ('customer_4', 'customer_0'), ('customer_5', 'customer_3'),
+ ('customer_6', 'customer_0'), ('customer_5', 'customer_6'), ('customer_4', 'customer_5'),
+ ('customer_3', 'customer_1')]
+
+ # Car Ownerships (who owns which car)
+ car_ownerships = [('customer_1', 'Car_0'), ('customer_2', 'Car_1'), ('customer_0', 'Car_2'), ('customer_3', 'Car_3'),
+ ('customer_4', 'Car_4'), ('customer_3', 'Car_0'), ('customer_2', 'Car_3'), ('customer_5', 'Car_5'),
+ ('customer_6', 'Car_6'), ('customer_0', 'Car_7'), ('customer_1', 'Car_8'), ('customer_4', 'Car_9'),
+ ('customer_3', 'Car_10'), ('customer_2', 'Car_11'), ('customer_5', 'Car_2'), ('customer_6', 'Car_4')]
+
+ # Pet Ownerships (who owns which pet)
+ pet_ownerships = [('customer_1', 'Pet_1'), ('customer_2', 'Pet_1'), ('customer_2', 'Pet_0'), ('customer_0', 'Pet_0'),
+ ('customer_3', 'Pet_2'), ('customer_4', 'Pet_2'), ('customer_5', 'Pet_3'), ('customer_6', 'Pet_4'),
+ ('customer_0', 'Pet_4')]
+
+ # Create a directed graph
+ g = nx.DiGraph()
+
+ # Add nodes for customers
+ for customer_id, details in customer_dict.items():
+ attributes = {
+ 'c_id': customer_id,
+ 'name': details[0],
+ 'gender': details[1],
+ 'city': details[2],
+ 'state': details[3],
+ }
+ name = "customer_" + str(customer_id)
+ g.add_node(name, **attributes)
+
+ # Add nodes for pets
+ for pet_id, details in pet_dict.items():
+ dynamic_attribute = f"Pet_{pet_id}"
+ attributes = {
+ 'pet_id': pet_id,
+ 'species': details[0],
+ 'class': details[1],
+ dynamic_attribute: 1
+ }
+ name = "Pet_" + str(pet_id)
+ g.add_node(name, **attributes)
+
+ # Add nodes for cars
+ for car_id, details in car_dict.items():
+ dynamic_attribute = f"Car_{car_id}"
+ attributes = {
+ 'car_id': car_id,
+ 'model': details[0],
+ 'color': details[1],
+ dynamic_attribute: 1
+ }
+ name = "Car_" + str(car_id)
+ g.add_node(name, **attributes)
+
+ # Add edges for relationships
+ for f1, f2 in friendships:
+ g.add_edge(f1, f2, Friends=1)
+ for owner, car in car_ownerships:
+ g.add_edge(owner, car, owns_car=1, car_color_id=int(car.split('_')[1]))
+ for owner, pet in pet_ownerships:
+ g.add_edge(owner, pet, owns_pet=1)
+
+ # Load the graph into PyReason
+ pr.load_graph(g)
+
+ # Set PyReason settings
+ pr.settings.verbose = True
+ pr.settings.atom_trace = True
+
+ # Define logical rules
+ pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y)', 'popular_pet_rule'))
+ pr.add_rule(pr.Rule('cool_car(x) <-1 owns_car(x,y),Car_4(y)', 'cool_car_rule'))
+ pr.add_rule(pr.Rule('cool_pet(x)<-1 owns_pet(x,y),Pet_2(y)', 'cool_pet_rule'))
+ pr.add_rule(pr.Rule('trendy(x) <- cool_car(x) , cool_pet(x)', 'trendy_rule'))
+
+ pr.add_rule(
+ pr.Rule("car_friend(x,y) <- owns_car(x,z), owns_car(y,z)", "car_friend_rule"))
+ pr.add_rule(
+ pr.Rule("same_color_car(x, y) <- owns_car(x, c1) , owns_car(y, c2)","same_car_color_rule"))
+
+
+ # Add a fact about 'customer_0' being popular
+ pr.add_fact(pr.Fact('popular-fact', 'popular(customer_0)', 0, 5))
+
+ # Perform reasoning over 5 timesteps
+ interpretation = pr.reason(timesteps=5)
+
+ # Get the interpretation and display it
+ interpretations_dict = interpretation.get_dict()
+ pprint(interpretations_dict)
+
+ # Filter and sort nodes based on specific attributes
+ df1 = pr.filter_and_sort_nodes(interpretation, ['trendy', 'cool_car', 'cool_pet', 'popular'])
+
+ # Filter and sort edges based on specific relationships
+ df2 = pr.filter_and_sort_edges(interpretation, ['car_friend', 'same_color_car'])
+
+ # Display filtered node and edge data
+ print(df1)
+ print(df2)
+
diff --git a/docs/source/examples_rst/advanced_output_example.rst b/docs/source/examples_rst/advanced_output_example.rst
new file mode 100644
index 00000000..9c4c416c
--- /dev/null
+++ b/docs/source/examples_rst/advanced_output_example.rst
@@ -0,0 +1,698 @@
+Advanced Example Full Output
+============================
+
+.. code:: text
+
+ Interpretations:
+ {0: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {},
+ 'customer_4': {},
+ 'customer_5': {},
+ 'customer_6': {},
+ 'popular-fac': {'popular-fac': (1.0, 1.0)},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 1: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {'popular-fac': (1.0, 1.0)},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 2: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {'popular-fac': (1.0, 1.0)},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 3: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {'popular-fac': (1.0, 1.0)},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 4: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {'popular-fac': (1.0, 1.0)},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 5: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {'popular-fac': (1.0, 1.0)},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 6: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 7: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 8: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 9: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 10: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}}}
+
\ No newline at end of file
diff --git a/docs/source/examples_rst/annF_average_example.rst b/docs/source/examples_rst/annF_average_example.rst
new file mode 100644
index 00000000..c27541e0
--- /dev/null
+++ b/docs/source/examples_rst/annF_average_example.rst
@@ -0,0 +1,61 @@
+
+Average Annotation Function Example
+=====================================
+
+.. code:: python
+
+ # Test if annotation functions work
+ import pyreason as pr
+ import numba
+ import numpy as np
+ import networkx as nx
+
+
+
+
+ @numba.njit
+ def avg_ann_fn(annotations, weights):
+ # annotations contains the bounds of the atoms that were used to ground the rule. It is a nested list that contains a list for each clause
+ # You can access for example the first grounded atom's bound by doing: annotations[0][0].lower or annotations[0][0].upper
+
+ # We want the normalised sum of the bounds of the grounded atoms
+ sum_upper_bounds = 0
+ sum_lower_bounds = 0
+ num_atoms = 0
+ for clause in annotations:
+ for atom in clause:
+ sum_lower_bounds += atom.lower
+ sum_upper_bounds += atom.upper
+ num_atoms += 1
+
+ a = sum_lower_bounds / num_atoms
+ b = sum_upper_bounds / num_atoms
+ return a, b
+
+
+
+ #Annotation function that returns average of both upper and lower bounds
+ def average_annotation_function():
+ # Reset PyReason
+ pr.reset()
+ pr.reset_rules()
+
+ pr.settings.allow_ground_rules = True
+
+ pr.add_fact(pr.Fact('P(A) : [0.01, 1]'))
+ pr.add_fact(pr.Fact('P(B) : [0.2, 1]'))
+ pr.add_annotation_function(avg_ann_fn)
+ pr.add_rule(pr.Rule('average_function(A, B):avg_ann_fn <- P(A):[0, 1], P(B):[0, 1]', infer_edges=True))
+
+ interpretation = pr.reason(timesteps=1)
+
+ dataframes = pr.filter_and_sort_edges(interpretation, ['average_function'])
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+
+ assert interpretation.query('average_function(A, B) : [0.105, 1]'), 'Average function should be [0.105, 1]'
+
+ average_annotation_function()
+
diff --git a/docs/source/examples_rst/annF_linear_combination_example.rst b/docs/source/examples_rst/annF_linear_combination_example.rst
new file mode 100644
index 00000000..18a897c5
--- /dev/null
+++ b/docs/source/examples_rst/annF_linear_combination_example.rst
@@ -0,0 +1,88 @@
+Linear Combination Annotation Function Example
+===================================================
+
+.. code:: python
+
+ # Test if annotation functions work
+ import pyreason as pr
+ import numba
+ import numpy as np
+ import networkx as nx
+
+
+
+ @numba.njit
+ def map_to_unit_interval(value, lower, upper):
+ """
+ Map a value from the interval [lower, upper] to the interval [0, 1].
+ The formula is f(t) = c + ((d - c) / (b - a)) * (t - a),
+ where a = lower, b = upper, c = 0, and d = 1.
+ """
+ if upper == lower:
+ return 0 # Avoid division by zero if upper == lower
+ return (value - lower) / (upper - lower)
+
+
+ @numba.njit
+ def lin_comb_ann_fn(annotations, weights):
+ sum_lower_comb = 0
+ sum_upper_comb = 0
+ num_atoms = 0
+ constant = 0.2
+
+ # Iterate over the clauses in the rule
+ for clause in annotations:
+ for atom in clause:
+ # Map the atom's lower and upper bounds to the interval [0, 1]
+ mapped_lower = map_to_unit_interval(atom.lower, 0, 1)
+ mapped_upper = map_to_unit_interval(atom.upper, 0, 1)
+
+ # Apply the weights to the lower and upper bounds, and accumulate
+ sum_lower_comb += constant * mapped_lower
+ sum_upper_comb += constant * mapped_upper
+ num_atoms += 1
+
+ # Return the weighted linear combination of the lower and upper bounds
+ return sum_lower_comb, sum_upper_comb
+
+
+
+ # Function to run the test
+ def linear_combination_annotation_function():
+
+ # Reset PyReason before starting the test
+ pr.reset()
+ pr.reset_rules()
+
+ pr.settings.allow_ground_rules = True
+
+
+ # Add facts (P(A) and P(B) with bounds)
+ pr.add_fact(pr.Fact('P(A) : [.3, 1]'))
+ pr.add_fact(pr.Fact('P(B) : [.2, 1]'))
+
+
+ # Register the custom annotation function with PyReason
+ pr.add_annotation_function(lin_comb_ann_fn)
+
+ # Define a rule that uses this linear combination function
+ pr.add_rule(pr.Rule('linear_combination_function(A, B):lin_comb_ann_fn <- P(A):[0, 1], P(B):[0, 1]', infer_edges=True))
+
+ # Perform reasoning for 1 timestep
+ interpretation = pr.reason(timesteps=1)
+
+ # Filter the results for the computed 'linear_combination_function' edges
+ dataframes = pr.filter_and_sort_edges(interpretation, ['linear_combination_function'])
+
+ # Print the resulting dataframes for each timestep
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+
+ # Assert that the linear combination function gives the expected result (adjusted for weights)
+ # Example assertion based on weights and bounds; adjust the expected result based on the weights
+ assert interpretation.query('linear_combination_function(A, B) : [0.1, 0.4]'), 'Linear combination function should be [0.105, 1]'
+
+ # Run the test function
+ linear_combination_annotation_function()
\ No newline at end of file
diff --git a/docs/source/examples_rst/basic_example.rst b/docs/source/examples_rst/basic_example.rst
new file mode 100644
index 00000000..8f416040
--- /dev/null
+++ b/docs/source/examples_rst/basic_example.rst
@@ -0,0 +1,75 @@
+Basic Example
+============================
+
+
+.. code:: python
+
+ # Test if the simple hello world program works
+ import pyreason as pr
+ import faulthandler
+ import networkx as nx
+
+
+ # Reset PyReason
+ pr.reset()
+ pr.reset_rules()
+ pr.reset_settings()
+
+
+ # ================================ CREATE GRAPH====================================
+ # Create a Directed graph
+ g = nx.DiGraph()
+
+ # Add the nodes
+ g.add_nodes_from(['John', 'Mary', 'Justin'])
+ g.add_nodes_from(['Dog', 'Cat'])
+
+ # Add the edges and their attributes. When an attribute = x which is <= 1, the annotation
+ # associated with it will be [x,1]. NOTE: These attributes are immutable
+ # Friend edges
+ g.add_edge('Justin', 'Mary', Friends=1)
+ g.add_edge('John', 'Mary', Friends=1)
+ g.add_edge('John', 'Justin', Friends=1)
+
+ # Pet edges
+ g.add_edge('Mary', 'Cat', owns=1)
+ g.add_edge('Justin', 'Cat', owns=1)
+ g.add_edge('Justin', 'Dog', owns=1)
+ g.add_edge('John', 'Dog', owns=1)
+
+
+ # Modify pyreason settings to make verbose
+ pr.settings.verbose = True # Print info to screen
+ # pr.settings.optimize_rules = False # Disable rule optimization for debugging
+
+ # Load all the files into pyreason
+ pr.load_graph(g)
+ pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule'))
+ pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2))
+
+ # Run the program for two timesteps to see the diffusion take place
+ faulthandler.enable()
+ interpretation = pr.reason(timesteps=2)
+
+ # Display the changes in the interpretation for each timestep
+ dataframes = pr.filter_and_sort_nodes(interpretation, ['popular'])
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+
+ assert len(dataframes[0]) == 1, 'At t=0 there should be one popular person'
+ assert len(dataframes[1]) == 2, 'At t=1 there should be two popular people'
+ assert len(dataframes[2]) == 3, 'At t=2 there should be three popular people'
+
+ # Mary should be popular in all three timesteps
+ assert 'Mary' in dataframes[0]['component'].values and dataframes[0].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps'
+ assert 'Mary' in dataframes[1]['component'].values and dataframes[1].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=1 timesteps'
+ assert 'Mary' in dataframes[2]['component'].values and dataframes[2].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=2 timesteps'
+
+ # Justin should be popular in timesteps 1, 2
+ assert 'Justin' in dataframes[1]['component'].values and dataframes[1].iloc[1].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=1 timesteps'
+ assert 'Justin' in dataframes[2]['component'].values and dataframes[2].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps'
+
+ # John should be popular in timestep 3
+ assert 'John' in dataframes[2]['component'].values and dataframes[2].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps'
diff --git a/docs/source/examples_rst/custom_threshold_example.rst b/docs/source/examples_rst/custom_threshold_example.rst
new file mode 100644
index 00000000..0deba0df
--- /dev/null
+++ b/docs/source/examples_rst/custom_threshold_example.rst
@@ -0,0 +1,82 @@
+Custom Threshold Example
+============================
+
+
+.. code:: python
+
+ # Test if the simple program works with thresholds defined
+ import pyreason as pr
+ from pyreason import Threshold
+ import networkx as nx
+
+ # Reset PyReason
+ pr.reset()
+ pr.reset_rules()
+
+
+ # Create an empty graph
+ G = nx.DiGraph()
+
+ # Add nodes
+ nodes = ["TextMessage", "Zach", "Justin", "Michelle", "Amy"]
+ G.add_nodes_from(nodes)
+
+ # Add edges with attribute 'HaveAccess'
+ G.add_edge("Zach", "TextMessage", HaveAccess=1)
+ G.add_edge("Justin", "TextMessage", HaveAccess=1)
+ G.add_edge("Michelle", "TextMessage", HaveAccess=1)
+ G.add_edge("Amy", "TextMessage", HaveAccess=1)
+
+
+
+ # Modify pyreason settings to make verbose
+ pr.reset_settings()
+ pr.settings.verbose = True # Print info to screen
+
+ #load the graph
+ pr.load_graph(G)
+
+ # add custom thresholds
+ user_defined_thresholds = [
+ Threshold("greater_equal", ("number", "total"), 1),
+ Threshold("greater_equal", ("percent", "total"), 100),
+
+ ]
+
+ pr.add_rule(
+ pr.Rule(
+ "ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)",
+ "viewed_by_all_rule",
+ custom_thresholds=user_defined_thresholds,
+ )
+ )
+
+ pr.add_fact(pr.Fact("Viewed(Zach)", "seen-fact-zach", 0, 3))
+ pr.add_fact(pr.Fact("Viewed(Justin)", "seen-fact-justin", 0, 3))
+ pr.add_fact(pr.Fact("Viewed(Michelle)", "seen-fact-michelle", 1, 3))
+ pr.add_fact(pr.Fact("Viewed(Amy)", "seen-fact-amy", 2, 3))
+
+ # Run the program for three timesteps to see the diffusion take place
+ interpretation = pr.reason(timesteps=3)
+
+ # Display the changes in the interpretation for each timestep
+ dataframes = pr.filter_and_sort_nodes(interpretation, ["ViewedByAll"])
+ for t, df in enumerate(dataframes):
+ print(f"TIMESTEP - {t}")
+ print(df)
+ print()
+
+ assert (
+ len(dataframes[0]) == 0
+ ), "At t=0 the TextMessage should not have been ViewedByAll"
+ assert (
+ len(dataframes[2]) == 1
+ ), "At t=2 the TextMessage should have been ViewedByAll"
+
+ # TextMessage should be ViewedByAll in t=2
+ assert "TextMessage" in dataframes[2]["component"].values and dataframes[2].iloc[
+ 0
+ ].ViewedByAll == [
+ 1,
+ 1,
+ ], "TextMessage should have ViewedByAll bounds [1,1] for t=2 timesteps"
diff --git a/docs/source/examples_rst/index.rst b/docs/source/examples_rst/index.rst
new file mode 100644
index 00000000..c5927584
--- /dev/null
+++ b/docs/source/examples_rst/index.rst
@@ -0,0 +1,16 @@
+Examples
+==========
+
+In this section we outline a series of tutorials that will help you get started with the basics of using the `pyreason` library.
+
+Examples
+--------
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Examples:
+ :glob:
+
+ ./*
+
+
\ No newline at end of file
diff --git a/docs/source/examples_rst/infer_edges_example.rst b/docs/source/examples_rst/infer_edges_example.rst
new file mode 100644
index 00000000..b295a589
--- /dev/null
+++ b/docs/source/examples_rst/infer_edges_example.rst
@@ -0,0 +1,86 @@
+Infer Edges Example
+============================
+
+
+.. code:: python
+
+ import pyreason as pr
+
+
+ import networkx as nx
+ import matplotlib.pyplot as plt
+
+ # Create a directed graph
+ G = nx.DiGraph()
+
+ # Add nodes with attributes
+ nodes = [
+ ("Amsterdam_Airport_Schiphol", {"Amsterdam_Airport_Schiphol": 1}),
+ ("Riga_International_Airport", {"Riga_International_Airport": 1}),
+ ("Chișinău_International_Airport", {"Chișinău_International_Airport": 1}),
+ ("Yali", {"Yali": 1}),
+ ("Düsseldorf_Airport", {"Düsseldorf_Airport": 1}),
+ ("Pobedilovo_Airport", {"Pobedilovo_Airport": 1}),
+ ("Dubrovnik_Airport", {"Dubrovnik_Airport": 1}),
+ ("HévÃz-Balaton_Airport", {"HévÃz-Balaton_Airport": 1}),
+ ("Athens_International_Airport", {"Athens_International_Airport": 1}),
+ ("Vnukovo_International_Airport", {"Vnukovo_International_Airport": 1})
+ ]
+
+ G.add_nodes_from(nodes)
+
+ # Add edges with 'isConnectedTo' attribute
+ edges = [
+ ("Pobedilovo_Airport", "Vnukovo_International_Airport", {"isConnectedTo": 1}),
+ ("Vnukovo_International_Airport", "HévÃz-Balaton_Airport", {"isConnectedTo": 1}),
+ ("Düsseldorf_Airport", "Dubrovnik_Airport", {"isConnectedTo": 1}),
+ ("Dubrovnik_Airport", "Athens_International_Airport", {"isConnectedTo": 1}),
+ ("Riga_International_Airport", "Amsterdam_Airport_Schiphol", {"isConnectedTo": 1}),
+ ("Riga_International_Airport", "Düsseldorf_Airport", {"isConnectedTo": 1}),
+ ("Chișinău_International_Airport", "Riga_International_Airport", {"isConnectedTo": 1}),
+ ("Amsterdam_Airport_Schiphol", "Yali", {"isConnectedTo": 1})
+ ]
+
+ G.add_edges_from(edges)
+
+
+
+ # Print a drawing of the directed graph
+ # nx.draw(G, with_labels=True, node_color='lightblue', font_weight='bold', node_size=3000)
+ # plt.show()
+
+
+
+
+ pr.reset()
+ pr.reset_rules()
+ # Modify pyreason settings to make verbose and to save the rule trace to a file
+ pr.settings.verbose = True
+ pr.settings.atom_trace = True
+ pr.settings.memory_profile = False
+ pr.settings.canonical = True
+ pr.settings.inconsistency_check = False
+ pr.settings.static_graph_facts = False
+ pr.settings.output_to_file = False
+ pr.settings.store_interpretation_changes = True
+ pr.settings.save_graph_attributes_to_trace = True
+ # Load all the files into pyreason
+ pr.load_graph(G)
+ pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True))
+
+ # Run the program for two timesteps to see the diffusion take place
+ interpretation = pr.reason(timesteps=1)
+ # pr.save_rule_trace(interpretation)
+
+ # Display the changes in the interpretation for each timestep
+ dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+ assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations'
+ assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
+ assert ('Vnukovo_International_Airport', 'Riga_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Riga_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps'
+
+ nx.draw(G, with_labels=True, node_color='lightblue', font_weight='bold', node_size=3000)
+ plt.show()
\ No newline at end of file
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 4cd32c38..3ff47b17 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -3,15 +3,33 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-Welcome to Pyreason's documentation!
+Welcome to PyReason Docs!
====================================
+.. image:: _static/pyreason_logo.jpg
+ :alt: PyReason Logo
+ :align: center
+
+Introduction
+------------
+Welcome to the documentation for **PyReason**, a powerful, optimized Python tool for Reasoning over Graphs. PyReason supports a variety of Logics such as Propositional, First Order, Annotated. This documentation will guide you through the installation, usage and API.
+
.. toctree::
- :caption: Tutorials
- :maxdepth: 2
- :glob:
+ :maxdepth: 1
+ :caption: Contents:
+
+ about
+ installation
+ key_concepts
+ user_guide/index
+ tutorials/index
+ license
+
+
- ./tutorials/*
+Getting Help
+------------
+If you encounter any issues or have questions, feel free to check our Github, or contact one of the authors (`dyuman.aditya@asu.edu`, `kmukher2@asu.edu`).
Indices and tables
==================
diff --git a/docs/source/installation.rst b/docs/source/installation.rst
new file mode 100644
index 00000000..20060dc0
--- /dev/null
+++ b/docs/source/installation.rst
@@ -0,0 +1,23 @@
+Installation
+==========
+
+PyReason is currently compatible with Python 3.9 and 3.10. To install PyReason, you can use pip:
+
+.. code:: bash
+
+ pip install pyreason
+
+
+Make sure you're using the correct version of Python. You can create a conda environment with the correct version of Python using the following command:
+
+.. code:: bash
+
+ conda create -n pyreason-env python=3.10
+
+PyReason uses a JIT compiler called `Numba <https://numba.pydata.org/>`_ to speed up the reasoning process. This means that
+the first time PyReason is imported it will have to compile certain functions which will result in faster runtimes later on.
+You will see a message like this when you import PyReason for the first time:
+
+.. code:: text
+
+ Imported PyReason for the first time. Initializing caches for faster runtimes ... this will take a minute
\ No newline at end of file
diff --git a/docs/source/key_concepts.rst b/docs/source/key_concepts.rst
new file mode 100644
index 00000000..8cbde34f
--- /dev/null
+++ b/docs/source/key_concepts.rst
@@ -0,0 +1,79 @@
+Understanding Key Concepts
+==========================
+
+.. _rule:
+Rule
+~~~~
+
+A rule is a statement that establishes a relationship between
+premises and a conclusion, allowing for the derivation of the
+conclusion if the premises are true. Rules are foundational to
+logical systems, facilitating the inference process.
+
+
+.. image:: _static/rule_image.png
+ :align: center
+
+Every rule has a head and a body. The head determines what will
+change in the graph if the body is true.
+
+.. _fact:
+Fact
+~~~~
+
+A fact is a statement that is used to store information in the graph. It is a basic unit
+of knowledge that is used to derive new information. It can be thought of as an initial condition before reasoning.
+Facts are used to initialize the graph and are the starting point for reasoning.
+
+Annotated atom
+~~~~~~~~~~~~~~~~~~~~~~~~~
+An annotated atom or function in logic, refers to an atomic formula (or a simple predicate) that is augmented with additional
+information, such as a certainty factor, a probability, or other annotations that provide context or constraints.
+
+In PyReason, an annotated atom is represented as a predicate with a bound, which is a list of two values that represent the lower and upper bounds of the predicate.
+For example, a predicate ``pred(x,y) : [0.2, 1]`` means that the predicate ``pred(x,y)`` is true with a certainty between 0.2 and 1.
+
+Interpretation
+~~~~~~~~~~~~~~
+An interpretation is a mapping from the set of atoms to the set of truth values. It is a way of assigning truth values to the atoms in the graph.
+
+Fixed point operator
+~~~~~~~~~~~~~~~~~~~~
+
+In simple terms, a fixed point operator is a function that says if you have a set of atoms,
+return that set plus any atoms that can be derived by a single application of a rule in the program.
+
+
+.. _inconsistent_predicate:
+Inconsistencies
+~~~~~~~~~~~~~~~
+A logic program is consistent if there exists an interpretation that satisfies the logic program, i.e., makes all the rules true.
+If no such interpretation exists, the logic program is inconsistent.
+
+For example if we have the following two rules:
+
+.. code-block:: text
+
+ rule-1: grass_wet <- rained,
+ rule-2: ~grass_wet <- rained,
+
+This creates an inconsistency because the first rule states that the grass is wet if it rained, while the second rule states that the grass is not wet if it rained.
+In PyReason, inconsistencies are detected and resolved to ensure the reasoning process remains robust. In such a case,
+the affected interpretations are reset to a state of complete uncertainty i.e ``grass_wet : [0,1]``.
+
+Inconsistent predicate list
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+An inconsistent predicate list is a list of predicates that are inconsistent with each other.
+
+For example, consider the following example of two predicates that are inconsistent with each other:
+
+.. code-block:: text
+
+ sick and healthy
+
+In this case, the predicates "sick" and "healthy" are inconsistent with each other because they cannot both be true at the same time.
+We can model this in PyReason such that when one predicate has a certain bound ``[l, u]``, the other predicate is given
+a bound ``[1-u, 1-l]`` automatically. See :ref:`here ` for more information.
+
+In this case, if "sick" is true with a bound ``[1, 1]``, then "healthy" is automatically set to ``[0, 0]``.
diff --git a/docs/source/license.rst b/docs/source/license.rst
new file mode 100644
index 00000000..c9569d7d
--- /dev/null
+++ b/docs/source/license.rst
@@ -0,0 +1,12 @@
+License
+==========
+
+This repository is licensed under `BSD-2-Clause <https://opensource.org/licenses/BSD-2-Clause>`_.
+
+Trademark Permission
+--------------------
+.. |logo| image:: _static/pyreason_logo.jpg
+ :width: 50
+
+PyReason™ and PyReason Design Logo |logo| ™ are trademarks of the Arizona Board of Regents/Arizona State University. Users of the software are permitted to use PyReason™ in association with the software for any purpose, provided such use is related to the software (e.g., Powered by PyReason™). Additionally, educational institutions are permitted to use the PyReason Design Logo |logo| ™ for non-commercial purposes.
+
diff --git a/docs/source/tutorials/Advanced tutorial.rst b/docs/source/tutorials/Advanced tutorial.rst
deleted file mode 100644
index e063edb6..00000000
--- a/docs/source/tutorials/Advanced tutorial.rst
+++ /dev/null
@@ -1,90 +0,0 @@
-Running Pyreason with an advanced graph
-==========================================
-
-In this tutorial, we will look at how to run PyReason with a more
-complex graph.
-
-Graph
-------------
-
-We use a larger graph for this example. In this example , we have customers , cars , pets and their relationships.
-We first have customer details followed by car details , pet details , travel details .
-
-.. literalinclude:: advanced_graph.py
- :language: python
- :lines: 16-24, 28-34 , 39-52, 58-63
-
-We now have the relationships between the customers , cars , pets and travel details.
-
-.. literalinclude:: advanced_graph.py
- :language: python
- :lines: 67-78
-
-Based on the relationships we now connect the nodes, edges and the form the graph.
-
-.. literalinclude:: advanced_graph.py
- :language: python
- :lines: 84-124
-
-We now have the graph ready. We can now add the rules for our use case. Take a look at it at
-
-.. figure:: advanced_graph.png
- :alt: image
-
- advanced graph image
-
-Rules
------
-
-The below are the rules we want to add:
-
-1. A customer is popular if he is friends with a popular customer.
-2. A customer has a cool car if he owns a car and the car is of type Car_4.
-3. A customer has a cool pet if he owns a pet and the pet is of type Pet_2.
-4. A customer is trendy if he has a cool car and a cool pet.
-
-.. code-block:: python
-
- pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y)', 'popular_pet_rule'))
- pr.add_rule(pr.Rule('cool_car(x) <-1 owns_car(x,y),Car_4(y)', 'cool_car_rule'))
- pr.add_rule(pr.Rule('cool_pet(x)<-1 owns_pet(x,y),Pet_2(y)', 'cool_pet_rule'))
- pr.add_rule(pr.Rule('trendy(x) <- cool_car(x) , cool_pet(x)', 'trendy_rule'))
-The above rules are based on nodes. Now let us add some more rules based on the edges.
-
-1. Two customers are car_friends if they own the same car.
-2. Two customers are friends if they own the same color car.
-
-.. code-block:: python
-
- pr.add_rule(pr.Rule("car_friend(x,y) <- owns_car(x,z), owns_car(y,z) , c_id(x) != c_id(y) ", "car_friend_rule"))
- pr.add_rule(pr.Rule("same_color_car(x, y) <- owns_car(x, c1) , owns_car(y, c2), car_color_id(x,c1) == car_color_id(y,c2) , c_id(x) != c_id(y)","same_car_color_rule"))
-
-Facts
--------
-
-We now add the facts to the graph.
-There is only one fact we are going to use.
-1. customer_0 is popular from time 0 to 5.
-
-.. code-block:: python
-
- pr.add_fact(pr.Fact(name='popular-fact', component='customer_0', attribute='popular', bound=[1, 1], start_time=0, end_time=5))
-
-
-Running Pyreason
-----------------
-
-We now run the PyReason with the graph and the rules.
-
-.. code-block:: python
-
- interpretation = pr.reason(timesteps=6)
- # pr.save_rule_trace(interpretation)
-
- interpretations_dict = interpretation.get_interpretation_dict()
-
- df1 = pr.filter_and_sort_nodes(interpretation, ['trendy', 'cool_car', 'cool_pet', 'popular'])
- df2 = pr.filter_and_sort_edges(interpretation, ['car_friend', 'same_color_car'])
-
-.. note::
- The complete code for this example is on github at advanced_graph.py
\ No newline at end of file
diff --git a/docs/source/tutorials/advanced_tutorial.rst b/docs/source/tutorials/advanced_tutorial.rst
new file mode 100644
index 00000000..b3e619f5
--- /dev/null
+++ b/docs/source/tutorials/advanced_tutorial.rst
@@ -0,0 +1,164 @@
+Advanced Tutorial
+==========================================
+
+In this tutorial, we will look at how to run PyReason with a more
+complex graph.
+
+.. note::
+ Find the full, executable code `here `_
+
+Graph
+------------
+
+We use a larger graph for this example. In this example, we have ``customers``, ``cars``, ``pets`` and their relationships.
+We first have ``customer_details`` followed by ``car_details``, ``pet_details``, and ``travel_details``.
+
+.. literalinclude:: advanced_graph.py
+ :language: python
+ :lines: 16-24, 28-34, 39-52, 58-63
+
+We now have the relationships between the customers, cars, pets, and travel details.
+
+.. literalinclude:: advanced_graph.py
+ :language: python
+ :lines: 67-78
+
+Based on the relationships, we now connect the nodes and edges and form the graph.
+
+.. literalinclude:: advanced_graph.py
+ :language: python
+ :lines: 84-124
+
+We now have the graph ready. We can now add the rules for our use case. Take a look at the graph below:
+
+.. figure:: advanced_graph.png
+ :alt: image
+
+ advanced graph image
+
+Rules
+-----
+
+The below are the rules we want to add:
+
+1. A customer is popular if he is friends with a popular customer.
+2. A customer has a cool car if he owns a car and the car is of type ``Car_4``.
+3. A customer has a cool pet if he owns a pet and the pet is of type ``Pet_2``.
+4. A customer is trendy if he has a cool car and a cool pet.
+
+.. code-block:: python
+
+ pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y)', 'popular_pet_rule'))
+ pr.add_rule(pr.Rule('cool_car(x) <-1 owns_car(x,y),Car_4(y)', 'cool_car_rule'))
+ pr.add_rule(pr.Rule('cool_pet(x)<-1 owns_pet(x,y),Pet_2(y)', 'cool_pet_rule'))
+ pr.add_rule(pr.Rule('trendy(x) <- cool_car(x) , cool_pet(x)', 'trendy_rule'))
+The above rules are based on nodes. Now let us add some more rules based on the edges.
+
+1. Two customers are ``car_friends`` if they own the same car.
+2. Two customers are ``friends`` if they own the same color car.
+
+.. code-block:: python
+
+ pr.add_rule(pr.Rule("car_friend(x,y) <- owns_car(x,z), owns_car(y,z)", "car_friend_rule"))
+ pr.add_rule(pr.Rule("same_color_car(x, y) <- owns_car(x, c1) , owns_car(y, c2)","same_car_color_rule"))
+
+Facts
+-------
+
+We now add the facts to the graph.
+There is only one fact we are going to use.
+1. ``customer_0`` is popular from time ``0`` to ``5``.
+
+.. code-block:: python
+
+    pr.add_fact(pr.Fact('popular(customer_0)', 'popular-fact', 0, 5))
+
+Running Pyreason
+----------------
+
+We now run the PyReason interpretation with the graph and the rules.
+
+.. code-block:: python
+
+ interpretation = pr.reason(timesteps=6)
+ # pr.save_rule_trace(interpretation)
+
+ interpretations_dict = interpretation.get_dict()
+
+ df1 = pr.filter_and_sort_nodes(interpretation, ['trendy', 'cool_car', 'cool_pet', 'popular'])
+ df2 = pr.filter_and_sort_edges(interpretation, ['car_friend', 'same_color_car'])
+
+Expected Output
+---------------
+Below is the expected output at timestep ``0``
+
+.. note::
+ Find the full expected output `here `_
+
+.. code:: text
+
+    shortened output
+
+ Interpretations:
+ {0: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {},
+ 'customer_4': {},
+ 'customer_5': {},
+ 'customer_6': {},
+ 'popular-fac': {'popular-fac': (1.0, 1.0)},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+
\ No newline at end of file
diff --git a/docs/source/tutorials/annotation_function.rst b/docs/source/tutorials/annotation_function.rst
new file mode 100644
index 00000000..f2ebe7a3
--- /dev/null
+++ b/docs/source/tutorials/annotation_function.rst
@@ -0,0 +1,261 @@
+PyReason Annotation Functions
+=============================
+
+In this tutorial, we will look at how to use annotation functions in PyReason.
+Read more about annotation functions `here `_.
+
+.. note::
+ Find the full, executable code for both annotation functions `here `_
+
+
+Average Annotation Function Example
+-----------------------------------
+This example takes the average of the lower and higher bounds of the nodes in the graph.
+
+Graph
+^^^^^^^
+
+This example will use a graph created with 2 facts, and only 2 nodes. The annotation functions can be run on a graph of any size. See :ref:`PyReason Graphs ` for more information on how to create graphs in PyReason.
+
+
+Facts
+^^^^^^^
+To initialize this graph, we will add 2 nodes ``P(A)`` and ``P(B)``, using ``add_fact``:
+
+.. code:: python
+
+ import pyreason as pr
+
+ pr.add_fact(pr.Fact('P(A) : [0.01, 1]'))
+ pr.add_fact(pr.Fact('P(B) : [0.2, 1]'))
+
+
+
+
+Average Annotation Function
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Next, we will then add the annotation function to find the average of all the upper and lower bounds of the graph.
+
+Here is the Average Annotation Function:
+
+.. code:: python
+
+ @numba.njit
+ def avg_ann_fn(annotations, weights):
+ # annotations contains the bounds of the atoms that were used to ground the rule. It is a nested list that contains a list for each clause
+ # You can access for example the first grounded atom's bound by doing: annotations[0][0].lower or annotations[0][0].upper
+
+ # We want the normalised sum of the bounds of the grounded atoms
+ sum_upper_bounds = 0
+ sum_lower_bounds = 0
+ num_atoms = 0
+ for clause in annotations:
+ for atom in clause:
+ sum_lower_bounds += atom.lower
+ sum_upper_bounds += atom.upper
+ num_atoms += 1
+
+ a = sum_lower_bounds / num_atoms
+ b = sum_upper_bounds / num_atoms
+ return a, b
+
+This takes the annotations (a list of the bounds of the grounded atoms) and the weights of the grounded atoms, and returns the average of the lower and upper bounds respectively.
+
+Next, we add this function in PyReason:
+
+.. code:: python
+
+ pr.add_annotation_function(avg_ann_fn)
+
+
+Rules
+^^^^^^^
+After we have created the graph, and added the annotation function, we add the annotation function to a Rule.
+
+Create Rules of this general format when using an annotation function:
+
+.. code:: text
+
+ 'average_function(A, B):avg_ann_fn <- P(A):[0, 1], P(B):[0, 1]'
+
+The annotation function will be called when all clauses in the rule have been satisfied and the head of the rule is to be annotated.
+
+.. code:: python
+
+ pr.add_rule(pr.Rule('average_function(A, B):avg_ann_fn <- P(A):[0, 1], P(B):[0, 1]', infer_edges=True))
+
+
+Running PyReason
+^^^^^^^^^^^^^^^^^^^^^
+Begin the PyReason reasoning process with the added annotation function with:
+
+.. code:: python
+
+ interpretation = pr.reason(timesteps=1)
+
+
+Expected Output
+^^^^^^^^^^^^^^^^^^^^^
+The expected output of this function is
+
+.. code:: text
+
+ Timestep: 0
+ Converged at time: 0
+ Fixed Point iterations: 2
+ TIMESTEP - 0
+ component average_function
+ 0 (A, B) [0.10500000000000001, 1.0]
+
+In this output:
+ - The lower bound of the ``avg_ann_fn(A, B)`` is computed as ``0.105``, based on the weighted combination of the lower bounds of ``P(A)`` (0.01) and ``P(B)`` (0.2), averaged together.
+ - The upper bound of the ``avg_ann_fn(A, B)`` is computed as ``1.0``, based on the upper bounds of ``P(A)`` (1.0) and ``P(B)`` (1.0), averaged together.
+
+
+
+Linear Combination Annotation Function
+----------------------------------------
+
+Now, we will define and use a new annotation function to compute a weighted linear combination of the bounds of grounded atoms in a rule.
+
+
+The `map_interval` Function
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+We will first define a helper function that maps a value from the interval `[lower, upper]` to the interval `[0, 1]`. This will be used in the main annotation function to normalize the bounds:
+
+.. code:: python
+
+ @numba.njit
+ def map_interval(t, a, b, c, d):
+ """
+ Maps a value `t` from the interval [a, b] to the interval [c, d] using the formula:
+
+ f(t) = c + ((d - c) / (b - a)) * (t - a)
+
+ Parameters:
+ - t: The value to be mapped.
+ - a: The lower bound of the original interval.
+ - b: The upper bound of the original interval.
+ - c: The lower bound of the target interval.
+ - d: The upper bound of the target interval.
+
+ Returns:
+ - The value `t` mapped to the new interval [c, d].
+ """
+ # Apply the formula to map the value t
+ mapped_value = c + ((d - c) / (b - a)) * (t - a)
+
+ return mapped_value
+
+
+Graph
+^^^^^^^^^^^^^^^^^^^^^
+
+This example will use a graph created with 2 facts, and only 2 nodes. The annotation functions can be run on a graph of any size. See :ref:`PyReason Graphs ` for more information on how to create graphs in PyReason.
+
+
+Facts
+^^^^^^^^^^^^^^
+To initialize this graph, we will add 3 nodes ``A``, ``B``, and ``C``, using ``add_fact``:
+
+.. code:: python
+
+ import pyreason as pr
+
+ pr.add_fact(pr.Fact('A : [.1, 1]'))
+ pr.add_fact(pr.Fact('B : [.2, 1]'))
+ pr.add_fact(pr.Fact('C : [.4, 1]'))
+
+
+
+Linear Combination Function
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Next, we define the annotation function that computes a weighted linear combination of the mapped lower and upper bounds of the grounded atoms. The weights are applied to normalize the values.
+For simplicity sake, we define the constant at 0.2 within the function, this is alterable for any constant, or for the weights in the graph.
+
+.. code:: python
+
+ @numba.njit
+ def lin_comb_ann_fn(annotations, weights):
+ sum_lower_comb = 0
+ sum_upper_comb = 0
+ num_atoms = 0
+ constant = 0.2
+ # Iterate over the clauses in the rule
+ for clause in annotations:
+ for atom in clause:
+
+ # Apply the constant weight to the lower and upper bounds, and accumulate
+ sum_lower_comb += constant * atom.lower
+ sum_upper_comb += constant * atom.upper
+ num_atoms += 1
+
+
+ #if the lower and upper are equal, return [0,1]
+ if sum_lower_comb == sum_upper_comb:
+ return 0,1
+
+ if sum_lower_comb> sum_upper_comb:
+ sum_lower_comb,sum_upper_comb= sum_upper_comb, sum_lower_comb
+
+ if sum_upper_comb>1:
+ sum_lower_comb = map_interval(sum_lower_comb, sum_lower_comb, sum_upper_comb, 0,1)
+
+ sum_upper_comb = map_interval(sum_upper_comb, sum_lower_comb, sum_lower_comb,0,1)
+
+ # Return the weighted linear combination of the lower and upper bounds
+ return sum_lower_comb, sum_upper_comb
+
+
+
+We now add the new annotation function within the PyReason framework:
+
+.. code:: python
+
+ # Register the custom annotation function with PyReason
+ pr.add_annotation_function(lin_comb_ann_fn)
+
+
+Rules
+^^^^^^^
+After we have created the graph, and added the annotation function, we add the annotation function to a Rule.
+
+Create Rules of this general format when using an annotation function:
+
+.. code:: text
+
+ linear_combination_function(A, B):lin_comb_ann_fn <- A:[0, 1], B:[0, 1], C:[0, 1]
+
+
+.. code:: python
+
+ pr.add_rule(pr.Rule('linear_combination_function(A, B):lin_comb_ann_fn <- A:[0, 1], B:[0, 1], C:[0, 1]', infer_edges=True))
+
+The annotation function will be called when all clauses in the rule have been satisfied and the head of the Rule is to be annotated.
+
+Running PyReason
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Begin the PyReason reasoning process with the added annotation function with:
+
+.. code:: python
+
+ interpretation = pr.reason(timesteps=1)
+
+
+Expected Output
+^^^^^^^^^^^^^^^^^^^^^
+Below is the expected output from running the ``linear_combination_annotation_function``:
+
+.. code:: text
+
+ Timestep: 0
+
+ Converged at time: 0
+ Fixed Point iterations: 2
+ TIMESTEP - 0
+ component linear_combination_function
+ 0 (A, B) [0.24000000000000005, 0.6000000000000001]
+
+In this output:
+ - The lower bound of the ``linear_combination_function(A, B, C)`` is computed as ``0.24000000000000005``, based on the weighted combination of the lower bounds of ``A`` (0.1), ``B`` (0.2), and ``C`` (0.4) multiplied by the constant(0.2) then added together.
+ - The upper bound of the ``linear_combination_function(A, B, C)`` is computed as ``0.6000000000000001``, based on the weighted combination of the upper bounds of ``A`` (1), ``B`` (1), and ``C`` (1) multiplied by the constant(0.2) then added together.
diff --git a/docs/source/tutorials/Basic tutorial.rst b/docs/source/tutorials/basic_tutorial.rst
similarity index 80%
rename from docs/source/tutorials/Basic tutorial.rst
rename to docs/source/tutorials/basic_tutorial.rst
index 22d29a76..677f94b8 100644
--- a/docs/source/tutorials/Basic tutorial.rst
+++ b/docs/source/tutorials/basic_tutorial.rst
@@ -1,10 +1,13 @@
-PyReason Hello World!
+PyReason Basic Tutorial
========================
Welcome to PyReason! In this document we outline a simple program that
demonstrates some of the capabilities of the software. If this is your
first time looking at the software, you’re in the right place.
+.. note::
+ Find the full, excecutable code `here `_
+
The following graph represents a network of people and the pets that
they own.
@@ -101,12 +104,12 @@ GraphML format.
-We then load the graph from the file using the following code:
+We then load the graph from the NetworkX graph using the following code:
.. code:: python
import pyreason as pr
- pr.load_graphml('path_to_file')
+ pr.load_graph(g)
.. figure:: basic_graph.png
:alt: image
@@ -127,17 +130,17 @@ who has the same pet as they do, then they are popular.
popular(x) : [1,1] <-1 popular(y) : [1,1] , Friends(x,y) : [1,1] , owns(y,z) : [1,1] , owns(x,z) : [1,1]
-Since PyReason by default assumes bounds in a rule to be `[1,1]`, we can omit them here and write:
+Since PyReason by default assumes bounds in a rule to be ``[1,1]``, we can omit them here and write:
.. code:: text
popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)
-The rule is read as follows: - The `head` of the rule is
-`popular(x)` and the body is
-`popular(y), Friends(x,y), owns(y,z), owns(x,z)`. The head and body
+The rule is read as follows: - The ``head`` of the rule is
+``popular(x)`` and the body is
+``popular(y), Friends(x,y), owns(y,z), owns(x,z)``. The head and body
are separated by an arrow and the time after which the head will become
-true `<-1` in our case this happens after `1` timestep.
+true ``<-1`` in our case this happens after ``1`` timestep.
To add this rule to PyReason, we can do the following:
@@ -154,7 +157,7 @@ To add the rule directly, we must specify the rule and a name for it.
The name helps understand which rules fired during reasoning later on.
Adding the rule from a file is also possible. The file should be in
-`.txt` format and should contain the rule in the format shown above.
+``.txt`` format and should contain the rule in the format shown above.
.. code:: text
@@ -172,22 +175,20 @@ Facts
Facts are initial conditions that we want to set in the graph.
-In the graph we have created, suppose we want to set `Mary` to be
-`popular` initially.
+In the graph we have created, suppose we want to set ``Mary`` to be
+``popular`` initially.
.. code:: python
import pyreason as pr
- pr.add_fact(pr.Fact(name='popular-fact', component='Mary', attribute='popular', bound=[1, 1], start_time=0, end_time=2))
+ pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2))
-The fact indicates that `Mary` is `popular` at time `0` and will
-remain so until time `2`.
+The fact indicates that ``Mary`` is ``popular`` at time ``0`` and will
+remain so until time ``2``.
Running PyReason
----------------
-The complete code for the basic tutorial is in the file
-`basic_tutorial.py`.
The main line that runs the reasoning in that file is:
@@ -195,9 +196,9 @@ The main line that runs the reasoning in that file is:
interpretation = pr.reason(timesteps=2)
-This line runs the reasoning for `2` timesteps and returns the
+This line runs the reasoning for ``2`` timesteps and returns the
interpretation of the graph at each timestep. We can also skip the
-`timesteps` argument and let PyReason run until the convergence is
+``timesteps`` argument and let PyReason run until the convergence is
reached.
Expected Output
@@ -207,11 +208,11 @@ Before checking the output , we can check manually what the expected
output should be. Since we have a small graph, we can reason through it
manually.
-1. At timestep 0, we have `Mary` to be `popular`.
-2. At timestep 1, `Justin` becomes `popular` because he has a
- popular friend (`Mary`) and has the same pet as `Mary` (cat).
-3. At timestep 2, `John` becomes `popular` because he has a popular
- friend (`Justin`) and has the same pet as `Justin` (dog).
+1. At timestep 0, we have ``Mary`` to be ``popular``.
+2. At timestep 1, ``Justin`` becomes ``popular`` because he has a
+ popular friend (``Mary``) and has the same pet as ``Mary`` (cat).
+3. At timestep 2, ``John`` becomes ``popular`` because he has a popular
+ friend (``Justin``) and has the same pet as ``Justin`` (dog).
4. At timestep 3, no new nodes become popular and the reasoning stops.
The output of the reasoning is as follows:
diff --git a/docs/source/tutorials/Creating Rules.rst b/docs/source/tutorials/creating_rules.rst
similarity index 100%
rename from docs/source/tutorials/Creating Rules.rst
rename to docs/source/tutorials/creating_rules.rst
diff --git a/docs/source/tutorials/custom_thresholds.rst b/docs/source/tutorials/custom_thresholds.rst
new file mode 100644
index 00000000..5e2398fd
--- /dev/null
+++ b/docs/source/tutorials/custom_thresholds.rst
@@ -0,0 +1,176 @@
+.. _custom_thresholds_tutorial:
+PyReason Custom Threshold
+=================================
+
+In this tutorial, we will look at how to run PyReason with Custom Thresholds.
+Custom Thresholds are parameters in the :ref:`Rule Class `.
+
+.. note::
+ Find the full, executable code `here `_
+
+The following graph represents a network of People and a Text Message in their group chat.
+
+.. image:: ../../../media/group_chat_graph.png
+ :align: center
+
+
+Graph
+------------
+
+First, we create the graph using Networkx. This graph has nodes ``Zach``, ``Justin``, ``Michelle``, ``Amy``, and ``TextMessages``.
+
+.. code:: python
+
+
+ import networkx as nx
+
+ # Create an empty graph
+ G = nx.Graph()
+
+ # Add nodes
+ nodes = ["TextMessage", "Zach", "Justin", "Michelle", "Amy"]
+ G.add_nodes_from(nodes)
+
+ # Add edges with attribute 'HaveAccess'
+ edges = [
+ ("Zach", "TextMessage", {"HaveAccess": 1}),
+ ("Justin", "TextMessage", {"HaveAccess": 1}),
+ ("Michelle", "TextMessage", {"HaveAccess": 1}),
+ ("Amy", "TextMessage", {"HaveAccess": 1})
+ ]
+ G.add_edges_from(edges)
+
+Then initialize and load the graph into PyReason with:
+
+.. code:: python
+
+ import pyreason as pr
+ pr.load_graph(G)
+
+
+Rules
+-----
+
+Considering that we only want a text message to be considered viewed by all if it has been viewed by everyone that can view it, we define the rule as follows:
+
+.. code-block:: text
+
+ ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)
+
+The ``head`` of the rule is ``ViewedByAll(x)`` and the body is ``HaveAccess(x,y), Viewed(y)``. The head and body are separated by an arrow which means the rule will start evaluating from
+timestep ``0``.
+
+Next, add in the custom thresholds. In this graph, the custom_thresholds ensure that in order for the rules to be fired, specific criteria must be met.
+
+
+.. code:: python
+
+ import pyreason as pr
+ from pyreason import Threshold
+
+
+
+ # add custom thresholds
+ user_defined_thresholds = [
+ Threshold("greater_equal", ("number", "total"), 1),
+ Threshold("greater_equal", ("percent", "total"), 100),
+ ]
+
+
+The ``user_defined_thresholds`` are a list of custom thresholds of the format: (quantifier, quantifier_type, thresh) where:
+ - quantifier can be greater_equal, greater, less_equal, less, equal
+ - quantifier_type is a tuple where the first element can be either number or percent and the second element can be either total or available
+ - thresh represents the numerical threshold value to compare against
+
+The custom thresholds are created corresponding to the two clauses ``(HaveAccess(x,y)`` and ``Viewed(y))`` as below:
+ - ('greater_equal', ('number', 'total'), 1) (there needs to be at least one person who has access to ``TextMessage`` for the first clause to be satisfied)
+ - ('greater_equal', ('percent', 'total'), 100) (100% of people who have access to ``TextMessage`` need to view the message for second clause to be satisfied)
+
+
+
+Next, add the Rule, with the ``user_defined_thresholds`` are passed as parameters to the new Rule. ``viewed_by_all_rule`` is the name of the rule. This helps to understand which rule/s are fired during reasoning later on.
+
+
+.. code:: python
+
+ pr.add_rule(
+ pr.Rule(
+ "ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)",
+ "viewed_by_all_rule",
+ custom_thresholds=user_defined_thresholds,
+ )
+ )
+
+
+Facts
+-----
+
+The facts determine the initial conditions of elements in the graph. They can be specified from the graph attributes but in that
+case they will be immutable later on. Adding PyReason facts gives us more flexibility.
+
+In our case we want one person to view the ``TextMessage`` in a particular interval of timestep.
+For example, we create facts stating:
+
+ - ``Zach`` and ``Justin`` view the ``TextMessage`` at timestep ``0``
+ - ``Michelle`` views the ``TextMessage`` at timestep ``1``
+ - ``Amy`` views the ``TextMessage`` at timestep ``2``
+ - ``3`` is the last timestep the rule is active for all.
+
+
+.. code:: python
+
+ pr.add_fact(pr.Fact("Viewed(Zach)", "seen-fact-zach", 0, 3))
+ pr.add_fact(pr.Fact("Viewed(Justin)", "seen-fact-justin", 0, 3))
+ pr.add_fact(pr.Fact("Viewed(Michelle)", "seen-fact-michelle", 1, 3))
+ pr.add_fact(pr.Fact("Viewed(Amy)", "seen-fact-amy", 2, 3))
+
+This allows us to specify components that have an initial condition.
+
+Running PyReason
+----------------
+
+To run the reasoning in the file:
+
+.. code:: python
+
+ # Run the program for three timesteps to see the diffusion take place
+ interpretation = pr.reason(timesteps=3)
+
+This specifies how many timesteps to run for.
+
+Expected output
+---------------
+After running the python file, the expected output is:
+
+.. code:: text
+
+ Timestep: 0
+ Timestep: 1
+ Timestep: 2
+ Timestep: 3
+
+ Converged at time: 3
+ Fixed Point iterations: 6
+ TIMESTEP - 0
+ Empty DataFrame
+ Columns: [component, ViewedByAll]
+ Index: []
+
+ TIMESTEP - 1
+ Empty DataFrame
+ Columns: [component, ViewedByAll]
+ Index: []
+
+ TIMESTEP - 2
+ component ViewedByAll
+ 0 TextMessage [1.0, 1.0]
+
+ TIMESTEP - 3
+ component ViewedByAll
+ 0 TextMessage [1.0, 1.0]
+
+
+1. For timestep 0, we set ``Zach -> Viewed: [1,1]`` and ``Justin -> Viewed: [1,1]`` in the facts
+2. For timestep 1, ``Michelle`` views the TextMessage as stated in facts ``Michelle -> Viewed: [1,1]``.
+3. For timestep 2, since ``Amy`` has just viewed the ``TextMessage``, therefore ``Amy -> Viewed: [1,1]``. As per the rule,
+ since all the people have viewed the ``TextMessage``, the message is marked as ``ViewedByAll``.
diff --git a/docs/source/tutorials/index.rst b/docs/source/tutorials/index.rst
new file mode 100644
index 00000000..e0c2f755
--- /dev/null
+++ b/docs/source/tutorials/index.rst
@@ -0,0 +1,19 @@
+Tutorials
+==========
+
+In this section we outline a series of tutorials that will help you get started with the basics of using the ``pyreason`` library.
+
+Contents
+--------
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Contents:
+ :glob:
+
+ ./basic_tutorial.rst
+ ./advanced_tutorial.rst
+ ./custom_thresholds.rst
+ ./infer_edges.rst
+ ./annotation_function.rst
+
\ No newline at end of file
diff --git a/docs/source/tutorials/infer_edges.rst b/docs/source/tutorials/infer_edges.rst
new file mode 100644
index 00000000..7aa1e86e
--- /dev/null
+++ b/docs/source/tutorials/infer_edges.rst
@@ -0,0 +1,151 @@
+PyReason Infer Edges
+============================
+
+In this tutorial, we will look at how to run PyReason with infer edges.
+infer edges is a parameter in the :ref:`Rule Class `.
+
+
+.. note::
+ Find the full, executable code `here `_
+
+The following graph represents a network of airports and connections.
+
+.. image:: ../../../media/infer_edges11.png
+ :align: center
+
+Graph
+------------
+
+First, we create the graph in Networkx. This graph has airports and flight connections.
+
+.. code:: python
+
+ import networkx as nx
+
+ # Create a directed graph
+ G = nx.DiGraph()
+
+ # Add nodes with attributes
+ nodes = [
+ ("Amsterdam_Airport_Schiphol", {"Amsterdam_Airport_Schiphol": 1}),
+ ("Riga_International_Airport", {"Riga_International_Airport": 1}),
+ ("Chișinău_International_Airport", {"Chișinău_International_Airport": 1}),
+ ("Yali", {"Yali": 1}),
+ ("Düsseldorf_Airport", {"Düsseldorf_Airport": 1}),
+ ("Pobedilovo_Airport", {"Pobedilovo_Airport": 1}),
+ ("Dubrovnik_Airport", {"Dubrovnik_Airport": 1}),
+ ("HévÃz-Balaton_Airport", {"HévÃz-Balaton_Airport": 1}),
+ ("Athens_International_Airport", {"Athens_International_Airport": 1}),
+ ("Vnukovo_International_Airport", {"Vnukovo_International_Airport": 1})
+ ]
+
+ G.add_nodes_from(nodes)
+
+ # Add edges with 'isConnectedTo' attribute
+ edges = [
+ ("Pobedilovo_Airport", "Vnukovo_International_Airport", {"isConnectedTo": 1}),
+ ("Vnukovo_International_Airport", "HévÃz-Balaton_Airport", {"isConnectedTo": 1}),
+ ("Düsseldorf_Airport", "Dubrovnik_Airport", {"isConnectedTo": 1}),
+ ("Dubrovnik_Airport", "Athens_International_Airport", {"isConnectedTo": 1}),
+ ("Riga_International_Airport", "Amsterdam_Airport_Schiphol", {"isConnectedTo": 1}),
+ ("Riga_International_Airport", "Düsseldorf_Airport", {"isConnectedTo": 1}),
+ ("Chișinău_International_Airport", "Riga_International_Airport", {"isConnectedTo": 1}),
+ ("Amsterdam_Airport_Schiphol", "Yali", {"isConnectedTo": 1})
+ ]
+
+ G.add_edges_from(edges)
+
+We can also load the graph from a GraphML `file `_
+
+We then initialize and load the graph using the following code:
+
+.. code:: python
+
+ import pyreason as pr
+
+ pr.reset()
+ pr.reset_rules()
+ # Modify pyreason settings to make verbose and to save the rule trace to a file
+ pr.settings.verbose = True
+ pr.settings.atom_trace = True
+ pr.settings.memory_profile = False
+ pr.settings.canonical = True
+ pr.settings.inconsistency_check = False
+ pr.settings.static_graph_facts = False
+ pr.settings.output_to_file = False
+ pr.settings.store_interpretation_changes = True
+ pr.settings.save_graph_attributes_to_trace = True
+ # Load all the files into pyreason
+ pr.load_graph(G)
+
+Rules
+------------
+
+Next, add the Rule and set ``infer_edges`` to ``True``
+
+.. code:: python
+
+ pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True))
+
+This should connect exactly one new relationship from A to Y. The Rule states that if there is a connection from Y to B, then infer an edge from A to Y.
+
+ - B is ``Amsterdam Airport Schiphol``
+ - Y is ``Riga_International_Airport``
+ - A is ``Vnukovo_International_Airport``
+
+Therefore the output of the graph after running 1 timestep should be a new connection [1,1] between ``Vnukovo_International_Airport`` (A) to ``Riga_International_Airport`` (Y).
+
+Facts
+-----
+This example does not have any facts. All initial conditions are set when the graph is created.
+
+
+Running PyReason
+----------------
+
+Run the program for ``1`` timestep.
+
+.. code:: python
+
+    # Run the program for one timestep to see the diffusion take place
+
+ interpretation = pr.reason(timesteps=1)
+
+
+Expected output
+---------------
+After running the python file, the expected output is:
+
+The expected output lists the initial connections at timestep 0 and, at timestep 1, the connection added due to the ``infer_edges`` parameter.
+
+.. code:: text
+
+ Timestep: 0
+ Timestep: 1
+
+ Converged at time: 1
+ Fixed Point iterations: 2
+ TIMESTEP - 0
+ component isConnectedTo
+ 0 (Amsterdam_Airport_Schiphol, Yali) [1.0, 1.0]
+ 1 (Riga_International_Airport, Amsterdam_Airport... [1.0, 1.0]
+ 2 (Riga_International_Airport, Düsseldorf_Airport) [1.0, 1.0]
+ 3 (Chișinău_International_Airport, Riga_Internat... [1.0, 1.0]
+ 4 (Düsseldorf_Airport, Dubrovnik_Airport) [1.0, 1.0]
+ 5 (Pobedilovo_Airport, Vnukovo_International_Air... [1.0, 1.0]
+ 6 (Dubrovnik_Airport, Athens_International_Airport) [1.0, 1.0]
+ 7 (Vnukovo_International_Airport, HévÃz-Balaton_... [1.0, 1.0]
+
+ TIMESTEP - 1
+ component isConnectedTo
+ 0 (Vnukovo_International_Airport, Riga_Internati... [1.0, 1.0]
+
+
+
+The graph after running shows a new connection from ``Vnukovo_International_Airport`` to ``Riga_International_Airport``, because during the reasoning process an edge between them was inferred.
+
+.. image:: ../../../media/infer_edges2.png
+ :align: center
+
+
+
diff --git a/docs/source/tutorials/Installation.rst b/docs/source/tutorials/installation.rst
similarity index 99%
rename from docs/source/tutorials/Installation.rst
rename to docs/source/tutorials/installation.rst
index 400420d6..3b41e3dd 100644
--- a/docs/source/tutorials/Installation.rst
+++ b/docs/source/tutorials/installation.rst
@@ -21,6 +21,7 @@ Step-by-Step Guide
1. Install pyenv
- Ensure your system has the necessary dependencies installed. The installation steps vary by operating system:
+
- **Linux/Unix/macOS**
.. code-block:: bash
diff --git a/docs/source/tutorials/rule_image.png b/docs/source/tutorials/rule_image.png
new file mode 100644
index 00000000..2c5891ec
Binary files /dev/null and b/docs/source/tutorials/rule_image.png differ
diff --git a/docs/source/tutorials/Understanding Logic.rst b/docs/source/tutorials/understanding_logic.rst
similarity index 98%
rename from docs/source/tutorials/Understanding Logic.rst
rename to docs/source/tutorials/understanding_logic.rst
index 28cb625e..cf8df84f 100644
--- a/docs/source/tutorials/Understanding Logic.rst
+++ b/docs/source/tutorials/understanding_logic.rst
@@ -50,4 +50,4 @@ Inconsistent predicate list
The first rule states that the grass is wet if it rained, while the second rule states that the grass is not wet if it rained.
The fact f1 states that it rained, which is consistent with the first rule, but inconsistent with the second rule.
-.. |rule_image| image:: Rule_image.png
+.. |rule_image| image:: rule_image.png
diff --git a/docs/source/user_guide/1_pyreason_graphs.rst b/docs/source/user_guide/1_pyreason_graphs.rst
new file mode 100644
index 00000000..6a72cad5
--- /dev/null
+++ b/docs/source/user_guide/1_pyreason_graphs.rst
@@ -0,0 +1,151 @@
+Graphs
+===============
+PyReason reasons over knowledge graphs. Graphs serve as a knowledge base with initial conditions given to nodes and edges.
+These initial conditions are used along with :ref:`PyReason rules ` that we'll see later on to infer new relations or attributes.
+
+
+How to Load a Graph in PyReason
+-------------------------------
+In PyReason there are two ways to load graphs:
+
+
+1. Using a NetworkX `DiGraph `_ object
+2. Using a `GraphML `_ file which is an encoding of a directed graph
+
+
+NetworkX allows you to manually add nodes and edges, whereas GraphML reads in a directed graph from a file.
+
+
+NetworkX Example
+~~~~~~~~~~~~~~~~
+Using NetworkX, you can create a `directed graph `_ object. Users can add and remove nodes and edges from the graph.
+
+Read more about NetworkX `here `_.
+
+Given a network of people and their pets, we can create a graph using NetworkX.
+
+#. Mary is friends with Justin
+#. Mary is friends with John
+#. Justin is friends with John
+
+And
+
+#. Mary owns a cat
+#. Justin owns a cat and a dog
+#. John owns a dog
+
+.. code-block:: python
+
+ import networkx as nx
+
+    # Create a NetworkX Directed graph object
+ g = nx.DiGraph()
+
+ # Add the people as nodes
+ g.add_nodes_from(['John', 'Mary', 'Justin'])
+ g.add_nodes_from(['Dog', 'Cat'])
+
+ # Add the edges and their attributes. When an attribute = x which is <= 1, the annotation
+ # associated with it will be [x,1]. NOTE: These attributes are immutable unless specified otherwise in pyreason settings
+ # Friend edges
+ g.add_edge('Justin', 'Mary', Friends=1)
+ g.add_edge('John', 'Mary', Friends=1)
+ g.add_edge('John', 'Justin', Friends=1)
+
+ # Pet edges
+ g.add_edge('Mary', 'Cat', owns=1)
+ g.add_edge('Justin', 'Cat', owns=1)
+ g.add_edge('Justin', 'Dog', owns=1)
+ g.add_edge('John', 'Dog', owns=1)
+
+After the graph has been created, in the same file, the DiGraph object can be loaded with:
+
+.. code-block:: python
+
+ import pyreason as pr
+ pr.load_graph(graph)
+
+
+
+GraphML Example
+~~~~~~~~~~~~~~~~
+Using `GraphML `_, you can read a graph in from a file. Below is the file format for the graph that we made above:
+
+.. code-block:: xml
+
+
+
+
+
+
+
+
+
+
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+
+
+Then load the graph using the following:
+
+.. code-block:: python
+
+ import pyreason as pr
+ pr.load_graphml('path_to_file')
+
+
+Initial Conditions
+------------------
+PyReason uses graph attributes (assigned to nodes or edges) as initial conditions, and converts them to *static facts*. *Static facts* do not change over time.
+Once the graph is loaded, all attributes will remain the same until the end of the section of PyReason using the graph.
+
+
+Graph Attributes to PyReason Bounds
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Since PyReason uses bounds that are associated with attributes, it is important to understand how PyReason converts NetworkX attributes to bounds.
+In NetworkX graphs, each node/edge can hold key/value attribute pairs in an associated attribute dictionary. These attributes get transformed into "bounds".
+Bounds are between 0 (false) and 1 (true). The attribute value of the key/value pair in Networkx, is translated into the lower bound in PyReason.
+
+For example in the graph above, the attribute "Friends" is set to 1. This is translated into the lower bound of the interval ``[1,1]``.
+
+.. note::
+ Creating False bounds ``[0,0]`` is a little tricky since the value of a NetworkX attribute cannot be a list, and PyReason only modifies the
+ lower bound keeping the upper bound as 1. To do this, we can set the attribute as a string as seen below.
+
+.. code-block:: python
+
+ import networkx as nx
+ g = nx.DiGraph()
+ g.add_node("some_node", attribute1=1, attribute2="0,0")
+
+
+When the graph is loaded:
+
+.. code-block:: text
+
+ "some_node" is given the attribute1: [1,1], and attribute2 :[0,0].
+
+If the attribute is set equal to a single value, the assumed upper bound is 1. If a specific pair of bounds is required (e.g., for coordinates or ranges), the value should be provided as a string in a specific format.
diff --git a/docs/source/user_guide/2_pyreason_facts.rst b/docs/source/user_guide/2_pyreason_facts.rst
new file mode 100644
index 00000000..4ec30442
--- /dev/null
+++ b/docs/source/user_guide/2_pyreason_facts.rst
@@ -0,0 +1,33 @@
+Facts
+-----
+This section outlines Fact creation and implementation. See :ref:`here ` for more information on Facts in logic.
+
+Fact Parameters
+~~~~~~~~~~~~~~~
+To create a new **Fact** object in PyReason, use the `Fact` class with the following parameters:
+
+1. ``fact_text`` **(str):** The fact in text format, where bounds can be specified or not. The bounds are optional. If not specified, the bounds are assumed to be [1,1]. The fact can also be negated using the '~' symbol.
+
+ Examples of valid fact_text are:
+
+.. code-block:: text
+
+ 1. 'pred(x,y) : [0.2, 1]'
+ 2. 'pred(x,y)'
+ 3. '~pred(x,y)'
+
+2. ``name`` **(str):** The name of the fact. This will appear in the trace so that you know when it was applied
+3. ``start_time`` **(int):** The timestep at which this fact becomes active (default is 0)
+4. ``end_time`` **(int):** The last timestep this fact is active (default is 0)
+5. ``static`` **(bool):** If the fact should be active for the entire program. In which case ``start_time`` and ``end_time`` will be ignored. (default is False)
+
+
+Fact Example
+~~~~~~~~~~~~
+
+To add a fact in PyReason, use the command:
+
+.. code-block:: python
+
+ import pyreason as pr
+ pr.add_fact(pr.Fact(fact_text='pred(x,y) : [0.2, 1]', name='fact1', start_time=0, end_time=2))
diff --git a/docs/source/user_guide/3_pyreason_rules.rst b/docs/source/user_guide/3_pyreason_rules.rst
new file mode 100644
index 00000000..51e9be53
--- /dev/null
+++ b/docs/source/user_guide/3_pyreason_rules.rst
@@ -0,0 +1,211 @@
+.. _pyreason_rules:
+
+Rules
+==============
+This section outlines Rule creation and implementation. See :ref:`here ` for more information on Rules in logic.
+
+Creating a New Rule Object
+--------------------------
+
+In PyReason, rules are used to create or modify predicate bounds associated with nodes or edges in the graph if the conditions in the rule body are met.
+
+
+Rule Parameters
+~~~~~~~~~~~~~~~
+
+To create a new **Rule** object in PyReason, use the ``Rule`` class with the following parameters:
+
+#. ``rule_text`` **(str)**:
+ The rule in textual format. It should define a head and body using the syntax
+
+ ``head <- body``, where the body can include predicates and optional bounds. See more on PyReason rule format :ref:`here `.
+
+#. ``name`` **(str, optional)**:
+ A name for the rule, which will appear in the explainable rule trace.
+
+#. ``infer_edges`` **(bool, optional)**:
+ Indicates whether new edges should be inferred between the head variables when the rule is applied:
+
+ * If set to **True**, the rule will connect unconnected nodes when the body is satisfied.
+ * Else, set to **False**, the rule will **only** apply for nodes that are already connected, i.e edges already present in the graph (Default).
+
+#. ``set_static`` **(bool, optional)**:
+ Indicates whether the atom in the head should be set as static after the rule is applied. This means the bounds of that atom will no longer change for the duration of the program.
+
+#. ``custom_thresholds`` **(None, list, or dict, optional)**:
+ A list or dictionary of custom thresholds for the rule.
+ If not specified, default thresholds for ANY will be used. It can either be:
+
+ - A list of thresholds corresponding to each clause.
+ - A dictionary of thresholds mapping clause indices to specific thresholds.
+
+#. ``weights`` **(None, numpy.ndarray, optional)**:
+ A numpy array of weights for the rule passed to an annotation function. The weights can be used to calculate the annotation for the head of the rule. If not specified, the weights will default to 1 for each clause.
+
+
+.. _rule_formatting:
+Important Notes on Rule Formatting:
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+1. The head of the rule is always on the left hand side of the rule.
+2. The body of the rule is always on the right hand side of the rule.
+3. You can include timestep in the rule by using the ``<-timestep`` body, if omitted, the rule will be applied with ``timestep=0``.
+4. You can include multiple clauses in the rule by using the ``<-timestep clause1, clause2, clause3``. If bounds are not specified, they default to ``[1,1]``.
+5. A tilde ``~`` can be used to negate a clause in the body of the rule, or the head itself.
+
+
+Rule Structure
+--------------
+Example rule in PyReason with correct formatting:
+
+.. code-block:: text
+
+ head(x) : [1,1] <-1 clause1(y) : [1,1] , clause2(x,y) : [1,1] , clause3(y,z) : [1,1] , clause4(x,z) : [1,1]
+
+which is equivalent to:
+
+.. code-block:: text
+
+ head(x) <-1 clause1(y), clause2(x,y), clause3(y,z), clause4(x,z)
+
+The rule is read as follows:
+
+**Head**:
+
+.. code-block:: text
+
+ head(x) : [1,1]
+
+**Body**:
+
+.. code-block:: text
+
+ clause1(x,y) : [1,1], clause2(y,z) : [1,1], clause3(x,z) : [1,1]
+
+
+The **head** and **body** are separated by an arrow (``<-``), and the rule is applied to the head after ``1`` timestep if the body conditions are met.
+
+
+Adding A Rule to PyReason
+-------------------------
+Add the rule directly
+~~~~~~~~~~~~~~~~~~~~~~
+
+To add the rule directly, we must specify the rule and (optionally) a name for it.
+
+.. code-block:: python
+
+ import pyreason as pr
+ pr.add_rule(pr.Rule('head(x) <-1 body1(y), body2(x,y), body3(y,z), body4(x,z)', 'rule_name'))
+
+The name helps understand which rules fired during reasoning later on.
+
+Add the rule from a .txt file
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To add the rule from a text file, ensure the file is in .txt format, and contains the rule in the format shown above. This
+allows for multiple rules to be added at once, with each rule on a new line. Comments can be added to the file using the ``#`` symbol, and will be ignored by PyReason.
+
+ .. code-block:: text
+
+ head1(x) <-1 body(y), body2(x,y), body3(y,z), body4(x,z)
+ head2(x) <-1 body(y), body2(x,y), body3(y,z), body4(x,z)
+ # This is a comment and will be ignored
+
+Now we can load the rules from the file using the following code:
+
+ .. code-block:: python
+
+ import pyreason as pr
+ pr.add_rules_from_file('rules.txt')
+
+Annotation Functions
+--------------------
+
+What are annotation functions?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Annotation Functions are specific user defined Python functions that are called when all clauses in a rule have been
+satisfied to annotate (give bounds to) the head of the rule. Annotation functions have access to the bounds of grounded
+atoms for each clause in the rule and users can use these bounds to make an annotation for the target of the rule.
+
+The Structure of an annotation function
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Only specifically structured annotation functions are allowed. The function has to be
+
+#. decorated with ``@numba.njit``
+#. has to take in 2 parameters whether you use them or not
+#. has to return 2 numbers
+
+**Example User Defined Annotation Function:**
+
+
+
+.. code-block:: python
+
+ import numba
+ import numpy as np
+
+ @numba.njit
+ def avg_ann_fn(annotations, weights):
+ # annotations contains the bounds of the atoms that were used to ground the rule. It is a nested list that contains a list for each clause
+ # You can access for example the first grounded atom's bound by doing: annotations[0][0].lower or annotations[0][0].upper
+
+ # We want the normalised sum of the bounds of the grounded atoms
+ sum_upper_bounds = 0
+ sum_lower_bounds = 0
+ num_atoms = 0
+ for clause in annotations:
+ for atom in clause:
+ sum_lower_bounds += atom.lower
+ sum_upper_bounds += atom.upper
+ num_atoms += 1
+
+ a = sum_lower_bounds / num_atoms
+ b = sum_upper_bounds / num_atoms
+ return a, b
+
+
+This annotation function calculates the average of the bounds of all grounded atoms in the rule. The function is decorated
+with ``@numba.njit`` to ensure that it is compiled to machine code for faster execution. The function takes in two parameters,
+``annotations`` and ``weights``, which are the bounds of the grounded atoms and the weights associated with each clause of the rule set by the user when the rule is added.
+The function returns two numbers, which are the lower and upper bounds of the annotation for the head of the rule.
+
+Adding an Annotation Function to a PyReason Rule
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Use the following to add an annotation function into pyreason so that it can be used by rules
+
+.. code-block:: python
+
+ import pyreason as pr
+ pr.add_annotation_function(avg_ann_fn)
+
+Then you can create rules of the following format:
+
+.. code-block:: text
+
+ head(x) : avg_ann_fn <- clause1(y), clause2(x,y), clause3(y,z), clause4(x,z)
+
+The annotation function will be called when all clauses in the rule have been satisfied and the head of the rule is to be annotated.
+The ``annotations`` parameter in the annotation function will contain the bounds of the grounded atoms for each of the 4 clauses in the rule.
+
+
+Custom Thresholds
+-----------------
+
+Custom thresholds allow you to specify specific thresholds for the clauses in the body of the rule. By default, with no
+custom thresholds specified, the rule will use the default thresholds for ANY. Custom thresholds can be specified as:
+
+1. A list of thresholds corresponding to each clause. Where the size of the list should be equal to the number of clauses in the rule.
+2. A dictionary of thresholds mapping clause indices to specific thresholds. The first clause has an index of 0.
+
+The Threshold Class
+~~~~~~~~~~~~~~~~~~~
+PyReason's ``Threshold`` class is used to define custom thresholds for a rule. The class has the following parameters:
+
+#. ``quantifier`` **(str)**: "greater_equal", "greater", "less_equal", "less", "equal"
+#. ``quantifier_type`` **(tuple)**: A tuple of two elements indicating the type of quantifier, where the first is either ``"number"`` or ``"percent"``
+and the second is either ``"total"`` or ``"available"``. ``"total"`` refers to all groundings of the clause, while ``"available"`` refers to the groundings that have the predicate of the clause.
+#. ``thresh`` **(int)**: The value of the threshold
+
+An example usage can be found :ref:`here `.
\ No newline at end of file
diff --git a/docs/source/user_guide/4_pyreason_settings.rst b/docs/source/user_guide/4_pyreason_settings.rst
new file mode 100644
index 00000000..bd7d3305
--- /dev/null
+++ b/docs/source/user_guide/4_pyreason_settings.rst
@@ -0,0 +1,94 @@
+
+Settings
+=================
+In this section, we detail the settings that can be used to configure PyReason. These settings can be used to control the behavior of the reasoning process.
+
+Settings can be accessed using the following code:
+
+.. code-block:: python
+
+ import pyreason as pr
+ pr.settings.setting_name = value
+
+Where ``setting_name`` is the name of the setting you want to change, and ``value`` is the value you want to set it to.
+Below is a table of all the settings that can be changed in PyReason using the code above.
+
+.. note::
+ All settings need to be modified **before** the reasoning process begins, otherwise they will not take effect.
+
+To reset all settings to their default values, use the following code:
+
+.. code-block:: python
+
+ import pyreason as pr
+ pr.reset_settings()
+
+
+.. list-table::
+
+ * - **Setting**
+ - **Default**
+ - **Description**
+ * - ``verbose``
+ - True
+ - | Whether to print extra information
+ | to screen during the reasoning process.
+ * - ``output_to_file``
+ - False
+ - | Whether to output print statements
+ | into a file.
+ * - ``output_file_name``
+ - 'pyreason_output'
+ - | The name the file output will be saved as
+ | (only if ``output_to_file = True``).
+ * - ``graph_attribute_parsing``
+ - True
+ - | Whether graph will be
+ | parsed for attributes.
+ * - ``reverse_digraph``
+ - False
+ - | Whether the directed edges in the graph
+ | will be reversed before reasoning.
+ * - ``atom_trace``
+ - False
+ - | Whether to keep track of all ground atoms
+ | which make the clauses true. **NOTE:** For large graphs
+ | this can use up a lot of memory and slow down the runtime.
+ * - ``save_graph_attributes_to_trace``
+ - False
+ - | Whether to save graph attribute facts to the
+ | rule trace. This might make the trace files large because
+ | there are generally many attributes in graphs.
+ * - ``persistent``
+ - False
+ - | Whether the bounds in the interpretation are reset
+ | to uncertain ``[0,1]`` at each timestep or keep
+ | their value from the previous timestep.
+ * - ``inconsistency_check``
+ - True
+ - | Whether to check for inconsistencies in the interpretation,
+ | and resolve them if found. Inconsistencies are resolved by
+ | resetting the bounds to ``[0,1]`` and making the atom static.
+ * - ``static_graph_facts``
+ - True
+ - | Whether to make graph facts static. In other words, the
+ | attributes in the graph remain constant throughout
+ | the reasoning process.
+ * - ``parallel_computing``
+ - False
+ - | Whether to use multiple CPU cores for inference.
+ | This can greatly speed up runtime if running on a
+ | cluster for large graphs.
+ * - ``update_mode``
+ - 'intersection'
+ - | The mode for updating interpretations. Options are ``'intersection'``
+ | or ``'override'``. When using ``'intersection'``, the resulting bound
+ | is the intersection of the new bound and the old bound. When using
+ | ``'override'``, the resulting bound is the new bound.
+
+
+Notes on Parallelism
+~~~~~~~~~~~~~~~~~~~~
+PyReason is parallelized over rules, so for large rulesets it is recommended that this setting is used. However, for small rulesets,
+the overhead might be more than the speedup and it is worth checking the performance on your specific use case.
+When possible we recommend using the same number of cores (or a multiple) as the number of rules in the program.
\ No newline at end of file
diff --git a/docs/source/user_guide/5_inconsistent_predicate_list.rst b/docs/source/user_guide/5_inconsistent_predicate_list.rst
new file mode 100644
index 00000000..2be00a50
--- /dev/null
+++ b/docs/source/user_guide/5_inconsistent_predicate_list.rst
@@ -0,0 +1,21 @@
+.. _inconsistent_predicate_list:
+
+Inconsistent Predicate List
+===========================
+
+In this section we detail how we can use inconsistent predicate lists to identify inconsistencies in the graph during reasoning.
+For more information on Inconsistencies and the Inconsistent Predicates list, see :ref:`here `.
+
+For this example, assume we have two inconsistent predicates, "sick" and "healthy". To be able to model this in PyReason
+such that when one predicate has a certain bound ``[l, u]``, the other predicate is given a bound ``[1-u, 1-l]`` automatically,
+we add the predicates to the **inconsistent predicate list**.
+
+This can be done by using the following code:
+
+.. code-block:: python
+
+ import pyreason as pr
+ pr.add_inconsistent_predicate('sick', 'healthy')
+
+This allows PyReason to automatically update the bounds of the predicates in the inconsistent predicate list to the
+negation of a predicate that is updated.
\ No newline at end of file
diff --git a/docs/source/user_guide/6_pyreason_expected_output.rst b/docs/source/user_guide/6_pyreason_expected_output.rst
new file mode 100644
index 00000000..e1aeb17a
--- /dev/null
+++ b/docs/source/user_guide/6_pyreason_expected_output.rst
@@ -0,0 +1,376 @@
+.. _pyreason_expected_output:
+
+PyReason Expected Output
+===========================
+
+This section outlines the functions that help display and explain the PyReason output and reasoning process.
+
+Filter and Sort Nodes
+-----------------------
+This function filters and sorts the node changes in the interpretation and returns as a list of Pandas dataframes that contain the filtered and sorted interpretations that are easy to access.
+
+Basic Tutorial Example
+^^^^^^^^^^^^^^^^^^^^^^^^
+To see ``filter_and_sort_nodes`` in action we will look at the example usage in PyReason's Basic Tutorial.
+
+.. note::
+    Find the full, explained tutorial `here `_
+
+
+The tutorial takes in a basic graph of people and their pets, then adds a Rule and a Fact.
+
+
+.. code:: python
+
+ pr.load_graph(g)
+ pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule'))
+ pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2))
+
+ # Run the program for two timesteps to see the diffusion take place
+ faulthandler.enable()
+ interpretation = pr.reason(timesteps=2)
+
+We add the ``filter_and_sort_nodes`` after the interpretation is run, before PyReason prints the output:
+
+.. code:: python
+
+ # Display the changes in the interpretation for each timestep
+ dataframes = pr.filter_and_sort_nodes(interpretation, ['popular'])
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+
+This will print the nodes at each timestep in the reasoning process.
+
+Expected Output
+^^^^^^^^^^^^^^^^^
+Using ``filter_and_sort_nodes``, the expected output is :
+
+.. code:: python
+
+ TIMESTEP - 0
+ component popular
+ 0 Mary [1.0,1.0]
+
+
+ TIMESTEP - 1
+ component popular
+ 0 Mary [1.0,1.0]
+ 1 Justin [1.0,1.0]
+
+
+ TIMESTEP - 2
+ component popular
+ 0 Mary [1.0,1.0]
+ 1 Justin [1.0,1.0]
+ 2 John [1.0,1.0]
+
+
+
+Filter and Sort Edges
+----------------------
+This function filters and sorts the edge changes in the interpretation and returns a list of Pandas dataframes that contain the filtered and sorted interpretations, making them easy to access.
+
+Infer Edges Example
+^^^^^^^^^^^^^^^^^^^^^^^^
+To see ``filter_and_sort_edges`` in action, we will look at the example usage in PyReason's Infer Edges Tutorial.
+
+.. note::
+    Find the full, explained tutorial `here `_.
+
+The tutorial takes in a basic graph of airports and connections, then infers an edge between two unconnected airports.
+
+.. code:: python
+
+ pr.load_graph(G)
+ pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True))
+
+ # Run the program for two timesteps to see the diffusion take place
+ interpretation = pr.reason(timesteps=1)
+
+
+We add the ``filter_and_sort_edges`` function after the interpretation is run, before PyReason prints the output:
+
+.. code:: python
+
+ # Display the changes in the interpretation for each timestep
+ dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+
+This will print the edges at each timestep in the reasoning process.
+
+Expected Output
+^^^^^^^^^^^^^^^^
+Using ``filter_and_sort_edges``, the expected output is:
+
+.. code:: text
+
+ Timestep: 0
+ Timestep: 1
+
+ Converged at time: 1
+ Fixed Point iterations: 2
+ TIMESTEP - 0
+ component isConnectedTo
+ 0 (Amsterdam_Airport_Schiphol, Yali) [1.0, 1.0]
+ 1 (Riga_International_Airport, Amsterdam_Airport... [1.0, 1.0]
+ 2 (Riga_International_Airport, Düsseldorf_Airport) [1.0, 1.0]
+ 3 (Chișinău_International_Airport, Riga_Internat... [1.0, 1.0]
+ 4 (Düsseldorf_Airport, Dubrovnik_Airport) [1.0, 1.0]
+ 5 (Pobedilovo_Airport, Vnukovo_International_Air... [1.0, 1.0]
+ 6 (Dubrovnik_Airport, Athens_International_Airport) [1.0, 1.0]
+ 7 (Vnukovo_International_Airport, HévÃz-Balaton_... [1.0, 1.0]
+
+ TIMESTEP - 1
+ component isConnectedTo
+ 0 (Vnukovo_International_Airport, Riga_Internati... [1.0, 1.0]
+
+
+
+Get Rule Trace
+---------------
+This function returns the trace of the program as 2 pandas dataframes (one for nodes, one for edges).
+This includes every change that has occurred to the interpretation. If ``atom_trace`` was set to true
+this gives us full explainability of why interpretations changed
+
+Advanced Tutorial Example
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To see ``get_rule_trace`` in action we will look at the example usage in PyReason's Advanced Tutorial.
+
+.. note::
+    Find the full, explained tutorial `here `_
+
+
+The tutorial takes in a graph where we have customers, cars, pets and their relationships. We first have customer_details, followed by car_details, pet_details, and travel_details.
+
+We will only add the ``get_rule_trace`` function after the interpretation:
+
+.. code:: python
+
+ interpretation = pr.reason(timesteps=5)
+ nodes_trace, edges_trace = pr.get_rule_trace(interpretation)
+
+
+Expected Output
+^^^^^^^^^^^^^^^^
+Using ``get_rule_trace``, the expected output of ``nodes_trace`` and ``edges_trace`` is:
+
+Click `here `_ for the full table.
+
+**Nodes Trace:**
+
+.. code:: text
+
+ Time Fixed-Point-Operation Node ... Occurred Due To Clause-1 Clause-2
+ 0 0 0 popular-fac ... popular(customer_0) None None
+ 1 1 2 popular-fac ... popular(customer_0) None None
+ 2 1 2 customer_4 ... cool_car_rule [(customer_4, Car_4)] [Car_4]
+ 3 1 2 customer_6 ... cool_car_rule [(customer_6, Car_4)] [Car_4]
+ 4 1 2 customer_3 ... cool_pet_rule [(customer_3, Pet_2)] [Pet_2]
+ 5 1 2 customer_4 ... cool_pet_rule [(customer_4, Pet_2)] [Pet_2]
+ 6 1 3 customer_4 ... trendy_rule [customer_4] [customer_4]
+ 7 2 4 popular-fac ... popular(customer_0) None None
+ 8 2 4 customer_4 ... cool_car_rule [(customer_4, Car_4)] [Car_4]
+ 9 2 4 customer_6 ... cool_car_rule [(customer_6, Car_4)] [Car_4]
+ 10 2 4 customer_3 ... cool_pet_rule [(customer_3, Pet_2)] [Pet_2]
+ 11 2 4 customer_4 ... cool_pet_rule [(customer_4, Pet_2)] [Pet_2]
+ 12 2 5 customer_4 ... trendy_rule [customer_4] [customer_4]
+ 13 3 6 popular-fac ... popular(customer_0) None None
+ 14 3 6 customer_4 ... cool_car_rule [(customer_4, Car_4)] [Car_4]
+ 15 3 6 customer_6 ... cool_car_rule [(customer_6, Car_4)] [Car_4]
+ 16 3 6 customer_3 ... cool_pet_rule [(customer_3, Pet_2)] [Pet_2]
+ 17 3 6 customer_4 ... cool_pet_rule [(customer_4, Pet_2)] [Pet_2]
+ 18 3 7 customer_4 ... trendy_rule [customer_4] [customer_4]
+ 19 4 8 popular-fac ... popular(customer_0) None None
+ 20 4 8 customer_4 ... cool_car_rule [(customer_4, Car_4)] [Car_4]
+ 21 4 8 customer_6 ... cool_car_rule [(customer_6, Car_4)] [Car_4]
+ 22 4 8 customer_3 ... cool_pet_rule [(customer_3, Pet_2)] [Pet_2]
+ 23 4 8 customer_4 ... cool_pet_rule [(customer_4, Pet_2)] [Pet_2]
+ 24 4 9 customer_4 ... trendy_rule [customer_4] [customer_4]
+ 25 5 10 popular-fac ... popular(customer_0) None None
+ 26 5 10 customer_4 ... cool_car_rule [(customer_4, Car_4)] [Car_4]
+ 27 5 10 customer_6 ... cool_car_rule [(customer_6, Car_4)] [Car_4]
+ 28 5 10 customer_3 ... cool_pet_rule [(customer_3, Pet_2)] [Pet_2]
+ 29 5 10 customer_4 ... cool_pet_rule [(customer_4, Pet_2)] [Pet_2]
+ 30 5 11 customer_4 ... trendy_rule [customer_4] [customer_4]
+
+**Edges Trace**
+
+Click `here `_ for the full table.
+
+
+.. code:: text
+
+ Time ... Clause-2
+ 0 0 ... [(customer_1, Car_0), (customer_1, Car_8)]
+ 1 0 ... [(customer_1, Car_0), (customer_1, Car_8)]
+ 2 0 ... [(customer_2, Car_1), (customer_2, Car_3), (cu...
+ 3 0 ... [(customer_1, Car_0), (customer_1, Car_8)]
+ 4 0 ... [(customer_1, Car_0), (customer_1, Car_8)]
+ .. ... ... ...
+ 61 5 ... [(customer_0, Car_2), (customer_0, Car_7)]
+ 62 5 ... [(customer_5, Car_5), (customer_5, Car_2)]
+ 63 5 ... [(customer_3, Car_3), (customer_3, Car_0), (cu...
+ 64 5 ... [(customer_6, Car_6), (customer_6, Car_4)]
+ 65 5 ... [(customer_0, Car_2), (customer_0, Car_7)]
+
+
+
+
+Save Rule Trace
+---------------
+This function saves the trace of the program as two pandas dataframes (one for nodes, one for edges).
+This includes every change that has occurred to the interpretation. If ``atom_trace`` was set to true,
+this provides full explainability of why interpretations changed.
+
+Infer Edges Tutorial Example
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To see ``save_rule_trace`` in action, we will look at an example usage in PyReason's Infer Edges Tutorial.
+
+.. note::
+    Find the full, explained tutorial `here `_.
+
+This tutorial takes a graph with airports and their connections.
+
+We will only add the ``save_rule_trace`` function after the interpretation:
+
+.. code:: python
+
+ interpretation = pr.reason(timesteps=1)
+ pr.save_rule_trace(interpretation, folder='./rule_trace_output')
+
+Expected Output
+^^^^^^^^^^^^^^^^
+Using ``save_rule_trace``, the expected output is:
+
+**Saved Nodes Trace:**
+
+The nodes trace will be saved as a CSV file in the specified folder. It will contain the time, the fixed-point operation, the node, and the clause information that led to the change in each timestep. Here's an example snippet of how the data will look when saved:
+
+Click `here `_ for the full table.
+
+
+
+.. code:: text
+
+ Time,Fixed-Point-Operation,Node,Label,Old Bound,New Bound,Occurred Due To
+ 0,0,Amsterdam_Airport_Schiphol,Amsterdam_Airport_Schiphol,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+ 0,0,Riga_International_Airport,Riga_International_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+ 0,0,Chișinău_International_Airport,Chișinău_International_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+ 0,0,Yali,Yali,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+ 0,0,Düsseldorf_Airport,Düsseldorf_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+ 0,0,Pobedilovo_Airport,Pobedilovo_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+ 0,0,Dubrovnik_Airport,Dubrovnik_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+ 0,0,HévÃz-Balaton_Airport,HévÃz-Balaton_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+ 0,0,Athens_International_Airport,Athens_International_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+ 0,0,Vnukovo_International_Airport,Vnukovo_International_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+
+
+
+**Saved Edges Trace:**
+
+The edges trace will be saved as another CSV file. It will contain the time, the edge relationship changes, and the clauses that were involved. Here’s a snippet of how the edge trace will look when saved:
+
+Click `here `_ for the full table.
+
+
+.. code:: text
+
+ Time,Fixed-Point-Operation,Edge,Label,Old Bound,New Bound,Occurred Due To,Clause-1,Clause-2,Clause-3
+ 0,0,"('Amsterdam_Airport_Schiphol', 'Yali')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+ 0,0,"('Riga_International_Airport', 'Amsterdam_Airport_Schiphol')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+ 0,0,"('Riga_International_Airport', 'Düsseldorf_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+ 0,0,"('Chișinău_International_Airport', 'Riga_International_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+ 0,0,"('Düsseldorf_Airport', 'Dubrovnik_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+ 0,0,"('Pobedilovo_Airport', 'Vnukovo_International_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+ 0,0,"('Dubrovnik_Airport', 'Athens_International_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+ 0,0,"('Vnukovo_International_Airport', 'HévÃz-Balaton_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+ 1,1,"('Vnukovo_International_Airport', 'Riga_International_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",connected_rule_1,"[('Riga_International_Airport', 'Amsterdam_Airport_Schiphol')]",['Amsterdam_Airport_Schiphol'],['Vnukovo_International_Airport']
+
+Reading PyReasons Explainable Trace
+------------------------------------
+When using the functions ``save_rule_trace`` and ``get_rule_trace``, PyReason will output an explainable trace of the reasoning process.
+
+In the trace, the columns represent the following:
+ - ``time``: the current timestep
+ - ``Fixed-Point Operation``:
+ - ``Edge``: The edge or node that has changed if applicable
+ - ``Label``: The predicate or head of the rule
+ - ``Old Bound`` and ``New Bound``: Bound before and after reasoning step
+ - ``Occurred Due To``: what the change in the step was due to, either ``fact`` or ``rule``
+ - ``Clause-x``: What grounded the clause in the rule
+
+Get Dictionary
+--------------------------
+The function ``interpretation.get_dict()`` can be called externally to retrieve a dictionary of the interpretation values. The dictionary is triply nested from ``time`` -> ``graph component`` -> ``predicate`` -> ``bound``.
+
+Basic Tutorial Example
+^^^^^^^^^^^^^^^^^^^^^^
+To see ``interpretation.get_dict()`` in action we will look at the example usage in PyReason's Basic Tutorial.
+
+.. note::
+    Find the full, explained tutorial `here `_
+
+Call ``.get_dict()`` function on the interpretation, and print using ``pprint``.
+
+.. code:: python
+
+ import pyreason as pr
+ from pprint import pprint
+
+ interpretation = pr.reason(timesteps=2)
+ interpretations_dict = interpretation.get_dict()
+ pprint(interpretations_dict)
+
+Expected Output
+^^^^^^^^^^^^^^^^
+Using ``.get_dict()``, the expected output is:
+
+
+.. code:: text
+
+ {0: {'Cat': {},
+ 'Dog': {},
+ 'John': {},
+ 'Justin': {},
+ 'Mary': {'popular': (1.0, 1.0)},
+ ('John', 'Dog'): {},
+ ('John', 'Justin'): {},
+ ('John', 'Mary'): {},
+ ('Justin', 'Cat'): {},
+ ('Justin', 'Dog'): {},
+ ('Justin', 'Mary'): {},
+ ('Mary', 'Cat'): {}},
+ 1: {'Cat': {},
+ 'Dog': {},
+ 'John': {},
+ 'Justin': {'popular': (1.0, 1.0)},
+ 'Mary': {'popular': (1.0, 1.0)},
+ ('John', 'Dog'): {},
+ ('John', 'Justin'): {},
+ ('John', 'Mary'): {},
+ ('Justin', 'Cat'): {},
+ ('Justin', 'Dog'): {},
+ ('Justin', 'Mary'): {},
+ ('Mary', 'Cat'): {}},
+ 2: {'Cat': {},
+ 'Dog': {},
+ 'John': {'popular': (1.0, 1.0)},
+ 'Justin': {'popular': (1.0, 1.0)},
+ 'Mary': {'popular': (1.0, 1.0)},
+ ('John', 'Dog'): {},
+ ('John', 'Justin'): {},
+ ('John', 'Mary'): {},
+ ('Justin', 'Cat'): {},
+ ('Justin', 'Dog'): {},
+ ('Justin', 'Mary'): {},
+ ('Mary', 'Cat'): {}}}
+
+
+``interpretation.get_dict()`` first goes through each time step, then the components of the graph, and finally the predicates and bounds.
diff --git a/docs/source/user_guide/7_jupyter_notebook_usage.rst b/docs/source/user_guide/7_jupyter_notebook_usage.rst
new file mode 100644
index 00000000..4eb2bea0
--- /dev/null
+++ b/docs/source/user_guide/7_jupyter_notebook_usage.rst
@@ -0,0 +1,11 @@
+Jupyter Notebook Usage
+===========================
+
+.. warning::
+   Using PyReason in a Jupyter Notebook can be a little tricky, and it is recommended to run PyReason in a normal Python file.
+   However, if you want to use PyReason in a Jupyter Notebook, make sure you understand the points below.
+
+
+1. When using functions like ``add_rule`` or ``add_fact`` in a Jupyter Notebook, make sure to run the cell only once. Running the cell multiple times will add the same rule/fact multiple times. It is recommended to store all the rules and facts in an array and then add them all at once in one cell towards the end
+2. Functions like ``load_graph`` and ``load_graphml`` can also cause the same issue if run multiple times. Make sure to run them only once.
+
diff --git a/docs/source/user_guide/8_advanced_usage.rst b/docs/source/user_guide/8_advanced_usage.rst
new file mode 100644
index 00000000..9813ab83
--- /dev/null
+++ b/docs/source/user_guide/8_advanced_usage.rst
@@ -0,0 +1,21 @@
+Advanced Usage of PyReason
+===========================
+
+PyReason is a powerful tool that can be used to reason over complex systems. This section outlines some advanced usage of PyReason.
+
+Reasoning Convergence
+---------------------
+PyReason uses a fixed point iteration algorithm to reason over the graph. This means that the reasoning process will continue
+until the graph reaches a fixed point, i.e., no new facts can be inferred. The fixed point iteration algorithm is guaranteed to converge for acyclic graphs.
+However, for cyclic graphs, the algorithm may not converge, and the user may need to set certain values to ensure convergence.
+The reasoner contains a few settings that can be used to control the convergence of the reasoning process, and can be set when calling
+``pr.reason(...)``
+
+1. ``convergence_threshold`` **(int, optional)**: The convergence threshold is the maximum number of interpretations that have changed between timesteps or fixed point operations until considered convergent. Program will end at convergence. -1 => no changes, perfect convergence, defaults to -1
+2. ``convergence_bound_threshold`` **(float, optional)**: The convergence bound threshold is the maximum difference between the bounds of the interpretations at each timestep or fixed point operation until considered convergent. Program will end at convergence. -1 => no changes, perfect convergence, defaults to -1
+
+Reasoning Multiple Times
+-------------------------
+PyReason allows you to reason over the graph multiple times. This can be useful when you want to reason over the graph iteratively
+and add facts that were not available before. To reason over the graph multiple times, you can set ``again=True`` in ``pr.reason(again=True)``.
+To specify additional facts, use the ``node_facts`` or ``edge_facts`` parameters in ``pr.reason(...)``. These parameters allow you to add additional facts to the graph before reasoning again.
\ No newline at end of file
diff --git a/docs/source/user_guide/index.rst b/docs/source/user_guide/index.rst
new file mode 100644
index 00000000..584d3f01
--- /dev/null
+++ b/docs/source/user_guide/index.rst
@@ -0,0 +1,15 @@
+User Guide
+==========
+
+In this section we demonstrate the functionality of the `pyreason` library and how to use it.
+
+
+.. toctree::
+ :caption: Contents:
+ :maxdepth: 2
+ :glob:
+
+ ./*
+
+
+
diff --git a/examples/advanced_graph_ex.py b/examples/advanced_graph_ex.py
new file mode 100644
index 00000000..b25aa055
--- /dev/null
+++ b/examples/advanced_graph_ex.py
@@ -0,0 +1,162 @@
+from pprint import pprint
+import networkx as nx
+import pyreason as pr
+
+# Customer Data
+customers = ['John', 'Mary', 'Justin', 'Alice', 'Bob', 'Eva', 'Mike']
+customer_details = [
+ ('John', 'M', 'New York', 'NY'),
+ ('Mary', 'F', 'Los Angeles', 'CA'),
+ ('Justin', 'M', 'Chicago', 'IL'),
+ ('Alice', 'F', 'Houston', 'TX'),
+ ('Bob', 'M', 'Phoenix', 'AZ'),
+ ('Eva', 'F', 'San Diego', 'CA'),
+ ('Mike', 'M', 'Dallas', 'TX')
+]
+
+# Creating a dictionary of customers with their details
+customer_dict = {i: customer for i, customer in enumerate(customer_details)}
+
+# Pet Data
+pet_details = [
+ ('Dog', 'Mammal'),
+ ('Cat', 'Mammal'),
+ ('Rabbit', 'Mammal'),
+ ('Parrot', 'Bird'),
+ ('Fish', 'Fish')
+]
+
+# Creating a dictionary of pets with their details
+pet_dict = {i: pet for i, pet in enumerate(pet_details)}
+
+# Car Data
+car_details = [
+ ('Toyota Camry', 'Red'),
+ ('Honda Civic', 'Blue'),
+ ('Ford Focus', 'Red'),
+ ('BMW 3 Series', 'Black'),
+ ('Tesla Model S', 'Red'),
+ ('Chevrolet Bolt EV', 'White'),
+ ('Ford Mustang', 'Yellow'),
+ ('Audi A4', 'Silver'),
+ ('Mercedes-Benz C-Class', 'Grey'),
+ ('Subaru Outback', 'Green'),
+ ('Volkswagen Golf', 'Blue'),
+ ('Porsche 911', 'Black')
+]
+
+# Creating a dictionary of cars with their details
+car_dict = {i: car for i, car in enumerate(car_details)}
+
+# Travel Data (customer movements between cities)
+travels = [
+ ('John', 'Los Angeles', 'CA', 'New York', 'NY', 2),
+ ('Alice', 'Houston', 'TX', 'Phoenix', 'AZ', 5),
+ ('Eva', 'San Diego', 'CA', 'Dallas', 'TX', 1),
+ ('Mike', 'Dallas', 'TX', 'Chicago', 'IL', 3)
+]
+
+# Friendships (who is friends with whom)
+friendships = [('customer_2', 'customer_1'), ('customer_0', 'customer_1'), ('customer_0', 'customer_2'),
+ ('customer_3', 'customer_4'), ('customer_4', 'customer_0'), ('customer_5', 'customer_3'),
+ ('customer_6', 'customer_0'), ('customer_5', 'customer_6'), ('customer_4', 'customer_5'),
+ ('customer_3', 'customer_1')]
+
+# Car Ownerships (who owns which car)
+car_ownerships = [('customer_1', 'Car_0'), ('customer_2', 'Car_1'), ('customer_0', 'Car_2'), ('customer_3', 'Car_3'),
+ ('customer_4', 'Car_4'), ('customer_3', 'Car_0'), ('customer_2', 'Car_3'), ('customer_5', 'Car_5'),
+ ('customer_6', 'Car_6'), ('customer_0', 'Car_7'), ('customer_1', 'Car_8'), ('customer_4', 'Car_9'),
+ ('customer_3', 'Car_10'), ('customer_2', 'Car_11'), ('customer_5', 'Car_2'), ('customer_6', 'Car_4')]
+
+# Pet Ownerships (who owns which pet)
+pet_ownerships = [('customer_1', 'Pet_1'), ('customer_2', 'Pet_1'), ('customer_2', 'Pet_0'), ('customer_0', 'Pet_0'),
+ ('customer_3', 'Pet_2'), ('customer_4', 'Pet_2'), ('customer_5', 'Pet_3'), ('customer_6', 'Pet_4'),
+ ('customer_0', 'Pet_4')]
+
+# Create a directed graph
+g = nx.DiGraph()
+
+# Add nodes for customers
+for customer_id, details in customer_dict.items():
+ attributes = {
+ 'c_id': customer_id,
+ 'name': details[0],
+ 'gender': details[1],
+ 'city': details[2],
+ 'state': details[3],
+ }
+ name = "customer_" + str(customer_id)
+ g.add_node(name, **attributes)
+
+# Add nodes for pets
+for pet_id, details in pet_dict.items():
+ dynamic_attribute = f"Pet_{pet_id}"
+ attributes = {
+ 'pet_id': pet_id,
+ 'species': details[0],
+ 'class': details[1],
+ dynamic_attribute: 1
+ }
+ name = "Pet_" + str(pet_id)
+ g.add_node(name, **attributes)
+
+# Add nodes for cars
+for car_id, details in car_dict.items():
+ dynamic_attribute = f"Car_{car_id}"
+ attributes = {
+ 'car_id': car_id,
+ 'model': details[0],
+ 'color': details[1],
+ dynamic_attribute: 1
+ }
+ name = "Car_" + str(car_id)
+ g.add_node(name, **attributes)
+
+# Add edges for relationships
+for f1, f2 in friendships:
+ g.add_edge(f1, f2, Friends=1)
+for owner, car in car_ownerships:
+ g.add_edge(owner, car, owns_car=1, car_color_id=int(car.split('_')[1]))
+for owner, pet in pet_ownerships:
+ g.add_edge(owner, pet, owns_pet=1)
+
+# Load the graph into PyReason
+pr.load_graph(g)
+
+# Set PyReason settings
+pr.settings.verbose = True
+pr.settings.atom_trace = True
+
+# Define logical rules
+pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y)', 'popular_pet_rule'))
+pr.add_rule(pr.Rule('cool_car(x) <-1 owns_car(x,y),Car_4(y)', 'cool_car_rule'))
+pr.add_rule(pr.Rule('cool_pet(x)<-1 owns_pet(x,y),Pet_2(y)', 'cool_pet_rule'))
+pr.add_rule(pr.Rule('trendy(x) <- cool_car(x) , cool_pet(x)', 'trendy_rule'))
+
+pr.add_rule(
+ pr.Rule("car_friend(x,y) <- owns_car(x,z), owns_car(y,z)", "car_friend_rule"))
+pr.add_rule(
+ pr.Rule("same_color_car(x, y) <- owns_car(x, c1) , owns_car(y, c2)","same_car_color_rule"))
+
+
+# Add a fact about 'customer_0' being popular
+pr.add_fact(pr.Fact('popular-fact', 'popular(customer_0)', 0, 5))
+
+# Perform reasoning over 5 timesteps
+interpretation = pr.reason(timesteps=5)
+
+# Get the interpretation and display it
+interpretations_dict = interpretation.get_dict()
+pprint(interpretations_dict)
+
+# Filter and sort nodes based on specific attributes
+df1 = pr.filter_and_sort_nodes(interpretation, ['trendy', 'cool_car', 'cool_pet', 'popular'])
+
+# Filter and sort edges based on specific relationships
+df2 = pr.filter_and_sort_edges(interpretation, ['car_friend', 'same_color_car'])
+
+#pr.save_rule_trace(interpretation)
+
+# Display filtered node and edge data
+print(df1)
+print(df2)
diff --git a/examples/advanced_output.txt b/examples/advanced_output.txt
new file mode 100644
index 00000000..b24449d2
--- /dev/null
+++ b/examples/advanced_output.txt
@@ -0,0 +1,839 @@
+Interpretations:
+{0: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {},
+ 'customer_4': {},
+ 'customer_5': {},
+ 'customer_6': {},
+ 'popular-fac': {'popular-fac': (1.0, 1.0)},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 1: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {'popular-fac': (1.0, 1.0)},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 2: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {'popular-fac': (1.0, 1.0)},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 3: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {'popular-fac': (1.0, 1.0)},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 4: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {'popular-fac': (1.0, 1.0)},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 5: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {'popular-fac': (1.0, 1.0)},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 6: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 7: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 8: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 9: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}},
+ 10: {'Car_0': {},
+ 'Car_1': {},
+ 'Car_10': {},
+ 'Car_11': {},
+ 'Car_2': {},
+ 'Car_3': {},
+ 'Car_4': {},
+ 'Car_5': {},
+ 'Car_6': {},
+ 'Car_7': {},
+ 'Car_8': {},
+ 'Car_9': {},
+ 'Pet_0': {},
+ 'Pet_1': {},
+ 'Pet_2': {},
+ 'Pet_3': {},
+ 'Pet_4': {},
+ 'customer_0': {},
+ 'customer_1': {},
+ 'customer_2': {},
+ 'customer_3': {'cool_pet': (1.0, 1.0)},
+ 'customer_4': {'cool_car': (1.0, 1.0),
+ 'cool_pet': (1.0, 1.0),
+ 'trendy': (1.0, 1.0)},
+ 'customer_5': {},
+ 'customer_6': {'cool_car': (1.0, 1.0)},
+ 'popular-fac': {},
+ ('customer_0', 'Car_2'): {},
+ ('customer_0', 'Car_7'): {},
+ ('customer_0', 'Pet_0'): {},
+ ('customer_0', 'Pet_4'): {},
+ ('customer_0', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_0', 'customer_2'): {'same_color_car': (1.0, 1.0)},
+ ('customer_1', 'Car_0'): {},
+ ('customer_1', 'Car_8'): {},
+ ('customer_1', 'Pet_1'): {},
+ ('customer_2', 'Car_1'): {},
+ ('customer_2', 'Car_11'): {},
+ ('customer_2', 'Car_3'): {},
+ ('customer_2', 'Pet_0'): {},
+ ('customer_2', 'Pet_1'): {},
+ ('customer_2', 'customer_1'): {'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'Car_0'): {},
+ ('customer_3', 'Car_10'): {},
+ ('customer_3', 'Car_3'): {},
+ ('customer_3', 'Pet_2'): {},
+ ('customer_3', 'customer_1'): {'car_friend': (1.0, 1.0),
+ 'same_color_car': (1.0, 1.0)},
+ ('customer_3', 'customer_4'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'Car_4'): {},
+ ('customer_4', 'Car_9'): {},
+ ('customer_4', 'Pet_2'): {},
+ ('customer_4', 'customer_0'): {'same_color_car': (1.0, 1.0)},
+ ('customer_4', 'customer_5'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'Car_2'): {},
+ ('customer_5', 'Car_5'): {},
+ ('customer_5', 'Pet_3'): {},
+ ('customer_5', 'customer_3'): {'same_color_car': (1.0, 1.0)},
+ ('customer_5', 'customer_6'): {'same_color_car': (1.0, 1.0)},
+ ('customer_6', 'Car_4'): {},
+ ('customer_6', 'Car_6'): {},
+ ('customer_6', 'Pet_4'): {},
+ ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}}}
+
+Filtered Nodes:
+[Empty DataFrame
+Columns: [component, trendy, cool_car, cool_pet, popular]
+Index: [], component trendy cool_car cool_pet popular
+0 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [0, 1]
+1 customer_6 [0, 1] [1.0, 1.0] [0, 1] [0, 1]
+2 customer_3 [0, 1] [0, 1] [1.0, 1.0] [0, 1], component trendy cool_car cool_pet popular
+0 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [0, 1]
+1 customer_6 [0, 1] [1.0, 1.0] [0, 1] [0, 1]
+2 customer_3 [0, 1] [0, 1] [1.0, 1.0] [0, 1], component trendy cool_car cool_pet popular
+0 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [0, 1]
+1 customer_6 [0, 1] [1.0, 1.0] [0, 1] [0, 1]
+2 customer_3 [0, 1] [0, 1] [1.0, 1.0] [0, 1], component trendy cool_car cool_pet popular
+0 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [0, 1]
+1 customer_6 [0, 1] [1.0, 1.0] [0, 1] [0, 1]
+2 customer_3 [0, 1] [0, 1] [1.0, 1.0] [0, 1], component trendy cool_car cool_pet popular
+0 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [0, 1]
+1 customer_6 [0, 1] [1.0, 1.0] [0, 1] [0, 1]
+2 customer_3 [0, 1] [0, 1] [1.0, 1.0] [0, 1], component trendy cool_car cool_pet popular
+0 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [0, 1]
+1 customer_6 [0, 1] [1.0, 1.0] [0, 1] [0, 1]
+2 customer_3 [0, 1] [0, 1] [1.0, 1.0] [0, 1], component trendy cool_car cool_pet popular
+0 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [0, 1]
+1 customer_6 [0, 1] [1.0, 1.0] [0, 1] [0, 1]
+2 customer_3 [0, 1] [0, 1] [1.0, 1.0] [0, 1], component trendy cool_car cool_pet popular
+0 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [0, 1]
+1 customer_6 [0, 1] [1.0, 1.0] [0, 1] [0, 1]
+2 customer_3 [0, 1] [0, 1] [1.0, 1.0] [0, 1], component trendy cool_car cool_pet popular
+0 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [0, 1]
+1 customer_6 [0, 1] [1.0, 1.0] [0, 1] [0, 1]
+2 customer_3 [0, 1] [0, 1] [1.0, 1.0] [0, 1], component trendy cool_car cool_pet popular
+0 customer_4 [1.0, 1.0] [1.0, 1.0] [1.0, 1.0] [0, 1]
+1 customer_6 [0, 1] [1.0, 1.0] [0, 1] [0, 1]
+2 customer_3 [0, 1] [0, 1] [1.0, 1.0] [0, 1]]
+Filtered Edges:
+[ component car_friend same_color_car
+0 (customer_3, customer_1) [1.0, 1.0] [1.0, 1.0]
+1 (customer_0, customer_1) [0, 1] [1.0, 1.0]
+2 (customer_0, customer_2) [0, 1] [1.0, 1.0]
+3 (customer_2, customer_1) [0, 1] [1.0, 1.0]
+4 (customer_3, customer_4) [0, 1] [1.0, 1.0]
+5 (customer_4, customer_0) [0, 1] [1.0, 1.0]
+6 (customer_4, customer_5) [0, 1] [1.0, 1.0]
+7 (customer_5, customer_3) [0, 1] [1.0, 1.0]
+8 (customer_5, customer_6) [0, 1] [1.0, 1.0]
+9 (customer_6, customer_0) [0, 1] [1.0, 1.0], component car_friend same_color_car
+0 (customer_3, customer_1) [1.0, 1.0] [1.0, 1.0]
+1 (customer_0, customer_1) [0, 1] [1.0, 1.0]
+2 (customer_0, customer_2) [0, 1] [1.0, 1.0]
+3 (customer_2, customer_1) [0, 1] [1.0, 1.0]
+4 (customer_3, customer_4) [0, 1] [1.0, 1.0]
+5 (customer_4, customer_0) [0, 1] [1.0, 1.0]
+6 (customer_4, customer_5) [0, 1] [1.0, 1.0]
+7 (customer_5, customer_3) [0, 1] [1.0, 1.0]
+8 (customer_5, customer_6) [0, 1] [1.0, 1.0]
+9 (customer_6, customer_0) [0, 1] [1.0, 1.0], component car_friend same_color_car
+0 (customer_3, customer_1) [1.0, 1.0] [1.0, 1.0]
+1 (customer_0, customer_1) [0, 1] [1.0, 1.0]
+2 (customer_0, customer_2) [0, 1] [1.0, 1.0]
+3 (customer_2, customer_1) [0, 1] [1.0, 1.0]
+4 (customer_3, customer_4) [0, 1] [1.0, 1.0]
+5 (customer_4, customer_0) [0, 1] [1.0, 1.0]
+6 (customer_4, customer_5) [0, 1] [1.0, 1.0]
+7 (customer_5, customer_3) [0, 1] [1.0, 1.0]
+8 (customer_5, customer_6) [0, 1] [1.0, 1.0]
+9 (customer_6, customer_0) [0, 1] [1.0, 1.0], component car_friend same_color_car
+0 (customer_3, customer_1) [1.0, 1.0] [1.0, 1.0]
+1 (customer_0, customer_1) [0, 1] [1.0, 1.0]
+2 (customer_0, customer_2) [0, 1] [1.0, 1.0]
+3 (customer_2, customer_1) [0, 1] [1.0, 1.0]
+4 (customer_3, customer_4) [0, 1] [1.0, 1.0]
+5 (customer_4, customer_0) [0, 1] [1.0, 1.0]
+6 (customer_4, customer_5) [0, 1] [1.0, 1.0]
+7 (customer_5, customer_3) [0, 1] [1.0, 1.0]
+8 (customer_5, customer_6) [0, 1] [1.0, 1.0]
+9 (customer_6, customer_0) [0, 1] [1.0, 1.0], component car_friend same_color_car
+0 (customer_3, customer_1) [1.0, 1.0] [1.0, 1.0]
+1 (customer_0, customer_1) [0, 1] [1.0, 1.0]
+2 (customer_0, customer_2) [0, 1] [1.0, 1.0]
+3 (customer_2, customer_1) [0, 1] [1.0, 1.0]
+4 (customer_3, customer_4) [0, 1] [1.0, 1.0]
+5 (customer_4, customer_0) [0, 1] [1.0, 1.0]
+6 (customer_4, customer_5) [0, 1] [1.0, 1.0]
+7 (customer_5, customer_3) [0, 1] [1.0, 1.0]
+8 (customer_5, customer_6) [0, 1] [1.0, 1.0]
+9 (customer_6, customer_0) [0, 1] [1.0, 1.0], component car_friend same_color_car
+0 (customer_3, customer_1) [1.0, 1.0] [1.0, 1.0]
+1 (customer_0, customer_1) [0, 1] [1.0, 1.0]
+2 (customer_0, customer_2) [0, 1] [1.0, 1.0]
+3 (customer_2, customer_1) [0, 1] [1.0, 1.0]
+4 (customer_3, customer_4) [0, 1] [1.0, 1.0]
+5 (customer_4, customer_0) [0, 1] [1.0, 1.0]
+6 (customer_4, customer_5) [0, 1] [1.0, 1.0]
+7 (customer_5, customer_3) [0, 1] [1.0, 1.0]
+8 (customer_5, customer_6) [0, 1] [1.0, 1.0]
+9 (customer_6, customer_0) [0, 1] [1.0, 1.0], component car_friend same_color_car
+0 (customer_3, customer_1) [1.0, 1.0] [1.0, 1.0]
+1 (customer_0, customer_1) [0, 1] [1.0, 1.0]
+2 (customer_0, customer_2) [0, 1] [1.0, 1.0]
+3 (customer_2, customer_1) [0, 1] [1.0, 1.0]
+4 (customer_3, customer_4) [0, 1] [1.0, 1.0]
+5 (customer_4, customer_0) [0, 1] [1.0, 1.0]
+6 (customer_4, customer_5) [0, 1] [1.0, 1.0]
+7 (customer_5, customer_3) [0, 1] [1.0, 1.0]
+8 (customer_5, customer_6) [0, 1] [1.0, 1.0]
+9 (customer_6, customer_0) [0, 1] [1.0, 1.0], component car_friend same_color_car
+0 (customer_3, customer_1) [1.0, 1.0] [1.0, 1.0]
+1 (customer_0, customer_1) [0, 1] [1.0, 1.0]
+2 (customer_0, customer_2) [0, 1] [1.0, 1.0]
+3 (customer_2, customer_1) [0, 1] [1.0, 1.0]
+4 (customer_3, customer_4) [0, 1] [1.0, 1.0]
+5 (customer_4, customer_0) [0, 1] [1.0, 1.0]
+6 (customer_4, customer_5) [0, 1] [1.0, 1.0]
+7 (customer_5, customer_3) [0, 1] [1.0, 1.0]
+8 (customer_5, customer_6) [0, 1] [1.0, 1.0]
+9 (customer_6, customer_0) [0, 1] [1.0, 1.0], component car_friend same_color_car
+0 (customer_3, customer_1) [1.0, 1.0] [1.0, 1.0]
+1 (customer_0, customer_1) [0, 1] [1.0, 1.0]
+2 (customer_0, customer_2) [0, 1] [1.0, 1.0]
+3 (customer_2, customer_1) [0, 1] [1.0, 1.0]
+4 (customer_3, customer_4) [0, 1] [1.0, 1.0]
+5 (customer_4, customer_0) [0, 1] [1.0, 1.0]
+6 (customer_4, customer_5) [0, 1] [1.0, 1.0]
+7 (customer_5, customer_3) [0, 1] [1.0, 1.0]
+8 (customer_5, customer_6) [0, 1] [1.0, 1.0]
+9 (customer_6, customer_0) [0, 1] [1.0, 1.0], component car_friend same_color_car
+0 (customer_3, customer_1) [1.0, 1.0] [1.0, 1.0]
+1 (customer_0, customer_1) [0, 1] [1.0, 1.0]
+2 (customer_0, customer_2) [0, 1] [1.0, 1.0]
+3 (customer_2, customer_1) [0, 1] [1.0, 1.0]
+4 (customer_3, customer_4) [0, 1] [1.0, 1.0]
+5 (customer_4, customer_0) [0, 1] [1.0, 1.0]
+6 (customer_4, customer_5) [0, 1] [1.0, 1.0]
+7 (customer_5, customer_3) [0, 1] [1.0, 1.0]
+8 (customer_5, customer_6) [0, 1] [1.0, 1.0]
+9 (customer_6, customer_0) [0, 1] [1.0, 1.0], component car_friend same_color_car
+0 (customer_3, customer_1) [1.0, 1.0] [1.0, 1.0]
+1 (customer_0, customer_1) [0, 1] [1.0, 1.0]
+2 (customer_0, customer_2) [0, 1] [1.0, 1.0]
+3 (customer_2, customer_1) [0, 1] [1.0, 1.0]
+4 (customer_3, customer_4) [0, 1] [1.0, 1.0]
+5 (customer_4, customer_0) [0, 1] [1.0, 1.0]
+6 (customer_4, customer_5) [0, 1] [1.0, 1.0]
+7 (customer_5, customer_3) [0, 1] [1.0, 1.0]
+8 (customer_5, customer_6) [0, 1] [1.0, 1.0]
+9 (customer_6, customer_0) [0, 1] [1.0, 1.0]]
\ No newline at end of file
diff --git a/examples/annotation_function_ex.py b/examples/annotation_function_ex.py
new file mode 100644
index 00000000..7250d707
--- /dev/null
+++ b/examples/annotation_function_ex.py
@@ -0,0 +1,158 @@
+# Test if annotation functions work
+import pyreason as pr
+import numba
+import numpy as np
+import networkx as nx
+
+
+
+
+@numba.njit
+def avg_ann_fn(annotations, weights):
+ # annotations contains the bounds of the atoms that were used to ground the rule. It is a nested list that contains a list for each clause
+ # You can access for example the first grounded atom's bound by doing: annotations[0][0].lower or annotations[0][0].upper
+
+ # We want the normalised sum of the bounds of the grounded atoms
+ sum_upper_bounds = 0
+ sum_lower_bounds = 0
+ num_atoms = 0
+ for clause in annotations:
+ for atom in clause:
+ sum_lower_bounds += atom.lower
+ sum_upper_bounds += atom.upper
+ num_atoms += 1
+
+ a = sum_lower_bounds / num_atoms
+ b = sum_upper_bounds / num_atoms
+ return a, b
+
+
+
+#Annotation function that returns average of both upper and lower bounds
+def average_annotation_function():
+ # Reset PyReason
+ pr.reset()
+ pr.reset_rules()
+
+ pr.settings.allow_ground_rules = True
+
+ pr.add_fact(pr.Fact('P(A) : [0.01, 1]'))
+ pr.add_fact(pr.Fact('P(B) : [0.2, 1]'))
+
+ pr.add_annotation_function(avg_ann_fn)
+ pr.add_rule(pr.Rule('average_function(A, B):avg_ann_fn <- P(A):[0, 1], P(B):[0, 1]', infer_edges=True))
+
+ interpretation = pr.reason(timesteps=1)
+
+ dataframes = pr.filter_and_sort_edges(interpretation, ['average_function'])
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+
+ assert interpretation.query('average_function(A, B) : [0.105, 1]'), 'Average function should be [0.105, 1]'
+
+#average_annotation_function()
+
+
+@numba.njit
+def map_interval(t, a, b, c, d):
+ """
+ Maps a value `t` from the interval [a, b] to the interval [c, d] using the formula:
+
+ f(t) = c + ((d - c) / (b - a)) * (t - a)
+
+ Parameters:
+ - t: The value to be mapped.
+ - a: The lower bound of the original interval.
+ - b: The upper bound of the original interval.
+ - c: The lower bound of the target interval.
+ - d: The upper bound of the target interval.
+
+ Returns:
+ - The value `t` mapped to the new interval [c, d].
+ """
+ # Apply the formula to map the value t
+ mapped_value = c + ((d - c) / (b - a)) * (t - a)
+
+ return mapped_value
+
+
+
+
+@numba.njit
+def lin_comb_ann_fn(annotations, weights):
+ sum_lower_comb = 0
+ sum_upper_comb = 0
+ num_atoms = 0
+ constant = 0.2
+ # Iterate over the clauses in the rule
+ for clause in annotations:
+ for atom in clause:
+
+ #weight = weights[clause][atom]
+ # Apply the weights to the lower and upper bounds, and accumulate
+ sum_lower_comb += constant * atom.lower
+ sum_upper_comb += constant * atom.upper
+ num_atoms += 1
+
+
+ #if the lower and upper are equal, return [0,1]
+ if sum_lower_comb == sum_upper_comb:
+ return 0,1
+
+ if sum_lower_comb> sum_upper_comb:
+ sum_lower_comb,sum_upper_comb= sum_upper_comb, sum_lower_comb
+
+ if sum_upper_comb>1:
+ #mapped_lower = map_interval(sum_lower_comb, atom.lower, atom.upper, 0,1)
+ sum_lower_comb = map_interval(sum_lower_comb, sum_lower_comb, sum_upper_comb, 0,1)
+
+ sum_upper_comb = map_interval(sum_upper_comb, sum_lower_comb, sum_upper_comb,0,1)
+
+
+ # Return the weighted linear combination of the lower and upper bounds
+ return sum_lower_comb, sum_upper_comb
+
+
+
+# Function to run the test
+def linear_combination_annotation_function():
+
+ # Reset PyReason before starting the test
+ pr.reset()
+ pr.reset_rules()
+
+ pr.settings.allow_ground_rules = True
+
+
+ # Add facts (P(A) and P(B) with bounds)
+ pr.add_fact(pr.Fact('A : [.1, 1]'))
+ pr.add_fact(pr.Fact('B : [.2, 1]'))
+ pr.add_fact(pr.Fact('C : [.4, 1]'))
+
+
+ # Register the custom annotation function with PyReason
+ pr.add_annotation_function(lin_comb_ann_fn)
+
+    # Define a rule that uses the linear-combination annotation function (TODO: confirm the rule head/clauses match the intended lin-comb semantics)
+ pr.add_rule(pr.Rule('linear_combination_function(A, B):lin_comb_ann_fn <- A:[0, 1], B:[0, 1], C:[0, 1]', infer_edges=True))
+
+ # Perform reasoning for 1 timestep
+ interpretation = pr.reason(timesteps=1)
+
+ # Filter the results for the computed 'linear_combination_function' edges
+ dataframes = pr.filter_and_sort_edges(interpretation, ['linear_combination_function'])
+
+ # Print the resulting dataframes for each timestep
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+
+ # Assert that the linear combination function gives the expected result (adjusted for weights)
+ # Example assertion based on weights and bounds; adjust the expected result based on the weights
+    assert interpretation.query('linear_combination_function(A, B, C) : [0.24000000000000005, 0.6000000000000001]'), 'Linear combination function should be [0.24000000000000005, 0.6000000000000001]'
+
+# Run the test function
+linear_combination_annotation_function()
diff --git a/examples/basic_tutorial_ex.py b/examples/basic_tutorial_ex.py
new file mode 100644
index 00000000..f103646e
--- /dev/null
+++ b/examples/basic_tutorial_ex.py
@@ -0,0 +1,79 @@
+# Test if the simple hello world program works
+import pyreason as pr
+import faulthandler
+import networkx as nx
+from typing import Tuple
+from pprint import pprint
+
+
+
+# Reset PyReason
+pr.reset()
+pr.reset_rules()
+pr.reset_settings()
+
+
+# ================================ CREATE GRAPH====================================
+# Create a Directed graph
+g = nx.DiGraph()
+
+# Add the nodes
+g.add_nodes_from(['John', 'Mary', 'Justin'])
+g.add_nodes_from(['Dog', 'Cat'])
+
+# Add the edges and their attributes. When an attribute = x which is <= 1, the annotation
+# associated with it will be [x,1]. NOTE: These attributes are immutable
+# Friend edges
+g.add_edge('Justin', 'Mary', Friends=1)
+g.add_edge('John', 'Mary', Friends=1)
+g.add_edge('John', 'Justin', Friends=1)
+
+# Pet edges
+g.add_edge('Mary', 'Cat', owns=1)
+g.add_edge('Justin', 'Cat', owns=1)
+g.add_edge('Justin', 'Dog', owns=1)
+g.add_edge('John', 'Dog', owns=1)
+
+
+# Modify pyreason settings to make verbose
+pr.settings.verbose = True # Print info to screen
+# pr.settings.optimize_rules = False # Disable rule optimization for debugging
+
+# Load all the files into pyreason
+pr.load_graph(g)
+pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule'))
+pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2))
+
+# Run the program for two timesteps to see the diffusion take place
+faulthandler.enable()
+interpretation = pr.reason(timesteps=2)
+pr.save_rule_trace(interpretation)
+
+interpretations_dict = interpretation.get_dict()
+print("stra")
+pprint(interpretations_dict)
+print("end")
+#Display the changes in the interpretation for each timestep
+dataframes = pr.filter_and_sort_nodes(interpretation, ['popular'])
+for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+
+
+
+assert len(dataframes[0]) == 1, 'At t=0 there should be one popular person'
+assert len(dataframes[1]) == 2, 'At t=1 there should be two popular people'
+assert len(dataframes[2]) == 3, 'At t=2 there should be three popular people'
+
+# Mary should be popular in all three timesteps
+assert 'Mary' in dataframes[0]['component'].values and dataframes[0].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps'
+assert 'Mary' in dataframes[1]['component'].values and dataframes[1].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=1 timesteps'
+assert 'Mary' in dataframes[2]['component'].values and dataframes[2].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=2 timesteps'
+
+# Justin should be popular in timesteps 1, 2
+assert 'Justin' in dataframes[1]['component'].values and dataframes[1].iloc[1].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=1 timesteps'
+assert 'Justin' in dataframes[2]['component'].values and dataframes[2].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps'
+
+# John should be popular in timestep 2
+assert 'John' in dataframes[2]['component'].values and dataframes[2].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps'
diff --git a/examples/csv outputs/advanced_rule_trace_edges_20241119-012153.csv b/examples/csv outputs/advanced_rule_trace_edges_20241119-012153.csv
new file mode 100644
index 00000000..1f937539
--- /dev/null
+++ b/examples/csv outputs/advanced_rule_trace_edges_20241119-012153.csv
@@ -0,0 +1,67 @@
+Time,Fixed-Point-Operation,Edge,Label,Old Bound,New Bound,Occurred Due To,Clause-1,Clause-2
+0,1,"('customer_3', 'customer_1')",car_friend,"[0.0,1.0]","[1.0,1.0]",car_friend_rule,"[('customer_3', 'Car_0')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+0,1,"('customer_0', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+0,1,"('customer_0', 'customer_2')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]","[('customer_2', 'Car_1'), ('customer_2', 'Car_3'), ('customer_2', 'Car_11')]"
+0,1,"('customer_2', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_2', 'Car_1'), ('customer_2', 'Car_3'), ('customer_2', 'Car_11')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+0,1,"('customer_3', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+0,1,"('customer_3', 'customer_4')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]","[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]"
+0,1,"('customer_4', 'customer_0')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]","[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]"
+0,1,"('customer_4', 'customer_5')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]","[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]"
+0,1,"('customer_5', 'customer_3')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]","[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]"
+0,1,"('customer_5', 'customer_6')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]","[('customer_6', 'Car_6'), ('customer_6', 'Car_4')]"
+0,1,"('customer_6', 'customer_0')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_6', 'Car_6'), ('customer_6', 'Car_4')]","[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]"
+1,3,"('customer_3', 'customer_1')",car_friend,"[0.0,1.0]","[1.0,1.0]",car_friend_rule,"[('customer_3', 'Car_0')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+1,3,"('customer_0', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+1,3,"('customer_0', 'customer_2')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]","[('customer_2', 'Car_1'), ('customer_2', 'Car_3'), ('customer_2', 'Car_11')]"
+1,3,"('customer_2', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_2', 'Car_1'), ('customer_2', 'Car_3'), ('customer_2', 'Car_11')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+1,3,"('customer_3', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+1,3,"('customer_3', 'customer_4')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]","[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]"
+1,3,"('customer_4', 'customer_0')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]","[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]"
+1,3,"('customer_4', 'customer_5')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]","[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]"
+1,3,"('customer_5', 'customer_3')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]","[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]"
+1,3,"('customer_5', 'customer_6')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]","[('customer_6', 'Car_6'), ('customer_6', 'Car_4')]"
+1,3,"('customer_6', 'customer_0')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_6', 'Car_6'), ('customer_6', 'Car_4')]","[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]"
+2,5,"('customer_3', 'customer_1')",car_friend,"[0.0,1.0]","[1.0,1.0]",car_friend_rule,"[('customer_3', 'Car_0')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+2,5,"('customer_0', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+2,5,"('customer_0', 'customer_2')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]","[('customer_2', 'Car_1'), ('customer_2', 'Car_3'), ('customer_2', 'Car_11')]"
+2,5,"('customer_2', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_2', 'Car_1'), ('customer_2', 'Car_3'), ('customer_2', 'Car_11')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+2,5,"('customer_3', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+2,5,"('customer_3', 'customer_4')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]","[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]"
+2,5,"('customer_4', 'customer_0')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]","[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]"
+2,5,"('customer_4', 'customer_5')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]","[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]"
+2,5,"('customer_5', 'customer_3')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]","[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]"
+2,5,"('customer_5', 'customer_6')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]","[('customer_6', 'Car_6'), ('customer_6', 'Car_4')]"
+2,5,"('customer_6', 'customer_0')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_6', 'Car_6'), ('customer_6', 'Car_4')]","[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]"
+3,7,"('customer_3', 'customer_1')",car_friend,"[0.0,1.0]","[1.0,1.0]",car_friend_rule,"[('customer_3', 'Car_0')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+3,7,"('customer_0', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+3,7,"('customer_0', 'customer_2')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]","[('customer_2', 'Car_1'), ('customer_2', 'Car_3'), ('customer_2', 'Car_11')]"
+3,7,"('customer_2', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_2', 'Car_1'), ('customer_2', 'Car_3'), ('customer_2', 'Car_11')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+3,7,"('customer_3', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+3,7,"('customer_3', 'customer_4')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]","[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]"
+3,7,"('customer_4', 'customer_0')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]","[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]"
+3,7,"('customer_4', 'customer_5')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]","[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]"
+3,7,"('customer_5', 'customer_3')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]","[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]"
+3,7,"('customer_5', 'customer_6')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]","[('customer_6', 'Car_6'), ('customer_6', 'Car_4')]"
+3,7,"('customer_6', 'customer_0')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_6', 'Car_6'), ('customer_6', 'Car_4')]","[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]"
+4,9,"('customer_3', 'customer_1')",car_friend,"[0.0,1.0]","[1.0,1.0]",car_friend_rule,"[('customer_3', 'Car_0')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+4,9,"('customer_0', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+4,9,"('customer_0', 'customer_2')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]","[('customer_2', 'Car_1'), ('customer_2', 'Car_3'), ('customer_2', 'Car_11')]"
+4,9,"('customer_2', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_2', 'Car_1'), ('customer_2', 'Car_3'), ('customer_2', 'Car_11')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+4,9,"('customer_3', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+4,9,"('customer_3', 'customer_4')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]","[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]"
+4,9,"('customer_4', 'customer_0')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]","[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]"
+4,9,"('customer_4', 'customer_5')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]","[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]"
+4,9,"('customer_5', 'customer_3')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]","[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]"
+4,9,"('customer_5', 'customer_6')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]","[('customer_6', 'Car_6'), ('customer_6', 'Car_4')]"
+4,9,"('customer_6', 'customer_0')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_6', 'Car_6'), ('customer_6', 'Car_4')]","[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]"
+5,11,"('customer_3', 'customer_1')",car_friend,"[0.0,1.0]","[1.0,1.0]",car_friend_rule,"[('customer_3', 'Car_0')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+5,11,"('customer_0', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+5,11,"('customer_0', 'customer_2')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]","[('customer_2', 'Car_1'), ('customer_2', 'Car_3'), ('customer_2', 'Car_11')]"
+5,11,"('customer_2', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_2', 'Car_1'), ('customer_2', 'Car_3'), ('customer_2', 'Car_11')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+5,11,"('customer_3', 'customer_1')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]","[('customer_1', 'Car_0'), ('customer_1', 'Car_8')]"
+5,11,"('customer_3', 'customer_4')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]","[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]"
+5,11,"('customer_4', 'customer_0')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]","[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]"
+5,11,"('customer_4', 'customer_5')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_4', 'Car_4'), ('customer_4', 'Car_9')]","[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]"
+5,11,"('customer_5', 'customer_3')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]","[('customer_3', 'Car_3'), ('customer_3', 'Car_0'), ('customer_3', 'Car_10')]"
+5,11,"('customer_5', 'customer_6')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_5', 'Car_5'), ('customer_5', 'Car_2')]","[('customer_6', 'Car_6'), ('customer_6', 'Car_4')]"
+5,11,"('customer_6', 'customer_0')",same_color_car,"[0.0,1.0]","[1.0,1.0]",same_car_color_rule,"[('customer_6', 'Car_6'), ('customer_6', 'Car_4')]","[('customer_0', 'Car_2'), ('customer_0', 'Car_7')]"
diff --git a/examples/csv outputs/advanced_rule_trace_nodes_20241119-012153.csv b/examples/csv outputs/advanced_rule_trace_nodes_20241119-012153.csv
new file mode 100644
index 00000000..f648b7e7
--- /dev/null
+++ b/examples/csv outputs/advanced_rule_trace_nodes_20241119-012153.csv
@@ -0,0 +1,32 @@
+Time,Fixed-Point-Operation,Node,Label,Old Bound,New Bound,Occurred Due To,Clause-1,Clause-2
+0,0,popular-fac,popular-fac,"[0.0,1.0]","[1.0,1.0]",popular(customer_0),,
+1,2,popular-fac,popular-fac,"[0.0,1.0]","[1.0,1.0]",popular(customer_0),,
+1,2,customer_4,cool_car,"[0.0,1.0]","[1.0,1.0]",cool_car_rule,"[('customer_4', 'Car_4')]",['Car_4']
+1,2,customer_6,cool_car,"[0.0,1.0]","[1.0,1.0]",cool_car_rule,"[('customer_6', 'Car_4')]",['Car_4']
+1,2,customer_3,cool_pet,"[0.0,1.0]","[1.0,1.0]",cool_pet_rule,"[('customer_3', 'Pet_2')]",['Pet_2']
+1,2,customer_4,cool_pet,"[0.0,1.0]","[1.0,1.0]",cool_pet_rule,"[('customer_4', 'Pet_2')]",['Pet_2']
+1,3,customer_4,trendy,"[0.0,1.0]","[1.0,1.0]",trendy_rule,['customer_4'],['customer_4']
+2,4,popular-fac,popular-fac,"[0.0,1.0]","[1.0,1.0]",popular(customer_0),,
+2,4,customer_4,cool_car,"[0.0,1.0]","[1.0,1.0]",cool_car_rule,"[('customer_4', 'Car_4')]",['Car_4']
+2,4,customer_6,cool_car,"[0.0,1.0]","[1.0,1.0]",cool_car_rule,"[('customer_6', 'Car_4')]",['Car_4']
+2,4,customer_3,cool_pet,"[0.0,1.0]","[1.0,1.0]",cool_pet_rule,"[('customer_3', 'Pet_2')]",['Pet_2']
+2,4,customer_4,cool_pet,"[0.0,1.0]","[1.0,1.0]",cool_pet_rule,"[('customer_4', 'Pet_2')]",['Pet_2']
+2,5,customer_4,trendy,"[0.0,1.0]","[1.0,1.0]",trendy_rule,['customer_4'],['customer_4']
+3,6,popular-fac,popular-fac,"[0.0,1.0]","[1.0,1.0]",popular(customer_0),,
+3,6,customer_4,cool_car,"[0.0,1.0]","[1.0,1.0]",cool_car_rule,"[('customer_4', 'Car_4')]",['Car_4']
+3,6,customer_6,cool_car,"[0.0,1.0]","[1.0,1.0]",cool_car_rule,"[('customer_6', 'Car_4')]",['Car_4']
+3,6,customer_3,cool_pet,"[0.0,1.0]","[1.0,1.0]",cool_pet_rule,"[('customer_3', 'Pet_2')]",['Pet_2']
+3,6,customer_4,cool_pet,"[0.0,1.0]","[1.0,1.0]",cool_pet_rule,"[('customer_4', 'Pet_2')]",['Pet_2']
+3,7,customer_4,trendy,"[0.0,1.0]","[1.0,1.0]",trendy_rule,['customer_4'],['customer_4']
+4,8,popular-fac,popular-fac,"[0.0,1.0]","[1.0,1.0]",popular(customer_0),,
+4,8,customer_4,cool_car,"[0.0,1.0]","[1.0,1.0]",cool_car_rule,"[('customer_4', 'Car_4')]",['Car_4']
+4,8,customer_6,cool_car,"[0.0,1.0]","[1.0,1.0]",cool_car_rule,"[('customer_6', 'Car_4')]",['Car_4']
+4,8,customer_3,cool_pet,"[0.0,1.0]","[1.0,1.0]",cool_pet_rule,"[('customer_3', 'Pet_2')]",['Pet_2']
+4,8,customer_4,cool_pet,"[0.0,1.0]","[1.0,1.0]",cool_pet_rule,"[('customer_4', 'Pet_2')]",['Pet_2']
+4,9,customer_4,trendy,"[0.0,1.0]","[1.0,1.0]",trendy_rule,['customer_4'],['customer_4']
+5,10,popular-fac,popular-fac,"[0.0,1.0]","[1.0,1.0]",popular(customer_0),,
+5,10,customer_4,cool_car,"[0.0,1.0]","[1.0,1.0]",cool_car_rule,"[('customer_4', 'Car_4')]",['Car_4']
+5,10,customer_6,cool_car,"[0.0,1.0]","[1.0,1.0]",cool_car_rule,"[('customer_6', 'Car_4')]",['Car_4']
+5,10,customer_3,cool_pet,"[0.0,1.0]","[1.0,1.0]",cool_pet_rule,"[('customer_3', 'Pet_2')]",['Pet_2']
+5,10,customer_4,cool_pet,"[0.0,1.0]","[1.0,1.0]",cool_pet_rule,"[('customer_4', 'Pet_2')]",['Pet_2']
+5,11,customer_4,trendy,"[0.0,1.0]","[1.0,1.0]",trendy_rule,['customer_4'],['customer_4']
diff --git a/examples/csv outputs/basic_rule_trace_nodes_20241119-012005.csv b/examples/csv outputs/basic_rule_trace_nodes_20241119-012005.csv
new file mode 100644
index 00000000..02ece211
--- /dev/null
+++ b/examples/csv outputs/basic_rule_trace_nodes_20241119-012005.csv
@@ -0,0 +1,7 @@
+Time,Fixed-Point-Operation,Node,Label,Old Bound,New Bound,Occurred Due To
+0,0,Mary,popular,-,"[1.0,1.0]",-
+1,1,Mary,popular,-,"[1.0,1.0]",-
+1,1,Justin,popular,-,"[1.0,1.0]",-
+2,2,Mary,popular,-,"[1.0,1.0]",-
+2,2,John,popular,-,"[1.0,1.0]",-
+2,2,Justin,popular,-,"[1.0,1.0]",-
diff --git a/examples/csv outputs/basic_rule_trace_nodes_20241125-114246.csv b/examples/csv outputs/basic_rule_trace_nodes_20241125-114246.csv
new file mode 100644
index 00000000..02ece211
--- /dev/null
+++ b/examples/csv outputs/basic_rule_trace_nodes_20241125-114246.csv
@@ -0,0 +1,7 @@
+Time,Fixed-Point-Operation,Node,Label,Old Bound,New Bound,Occurred Due To
+0,0,Mary,popular,-,"[1.0,1.0]",-
+1,1,Mary,popular,-,"[1.0,1.0]",-
+1,1,Justin,popular,-,"[1.0,1.0]",-
+2,2,Mary,popular,-,"[1.0,1.0]",-
+2,2,John,popular,-,"[1.0,1.0]",-
+2,2,Justin,popular,-,"[1.0,1.0]",-
diff --git a/examples/csv outputs/infer_edges_rule_trace_edges_20241119-140955.csv b/examples/csv outputs/infer_edges_rule_trace_edges_20241119-140955.csv
new file mode 100644
index 00000000..d540071b
--- /dev/null
+++ b/examples/csv outputs/infer_edges_rule_trace_edges_20241119-140955.csv
@@ -0,0 +1,10 @@
+Time,Fixed-Point-Operation,Edge,Label,Old Bound,New Bound,Occurred Due To,Clause-1,Clause-2,Clause-3
+0,0,"('Amsterdam_Airport_Schiphol', 'Yali')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+0,0,"('Riga_International_Airport', 'Amsterdam_Airport_Schiphol')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+0,0,"('Riga_International_Airport', 'Düsseldorf_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+0,0,"('Chișinău_International_Airport', 'Riga_International_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+0,0,"('Düsseldorf_Airport', 'Dubrovnik_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+0,0,"('Pobedilovo_Airport', 'Vnukovo_International_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+0,0,"('Dubrovnik_Airport', 'Athens_International_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+0,0,"('Vnukovo_International_Airport', 'HévÃz-Balaton_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact,,,
+1,1,"('Vnukovo_International_Airport', 'Riga_International_Airport')",isConnectedTo,"[0.0,1.0]","[1.0,1.0]",connected_rule_1,"[('Riga_International_Airport', 'Amsterdam_Airport_Schiphol')]",['Amsterdam_Airport_Schiphol'],['Vnukovo_International_Airport']
diff --git a/examples/csv outputs/infer_edges_rule_trace_nodes_20241119-140955.csv b/examples/csv outputs/infer_edges_rule_trace_nodes_20241119-140955.csv
new file mode 100644
index 00000000..17adc715
--- /dev/null
+++ b/examples/csv outputs/infer_edges_rule_trace_nodes_20241119-140955.csv
@@ -0,0 +1,11 @@
+Time,Fixed-Point-Operation,Node,Label,Old Bound,New Bound,Occurred Due To
+0,0,Amsterdam_Airport_Schiphol,Amsterdam_Airport_Schiphol,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+0,0,Riga_International_Airport,Riga_International_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+0,0,Chișinău_International_Airport,Chișinău_International_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+0,0,Yali,Yali,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+0,0,Düsseldorf_Airport,Düsseldorf_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+0,0,Pobedilovo_Airport,Pobedilovo_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+0,0,Dubrovnik_Airport,Dubrovnik_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+0,0,HévÃz-Balaton_Airport,HévÃz-Balaton_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+0,0,Athens_International_Airport,Athens_International_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
+0,0,Vnukovo_International_Airport,Vnukovo_International_Airport,"[0.0,1.0]","[1.0,1.0]",graph-attribute-fact
diff --git a/examples/custom_threshold_ex.py b/examples/custom_threshold_ex.py
new file mode 100644
index 00000000..d52d067b
--- /dev/null
+++ b/examples/custom_threshold_ex.py
@@ -0,0 +1,76 @@
+# Test if the simple program works with thresholds defined
+import pyreason as pr
+from pyreason import Threshold
+import networkx as nx
+
+# Reset PyReason
+pr.reset()
+pr.reset_rules()
+
+
+# Create an empty graph
+G = nx.DiGraph()
+
+# Add nodes
+nodes = ["TextMessage", "Zach", "Justin", "Michelle", "Amy"]
+G.add_nodes_from(nodes)
+
+# Add edges with attribute 'HaveAccess'
+G.add_edge("Zach", "TextMessage", HaveAccess=1)
+G.add_edge("Justin", "TextMessage", HaveAccess=1)
+G.add_edge("Michelle", "TextMessage", HaveAccess=1)
+G.add_edge("Amy", "TextMessage", HaveAccess=1)
+
+
+
+# Modify pyreason settings to make verbose
+pr.reset_settings()
+pr.settings.verbose = True # Print info to screen
+
+#load the graph
+pr.load_graph(G)
+
+# add custom thresholds
+user_defined_thresholds = [
+ Threshold("greater_equal", ("number", "total"), 1),
+ Threshold("greater_equal", ("percent", "total"), 100),
+
+]
+
+pr.add_rule(
+ pr.Rule(
+ "ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)",
+ "viewed_by_all_rule",
+ custom_thresholds=user_defined_thresholds,
+ )
+)
+
+pr.add_fact(pr.Fact("Viewed(Zach)", "seen-fact-zach", 0, 3))
+pr.add_fact(pr.Fact("Viewed(Justin)", "seen-fact-justin", 0, 3))
+pr.add_fact(pr.Fact("Viewed(Michelle)", "seen-fact-michelle", 1, 3))
+pr.add_fact(pr.Fact("Viewed(Amy)", "seen-fact-amy", 2, 3))
+
+# Run the program for three timesteps to see the diffusion take place
+interpretation = pr.reason(timesteps=3)
+
+# Display the changes in the interpretation for each timestep
+dataframes = pr.filter_and_sort_nodes(interpretation, ["ViewedByAll"])
+for t, df in enumerate(dataframes):
+ print(f"TIMESTEP - {t}")
+ print(df)
+ print()
+
+assert (
+ len(dataframes[0]) == 0
+), "At t=0 the TextMessage should not have been ViewedByAll"
+assert (
+ len(dataframes[2]) == 1
+), "At t=2 the TextMessage should have been ViewedByAll"
+
+# TextMessage should be ViewedByAll in t=2
+assert "TextMessage" in dataframes[2]["component"].values and dataframes[2].iloc[
+ 0
+].ViewedByAll == [
+ 1,
+ 1,
+], "TextMessage should have ViewedByAll bounds [1,1] for t=2 timesteps"
diff --git a/examples/infer_edges_ex.py b/examples/infer_edges_ex.py
new file mode 100644
index 00000000..b094f122
--- /dev/null
+++ b/examples/infer_edges_ex.py
@@ -0,0 +1,68 @@
+import pyreason as pr
+import networkx as nx
+import matplotlib.pyplot as plt
+
+# Create a directed graph
+G = nx.DiGraph()
+
+# Add nodes with attributes
+nodes = [
+ ("Amsterdam_Airport_Schiphol", {"Amsterdam_Airport_Schiphol": 1}),
+ ("Riga_International_Airport", {"Riga_International_Airport": 1}),
+ ("Chișinău_International_Airport", {"Chișinău_International_Airport": 1}),
+ ("Yali", {"Yali": 1}),
+ ("Düsseldorf_Airport", {"Düsseldorf_Airport": 1}),
+ ("Pobedilovo_Airport", {"Pobedilovo_Airport": 1}),
+ ("Dubrovnik_Airport", {"Dubrovnik_Airport": 1}),
+ ("HévÃz-Balaton_Airport", {"HévÃz-Balaton_Airport": 1}),
+ ("Athens_International_Airport", {"Athens_International_Airport": 1}),
+ ("Vnukovo_International_Airport", {"Vnukovo_International_Airport": 1})
+]
+
+G.add_nodes_from(nodes)
+
+# Add edges with 'isConnectedTo' attribute
+edges = [
+ ("Pobedilovo_Airport", "Vnukovo_International_Airport", {"isConnectedTo": 1}),
+ ("Vnukovo_International_Airport", "HévÃz-Balaton_Airport", {"isConnectedTo": 1}),
+ ("Düsseldorf_Airport", "Dubrovnik_Airport", {"isConnectedTo": 1}),
+ ("Dubrovnik_Airport", "Athens_International_Airport", {"isConnectedTo": 1}),
+ ("Riga_International_Airport", "Amsterdam_Airport_Schiphol", {"isConnectedTo": 1}),
+ ("Riga_International_Airport", "Düsseldorf_Airport", {"isConnectedTo": 1}),
+ ("Chișinău_International_Airport", "Riga_International_Airport", {"isConnectedTo": 1}),
+ ("Amsterdam_Airport_Schiphol", "Yali", {"isConnectedTo": 1}),
+]
+
+G.add_edges_from(edges)
+
+
+
+pr.reset()
+pr.reset_rules()
+# Modify pyreason settings to make verbose and to save the rule trace to a file
+pr.settings.verbose = True
+pr.settings.atom_trace = True
+pr.settings.memory_profile = False
+pr.settings.canonical = True
+pr.settings.inconsistency_check = False
+pr.settings.static_graph_facts = False
+pr.settings.output_to_file = False
+pr.settings.store_interpretation_changes = True
+pr.settings.save_graph_attributes_to_trace = True
+# Load all the files into pyreason
+pr.load_graph(G)
+pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True))
+
+# Run the program for two timesteps to see the diffusion take place
+interpretation = pr.reason(timesteps=1)
+#pr.save_rule_trace(interpretation)
+
+# Display the changes in the interpretation for each timestep
+dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
+for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations'
+assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
+assert ('Vnukovo_International_Airport', 'Riga_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Riga_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps'
diff --git a/examples/text.py b/examples/text.py
new file mode 100644
index 00000000..7f881ceb
--- /dev/null
+++ b/examples/text.py
@@ -0,0 +1,168 @@
+from pprint import pprint
+import networkx as nx
+import pyreason as pr
+
+# Customer Data
+customers = ['John', 'Mary', 'Justin', 'Alice', 'Bob', 'Eva', 'Mike']
+customer_details = [
+ ('John', 'M', 'New York', 'NY'),
+ ('Mary', 'F', 'Los Angeles', 'CA'),
+ ('Justin', 'M', 'Chicago', 'IL'),
+ ('Alice', 'F', 'Houston', 'TX'),
+ ('Bob', 'M', 'Phoenix', 'AZ'),
+ ('Eva', 'F', 'San Diego', 'CA'),
+ ('Mike', 'M', 'Dallas', 'TX')
+]
+
+# Creating a dictionary of customers with their details
+customer_dict = {i: customer for i, customer in enumerate(customer_details)}
+
+# Pet Data
+pet_details = [
+ ('Dog', 'Mammal'),
+ ('Cat', 'Mammal'),
+ ('Rabbit', 'Mammal'),
+ ('Parrot', 'Bird'),
+ ('Fish', 'Fish')
+]
+
+# Creating a dictionary of pets with their details
+pet_dict = {i: pet for i, pet in enumerate(pet_details)}
+
+# Car Data
+car_details = [
+ ('Toyota Camry', 'Red'),
+ ('Honda Civic', 'Blue'),
+ ('Ford Focus', 'Red'),
+ ('BMW 3 Series', 'Black'),
+ ('Tesla Model S', 'Red'),
+ ('Chevrolet Bolt EV', 'White'),
+ ('Ford Mustang', 'Yellow'),
+ ('Audi A4', 'Silver'),
+ ('Mercedes-Benz C-Class', 'Grey'),
+ ('Subaru Outback', 'Green'),
+ ('Volkswagen Golf', 'Blue'),
+ ('Porsche 911', 'Black')
+]
+
+# Creating a dictionary of cars with their details
+car_dict = {i: car for i, car in enumerate(car_details)}
+
+# Travel Data (customer movements between cities)
+travels = [
+ ('John', 'Los Angeles', 'CA', 'New York', 'NY', 2),
+ ('Alice', 'Houston', 'TX', 'Phoenix', 'AZ', 5),
+ ('Eva', 'San Diego', 'CA', 'Dallas', 'TX', 1),
+ ('Mike', 'Dallas', 'TX', 'Chicago', 'IL', 3)
+]
+
+# Friendships (who is friends with whom)
+friendships = [('customer_2', 'customer_1'), ('customer_0', 'customer_1'), ('customer_0', 'customer_2'),
+ ('customer_3', 'customer_4'), ('customer_4', 'customer_0'), ('customer_5', 'customer_3'),
+ ('customer_6', 'customer_0'), ('customer_5', 'customer_6'), ('customer_4', 'customer_5'),
+ ('customer_3', 'customer_1')]
+
+# Car Ownerships (who owns which car)
+car_ownerships = [('customer_1', 'Car_0'), ('customer_2', 'Car_1'), ('customer_0', 'Car_2'), ('customer_3', 'Car_3'),
+ ('customer_4', 'Car_4'), ('customer_3', 'Car_0'), ('customer_2', 'Car_3'), ('customer_5', 'Car_5'),
+ ('customer_6', 'Car_6'), ('customer_0', 'Car_7'), ('customer_1', 'Car_8'), ('customer_4', 'Car_9'),
+ ('customer_3', 'Car_10'), ('customer_2', 'Car_11'), ('customer_5', 'Car_2'), ('customer_6', 'Car_4')]
+
+# Pet Ownerships (who owns which pet)
+pet_ownerships = [('customer_1', 'Pet_1'), ('customer_2', 'Pet_1'), ('customer_2', 'Pet_0'), ('customer_0', 'Pet_0'),
+ ('customer_3', 'Pet_2'), ('customer_4', 'Pet_2'), ('customer_5', 'Pet_3'), ('customer_6', 'Pet_4'),
+ ('customer_0', 'Pet_4')]
+
+# Create a directed graph
+g = nx.DiGraph()
+
+# Add nodes for customers
+for customer_id, details in customer_dict.items():
+ attributes = {
+ 'c_id': customer_id,
+ 'name': details[0],
+ 'gender': details[1],
+ 'city': details[2],
+ 'state': details[3],
+ }
+ name = "customer_" + str(customer_id)
+ g.add_node(name, **attributes)
+
+# Add nodes for pets
+for pet_id, details in pet_dict.items():
+ dynamic_attribute = f"Pet_{pet_id}"
+ attributes = {
+ 'pet_id': pet_id,
+ 'species': details[0],
+ 'class': details[1],
+ dynamic_attribute: 1
+ }
+ name = "Pet_" + str(pet_id)
+ g.add_node(name, **attributes)
+
+# Add nodes for cars
+for car_id, details in car_dict.items():
+ dynamic_attribute = f"Car_{car_id}"
+ attributes = {
+ 'car_id': car_id,
+ 'model': details[0],
+ 'color': details[1],
+ dynamic_attribute: 1
+ }
+ name = "Car_" + str(car_id)
+ g.add_node(name, **attributes)
+
+# Add edges for relationships
+for f1, f2 in friendships:
+ g.add_edge(f1, f2, Friends=1)
+for owner, car in car_ownerships:
+ g.add_edge(owner, car, owns_car=1, car_color_id=int(car.split('_')[1]))
+for owner, pet in pet_ownerships:
+ g.add_edge(owner, pet, owns_pet=1)
+
+# Load the graph into PyReason
+pr.load_graph(g)
+
+# Set PyReason settings
+pr.settings.verbose = True
+pr.settings.atom_trace = True
+
+# Define logical rules
+pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y)', 'popular_pet_rule'))
+pr.add_rule(pr.Rule('cool_car(x) <-1 owns_car(x,y),Car_4(y)', 'cool_car_rule'))
+pr.add_rule(pr.Rule('cool_pet(x)<-1 owns_pet(x,y),Pet_2(y)', 'cool_pet_rule'))
+pr.add_rule(pr.Rule('trendy(x) <- cool_car(x) , cool_pet(x)', 'trendy_rule'))
+
+pr.add_rule(
+ pr.Rule("car_friend(x,y) <- owns_car(x,z), owns_car(y,z)", "car_friend_rule"))
+pr.add_rule(
+ pr.Rule("same_color_car(x, y) <- owns_car(x, c1) , owns_car(y, c2)","same_car_color_rule"))
+
+
+# Add a fact about 'customer_0' being popular
+pr.add_fact(pr.Fact('popular-fact', 'popular(customer_0)', 0, 5))
+
+# Perform reasoning over 10 timesteps
+interpretation = pr.reason(timesteps=10)
+
+# Get the interpretation and display it
+interpretations_dict = interpretation.get_dict()
+
+# Open a file to write the output
+with open('output.txt', 'w') as f:
+ # Write the interpretation dict to the file
+ f.write("Interpretations:\n")
+ pprint(interpretations_dict, stream=f) # Using pprint to format the output nicely
+
+ # Filter and sort nodes based on specific attributes
+ df1 = pr.filter_and_sort_nodes(interpretation, ['trendy', 'cool_car', 'cool_pet', 'popular'])
+
+ # Filter and sort edges based on specific relationships
+ df2 = pr.filter_and_sort_edges(interpretation, ['car_friend', 'same_color_car'])
+
+ # Write filtered node and edge data to the file
+ f.write("\nFiltered Nodes:\n")
+ f.write(str(df1)) # Convert the DataFrame or list to string for writing
+
+ f.write("\nFiltered Edges:\n")
+ f.write(str(df2)) # Convert the DataFrame or list to string for writing
diff --git a/media/group_chat_graph.png b/media/group_chat_graph.png
index dc9afac6..9da3f58f 100644
Binary files a/media/group_chat_graph.png and b/media/group_chat_graph.png differ
diff --git a/media/infer_edges1.png b/media/infer_edges1.png
new file mode 100644
index 00000000..2c66341d
Binary files /dev/null and b/media/infer_edges1.png differ
diff --git a/media/infer_edges11.png b/media/infer_edges11.png
new file mode 100644
index 00000000..1af9d318
Binary files /dev/null and b/media/infer_edges11.png differ
diff --git a/media/infer_edges2.png b/media/infer_edges2.png
new file mode 100644
index 00000000..5c874c7c
Binary files /dev/null and b/media/infer_edges2.png differ
diff --git a/pyreason/.cache_status.yaml b/pyreason/.cache_status.yaml
index 32458f5d..71173842 100644
--- a/pyreason/.cache_status.yaml
+++ b/pyreason/.cache_status.yaml
@@ -1 +1 @@
-initialized: false
+initialized: true
diff --git a/pyreason/__init__.py b/pyreason/__init__.py
index 85f8319a..15fec585 100755
--- a/pyreason/__init__.py
+++ b/pyreason/__init__.py
@@ -8,23 +8,33 @@
from pyreason.pyreason import *
import yaml
+from importlib.metadata import version
+from pkg_resources import get_distribution, DistributionNotFound
+
+try:
+ __version__ = get_distribution(__name__).version
+except DistributionNotFound:
+ # package is not installed
+ pass
with open(cache_status_path) as file:
cache_status = yaml.safe_load(file)
if not cache_status['initialized']:
- print('Imported PyReason for the first time. Initializing ... this will take a minute')
+ print('Imported PyReason for the first time. Initializing caches for faster runtimes ... this will take a minute')
graph_path = os.path.join(package_path, 'examples', 'hello-world', 'friends_graph.graphml')
settings.verbose = False
load_graphml(graph_path)
add_rule(Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule'))
- add_fact(Fact('popular-fact', 'Mary', 'popular', [1, 1], 0, 2))
+ add_fact(Fact('popular(Mary)', 'popular_fact', 0, 2))
reason(timesteps=2)
reset()
reset_rules()
+ print('PyReason initialized!')
+ print()
# Update cache status
cache_status['initialized'] = True
diff --git a/pyreason/pyreason.py b/pyreason/pyreason.py
index 48ae672f..0948368f 100755
--- a/pyreason/pyreason.py
+++ b/pyreason/pyreason.py
@@ -5,6 +5,7 @@
import sys
import pandas as pd
import memory_profiler as mp
+import warnings
from typing import List, Type, Callable, Tuple
from pyreason.scripts.utils.output import Output
@@ -13,19 +14,42 @@
from pyreason.scripts.utils.graphml_parser import GraphmlParser
import pyreason.scripts.utils.yaml_parser as yaml_parser
import pyreason.scripts.utils.rule_parser as rule_parser
+import pyreason.scripts.utils.filter_ruleset as ruleset_filter
import pyreason.scripts.numba_wrapper.numba_types.label_type as label
import pyreason.scripts.numba_wrapper.numba_types.rule_type as rule
from pyreason.scripts.facts.fact import Fact
from pyreason.scripts.rules.rule import Rule
from pyreason.scripts.threshold.threshold import Threshold
+from pyreason.scripts.query.query import Query
import pyreason.scripts.numba_wrapper.numba_types.fact_node_type as fact_node
import pyreason.scripts.numba_wrapper.numba_types.fact_edge_type as fact_edge
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
+from pyreason.scripts.utils.reorder_clauses import reorder_clauses
# USER VARIABLES
class _Settings:
def __init__(self):
+ self.__verbose = None
+ self.__output_to_file = None
+ self.__output_file_name = None
+ self.__graph_attribute_parsing = None
+ self.__abort_on_inconsistency = None
+ self.__memory_profile = None
+ self.__reverse_digraph = None
+ self.__atom_trace = None
+ self.__save_graph_attributes_to_trace = None
+ self.__canonical = None
+ self.__persistent = None
+ self.__inconsistency_check = None
+ self.__static_graph_facts = None
+ self.__store_interpretation_changes = None
+ self.__parallel_computing = None
+ self.__update_mode = None
+ self.__allow_ground_rules = None
+ self.reset()
+
+ def reset(self):
self.__verbose = True
self.__output_to_file = False
self.__output_file_name = 'pyreason_output'
@@ -36,11 +60,13 @@ def __init__(self):
self.__atom_trace = False
self.__save_graph_attributes_to_trace = False
self.__canonical = False
+ self.__persistent = False
self.__inconsistency_check = True
self.__static_graph_facts = True
self.__store_interpretation_changes = True
self.__parallel_computing = False
self.__update_mode = 'intersection'
+ self.__allow_ground_rules = False
@property
def verbose(self) -> bool:
@@ -119,12 +145,21 @@ def save_graph_attributes_to_trace(self) -> bool:
@property
def canonical(self) -> bool:
- """Returns whether the interpretation is canonical or non-canonical. Default is False
+ """DEPRECATED, use persistent instead
+ Returns whether the interpretation is canonical or non-canonical. Default is False
:return: bool
"""
- return self.__canonical
-
+ return self.__persistent
+
+ @property
+ def persistent(self) -> bool:
+ """Returns whether the interpretation is persistent (Does not reset bounds at each timestep). Default is False
+
+ :return: bool
+ """
+ return self.__persistent
+
@property
def inconsistency_check(self) -> bool:
"""Returns whether to check for inconsistencies in the interpretation or not. Default is True
@@ -167,6 +202,14 @@ def update_mode(self) -> str:
"""
return self.__update_mode
+ @property
+ def allow_ground_rules(self) -> bool:
+ """Returns whether rules can have ground atoms or not. Default is False
+
+ :return: bool
+ """
+ return self.__allow_ground_rules
+
@verbose.setter
def verbose(self, value: bool) -> None:
"""Set verbose mode. Default is True
@@ -289,8 +332,20 @@ def canonical(self, value: bool) -> None:
if not isinstance(value, bool):
raise TypeError('value has to be a bool')
else:
- self.__canonical = value
-
+ self.__persistent = value
+
+ @persistent.setter
+ def persistent(self, value: bool) -> None:
+ """Whether the interpretation should be canonical where bounds are reset at each timestep or not
+
+ :param value: Whether to reset all bounds at each timestep (non-persistent) or (persistent)
+ :raises TypeError: If not bool raise error
+ """
+ if not isinstance(value, bool):
+ raise TypeError('value has to be a bool')
+ else:
+ self.__persistent = value
+
@inconsistency_check.setter
def inconsistency_check(self, value: bool) -> None:
"""Whether to check for inconsistencies in the interpretation or not
@@ -354,15 +409,26 @@ def update_mode(self, value: str) -> None:
else:
self.__update_mode = value
+ @allow_ground_rules.setter
+ def allow_ground_rules(self, value: bool) -> None:
+ """Allow ground atoms to be used in rules when possible. Default is False
+
+ :param value: Whether to allow ground atoms or not
+ :raises TypeError: If not bool raise error
+ """
+ if not isinstance(value, bool):
+ raise TypeError('value has to be a bool')
+ else:
+ self.__allow_ground_rules = value
+
# VARIABLES
__graph = None
__rules = None
+__clause_maps = None
__node_facts = None
__edge_facts = None
__ipl = None
-__node_labels = None
-__edge_labels = None
__specific_node_labels = None
__specific_edge_labels = None
@@ -384,11 +450,17 @@ def reset():
"""Resets certain variables to None to be able to do pr.reason() multiple times in a program
without memory blowing up
"""
- global __node_facts, __edge_facts, __node_labels, __edge_labels
+ global __node_facts, __edge_facts
__node_facts = None
__edge_facts = None
- __node_labels = None
- __edge_labels = None
+
+
+def get_rules():
+ """
+ Returns the rules
+ """
+ global __rules
+ return __rules
def reset_rules():
@@ -399,6 +471,22 @@ def reset_rules():
__rules = None
+def reset_graph():
+ """
+ Resets graph to none
+ """
+ global __graph
+ __graph = None
+
+
+def reset_settings():
+ """
+ Resets settings to default
+ """
+ global settings
+ settings.reset()
+
+
# FUNCTIONS
def load_graphml(path: str) -> None:
"""Loads graph from GraphMl file path into program
@@ -451,6 +539,18 @@ def load_inconsistent_predicate_list(path: str) -> None:
__ipl = yaml_parser.parse_ipl(path)
+def add_inconsistent_predicate(pred1: str, pred2: str) -> None:
+ """Add an inconsistent predicate pair to the IPL
+
+ :param pred1: First predicate in the inconsistent pair
+ :param pred2: Second predicate in the inconsistent pair
+ """
+ global __ipl
+ if __ipl is None:
+ __ipl = numba.typed.List.empty_list(numba.types.Tuple((label.label_type, label.label_type)))
+ __ipl.append((label.Label(pred1), label.Label(pred2)))
+
+
def add_rule(pr_rule: Rule) -> None:
"""Add a rule to pyreason from text format. This format is not as modular as the YAML format.
"""
@@ -459,6 +559,11 @@ def add_rule(pr_rule: Rule) -> None:
# Add to collection of rules
if __rules is None:
__rules = numba.typed.List.empty_list(rule.rule_type)
+
+ # Generate name for rule if not set
+ if pr_rule.rule.get_rule_name() is None:
+ pr_rule.rule.set_rule_name(f'rule_{len(__rules)}')
+
__rules.append(pr_rule.rule)
@@ -487,16 +592,20 @@ def add_fact(pyreason_fact: Fact) -> None:
"""
global __node_facts, __edge_facts
+ if __node_facts is None:
+ __node_facts = numba.typed.List.empty_list(fact_node.fact_type)
+ if __edge_facts is None:
+ __edge_facts = numba.typed.List.empty_list(fact_edge.fact_type)
+
if pyreason_fact.type == 'node':
- f = fact_node.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.label, pyreason_fact.interval, pyreason_fact.t_lower, pyreason_fact.t_upper, pyreason_fact.static)
- if __node_facts is None:
- __node_facts = numba.typed.List.empty_list(fact_node.fact_type)
+ if pyreason_fact.name is None:
+ pyreason_fact.name = f'fact_{len(__node_facts)+len(__edge_facts)}'
+ f = fact_node.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.pred, pyreason_fact.bound, pyreason_fact.start_time, pyreason_fact.end_time, pyreason_fact.static)
__node_facts.append(f)
-
else:
- f = fact_edge.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.label, pyreason_fact.interval, pyreason_fact.t_lower, pyreason_fact.t_upper, pyreason_fact.static)
- if __edge_facts is None:
- __edge_facts = numba.typed.List.empty_list(fact_edge.fact_type)
+ if pyreason_fact.name is None:
+ pyreason_fact.name = f'fact_{len(__node_facts)+len(__edge_facts)}'
+ f = fact_edge.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.pred, pyreason_fact.bound, pyreason_fact.start_time, pyreason_fact.end_time, pyreason_fact.static)
__edge_facts.append(f)
@@ -515,12 +624,13 @@ def add_annotation_function(function: Callable) -> None:
__annotation_functions.append(function)
-def reason(timesteps: int=-1, convergence_threshold: int=-1, convergence_bound_threshold: float=-1, again: bool=False, node_facts: List[Type[fact_node.Fact]]=None, edge_facts: List[Type[fact_edge.Fact]]=None):
+def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bound_threshold: float = -1, queries: List[Query] = None, again: bool = False, node_facts: List[Type[fact_node.Fact]] = None, edge_facts: List[Type[fact_edge.Fact]] = None):
"""Function to start the main reasoning process. Graph and rules must already be loaded.
:param timesteps: Max number of timesteps to run. -1 specifies run till convergence. If reasoning again, this is the number of timesteps to reason for extra (no zero timestep), defaults to -1
- :param convergence_threshold: Maximim number of interpretations that have changed between timesteps or fixed point operations until considered convergent. Program will end at convergence. -1 => no changes, perfect convergence, defaults to -1
+ :param convergence_threshold: Maximum number of interpretations that have changed between timesteps or fixed point operations until considered convergent. Program will end at convergence. -1 => no changes, perfect convergence, defaults to -1
:param convergence_bound_threshold: Maximum change in any interpretation (bounds) between timesteps or fixed point operations until considered convergent, defaults to -1
+ :param queries: A list of PyReason query objects that can be used to filter the ruleset based on the query. Default is None
:param again: Whether to reason again on an existing interpretation, defaults to False
:param node_facts: New node facts to use during the next reasoning process. Other facts from file will be discarded, defaults to None
:param edge_facts: New edge facts to use during the next reasoning process. Other facts from file will be discarded, defaults to None
@@ -537,10 +647,10 @@ def reason(timesteps: int=-1, convergence_threshold: int=-1, convergence_bound_t
if not again or __program is None:
if settings.memory_profile:
start_mem = mp.memory_usage(max_usage=True)
- mem_usage, interp = mp.memory_usage((_reason, [timesteps, convergence_threshold, convergence_bound_threshold]), max_usage=True, retval=True)
+ mem_usage, interp = mp.memory_usage((_reason, [timesteps, convergence_threshold, convergence_bound_threshold, queries]), max_usage=True, retval=True)
print(f"\nProgram used {mem_usage-start_mem} MB of memory")
else:
- interp = _reason(timesteps, convergence_threshold, convergence_bound_threshold)
+ interp = _reason(timesteps, convergence_threshold, convergence_bound_threshold, queries)
else:
if settings.memory_profile:
start_mem = mp.memory_usage(max_usage=True)
@@ -552,9 +662,9 @@ def reason(timesteps: int=-1, convergence_threshold: int=-1, convergence_bound_t
return interp
-def _reason(timesteps, convergence_threshold, convergence_bound_threshold):
+def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queries):
# Globals
- global __graph, __rules, __node_facts, __edge_facts, __ipl, __node_labels, __edge_labels, __specific_node_labels, __specific_edge_labels, __graphml_parser
+ global __graph, __rules, __clause_maps, __node_facts, __edge_facts, __ipl, __specific_node_labels, __specific_edge_labels, __graphml_parser
global settings, __timestamp, __program
# Assert variables are of correct type
@@ -564,16 +674,12 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold):
# Check variables that HAVE to be set. Exceptions
if __graph is None:
- raise Exception('Graph not loaded. Use `load_graph` to load the graphml file')
+ load_graph(nx.DiGraph())
+ if settings.verbose:
+ warnings.warn('Graph not loaded. Use `load_graph` to load the graphml file. Using empty graph')
if __rules is None:
raise Exception('There are no rules, use `add_rule` or `add_rules_from_file`')
- # Check variables that are highly recommended. Warnings
- if __node_labels is None and __edge_labels is None:
- __node_labels = numba.typed.List.empty_list(label.label_type)
- __edge_labels = numba.typed.List.empty_list(label.label_type)
- __specific_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(numba.types.string))
- __specific_edge_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(numba.types.Tuple((numba.types.string, numba.types.string))))
if __node_facts is None:
__node_facts = numba.typed.List.empty_list(fact_node.fact_type)
@@ -583,7 +689,9 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold):
if __ipl is None:
__ipl = numba.typed.List.empty_list(numba.types.Tuple((label.label_type, label.label_type)))
- # If graph attribute parsing, add results to existing specific labels and facts
+ # Add results of graph parse to existing specific labels and facts
+ __specific_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(numba.types.string))
+ __specific_edge_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(numba.types.Tuple((numba.types.string, numba.types.string))))
for label_name, nodes in __specific_graph_node_labels.items():
if label_name in __specific_node_labels:
__specific_node_labels[label_name].extend(nodes)
@@ -610,10 +718,25 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold):
# Convert list of annotation functions into tuple to be numba compatible
annotation_functions = tuple(__annotation_functions)
+ # Filter rules based on queries
+ if settings.verbose:
+ print('Filtering rules based on queries')
+ if queries is not None:
+ __rules = ruleset_filter.filter_ruleset(queries, __rules)
+
+ # Optimize rules by moving clauses around, only if there are more edges than nodes in the graph
+ __clause_maps = {r.get_rule_name(): {i: i for i in range(len(r.get_clauses()))} for r in __rules}
+ if len(__graph.edges) > len(__graph.nodes):
+ if settings.verbose:
+ print('Optimizing rules by moving node clauses ahead of edge clauses')
+ __rules_copy = __rules.copy()
+ __rules = numba.typed.List.empty_list(rule.rule_type)
+ for i, r in enumerate(__rules_copy):
+ r, __clause_maps[r.get_rule_name()] = reorder_clauses(r)
+ __rules.append(r)
+
# Setup logical program
- __program = Program(__graph, all_node_facts, all_edge_facts, __rules, __ipl, annotation_functions, settings.reverse_digraph, settings.atom_trace, settings.save_graph_attributes_to_trace, settings.canonical, settings.inconsistency_check, settings.store_interpretation_changes, settings.parallel_computing, settings.update_mode)
- __program.available_labels_node = __node_labels
- __program.available_labels_edge = __edge_labels
+ __program = Program(__graph, all_node_facts, all_edge_facts, __rules, __ipl, annotation_functions, settings.reverse_digraph, settings.atom_trace, settings.save_graph_attributes_to_trace, settings.persistent, settings.inconsistency_check, settings.store_interpretation_changes, settings.parallel_computing, settings.update_mode, settings.allow_ground_rules)
__program.specific_node_labels = __specific_node_labels
__program.specific_edge_labels = __specific_edge_labels
@@ -625,7 +748,7 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold):
def _reason_again(timesteps, convergence_threshold, convergence_bound_threshold, node_facts, edge_facts):
# Globals
- global __graph, __rules, __node_facts, __edge_facts, __ipl, __node_labels, __edge_labels, __specific_node_labels, __specific_edge_labels, __graphml_parser
+ global __graph, __rules, __node_facts, __edge_facts, __ipl, __specific_node_labels, __specific_edge_labels, __graphml_parser
global settings, __timestamp, __program
assert __program is not None, 'To run `reason_again` you need to have reasoned once before'
@@ -651,11 +774,11 @@ def save_rule_trace(interpretation, folder: str='./'):
:param interpretation: the output of `pyreason.reason()`, the final interpretation
:param folder: the folder in which to save the result, defaults to './'
"""
- global __timestamp, settings
+ global __timestamp, __clause_maps, settings
assert settings.store_interpretation_changes, 'store interpretation changes setting is off, turn on to save rule trace'
- output = Output(__timestamp)
+ output = Output(__timestamp, __clause_maps)
output.save_rule_trace(interpretation, folder)
@@ -667,11 +790,11 @@ def get_rule_trace(interpretation) -> Tuple[pd.DataFrame, pd.DataFrame]:
:param interpretation: the output of `pyreason.reason()`, the final interpretation
:returns two pandas dataframes (nodes, edges) representing the changes that occurred during reasoning
"""
- global __timestamp, settings
+ global __timestamp, __clause_maps, settings
assert settings.store_interpretation_changes, 'store interpretation changes setting is off, turn on to save rule trace'
- output = Output(__timestamp)
+ output = Output(__timestamp, __clause_maps)
return output.get_rule_trace(interpretation)
diff --git a/pyreason/scripts/args.py b/pyreason/scripts/args.py
deleted file mode 100755
index d974d5d1..00000000
--- a/pyreason/scripts/args.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import argparse
-
-def argparser():
- parser = argparse.ArgumentParser()
- # General
- parser.add_argument("--graph_path", type=str, required=True, help='[REQUIRED] The path pointing to the graphml file')
- # YAML
- parser.add_argument("--labels_yaml_path", type=str, required=True, help='[REQUIRED] The path pointing to the labels YAML file')
- parser.add_argument("--rules_yaml_path", type=str, required=True, help='[REQUIRED] The path pointing to the rules YAML file')
- parser.add_argument("--facts_yaml_path", type=str, required=True, help='[REQUIRED] The path pointing to the facts YAML file')
- parser.add_argument("--ipl_yaml_path", type=str, required=True, help='[REQUIRED] The path pointing to the IPL YAML file')
- # TMAX
- parser.add_argument("--timesteps", type=int, default=-1, help='The max number of timesteps to run the diffusion')
- # Profile
- parser.add_argument("--no-profile", dest='profile', action='store_false', help='Do not profile the code using cProfile. Profiling is off by Default')
- parser.add_argument("--profile", dest='profile', action='store_true', help='Profile the code using cProfile. Profiling is off by Default')
- parser.set_defaults(profile=False)
- parser.add_argument("--profile_output", type=str, default='profile_output', help='If profile is switched on, specify the file name of the profile output')
- # Output form - on screen or in file
- parser.add_argument("--no-output_to_file", dest='output_to_file', action='store_false', help='Print all output from the program onto the console screen. This is on by default')
- parser.add_argument("--output_to_file", dest='output_to_file', action='store_true', help='Print all output from the program into a file. This is off by default')
- parser.add_argument("--output_file_name", type=str, default='pyreason_output', help='If output_to_file option has been specified, name of the file to print the output')
- parser.set_defaults(output_to_file=False)
- # Graph attribute parsing
- parser.add_argument("--no-graph_attribute_parsing", dest='graph_attribute_parsing', action='store_false', help='Option to not make non fluent facts based on the attributes of the graph.')
- parser.add_argument("--graph_attribute_parsing", dest='graph_attribute_parsing', action='store_true', help='Option to make non fluent facts based on the attributes of the graph. On by default')
- parser.set_defaults(graph_attribute_parsing=True)
- # Check for inconsistencies
- parser.add_argument("--no-inconsistency_check", dest='inconsistency_check', action='store_false', help='Option to not check for any inconsistencies in the interpretation.')
- parser.add_argument("--inconsistency_check", dest='inconsistency_check', action='store_true', help='Option to check for inconsistencies in the interpretation. On by default')
- parser.set_defaults(inconsistency_check=True)
-
- # Interpretation inconsistency check (not done)
- parser.add_argument("--abort_on_inconsistency", dest='abort_on_inconsistency', action='store_true', help='Stop the program if there are inconsistencies, do not fix them automatically')
- parser.set_defaults(abort_on_inconsistency=False)
- # Memory profiling
- parser.add_argument("--no-memory_profile", dest='memory_profile', action='store_false', help='Option to disable memory profiling. Memory profiling is on by default')
- parser.add_argument("--memory_profile", dest='memory_profile', action='store_true',help='Option to enable memory profiling. Memory profiling is on by default')
- parser.set_defaults(memory_profile=True)
- # Reverse Digraph
- parser.add_argument("--reverse_digraph", dest='reverse_digraph', action='store_true', help='Option to reverse the edges of a directed graph')
- parser.set_defaults(reverse_digraph=False)
- # Rule trace with ground atoms (not done)
- parser.add_argument("--atom_trace", dest='atom_trace', action='store_true', help='Option to track the ground atoms which lead to a rule firing. This could be very memory heavy. Default is off')
- parser.set_defaults(atom_trace=False)
- parser.add_argument("--save_graph_attributes_to_trace", dest='save_graph_attributes_to_trace', action='store_true', help='Option to save graph attributes to trace. Graphs are big and turning this on can be very memory heavy. Graph attributes are represented as facts. Default is off')
- parser.set_defaults(save_graph_attributes_to_trace=False)
- # Convergence options
- parser.add_argument("--convergence_threshold", type=int, default=-1, help='Number of interpretations that have changed between timesteps or fixed point operations until considered convergent. Program will end at convergence. -1 => Perfect convergence. This option is default')
- parser.add_argument("--convergence_bound_threshold", type=float, default=-1, help='Max change in any interpretation between timesteps or fixed point operations until considered convergent. Program will end at convergence. --convergence_threshold is default')
- # Canonical vs non-canonical
- parser.add_argument("--non-canonical", dest='canonical', action='store_false', help='Option to reset bounds of interpretation at each timestep. Default is non-canonical')
- parser.add_argument("--canonical", dest='canonical', action='store_true',help='Option to NOT reset bounds of interpretation at each timestep. Default is non-canonical')
- parser.set_defaults(canonical=False)
- # Whether to make graphml facts static or not
- parser.add_argument("--non-static_graph_facts", dest='static_graph_facts', action='store_false', help='Option to keep facts from graphml non-static, for t=0 only. Default is static')
- parser.add_argument("--static_graph_facts", dest='static_graph_facts', action='store_true',help='Option to to keep facts from graphml static, for entire program. Default is static')
- parser.set_defaults(static_graph_facts=True)
-
- # Pickling options
-
- # Filtering options
- parser.add_argument("--filter_sort_by", help='Sort output by lower or upper bound', default='lower')
- parser.add_argument('--filter_labels', nargs='+', type=str, default=[], help='Filter the output by this list of labels')
- parser.add_argument("--filter_ascending", dest='descending', action='store_false', help='Sort by ascending order instead of descending')
- parser.add_argument("--filter_descending", dest='descending', action='store_true', help='Sort by descending order instead of descending')
- parser.set_defaults(descending=True)
-
-
-
-
-
- return parser.parse_args()
\ No newline at end of file
diff --git a/pyreason/scripts/diffuse.py b/pyreason/scripts/diffuse.py
deleted file mode 100755
index 1169aae5..00000000
--- a/pyreason/scripts/diffuse.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import io
-import time
-import cProfile
-import pstats
-import sys
-import memory_profiler as mp
-
-import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
-from pyreason.scripts.program.program import Program
-import pyreason.scripts.utils.yaml_parser as yaml_parser
-from pyreason.scripts.utils.graphml_parser import GraphmlParser
-from pyreason.scripts.utils.filter import Filter
-from pyreason.scripts.utils.output import Output
-from pyreason.scripts.args import argparser
-
-
-
-def main(args):
- timestamp = time.strftime('%Y%m%d-%H%M%S')
- if args.output_to_file:
- sys.stdout = open(f"./output/{args.output_file_name}_{timestamp}.txt", "w")
-
- # Initialize parsers
- graphml_parser = GraphmlParser()
-
- start = time.time()
- graph = graphml_parser.parse_graph(args.graph_path, args.reverse_digraph)
- end = time.time()
- print('Time to read graph:', end-start)
-
- if args.graph_attribute_parsing:
- start = time.time()
- non_fluent_facts_node, non_fluent_facts_edge, specific_node_labels, specific_edge_labels = graphml_parser.parse_graph_attributes(args.static_graph_facts)
- end = time.time()
- print('Time to read graph attributes:', end-start)
- else:
- non_fluent_facts_node = []
- non_fluent_facts_edge = []
-
- tmax = args.timesteps
-
- # Initialize labels
- node_labels, edge_labels, snl, sel = yaml_parser.parse_labels(args.labels_yaml_path)
- if args.graph_attribute_parsing:
- specific_node_labels.update(snl)
- specific_edge_labels.update(sel)
- for label_name, nodes in specific_node_labels.items():
- if label_name in snl:
- snl[label_name].extend(nodes)
- else:
- snl[label_name] = nodes
-
- for label_name, edges in specific_edge_labels.items():
- if label_name in sel:
- sel[label_name].extend(edges)
- else:
- sel[label_name] = edges
- else:
- specific_node_labels = snl
- specific_edge_labels = sel
-
- # Rules come here
- rules = yaml_parser.parse_rules(args.rules_yaml_path)
-
- # Facts come here. Add non fluent facts that came from the graph
- facts_node, facts_edge = yaml_parser.parse_facts(args.facts_yaml_path, args.reverse_digraph)
- facts_node += non_fluent_facts_node
- facts_edge += non_fluent_facts_edge
-
- # Inconsistent predicate list
- ipl = yaml_parser.parse_ipl(args.ipl_yaml_path)
-
- # Program comes here
- program = Program(graph, facts_node, facts_edge, rules, ipl, args.reverse_digraph, args.atom_trace, args.save_graph_attributes_to_trace, args.canonical, args.inconsistency_check)
- program.available_labels_node = node_labels
- program.available_labels_edge = edge_labels
- program.specific_node_labels = specific_node_labels
- program.specific_edge_labels = specific_edge_labels
-
- # Reasoning process
- print('Graph loaded successfully, rules, labels, facts and ipl parsed successfully')
- print('Starting diffusion')
- start = time.time()
- interpretation = program.reason(tmax, args.convergence_threshold, args.convergence_bound_threshold)
- end = time.time()
- print('Time to complete diffusion:', end-start)
- print('Finished diffusion')
-
- # Save the rule trace to a file
- output = Output(timestamp)
- output.save_rule_trace(interpretation, folder='./output')
-
- # This is how you filter the dataframe to show only nodes that have success in a certain interval
- print('Filtering data...')
- filterer = Filter(interpretation.time)
- filtered_df = filterer.filter_and_sort_nodes(interpretation, labels=args.filter_labels, bound=interval.closed(0, 1), sort_by=args.filter_sort_by, descending=args.descending)
-
- # You can index into filtered_df to get a particular timestep
- # This is for each timestep
- for t in range(interpretation.time+1):
- print(f'\n TIMESTEP - {t}')
- print(filtered_df[t])
- print()
-
-
-
-
-if __name__ == "__main__":
- args = argparser()
-
-
- if args.profile:
- profiler = cProfile.Profile()
- profiler.enable()
- main(args)
- profiler.disable()
- s = io.StringIO()
- stats = pstats.Stats(profiler, stream=s).sort_stats('tottime')
- stats.print_stats()
- with open('./profiling/' + args.profile_output + '.txt', 'w+') as f:
- f.write(s.getvalue())
-
- else:
- if args.memory_profile:
- start_mem = mp.memory_usage(max_usage=True)
- mem_usage = mp.memory_usage((main, [args]), max_usage=True)
- print(f"\nProgram used {mem_usage-start_mem} MB of memory")
- else:
- main(args)
\ No newline at end of file
diff --git a/pyreason/scripts/facts/fact.py b/pyreason/scripts/facts/fact.py
index 1004cec3..44d823d0 100644
--- a/pyreason/scripts/facts/fact.py
+++ b/pyreason/scripts/facts/fact.py
@@ -1,23 +1,15 @@
-import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
+import pyreason.scripts.utils.fact_parser as fact_parser
import pyreason.scripts.numba_wrapper.numba_types.label_type as label
-from typing import Tuple
-from typing import List
-from typing import Union
-
class Fact:
- def __init__(self, name: str, component: Union[str, Tuple[str, str]], attribute: str, bound: Union[interval.Interval, List[float]], start_time: int, end_time: int, static: bool = False):
+ def __init__(self, fact_text: str, name: str = None, start_time: int = 0, end_time: int = 0, static: bool = False):
"""Define a PyReason fact that can be loaded into the program using `pr.add_fact()`
+ :param fact_text: The fact in text format. Example: `'pred(x,y) : [0.2, 1]'` or `'pred(x,y) : True'`
+ :type fact_text: str
:param name: The name of the fact. This will appear in the trace so that you know when it was applied
:type name: str
- :param component: The node or edge that whose attribute you want to change
- :type component: str | Tuple[str, str]
- :param attribute: The attribute you would like to change for the specified node/edge
- :type attribute: str
- :param bound: The bound to which you'd like to set the attribute corresponding to the specified node/edge
- :type bound: interval.Interval | List[float]
:param start_time: The timestep at which this fact becomes active
:type start_time: int
:param end_time: The last timestep this fact is active
@@ -25,23 +17,12 @@ def __init__(self, name: str, component: Union[str, Tuple[str, str]], attribute:
:param static: If the fact should be active for the entire program. In which case `start_time` and `end_time` will be ignored
:type static: bool
"""
+ pred, component, bound, fact_type = fact_parser.parse_fact(fact_text)
self.name = name
- self.t_upper = end_time
- self.t_lower = start_time
- self.component = component
- self.label = attribute
- self.interval = bound
+ self.start_time = start_time
+ self.end_time = end_time
self.static = static
-
- # Check if it is a node fact or edge fact
- if isinstance(self.component, str):
- self.type = 'node'
- else:
- self.type = 'edge'
-
- # Set label to correct type
- self.label = label.Label(attribute)
-
- # Set bound to correct type
- if isinstance(bound, list):
- self.interval = interval.closed(*bound)
+ self.pred = label.Label(pred)
+ self.component = component
+ self.bound = bound
+ self.type = fact_type
diff --git a/pyreason/scripts/facts/fact_edge.py b/pyreason/scripts/facts/fact_edge.py
index 935d40f4..bbeb3e64 100755
--- a/pyreason/scripts/facts/fact_edge.py
+++ b/pyreason/scripts/facts/fact_edge.py
@@ -12,6 +12,9 @@ def __init__(self, name, component, label, interval, t_lower, t_upper, static=Fa
def get_name(self):
return self._name
+ def set_name(self, name):
+ self._name = name
+
def get_component(self):
return self._component
diff --git a/pyreason/scripts/facts/fact_node.py b/pyreason/scripts/facts/fact_node.py
index 92e97c88..69e379eb 100755
--- a/pyreason/scripts/facts/fact_node.py
+++ b/pyreason/scripts/facts/fact_node.py
@@ -12,6 +12,9 @@ def __init__(self, name, component, label, interval, t_lower, t_upper, static=Fa
def get_name(self):
return self._name
+ def set_name(self, name):
+ self._name = name
+
def get_component(self):
return self._component
diff --git a/pyreason/scripts/interpretation/interpretation.py b/pyreason/scripts/interpretation/interpretation.py
index 779ce688..57777b12 100755
--- a/pyreason/scripts/interpretation/interpretation.py
+++ b/pyreason/scripts/interpretation/interpretation.py
@@ -1,3 +1,5 @@
+from typing import Union, Tuple
+
import pyreason.scripts.numba_wrapper.numba_types.world_type as world
import pyreason.scripts.numba_wrapper.numba_types.label_type as label
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
@@ -15,6 +17,12 @@
list_of_nodes = numba.types.ListType(node_type)
list_of_edges = numba.types.ListType(edge_type)
+# Type for storing clause data
+clause_data = numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string)))
+
+# Type for storing refine clause data
+refine_data = numba.types.Tuple((numba.types.string, numba.types.string, numba.types.int8))
+
# Type for facts to be applied
facts_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))
facts_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))
@@ -37,36 +45,44 @@
numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type))
))
+rules_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean))
+rules_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean))
+rules_to_be_applied_trace_type = numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string))
+edges_to_be_added_type = numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type))
+
class Interpretation:
- available_labels_node = []
- available_labels_edge = []
specific_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(node_type))
specific_edge_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(edge_type))
- def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode):
+ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, persistent, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules):
self.graph = graph
self.ipl = ipl
self.annotation_functions = annotation_functions
self.reverse_graph = reverse_graph
self.atom_trace = atom_trace
self.save_graph_attributes_to_rule_trace = save_graph_attributes_to_rule_trace
- self.canonical = canonical
+ self.persistent = persistent
self.inconsistency_check = inconsistency_check
self.store_interpretation_changes = store_interpretation_changes
self.update_mode = update_mode
+ self.allow_ground_rules = allow_ground_rules
+
+ # Counter for number of ground atoms for each timestep, start with zero for the zeroth timestep
+ self.num_ga = numba.typed.List.empty_list(numba.types.int64)
+ self.num_ga.append(0)
# For reasoning and reasoning again (contains previous time and previous fp operation cnt)
self.time = 0
self.prev_reasoning_data = numba.typed.List([0, 0])
# Initialize list of tuples for rules/facts to be applied, along with all the ground atoms that fired the rule. One to One correspondence between rules_to_be_applied_node and rules_to_be_applied_node_trace if atom_trace is true
- self.rules_to_be_applied_node_trace = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string)))
- self.rules_to_be_applied_edge_trace = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string)))
+ self.rules_to_be_applied_node_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type)
+ self.rules_to_be_applied_edge_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type)
self.facts_to_be_applied_node_trace = numba.typed.List.empty_list(numba.types.string)
self.facts_to_be_applied_edge_trace = numba.typed.List.empty_list(numba.types.string)
- self.rules_to_be_applied_node = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)))
- self.rules_to_be_applied_edge = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)))
+ self.rules_to_be_applied_node = numba.typed.List.empty_list(rules_to_be_applied_node_type)
+ self.rules_to_be_applied_edge = numba.typed.List.empty_list(rules_to_be_applied_edge_type)
self.facts_to_be_applied_node = numba.typed.List.empty_list(facts_to_be_applied_node_type)
self.facts_to_be_applied_edge = numba.typed.List.empty_list(facts_to_be_applied_edge_type)
self.edges_to_be_added_node_rule = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)))
@@ -84,18 +100,8 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace,
self.nodes.extend(numba.typed.List(self.graph.nodes()))
self.edges.extend(numba.typed.List(self.graph.edges()))
- # Make sure they are correct type
- if len(self.available_labels_node)==0:
- self.available_labels_node = numba.typed.List.empty_list(label.label_type)
- else:
- self.available_labels_node = numba.typed.List(self.available_labels_node)
- if len(self.available_labels_edge)==0:
- self.available_labels_edge = numba.typed.List.empty_list(label.label_type)
- else:
- self.available_labels_edge = numba.typed.List(self.available_labels_edge)
-
- self.interpretations_node = self._init_interpretations_node(self.nodes, self.available_labels_node, self.specific_node_labels)
- self.interpretations_edge = self._init_interpretations_edge(self.edges, self.available_labels_edge, self.specific_edge_labels)
+ self.interpretations_node, self.predicate_map_node = self._init_interpretations_node(self.nodes, self.specific_node_labels, self.num_ga)
+ self.interpretations_edge, self.predicate_map_edge = self._init_interpretations_edge(self.edges, self.specific_edge_labels, self.num_ga)
# Setup graph neighbors and reverse neighbors
self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type))
@@ -124,32 +130,46 @@ def _init_reverse_neighbors(neighbors):
@staticmethod
@numba.njit(cache=True)
- def _init_interpretations_node(nodes, available_labels, specific_labels):
+ def _init_interpretations_node(nodes, specific_labels, num_ga):
interpretations = numba.typed.Dict.empty(key_type=node_type, value_type=world.world_type)
- # General labels
+ predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_nodes)
+
+ # Initialize nodes
for n in nodes:
- interpretations[n] = world.World(available_labels)
+ interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type))
+
# Specific labels
for l, ns in specific_labels.items():
for n in ns:
interpretations[n].world[l] = interval.closed(0.0, 1.0)
+ num_ga[0] += 1
+
+ for l, ns in specific_labels.items():
+ predicate_map[l] = numba.typed.List(ns)
+
+ return interpretations, predicate_map
- return interpretations
-
@staticmethod
@numba.njit(cache=True)
- def _init_interpretations_edge(edges, available_labels, specific_labels):
+ def _init_interpretations_edge(edges, specific_labels, num_ga):
interpretations = numba.typed.Dict.empty(key_type=edge_type, value_type=world.world_type)
- # General labels
- for e in edges:
- interpretations[e] = world.World(available_labels)
+ predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_edges)
+
+ # Initialize edges
+ for n in edges:
+ interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type))
+
# Specific labels
for l, es in specific_labels.items():
for e in es:
interpretations[e].world[l] = interval.closed(0.0, 1.0)
+ num_ga[0] += 1
+
+ for l, es in specific_labels.items():
+ predicate_map[l] = numba.typed.List(es)
+
+ return interpretations, predicate_map
- return interpretations
-
@staticmethod
@numba.njit(cache=True)
def _init_convergence(convergence_bound_threshold, convergence_threshold):
@@ -193,7 +213,7 @@ def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_ap
return max_time
def _start_fp(self, rules, max_facts_time, verbose, again):
- fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.canonical, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, verbose, again)
+ fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.predicate_map_node, self.predicate_map_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.persistent, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, self.allow_ground_rules, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, self.num_ga, verbose, again)
self.time = t - 1
# If we need to reason again, store the next timestep to start from
self.prev_reasoning_data[0] = t
@@ -202,23 +222,24 @@ def _start_fp(self, rules, max_facts_time, verbose, again):
print('Fixed Point iterations:', fp_cnt)
@staticmethod
- @numba.njit(cache=True)
- def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, max_facts_time, annotation_functions, convergence_mode, convergence_delta, verbose, again):
+ @numba.njit(cache=True, parallel=False)
+ def reason(interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, persistent, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules, max_facts_time, annotation_functions, convergence_mode, convergence_delta, num_ga, verbose, again):
t = prev_reasoning_data[0]
fp_cnt = prev_reasoning_data[1]
max_rules_time = 0
timestep_loop = True
facts_to_be_applied_node_new = numba.typed.List.empty_list(facts_to_be_applied_node_type)
facts_to_be_applied_edge_new = numba.typed.List.empty_list(facts_to_be_applied_edge_type)
- rules_to_remove_idx = numba.typed.List.empty_list(numba.types.int64)
+ rules_to_remove_idx = set()
+ rules_to_remove_idx.add(-1)
while timestep_loop:
if t==tmax:
timestep_loop = False
if verbose:
with objmode():
print('Timestep:', t, flush=True)
- # Reset Interpretation at beginning of timestep if non-canonical
- if t>0 and not canonical:
+ # Reset Interpretation at beginning of timestep if non-persistent
+ if t>0 and not persistent:
# Reset nodes (only if not static)
for n in nodes:
w = interpretations_node[n].world
@@ -238,24 +259,18 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
bound_delta = 0
update = False
- # Parameters for immediate rules
- immediate_node_rule_fire = False
- immediate_edge_rule_fire = False
- immediate_rule_applied = False
- # When delta_t = 0, we don't want to check the same rule with the same node/edge after coming back to the fp operator
- nodes_to_skip = numba.typed.Dict.empty(key_type=numba.types.int64, value_type=list_of_nodes)
- edges_to_skip = numba.typed.Dict.empty(key_type=numba.types.int64, value_type=list_of_edges)
- # Initialize the above
- for i in range(len(rules)):
- nodes_to_skip[i] = numba.typed.List.empty_list(node_type)
- edges_to_skip[i] = numba.typed.List.empty_list(edge_type)
-
# Start by applying facts
# Nodes
facts_to_be_applied_node_new.clear()
+ nodes_set = set(nodes)
for i in range(len(facts_to_be_applied_node)):
- if facts_to_be_applied_node[i][0]==t:
+ if facts_to_be_applied_node[i][0] == t:
comp, l, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5]
+ # If the component is not in the graph, add it
+ if comp not in nodes_set:
+ _add_node(comp, neighbors, reverse_neighbors, nodes, interpretations_node)
+ nodes_set.add(comp)
+
# Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well
if l in interpretations_node[comp].world and interpretations_node[comp].world[l].is_static():
# Check if we should even store any of the changes to the rule trace etc.
@@ -273,13 +288,13 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[comp].world[p1]))
if atom_trace:
_update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p1], facts_to_be_applied_node_trace[i])
-
+
else:
# Check for inconsistencies (multiple facts)
if check_consistent_node(interpretations_node, comp, (l, bnd)):
mode = 'graph-attribute-fact' if graph_attribute else 'fact'
override = True if update_mode == 'override' else False
- u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=override)
+ u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=override)
update = u or update
# Update convergence params
@@ -289,11 +304,11 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
changes_cnt += changes
# Resolve inconsistency if necessary otherwise override bounds
else:
+ mode = 'graph-attribute-fact' if graph_attribute else 'fact'
if inconsistency_check:
- resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_node, rule_trace_node_atoms, store_interpretation_changes)
+ resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode)
else:
- mode = 'graph-attribute-fact' if graph_attribute else 'fact'
- u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=True)
+ u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=True)
update = u or update
# Update convergence params
@@ -315,9 +330,15 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Edges
facts_to_be_applied_edge_new.clear()
+ edges_set = set(edges)
for i in range(len(facts_to_be_applied_edge)):
if facts_to_be_applied_edge[i][0]==t:
comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5]
+ # If the component is not in the graph, add it
+ if comp not in edges_set:
+ _add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
+ edges_set.add(comp)
+
# Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well
if l in interpretations_edge[comp].world and interpretations_edge[comp].world[l].is_static():
# Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute
@@ -339,7 +360,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
if check_consistent_edge(interpretations_edge, comp, (l, bnd)):
mode = 'graph-attribute-fact' if graph_attribute else 'fact'
override = True if update_mode == 'override' else False
- u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=override)
+ u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=override)
update = u or update
# Update convergence params
@@ -349,11 +370,11 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
changes_cnt += changes
# Resolve inconsistency
else:
+ mode = 'graph-attribute-fact' if graph_attribute else 'fact'
if inconsistency_check:
- resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes)
+ resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode)
else:
- mode = 'graph-attribute-fact' if graph_attribute else 'fact'
- u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=True)
+ u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=True)
update = u or update
# Update convergence params
@@ -382,50 +403,25 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Nodes
rules_to_remove_idx.clear()
for idx, i in enumerate(rules_to_be_applied_node):
- # If we are coming here from an immediate rule firing with delta_t=0 we have to apply that one rule. Which was just added to the list to_be_applied
- if immediate_node_rule_fire and rules_to_be_applied_node[-1][4]:
- i = rules_to_be_applied_node[-1]
- idx = len(rules_to_be_applied_node) - 1
-
- if i[0]==t:
- comp, l, bnd, immediate, set_static = i[1], i[2], i[3], i[4], i[5]
- sources, targets, edge_l = edges_to_be_added_node_rule[idx]
- edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge)
- changes_cnt += changes
-
- # Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally
- if edge_l.value!='':
- for e in edges_added:
- if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)):
- override = True if update_mode == 'override' else False
- u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override)
-
- update = u or update
-
- # Update convergence params
- if convergence_mode=='delta_bound':
- bound_delta = max(bound_delta, changes)
- else:
- changes_cnt += changes
- # Resolve inconsistency
- else:
- if inconsistency_check:
- resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes)
- else:
- u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True)
-
- update = u or update
+ if i[0] == t:
+ comp, l, bnd, set_static = i[1], i[2], i[3], i[4]
+ # Check for inconsistencies
+ if check_consistent_node(interpretations_node, comp, (l, bnd)):
+ override = True if update_mode == 'override' else False
+ u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)
- # Update convergence params
- if convergence_mode=='delta_bound':
- bound_delta = max(bound_delta, changes)
- else:
- changes_cnt += changes
+ update = u or update
+ # Update convergence params
+ if convergence_mode=='delta_bound':
+ bound_delta = max(bound_delta, changes)
+ else:
+ changes_cnt += changes
+ # Resolve inconsistency
else:
- # Check for inconsistencies
- if check_consistent_node(interpretations_node, comp, (l, bnd)):
- override = True if update_mode == 'override' else False
- u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=override)
+ if inconsistency_check:
+ resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule')
+ else:
+ u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=True)
update = u or update
# Update convergence params
@@ -433,32 +429,9 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
bound_delta = max(bound_delta, changes)
else:
changes_cnt += changes
- # Resolve inconsistency
- else:
- if inconsistency_check:
- resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_node, rule_trace_node_atoms, store_interpretation_changes)
- else:
- u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=True)
-
- update = u or update
- # Update convergence params
- if convergence_mode=='delta_bound':
- bound_delta = max(bound_delta, changes)
- else:
- changes_cnt += changes
# Delete rules that have been applied from list by adding index to list
- rules_to_remove_idx.append(idx)
-
- # Break out of the apply rules loop if a rule is immediate. Then we go to the fp operator and check for other applicable rules then come back
- if immediate:
- # If delta_t=0 we want to apply one rule and go back to the fp operator
- # If delta_t>0 we want to come back here and apply the rest of the rules
- if immediate_edge_rule_fire:
- break
- elif not immediate_edge_rule_fire and u:
- immediate_rule_applied = True
- break
+ rules_to_remove_idx.add(idx)
# Remove from rules to be applied and edges to be applied lists after coming out from loop
rules_to_be_applied_node[:] = numba.typed.List([rules_to_be_applied_node[i] for i in range(len(rules_to_be_applied_node)) if i not in rules_to_remove_idx])
@@ -469,26 +442,20 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Edges
rules_to_remove_idx.clear()
for idx, i in enumerate(rules_to_be_applied_edge):
- # If we broke from above loop to apply more rules, then break from here
- if immediate_rule_applied and not immediate_edge_rule_fire:
- break
- # If we are coming here from an immediate rule firing with delta_t=0 we have to apply that one rule. Which was just added to the list to_be_applied
- if immediate_edge_rule_fire and rules_to_be_applied_edge[-1][4]:
- i = rules_to_be_applied_edge[-1]
- idx = len(rules_to_be_applied_edge) - 1
-
- if i[0]==t:
- comp, l, bnd, immediate, set_static = i[1], i[2], i[3], i[4], i[5]
+ if i[0] == t:
+ comp, l, bnd, set_static = i[1], i[2], i[3], i[4]
sources, targets, edge_l = edges_to_be_added_edge_rule[idx]
- edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge)
+ edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
changes_cnt += changes
# Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally
- if edge_l.value!='':
+ if edge_l.value != '':
for e in edges_added:
+ if interpretations_edge[e].world[edge_l].is_static():
+ continue
if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)):
override = True if update_mode == 'override' else False
- u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override)
+ u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)
update = u or update
@@ -500,9 +467,9 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Resolve inconsistency
else:
if inconsistency_check:
- resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes)
+ resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule')
else:
- u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True)
+ u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True)
update = u or update
@@ -516,7 +483,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Check for inconsistencies
if check_consistent_edge(interpretations_edge, comp, (l, bnd)):
override = True if update_mode == 'override' else False
- u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override)
+ u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)
update = u or update
# Update convergence params
@@ -527,9 +494,9 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Resolve inconsistency
else:
if inconsistency_check:
- resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes)
+ resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule')
else:
- u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True)
+ u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True)
update = u or update
# Update convergence params
@@ -539,17 +506,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
changes_cnt += changes
# Delete rules that have been applied from list by adding the index to list
- rules_to_remove_idx.append(idx)
-
- # Break out of the apply rules loop if a rule is immediate. Then we go to the fp operator and check for other applicable rules then come back
- if immediate:
- # If t=0 we want to apply one rule and go back to the fp operator
- # If t>0 we want to come back here and apply the rest of the rules
- if immediate_edge_rule_fire:
- break
- elif not immediate_edge_rule_fire and u:
- immediate_rule_applied = True
- break
+ rules_to_remove_idx.add(idx)
# Remove from rules to be applied and edges to be applied lists after coming out from loop
rules_to_be_applied_edge[:] = numba.typed.List([rules_to_be_applied_edge[i] for i in range(len(rules_to_be_applied_edge)) if i not in rules_to_remove_idx])
@@ -558,61 +515,45 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
rules_to_be_applied_edge_trace[:] = numba.typed.List([rules_to_be_applied_edge_trace[i] for i in range(len(rules_to_be_applied_edge_trace)) if i not in rules_to_remove_idx])
# Fixed point
- # if update or immediate_node_rule_fire or immediate_edge_rule_fire or immediate_rule_applied:
if update:
- # Increase fp operator count only if not an immediate rule
- if not (immediate_node_rule_fire or immediate_edge_rule_fire):
- fp_cnt += 1
+ # Increase fp operator count
+ fp_cnt += 1
- for i in range(len(rules)):
+ # Lists or threadsafe operations (when parallel is on)
+ rules_to_be_applied_node_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_node_type) for _ in range(len(rules))])
+ rules_to_be_applied_edge_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_edge_type) for _ in range(len(rules))])
+ if atom_trace:
+ rules_to_be_applied_node_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
+ rules_to_be_applied_edge_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
+ edges_to_be_added_edge_rule_threadsafe = numba.typed.List([numba.typed.List.empty_list(edges_to_be_added_type) for _ in range(len(rules))])
+
+ for i in prange(len(rules)):
rule = rules[i]
- immediate_rule = rule.is_immediate_rule()
- immediate_node_rule_fire = False
- immediate_edge_rule_fire = False
# Only go through if the rule can be applied within the given timesteps, or we're running until convergence
delta_t = rule.get_delta()
if t + delta_t <= tmax or tmax == -1 or again:
- applicable_node_rules = _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, neighbors, reverse_neighbors, atom_trace, reverse_graph, nodes_to_skip[i])
- applicable_edge_rules = _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, reverse_graph, edges_to_skip[i])
+ applicable_node_rules, applicable_edge_rules = _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules, num_ga, t)
# Loop through applicable rules and add them to the rules to be applied for later or next fp operation
for applicable_rule in applicable_node_rules:
- n, annotations, qualified_nodes, qualified_edges, edges_to_add = applicable_rule
+ n, annotations, qualified_nodes, qualified_edges, _ = applicable_rule
# If there is an edge to add or the predicate doesn't exist or the interpretation is not static
- if len(edges_to_add[0]) > 0 or rule.get_target() not in interpretations_node[n].world or not interpretations_node[n].world[rule.get_target()].is_static():
+ if rule.get_target() not in interpretations_node[n].world or not interpretations_node[n].world[rule.get_target()].is_static():
bnd = annotate(annotation_functions, rule, annotations, rule.get_weights())
# Bound annotations in between 0 and 1
bnd_l = min(max(bnd[0], 0), 1)
bnd_u = min(max(bnd[1], 0), 1)
bnd = interval.closed(bnd_l, bnd_u)
max_rules_time = max(max_rules_time, t + delta_t)
- edges_to_be_added_node_rule.append(edges_to_add)
- rules_to_be_applied_node.append((numba.types.uint16(t + delta_t), n, rule.get_target(), bnd, immediate_rule, rule.is_static_rule()))
+ rules_to_be_applied_node_threadsafe[i].append((numba.types.uint16(t + delta_t), n, rule.get_target(), bnd, rule.is_static_rule()))
if atom_trace:
- rules_to_be_applied_node_trace.append((qualified_nodes, qualified_edges, rule.get_name()))
+ rules_to_be_applied_node_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name()))
- # We apply a rule on a node/edge only once in each timestep to prevent it from being added to the to_be_added list continuously (this will improve performance
- # It's possible to have an annotation function that keeps changing the value of a node/edge. Do this only for delta_t>0
- if delta_t != 0:
- nodes_to_skip[i].append(n)
-
- # Handle loop parameters for the next (maybe) fp operation
- # If it is a t=0 rule or an immediate rule we want to go back for another fp operation to check for new rules that may fire
- # Next fp operation we will skip this rule on this node because anyway there won't be an update
+ # If delta_t is zero we apply the rules and check if more are applicable
if delta_t == 0:
in_loop = True
update = False
- if immediate_rule and delta_t == 0:
- # immediate_rule_fire becomes True because we still need to check for more eligible rules, we're not done.
- in_loop = True
- update = True
- immediate_node_rule_fire = True
- break
-
- # Break, apply immediate rule then come back to check for more applicable rules
- if immediate_node_rule_fire:
- break
for applicable_rule in applicable_edge_rules:
e, annotations, qualified_nodes, qualified_edges, edges_to_add = applicable_rule
@@ -624,51 +565,43 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
bnd_u = min(max(bnd[1], 0), 1)
bnd = interval.closed(bnd_l, bnd_u)
max_rules_time = max(max_rules_time, t+delta_t)
- edges_to_be_added_edge_rule.append(edges_to_add)
- rules_to_be_applied_edge.append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, immediate_rule, rule.is_static_rule()))
+ # edges_to_be_added_edge_rule.append(edges_to_add)
+ edges_to_be_added_edge_rule_threadsafe[i].append(edges_to_add)
+ rules_to_be_applied_edge_threadsafe[i].append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, rule.is_static_rule()))
if atom_trace:
- rules_to_be_applied_edge_trace.append((qualified_nodes, qualified_edges, rule.get_name()))
-
- # We apply a rule on a node/edge only once in each timestep to prevent it from being added to the to_be_added list continuously (this will improve performance
- # It's possible to have an annotation function that keeps changing the value of a node/edge. Do this only for delta_t>0
- if delta_t != 0:
- edges_to_skip[i].append(e)
+ # rules_to_be_applied_edge_trace.append((qualified_nodes, qualified_edges, rule.get_name()))
+ rules_to_be_applied_edge_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name()))
- # Handle loop parameters for the next (maybe) fp operation
- # If it is a t=0 rule or an immediate rule we want to go back for another fp operation to check for new rules that may fire
- # Next fp operation we will skip this rule on this node because anyway there won't be an update
+ # If delta_t is zero we apply the rules and check if more are applicable
if delta_t == 0:
in_loop = True
update = False
- if immediate_rule and delta_t == 0:
- # immediate_rule_fire becomes True because we still need to check for more eligible rules, we're not done.
- in_loop = True
- update = True
- immediate_edge_rule_fire = True
- break
-
- # Break, apply immediate rule then come back to check for more applicable rules
- if immediate_edge_rule_fire:
- break
-
- # Go through all the rules and go back to applying the rules if we came here because of an immediate rule where delta_t>0
- if immediate_rule_applied and not (immediate_node_rule_fire or immediate_edge_rule_fire):
- immediate_rule_applied = False
- in_loop = True
- update = False
- continue
-
+
+ # Update lists after parallel run
+ for i in range(len(rules)):
+ if len(rules_to_be_applied_node_threadsafe[i]) > 0:
+ rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i])
+ if len(rules_to_be_applied_edge_threadsafe[i]) > 0:
+ rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i])
+ if atom_trace:
+ if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0:
+ rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i])
+ if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0:
+ rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i])
+ if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0:
+ edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i])
+
# Check for convergence after each timestep (perfect convergence or convergence specified by user)
# Check number of changed interpretations or max bound change
# User specified convergence
- if convergence_mode=='delta_interpretation':
+ if convergence_mode == 'delta_interpretation':
if changes_cnt <= convergence_delta:
if verbose:
print(f'\nConverged at time: {t} with {int(changes_cnt)} changes from the previous interpretation')
# Be consistent with time returned when we don't converge
t += 1
break
- elif convergence_mode=='delta_bound':
+ elif convergence_mode == 'delta_bound':
if bound_delta <= convergence_delta:
if verbose:
print(f'\nConverged at time: {t} with {float_to_str(bound_delta)} as the maximum bound change from the previous interpretation')
@@ -678,22 +611,23 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Perfect convergence
# Make sure there are no rules to be applied, and no facts that will be applied in the future. We do this by checking the max time any rule/fact is applicable
# If no more rules/facts to be applied
- elif convergence_mode=='perfect_convergence':
- if t>=max_facts_time and t>=max_rules_time:
+ elif convergence_mode == 'perfect_convergence':
+ if t>=max_facts_time and t >= max_rules_time:
if verbose:
print(f'\nConverged at time: {t}')
# Be consistent with time returned when we don't converge
t += 1
break
- # Increment t
+ # Increment t, update number of ground atoms
t += 1
+ num_ga.append(num_ga[-1])
return fp_cnt, t
def add_edge(self, edge, l):
# This function is useful for pyreason gym, called externally
- _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge)
+ _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, self.num_ga, -1)
def add_node(self, node, labels):
# This function is useful for pyreason gym, called externally
@@ -704,19 +638,19 @@ def add_node(self, node, labels):
def delete_edge(self, edge):
# This function is useful for pyreason gym, called externally
- _delete_edge(edge, self.neighbors, self.reverse_neighbors, self.edges, self.interpretations_edge)
+ _delete_edge(edge, self.neighbors, self.reverse_neighbors, self.edges, self.interpretations_edge, self.predicate_map_edge, self.num_ga)
def delete_node(self, node):
# This function is useful for pyreason gym, called externally
- _delete_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node)
+ _delete_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node, self.predicate_map_node, self.num_ga)
- def get_interpretation_dict(self):
+ def get_dict(self):
# This function can be called externally to retrieve a dict of the interpretation values
# Only values in the rule trace will be added
# Initialize interpretations for each time and node and edge
interpretations = {}
- for t in range(self.tmax+1):
+ for t in range(self.time+1):
interpretations[t] = {}
for node in self.nodes:
interpretations[t][node] = InterpretationDict()
@@ -728,9 +662,9 @@ def get_interpretation_dict(self):
time, _, node, l, bnd = change
interpretations[time][node][l._value] = (bnd.lower, bnd.upper)
- # If canonical, update all following timesteps as well
- if self. canonical:
- for t in range(time+1, self.tmax+1):
+ # If persistent, update all following timesteps as well
+ if self. persistent:
+ for t in range(time+1, self.time+1):
interpretations[t][node][l._value] = (bnd.lower, bnd.upper)
# Update interpretation edges
@@ -738,747 +672,694 @@ def get_interpretation_dict(self):
time, _, edge, l, bnd, = change
interpretations[time][edge][l._value] = (bnd.lower, bnd.upper)
- # If canonical, update all following timesteps as well
- if self. canonical:
- for t in range(time+1, self.tmax+1):
+ # If persistent, update all following timesteps as well
+ if self. persistent:
+ for t in range(time+1, self.time+1):
interpretations[t][edge][l._value] = (bnd.lower, bnd.upper)
return interpretations
+ def get_final_num_ground_atoms(self):
+ """
+ This function returns the number of ground atoms after the reasoning process, for the final timestep
+ :return: int: Number of ground atoms in the interpretation after reasoning
+ """
+ ga_cnt = 0
+
+ for node in self.nodes:
+ for l in self.interpretations_node[node].world:
+ ga_cnt += 1
+ for edge in self.edges:
+ for l in self.interpretations_edge[edge].world:
+ ga_cnt += 1
+
+ return ga_cnt
+
+ def get_num_ground_atoms(self):
+ """
+ This function returns the number of ground atoms after the reasoning process, for each timestep
+ :return: list: Number of ground atoms in the interpretation after reasoning for each timestep
+ """
+ if self.num_ga[-1] == 0:
+ self.num_ga.pop()
+ return self.num_ga
+
+ def query(self, query, return_bool=True) -> Union[bool, Tuple[float, float]]:
+ """
+ This function is used to query the graph after reasoning
+ :param query: A PyReason query object
+ :param return_bool: If True, returns boolean of query, else the bounds associated with it
+ :return: bool, or bounds
+ """
+
+ comp_type = query.get_component_type()
+ component = query.get_component()
+ pred = query.get_predicate()
+ bnd = query.get_bounds()
+
+ # Check if the component exists
+ if comp_type == 'node':
+ if component not in self.nodes:
+ return False if return_bool else (0, 0)
+ else:
+ if component not in self.edges:
+ return False if return_bool else (0, 0)
-@numba.njit(cache=True)
-def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, neighbors, reverse_neighbors, atom_trace, reverse_graph, nodes_to_skip):
- # Extract rule params
- rule_type = rule.get_type()
- clauses = rule.get_clauses()
- thresholds = rule.get_thresholds()
- ann_fn = rule.get_annotation_function()
- rule_edges = rule.get_edges()
-
- # We return a list of tuples which specify the target nodes/edges that have made the rule body true
- applicable_rules = numba.typed.List.empty_list(node_applicable_rule_type)
-
- # Return empty list if rule is not node rule and if we are not inferring edges
- if rule_type != 'node' and rule_edges[0] == '':
- return applicable_rules
-
- # Steps
- # 1. Loop through all nodes and evaluate each clause with that node and check the truth with the thresholds
- # 2. Inside the clause loop it may be necessary to loop through all nodes/edges while grounding the variables
- # 3. If the clause is true add the qualified nodes and qualified edges to the atom trace, if on. Break otherwise
- # 4. After going through all clauses, add to the annotations list all the annotations of the specified subset. These will be passed to the annotation function
- # 5. Finally, if there are any edges to be added, place them in the list
-
- for piter in prange(len(nodes)):
- target_node = nodes[piter]
- if target_node in nodes_to_skip:
- continue
- # Initialize dictionary where keys are strings (x1, x2 etc.) and values are lists of qualified neighbors
- # Keep track of qualified nodes and qualified edges
- # If it's a node clause update (x1 or x2 etc.) qualified neighbors, if it's an edge clause update the qualified neighbors for the source and target (x1, x2)
- subsets = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes)
- qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
- qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
- annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
- edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
-
- satisfaction = True
- for i, clause in enumerate(clauses):
- # Unpack clause variables
- clause_type = clause[0]
- clause_label = clause[1]
- clause_variables = clause[2]
- clause_bnd = clause[3]
- clause_operator = clause[4]
-
- # Unpack thresholds
- # This value is total/available
- threshold_quantifier_type = thresholds[i][1][1]
-
- # This is a node clause
- # The groundings for node clauses are either the target node, neighbors of the target node, or an existing subset of nodes
- if clause_type == 'node':
- clause_var_1 = clause_variables[0]
- subset = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, neighbors)
-
- subsets[clause_var_1] = get_qualified_components_node_clause(interpretations_node, subset, clause_label, clause_bnd)
-
- if atom_trace:
- qualified_nodes.append(numba.typed.List(subsets[clause_var_1]))
- qualified_edges.append(numba.typed.List.empty_list(edge_type))
-
- # Add annotations if necessary
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qn in subsets[clause_var_1]:
- a.append(interpretations_node[qn].world[clause_label])
- annotations.append(a)
-
- # This is an edge clause
- elif clause_type == 'edge':
- clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
- subset_source, subset_target = get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, subsets, neighbors, reverse_neighbors, nodes)
-
- # Get qualified edges
- qe = get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, clause_bnd, reverse_graph)
- subsets[clause_var_1] = qe[0]
- subsets[clause_var_2] = qe[1]
+ # Check if the predicate exists
+ if comp_type == 'node':
+ if pred not in self.interpretations_node[component].world:
+ return False if return_bool else (0, 0)
+ else:
+ if pred not in self.interpretations_edge[component].world:
+ return False if return_bool else (0, 0)
- if atom_trace:
- qualified_nodes.append(numba.typed.List.empty_list(node_type))
- qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])))
-
- # Add annotations if necessary
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])):
- a.append(interpretations_edge[qe].world[clause_label])
- annotations.append(a)
+ # Check if the bounds are satisfied
+ if comp_type == 'node':
+ if self.interpretations_node[component].world[pred] in bnd:
+ return True if return_bool else (self.interpretations_node[component].world[pred].lower, self.interpretations_node[component].world[pred].upper)
else:
- # This is a comparison clause
- # Make sure there is at least one ground atom such that pred-num(x) : [1,1] or pred-num(x,y) : [1,1]
- # Remember that the predicate in the clause will not contain the "-num" where num is some number.
- # We have to remove that manually while checking
- # Steps:
- # 1. get qualified nodes/edges as well as number associated for first predicate
- # 2. get qualified nodes/edges as well as number associated for second predicate
- # 3. if there's no number in steps 1 or 2 return false clause
- # 4. do comparison with each qualified component from step 1 with each qualified component in step 2
-
- # It's a node comparison
- if len(clause_variables) == 2:
- clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
- subset_1 = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, neighbors)
- subset_2 = get_node_rule_node_clause_subset(clause_var_2, target_node, subsets, neighbors)
-
- # 1, 2
- qualified_nodes_for_comparison_1, numbers_1 = get_qualified_components_node_comparison_clause(interpretations_node, subset_1, clause_label, clause_bnd)
- qualified_nodes_for_comparison_2, numbers_2 = get_qualified_components_node_comparison_clause(interpretations_node, subset_2, clause_label, clause_bnd)
-
- # It's an edge comparison
- elif len(clause_variables) == 4:
- clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables[0], clause_variables[1], clause_variables[2], clause_variables[3]
- subset_1_source, subset_1_target = get_node_rule_edge_clause_subset(clause_var_1_source, clause_var_1_target, target_node, subsets, neighbors, reverse_neighbors, nodes)
- subset_2_source, subset_2_target = get_node_rule_edge_clause_subset(clause_var_2_source, clause_var_2_target, target_node, subsets, neighbors, reverse_neighbors, nodes)
-
- # 1, 2
- qualified_nodes_for_comparison_1_source, qualified_nodes_for_comparison_1_target, numbers_1 = get_qualified_components_edge_comparison_clause(interpretations_edge, subset_1_source, subset_1_target, clause_label, clause_bnd, reverse_graph)
- qualified_nodes_for_comparison_2_source, qualified_nodes_for_comparison_2_target, numbers_2 = get_qualified_components_edge_comparison_clause(interpretations_edge, subset_2_source, subset_2_target, clause_label, clause_bnd, reverse_graph)
-
- # Check if thresholds are satisfied
- # If it's a comparison clause we just need to check if the numbers list is not empty (no threshold support)
- if clause_type == 'comparison':
- if len(numbers_1) == 0 or len(numbers_2) == 0:
- satisfaction = False
- # Node comparison. Compare stage
- elif len(clause_variables) == 2:
- satisfaction, qualified_nodes_1, qualified_nodes_2 = compare_numbers_node_predicate(numbers_1, numbers_2, clause_operator, qualified_nodes_for_comparison_1, qualified_nodes_for_comparison_2)
-
- # Update subsets with final qualified nodes
- subsets[clause_var_1] = qualified_nodes_1
- subsets[clause_var_2] = qualified_nodes_2
- qualified_comparison_nodes = numba.typed.List(qualified_nodes_1)
- qualified_comparison_nodes.extend(qualified_nodes_2)
-
- if atom_trace:
- qualified_nodes.append(qualified_comparison_nodes)
- qualified_edges.append(numba.typed.List.empty_list(edge_type))
-
- # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qn in qualified_comparison_nodes:
- a.append(interval.closed(1, 1))
- annotations.append(a)
- # Edge comparison. Compare stage
- else:
- satisfaction, qualified_nodes_1_source, qualified_nodes_1_target, qualified_nodes_2_source, qualified_nodes_2_target = compare_numbers_edge_predicate(numbers_1, numbers_2, clause_operator,
- qualified_nodes_for_comparison_1_source,
- qualified_nodes_for_comparison_1_target,
- qualified_nodes_for_comparison_2_source,
- qualified_nodes_for_comparison_2_target)
- # Update subsets with final qualified nodes
- subsets[clause_var_1_source] = qualified_nodes_1_source
- subsets[clause_var_1_target] = qualified_nodes_1_target
- subsets[clause_var_2_source] = qualified_nodes_2_source
- subsets[clause_var_2_target] = qualified_nodes_2_target
-
- qualified_comparison_nodes_1 = numba.typed.List(zip(qualified_nodes_1_source, qualified_nodes_1_target))
- qualified_comparison_nodes_2 = numba.typed.List(zip(qualified_nodes_2_source, qualified_nodes_2_target))
- qualified_comparison_nodes = numba.typed.List(qualified_comparison_nodes_1)
- qualified_comparison_nodes.extend(qualified_comparison_nodes_2)
-
- if atom_trace:
- qualified_nodes.append(numba.typed.List.empty_list(node_type))
- qualified_edges.append(qualified_comparison_nodes)
-
- # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qe in qualified_comparison_nodes:
- a.append(interval.closed(1, 1))
- annotations.append(a)
-
- # Non comparison clause
+ return False if return_bool else (0, 0)
+ else:
+ if self.interpretations_edge[component].world[pred] in bnd:
+ return True if return_bool else (self.interpretations_edge[component].world[pred].lower, self.interpretations_edge[component].world[pred].upper)
else:
- if threshold_quantifier_type == 'total':
- if clause_type == 'node':
- neigh_len = len(subset)
- else:
- neigh_len = sum([len(l) for l in subset_target])
-
- # Available is all neighbors that have a particular label with bound inside [0,1]
- elif threshold_quantifier_type == 'available':
- if clause_type == 'node':
- neigh_len = len(get_qualified_components_node_clause(interpretations_node, subset, clause_label, interval.closed(0,1)))
- else:
- neigh_len = len(get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, interval.closed(0,1), reverse_graph)[0])
-
- qualified_neigh_len = len(subsets[clause_var_1])
- satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, thresholds[i]) and satisfaction
-
- # Exit loop if even one clause is not satisfied
- if not satisfaction:
- break
-
- if satisfaction:
- # Collect edges to be added
- source, target, _ = rule_edges
-
- # Edges to be added
- if source != '' and target != '':
- # Check if edge nodes are target
- if source == '__target':
- edges_to_be_added[0].append(target_node)
- elif source in subsets:
- edges_to_be_added[0].extend(subsets[source])
- else:
- edges_to_be_added[0].append(source)
-
- if target == '__target':
- edges_to_be_added[1].append(target_node)
- elif target in subsets:
- edges_to_be_added[1].extend(subsets[target])
- else:
- edges_to_be_added[1].append(target)
-
- # node/edge, annotations, qualified nodes, qualified edges, edges to be added
- applicable_rules.append((target_node, annotations, qualified_nodes, qualified_edges, edges_to_be_added))
-
- return applicable_rules
+ return False if return_bool else (0, 0)
@numba.njit(cache=True)
-def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, reverse_graph, edges_to_skip):
+def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules, num_ga, t):
# Extract rule params
rule_type = rule.get_type()
+ head_variables = rule.get_head_variables()
clauses = rule.get_clauses()
thresholds = rule.get_thresholds()
ann_fn = rule.get_annotation_function()
rule_edges = rule.get_edges()
- # We return a list of tuples which specify the target nodes/edges that have made the rule body true
- applicable_rules = numba.typed.List.empty_list(edge_applicable_rule_type)
-
- # Return empty list if rule is not node rule
- if rule_type != 'edge':
- return applicable_rules
-
- # Steps
- # 1. Loop through all nodes and evaluate each clause with that node and check the truth with the thresholds
- # 2. Inside the clause loop it may be necessary to loop through all nodes/edges while grounding the variables
- # 3. If the clause is true add the qualified nodes and qualified edges to the atom trace, if on. Break otherwise
- # 4. After going through all clauses, add to the annotations list all the annotations of the specified subset. These will be passed to the annotation function
- # 5. Finally, if there are any edges to be added, place them in the list
-
- for piter in prange(len(edges)):
- target_edge = edges[piter]
- if target_edge in edges_to_skip:
- continue
- # Initialize dictionary where keys are strings (x1, x2 etc.) and values are lists of qualified neighbors
- # Keep track of qualified nodes and qualified edges
- # If it's a node clause update (x1 or x2 etc.) qualified neighbors, if it's an edge clause update the qualified neighbors for the source and target (x1, x2)
- subsets = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes)
- qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
- qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
- annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
- edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
-
- satisfaction = True
- for i, clause in enumerate(clauses):
- # Unpack clause variables
- clause_type = clause[0]
- clause_label = clause[1]
- clause_variables = clause[2]
- clause_bnd = clause[3]
- clause_operator = clause[4]
-
- # Unpack thresholds
- # This value is total/available
- threshold_quantifier_type = thresholds[i][1][1]
-
- # This is a node clause
- # The groundings for node clauses are either the source, target, neighbors of the source node, or an existing subset of nodes
- if clause_type == 'node':
- clause_var_1 = clause_variables[0]
- subset = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, neighbors)
-
- subsets[clause_var_1] = get_qualified_components_node_clause(interpretations_node, subset, clause_label, clause_bnd)
- if atom_trace:
- qualified_nodes.append(numba.typed.List(subsets[clause_var_1]))
- qualified_edges.append(numba.typed.List.empty_list(edge_type))
-
- # Add annotations if necessary
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qn in subsets[clause_var_1]:
- a.append(interpretations_node[qn].world[clause_label])
- annotations.append(a)
-
- # This is an edge clause
- elif clause_type == 'edge':
- clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
- subset_source, subset_target = get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, subsets, neighbors, reverse_neighbors, nodes)
-
- # Get qualified edges
- qe = get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, clause_bnd, reverse_graph)
- subsets[clause_var_1] = qe[0]
- subsets[clause_var_2] = qe[1]
+ if rule_type == 'node':
+ head_var_1 = head_variables[0]
+ else:
+ head_var_1, head_var_2 = head_variables[0], head_variables[1]
- if atom_trace:
- qualified_nodes.append(numba.typed.List.empty_list(node_type))
- qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])))
-
- # Add annotations if necessary
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])):
- a.append(interpretations_edge[qe].world[clause_label])
- annotations.append(a)
-
+ # We return a list of tuples which specify the target nodes/edges that have made the rule body true
+ applicable_rules_node = numba.typed.List.empty_list(node_applicable_rule_type)
+ applicable_rules_edge = numba.typed.List.empty_list(edge_applicable_rule_type)
+
+ # Grounding procedure
+ # 1. Go through each clause and check which variables have not been initialized in groundings
+ # 2. Check satisfaction of variables based on the predicate in the clause
+
+ # Grounding variable that maps variables in the body to a list of grounded nodes
+ # Grounding edges that maps edge variables to a list of edges
+ groundings = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes)
+ groundings_edges = numba.typed.Dict.empty(key_type=edge_type, value_type=list_of_edges)
+
+ # Dependency graph that keeps track of the connections between the variables in the body
+ dependency_graph_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes)
+ dependency_graph_reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes)
+
+ nodes_set = set(nodes)
+ edges_set = set(edges)
+
+ satisfaction = True
+ for i, clause in enumerate(clauses):
+ # Unpack clause variables
+ clause_type = clause[0]
+ clause_label = clause[1]
+ clause_variables = clause[2]
+ clause_bnd = clause[3]
+ clause_operator = clause[4]
+
+ # This is a node clause
+ if clause_type == 'node':
+ clause_var_1 = clause_variables[0]
+
+ # Get subset of nodes that can be used to ground the variable
+ # If we allow ground atoms, we can use the nodes directly
+ if allow_ground_rules and clause_var_1 in nodes_set:
+ grounding = numba.typed.List([clause_var_1])
+ else:
+ grounding = get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map_node, clause_label, nodes)
+
+ # Narrow subset based on predicate
+ qualified_groundings = get_qualified_node_groundings(interpretations_node, grounding, clause_label, clause_bnd)
+ groundings[clause_var_1] = qualified_groundings
+ qualified_groundings_set = set(qualified_groundings)
+ for c1, c2 in groundings_edges:
+ if c1 == clause_var_1:
+ groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[0] in qualified_groundings_set])
+ if c2 == clause_var_1:
+ groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[1] in qualified_groundings_set])
+
+ # Check satisfaction of those nodes wrt the threshold
+ # Only check satisfaction if the default threshold is used. This saves us from grounding the rest of the rule
+ # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges
+ # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0:
+ satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction
+
+ # This is an edge clause
+ elif clause_type == 'edge':
+ clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+
+ # Get subset of edges that can be used to ground the variables
+ # If we allow ground atoms, we can use the nodes directly
+ if allow_ground_rules and (clause_var_1, clause_var_2) in edges_set:
+ grounding = numba.typed.List([(clause_var_1, clause_var_2)])
else:
- # This is a comparison clause
- # Make sure there is at least one ground atom such that pred-num(x) : [1,1] or pred-num(x,y) : [1,1]
- # Remember that the predicate in the clause will not contain the "-num" where num is some number.
- # We have to remove that manually while checking
- # Steps:
- # 1. get qualified nodes/edges as well as number associated for first predicate
- # 2. get qualified nodes/edges as well as number associated for second predicate
- # 3. if there's no number in steps 1 or 2 return false clause
- # 4. do comparison with each qualified component from step 1 with each qualified component in step 2
-
- # It's a node comparison
- if len(clause_variables) == 2:
- clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
- subset_1 = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, neighbors)
- subset_2 = get_edge_rule_node_clause_subset(clause_var_2, target_edge, subsets, neighbors)
-
- # 1, 2
- qualified_nodes_for_comparison_1, numbers_1 = get_qualified_components_node_comparison_clause(interpretations_node, subset_1, clause_label, clause_bnd)
- qualified_nodes_for_comparison_2, numbers_2 = get_qualified_components_node_comparison_clause(interpretations_node, subset_2, clause_label, clause_bnd)
-
- # It's an edge comparison
- elif len(clause_variables) == 4:
- clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables[0], clause_variables[1], clause_variables[2], clause_variables[3]
- subset_1_source, subset_1_target = get_edge_rule_edge_clause_subset(clause_var_1_source, clause_var_1_target, target_edge, subsets, neighbors, reverse_neighbors, nodes)
- subset_2_source, subset_2_target = get_edge_rule_edge_clause_subset(clause_var_2_source, clause_var_2_target, target_edge, subsets, neighbors, reverse_neighbors, nodes)
-
- # 1, 2
- qualified_nodes_for_comparison_1_source, qualified_nodes_for_comparison_1_target, numbers_1 = get_qualified_components_edge_comparison_clause(interpretations_edge, subset_1_source, subset_1_target, clause_label, clause_bnd, reverse_graph)
- qualified_nodes_for_comparison_2_source, qualified_nodes_for_comparison_2_target, numbers_2 = get_qualified_components_edge_comparison_clause(interpretations_edge, subset_2_source, subset_2_target, clause_label, clause_bnd, reverse_graph)
-
- # Check if thresholds are satisfied
- # If it's a comparison clause we just need to check if the numbers list is not empty (no threshold support)
- if clause_type == 'comparison':
- if len(numbers_1) == 0 or len(numbers_2) == 0:
- satisfaction = False
- # Node comparison. Compare stage
- elif len(clause_variables) == 2:
- satisfaction, qualified_nodes_1, qualified_nodes_2 = compare_numbers_node_predicate(numbers_1, numbers_2, clause_operator, qualified_nodes_for_comparison_1, qualified_nodes_for_comparison_2)
-
- # Update subsets with final qualified nodes
- subsets[clause_var_1] = qualified_nodes_1
- subsets[clause_var_2] = qualified_nodes_2
- qualified_comparison_nodes = numba.typed.List(qualified_nodes_1)
- qualified_comparison_nodes.extend(qualified_nodes_2)
+ grounding = get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map_edge, clause_label, edges)
+
+ # Narrow subset based on predicate (save the edges that are qualified to use for finding future groundings faster)
+ qualified_groundings = get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, clause_bnd)
+
+ # Check satisfaction of those edges wrt the threshold
+ # Only check satisfaction if the default threshold is used. This saves us from grounding the rest of the rule
+ # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges
+ # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0:
+ satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction
+
+ # Update the groundings
+ groundings[clause_var_1] = numba.typed.List.empty_list(node_type)
+ groundings[clause_var_2] = numba.typed.List.empty_list(node_type)
+ groundings_clause_1_set = set(groundings[clause_var_1])
+ groundings_clause_2_set = set(groundings[clause_var_2])
+ for e in qualified_groundings:
+ if e[0] not in groundings_clause_1_set:
+ groundings[clause_var_1].append(e[0])
+ groundings_clause_1_set.add(e[0])
+ if e[1] not in groundings_clause_2_set:
+ groundings[clause_var_2].append(e[1])
+ groundings_clause_2_set.add(e[1])
+
+ # Update the edge groundings (to use later for grounding other clauses with the same variables)
+ groundings_edges[(clause_var_1, clause_var_2)] = qualified_groundings
+
+ # Update dependency graph
+ # Add a connection between clause_var_1 -> clause_var_2 and vice versa
+ if clause_var_1 not in dependency_graph_neighbors:
+ dependency_graph_neighbors[clause_var_1] = numba.typed.List([clause_var_2])
+ elif clause_var_2 not in dependency_graph_neighbors[clause_var_1]:
+ dependency_graph_neighbors[clause_var_1].append(clause_var_2)
+ if clause_var_2 not in dependency_graph_reverse_neighbors:
+ dependency_graph_reverse_neighbors[clause_var_2] = numba.typed.List([clause_var_1])
+ elif clause_var_1 not in dependency_graph_reverse_neighbors[clause_var_2]:
+ dependency_graph_reverse_neighbors[clause_var_2].append(clause_var_1)
+
+ # This is a comparison clause
+ else:
+ pass
- if atom_trace:
- qualified_nodes.append(qualified_comparison_nodes)
- qualified_edges.append(numba.typed.List.empty_list(edge_type))
-
- # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qn in qualified_comparison_nodes:
- a.append(interval.closed(1, 1))
- annotations.append(a)
- # Edge comparison. Compare stage
- else:
- satisfaction, qualified_nodes_1_source, qualified_nodes_1_target, qualified_nodes_2_source, qualified_nodes_2_target = compare_numbers_edge_predicate(numbers_1, numbers_2, clause_operator,
- qualified_nodes_for_comparison_1_source,
- qualified_nodes_for_comparison_1_target,
- qualified_nodes_for_comparison_2_source,
- qualified_nodes_for_comparison_2_target)
- # Update subsets with final qualified nodes
- subsets[clause_var_1_source] = qualified_nodes_1_source
- subsets[clause_var_1_target] = qualified_nodes_1_target
- subsets[clause_var_2_source] = qualified_nodes_2_source
- subsets[clause_var_2_target] = qualified_nodes_2_target
-
- qualified_comparison_nodes_1 = numba.typed.List(zip(qualified_nodes_1_source, qualified_nodes_1_target))
- qualified_comparison_nodes_2 = numba.typed.List(zip(qualified_nodes_2_source, qualified_nodes_2_target))
- qualified_comparison_nodes = numba.typed.List(qualified_comparison_nodes_1)
- qualified_comparison_nodes.extend(qualified_comparison_nodes_2)
+ # Refine the subsets based on any updates
+ if satisfaction:
+ refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors)
+
+ # If satisfaction is false, break
+ if not satisfaction:
+ break
+
+ # If satisfaction is still true, one final refinement to check if each edge pair is valid in edge rules
+ # Then continue to setup any edges to be added and annotations
+ # Fill out the rules to be applied lists
+ if satisfaction:
+ # Create temp grounding containers to verify if the head groundings are valid (only for edge rules)
+ # Setup edges to be added and fill rules to be applied
+ # Setup traces and inputs for annotation function
+ # Loop through the clause data and setup final annotations and trace variables
+ # Three cases: 1.node rule, 2. edge rule with infer edges, 3. edge rule
+ if rule_type == 'node':
+ # Loop through all the head variable groundings and add it to the rules to be applied
+ # Loop through the clauses and add appropriate trace data and annotations
+
+ # If there is no grounding for head_var_1, we treat it as a ground atom and add it to the graph
+ head_var_1_in_nodes = head_var_1 in nodes
+ add_head_var_node_to_graph = False
+ if allow_ground_rules and head_var_1_in_nodes:
+ groundings[head_var_1] = numba.typed.List([head_var_1])
+ elif head_var_1 not in groundings:
+ if not head_var_1_in_nodes:
+ add_head_var_node_to_graph = True
+ groundings[head_var_1] = numba.typed.List([head_var_1])
+
+ for head_grounding in groundings[head_var_1]:
+ qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
+ qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
+ annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
+ edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
+
+ # Check for satisfaction one more time in case the refining process has changed the groundings
+ satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges)
+ if not satisfaction:
+ continue
+
+ for i, clause in enumerate(clauses):
+ clause_type = clause[0]
+ clause_label = clause[1]
+ clause_variables = clause[2]
- if atom_trace:
- qualified_nodes.append(numba.typed.List.empty_list(node_type))
- qualified_edges.append(qualified_comparison_nodes)
-
- # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qe in qualified_comparison_nodes:
- a.append(interval.closed(1, 1))
- annotations.append(a)
-
- # Non comparison clause
- else:
- if threshold_quantifier_type == 'total':
if clause_type == 'node':
- neigh_len = len(subset)
- else:
- neigh_len = sum([len(l) for l in subset_target])
-
- # Available is all neighbors that have a particular label with bound inside [0,1]
- elif threshold_quantifier_type == 'available':
- if clause_type == 'node':
- neigh_len = len(get_qualified_components_node_clause(interpretations_node, subset, clause_label, interval.closed(0, 1)))
+ clause_var_1 = clause_variables[0]
+
+ # 1.
+ if atom_trace:
+ if clause_var_1 == head_var_1:
+ qualified_nodes.append(numba.typed.List([head_grounding]))
+ else:
+ qualified_nodes.append(numba.typed.List(groundings[clause_var_1]))
+ qualified_edges.append(numba.typed.List.empty_list(edge_type))
+ # 2.
+ if ann_fn != '':
+ a = numba.typed.List.empty_list(interval.interval_type)
+ if clause_var_1 == head_var_1:
+ a.append(interpretations_node[head_grounding].world[clause_label])
+ else:
+ for qn in groundings[clause_var_1]:
+ a.append(interpretations_node[qn].world[clause_label])
+ annotations.append(a)
+
+ elif clause_type == 'edge':
+ clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+ # 1.
+ if atom_trace:
+ # Cases: Both equal, one equal, none equal
+ qualified_nodes.append(numba.typed.List.empty_list(node_type))
+ if clause_var_1 == head_var_1:
+ es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_grounding])
+ qualified_edges.append(es)
+ elif clause_var_2 == head_var_1:
+ es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_grounding])
+ qualified_edges.append(es)
+ else:
+ qualified_edges.append(numba.typed.List(groundings_edges[(clause_var_1, clause_var_2)]))
+ # 2.
+ if ann_fn != '':
+ a = numba.typed.List.empty_list(interval.interval_type)
+ if clause_var_1 == head_var_1:
+ for e in groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_2 == head_var_1:
+ for e in groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[1] == head_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ else:
+ for qe in groundings_edges[(clause_var_1, clause_var_2)]:
+ a.append(interpretations_edge[qe].world[clause_label])
+ annotations.append(a)
else:
- neigh_len = len(get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, interval.closed(0, 1), reverse_graph)[0])
-
- qualified_neigh_len = len(subsets[clause_var_1])
- satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, thresholds[i]) and satisfaction
-
- # Exit loop if even one clause is not satisfied
- if not satisfaction:
- break
-
- # Here we are done going through each clause of the rule
- # If all clauses we're satisfied, proceed to collect annotations and prepare edges to be added
- if satisfaction:
- # Collect edges to be added
+ # Comparison clause (we do not handle for now)
+ pass
+
+ # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules)
+ if add_head_var_node_to_graph:
+ _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)
+
+ # For each grounding add a rule to be applied
+ applicable_rules_node.append((head_grounding, annotations, qualified_nodes, qualified_edges, edges_to_be_added))
+
+ elif rule_type == 'edge':
+ head_var_1 = head_variables[0]
+ head_var_2 = head_variables[1]
+
+ # If there is no grounding for head_var_1 or head_var_2, we treat it as a ground atom and add it to the graph
+ head_var_1_in_nodes = head_var_1 in nodes
+ head_var_2_in_nodes = head_var_2 in nodes
+ add_head_var_1_node_to_graph = False
+ add_head_var_2_node_to_graph = False
+ add_head_edge_to_graph = False
+ if allow_ground_rules and head_var_1_in_nodes:
+ groundings[head_var_1] = numba.typed.List([head_var_1])
+ if allow_ground_rules and head_var_2_in_nodes:
+ groundings[head_var_2] = numba.typed.List([head_var_2])
+
+ if head_var_1 not in groundings:
+ if not head_var_1_in_nodes:
+ add_head_var_1_node_to_graph = True
+ groundings[head_var_1] = numba.typed.List([head_var_1])
+ if head_var_2 not in groundings:
+ if not head_var_2_in_nodes:
+ add_head_var_2_node_to_graph = True
+ groundings[head_var_2] = numba.typed.List([head_var_2])
+
+ # Artificially connect the head variables with an edge if both of them were not in the graph
+ if not head_var_1_in_nodes and not head_var_2_in_nodes:
+ add_head_edge_to_graph = True
+
+ head_var_1_groundings = groundings[head_var_1]
+ head_var_2_groundings = groundings[head_var_2]
+
source, target, _ = rule_edges
+ infer_edges = True if source != '' and target != '' else False
+
+ # Prepare the edges that we will loop over.
+ # For infer edges we loop over each combination pair
+ # Else we loop over the valid edges in the graph
+ valid_edge_groundings = numba.typed.List.empty_list(edge_type)
+ for g1 in head_var_1_groundings:
+ for g2 in head_var_2_groundings:
+ if infer_edges:
+ valid_edge_groundings.append((g1, g2))
+ else:
+ if (g1, g2) in edges_set:
+ valid_edge_groundings.append((g1, g2))
+
+ # Loop through the head variable groundings
+ for valid_e in valid_edge_groundings:
+ head_var_1_grounding, head_var_2_grounding = valid_e[0], valid_e[1]
+ qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
+ qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
+ annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
+ edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
+
+ # Containers to keep track of groundings to make sure that the edge pair is valid
+ # We do this because we cannot know beforehand the edge matches from source groundings to target groundings
+ temp_groundings = groundings.copy()
+ temp_groundings_edges = groundings_edges.copy()
+
+ # Refine the temp groundings for the specific edge head grounding
+ # We update the edge collection as well depending on if there's a match between the clause variables and head variables
+ temp_groundings[head_var_1] = numba.typed.List([head_var_1_grounding])
+ temp_groundings[head_var_2] = numba.typed.List([head_var_2_grounding])
+ for c1, c2 in temp_groundings_edges.keys():
+ if c1 == head_var_1 and c2 == head_var_2:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_1_grounding, head_var_2_grounding)])
+ elif c1 == head_var_2 and c2 == head_var_1:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_2_grounding, head_var_1_grounding)])
+ elif c1 == head_var_1:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_1_grounding])
+ elif c2 == head_var_1:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_1_grounding])
+ elif c1 == head_var_2:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_2_grounding])
+ elif c2 == head_var_2:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_2_grounding])
+
+ refine_groundings(head_variables, temp_groundings, temp_groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors)
+
+ # Check if the thresholds are still satisfied
+ # Check if all clauses are satisfied again in case the refining process changed anything
+ satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, temp_groundings, temp_groundings_edges)
+
+ if not satisfaction:
+ continue
+
+ if infer_edges:
+ # Prevent self loops while inferring edges if the clause variables are not the same
+ if source != target and head_var_1_grounding == head_var_2_grounding:
+ continue
+ edges_to_be_added[0].append(head_var_1_grounding)
+ edges_to_be_added[1].append(head_var_2_grounding)
- # Edges to be added
- if source != '' and target != '':
- # Check if edge nodes are source/target
- if source == '__source':
- edges_to_be_added[0].append(target_edge[0])
- elif source == '__target':
- edges_to_be_added[0].append(target_edge[1])
- elif source in subsets:
- edges_to_be_added[0].extend(subsets[source])
- else:
- edges_to_be_added[0].append(source)
-
- if target == '__source':
- edges_to_be_added[1].append(target_edge[0])
- elif target == '__target':
- edges_to_be_added[1].append(target_edge[1])
- elif target in subsets:
- edges_to_be_added[1].extend(subsets[target])
- else:
- edges_to_be_added[1].append(target)
+ for i, clause in enumerate(clauses):
+ clause_type = clause[0]
+ clause_label = clause[1]
+ clause_variables = clause[2]
+
+ if clause_type == 'node':
+ clause_var_1 = clause_variables[0]
+ # 1.
+ if atom_trace:
+ if clause_var_1 == head_var_1:
+ qualified_nodes.append(numba.typed.List([head_var_1_grounding]))
+ elif clause_var_1 == head_var_2:
+ qualified_nodes.append(numba.typed.List([head_var_2_grounding]))
+ else:
+ qualified_nodes.append(numba.typed.List(temp_groundings[clause_var_1]))
+ qualified_edges.append(numba.typed.List.empty_list(edge_type))
+ # 2.
+ if ann_fn != '':
+ a = numba.typed.List.empty_list(interval.interval_type)
+ if clause_var_1 == head_var_1:
+ a.append(interpretations_node[head_var_1_grounding].world[clause_label])
+ elif clause_var_1 == head_var_2:
+ a.append(interpretations_node[head_var_2_grounding].world[clause_label])
+ else:
+ for qn in temp_groundings[clause_var_1]:
+ a.append(interpretations_node[qn].world[clause_label])
+ annotations.append(a)
+
+ elif clause_type == 'edge':
+ clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+ # 1.
+ if atom_trace:
+ # Cases:
+ # 1. Both equal (cv1 = hv1 and cv2 = hv2 or cv1 = hv2 and cv2 = hv1)
+ # 2. One equal (cv1 = hv1 or cv2 = hv1 or cv1 = hv2 or cv2 = hv2)
+ # 3. None equal
+ qualified_nodes.append(numba.typed.List.empty_list(node_type))
+ if clause_var_1 == head_var_1 and clause_var_2 == head_var_2:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding])
+ qualified_edges.append(es)
+ elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding])
+ qualified_edges.append(es)
+ elif clause_var_1 == head_var_1:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding])
+ qualified_edges.append(es)
+ elif clause_var_1 == head_var_2:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding])
+ qualified_edges.append(es)
+ elif clause_var_2 == head_var_1:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_1_grounding])
+ qualified_edges.append(es)
+ elif clause_var_2 == head_var_2:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_2_grounding])
+ qualified_edges.append(es)
+ else:
+ qualified_edges.append(numba.typed.List(temp_groundings_edges[(clause_var_1, clause_var_2)]))
+
+ # 2.
+ if ann_fn != '':
+ a = numba.typed.List.empty_list(interval.interval_type)
+ if clause_var_1 == head_var_1 and clause_var_2 == head_var_2:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_1 == head_var_1:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_var_1_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_1 == head_var_2:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_var_2_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_2 == head_var_1:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[1] == head_var_1_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_2 == head_var_2:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[1] == head_var_2_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ else:
+ for qe in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ a.append(interpretations_edge[qe].world[clause_label])
+ annotations.append(a)
+
+ # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules)
+ if add_head_var_1_node_to_graph and head_var_1_grounding == head_var_1:
+ _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)
+ if add_head_var_2_node_to_graph and head_var_2_grounding == head_var_2:
+ _add_node(head_var_2, neighbors, reverse_neighbors, nodes, interpretations_node)
+ if add_head_edge_to_graph and (head_var_1, head_var_2) == (head_var_1_grounding, head_var_2_grounding):
+ _add_edge(head_var_1, head_var_2, neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
- # node/edge, annotations, qualified nodes, qualified edges, edges to be added
- applicable_rules.append((target_edge, annotations, qualified_nodes, qualified_edges, edges_to_be_added))
+ # For each grounding combination add a rule to be applied
+ # Only if all the clauses have valid groundings
+ # if satisfaction:
+ e = (head_var_1_grounding, head_var_2_grounding)
+ applicable_rules_edge.append((e, annotations, qualified_nodes, qualified_edges, edges_to_be_added))
- return applicable_rules
+ # Return the applicable rules
+ return applicable_rules_node, applicable_rules_edge
@numba.njit(cache=True)
-def get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, neighbors):
- # The groundings for node clauses are either the target node, neighbors of the target node, or an existing subset of nodes
- if clause_var_1 == '__target':
- subset = numba.typed.List([target_node])
- else:
- subset = neighbors[target_node] if clause_var_1 not in subsets else subsets[clause_var_1]
- return subset
def check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges):
    """Return True iff every clause's threshold is satisfied by the current groundings.

    :param interpretations_node: Map from node -> world, used to qualify node clauses
    :param interpretations_edge: Map from edge -> world, used to qualify edge clauses
    :param clauses: Typed list of clauses; each clause unpacks as (type, label, variables, ...)
    :param thresholds: Per-clause thresholds, aligned by index with ``clauses``
    :param groundings: Map from clause variable -> list of candidate nodes
    :param groundings_edges: Map from (var1, var2) -> list of candidate edges
    :return: True if all clause thresholds hold, False otherwise
    """
    satisfaction = True
    for i, clause in enumerate(clauses):
        # Unpack clause variables
        clause_type = clause[0]
        clause_label = clause[1]
        clause_variables = clause[2]

        if clause_type == 'node':
            clause_var_1 = clause_variables[0]
            # NOTE(review): the grounding is passed as both the candidate set and the
            # qualified set — groundings appear to be pre-filtered by the clause
            # predicate at this point; confirm against the grounding pipeline.
            satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, groundings[clause_var_1], groundings[clause_var_1], clause_label, thresholds[i]) and satisfaction
        elif clause_type == 'edge':
            clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
            satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, groundings_edges[(clause_var_1, clause_var_2)], groundings_edges[(clause_var_1, clause_var_2)], clause_label, thresholds[i]) and satisfaction

        # Short-circuit: once any clause fails, the conjunction can never recover,
        # so skip the remaining (potentially expensive) threshold checks.
        if not satisfaction:
            break
    return satisfaction
@numba.njit(cache=True)
-def get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, subsets, neighbors, reverse_neighbors, nodes):
- # There are 5 cases for predicate(Y,Z):
- # 1. Either one or both of Y, Z are the target node
- # 2. Both predicate variables Y and Z have not been encountered before
- # 3. The source variable Y has not been encountered before but the target variable Z has
- # 4. The target variable Z has not been encountered before but the source variable Y has
- # 5. Both predicate variables Y and Z have been encountered before
def refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors):
    """Propagate a narrowed grounding through the clause dependency graph to a fixpoint.

    Starting from ``clause_variables`` (whose node groundings were just restricted),
    repeatedly filter the edge groundings of every dependent variable and rebuild that
    variable's node grounding from the surviving edges, until no new variable is touched.
    Mutates ``groundings`` and ``groundings_edges`` in place.

    :param clause_variables: Seed variables whose groundings were just refined
    :param groundings: Map variable -> typed list of candidate nodes (mutated)
    :param groundings_edges: Map (source_var, target_var) -> typed list of candidate edges (mutated)
    :param dependency_graph_neighbors: Map variable -> variables it points to via edge clauses
    :param dependency_graph_reverse_neighbors: Map variable -> variables that point to it
    """
    # Loop through the dependency graph and refine the groundings that have connections
    all_variables_refined = numba.typed.List(clause_variables)
    variables_just_refined = numba.typed.List(clause_variables)
    new_variables_refined = numba.typed.List.empty_list(numba.types.string)
    while len(variables_just_refined) > 0:
        for refined_variable in variables_just_refined:
            # Refine all the neighbors of the refined variable
            if refined_variable in dependency_graph_neighbors:
                for neighbor in dependency_graph_neighbors[refined_variable]:
                    old_edge_groundings = groundings_edges[(refined_variable, neighbor)]
                    new_node_groundings = groundings[refined_variable]

                    # Delete old groundings for the variable being refined
                    del groundings[neighbor]
                    groundings[neighbor] = numba.typed.List.empty_list(node_type)

                    # Update the edge groundings and node groundings:
                    # keep only edges whose source survived the refinement, then rebuild the
                    # neighbor's node grounding from the surviving targets (set for O(1) dedup)
                    qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[0] in new_node_groundings])
                    groundings_neighbor_set = set(groundings[neighbor])
                    for e in qualified_groundings:
                        if e[1] not in groundings_neighbor_set:
                            groundings[neighbor].append(e[1])
                            groundings_neighbor_set.add(e[1])
                    groundings_edges[(refined_variable, neighbor)] = qualified_groundings

                    # Add the neighbor to the list of refined variables so that we can refine for all its neighbors
                    if neighbor not in all_variables_refined:
                        new_variables_refined.append(neighbor)

            if refined_variable in dependency_graph_reverse_neighbors:
                for reverse_neighbor in dependency_graph_reverse_neighbors[refined_variable]:
                    old_edge_groundings = groundings_edges[(reverse_neighbor, refined_variable)]
                    new_node_groundings = groundings[refined_variable]

                    # Delete old groundings for the variable being refined
                    del groundings[reverse_neighbor]
                    groundings[reverse_neighbor] = numba.typed.List.empty_list(node_type)

                    # Update the edge groundings and node groundings
                    # (symmetric to the forward case: filter edges by surviving targets)
                    qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[1] in new_node_groundings])
                    groundings_reverse_neighbor_set = set(groundings[reverse_neighbor])
                    for e in qualified_groundings:
                        if e[0] not in groundings_reverse_neighbor_set:
                            groundings[reverse_neighbor].append(e[0])
                            groundings_reverse_neighbor_set.add(e[0])
                    groundings_edges[(reverse_neighbor, refined_variable)] = qualified_groundings

                    # Add the neighbor to the list of refined variables so that we can refine for all its neighbors
                    if reverse_neighbor not in all_variables_refined:
                        new_variables_refined.append(reverse_neighbor)

        # Next wave: variables touched this round become the frontier; reset the buffer
        variables_just_refined = numba.typed.List(new_variables_refined)
        all_variables_refined.extend(new_variables_refined)
        new_variables_refined.clear()
- # Case 1:
- # Check if 1st variable or 1st and 2nd variables are the target
- if clause_var_1 == '__target':
- subset_source = numba.typed.List([target_node])
-
- # If both variables are the same
- if clause_var_2 == '__target':
- subset_target = numba.typed.List([numba.typed.List([target_node])])
- elif clause_var_2 in subsets:
- subset_target = numba.typed.List([subsets[clause_var_2]])
- else:
- subset_target = numba.typed.List([neighbors[target_node]])
- # Check if 2nd variable is the target (this means 1st variable isn't the target)
- elif clause_var_2 == '__target':
- subset_source = reverse_neighbors[target_node] if clause_var_1 not in subsets else subsets[clause_var_1]
- subset_target = numba.typed.List([numba.typed.List([target_node]) for _ in subset_source])
+@numba.njit(cache=True)
def check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_grounding, clause_label, threshold):
    """Check whether a node clause's threshold holds for the given grounding.

    :param interpretations_node: Map from node -> world, consulted for the 'available' quantifier
    :param grounding: All candidate nodes for the clause variable
    :param qualified_grounding: The subset of ``grounding`` satisfying the clause
    :param clause_label: Predicate label of the clause
    :param threshold: Threshold tuple; threshold[1][1] is the quantifier type ('total' or 'available')
    :return: True if the threshold is satisfied
    """
    threshold_quantifier_type = threshold[1][1]
    # 'available' counts only candidates carrying the label with a bound inside [0,1].
    # 'total' — and, as a safe default, any unrecognized quantifier (the original code
    # left neigh_len undefined in that case) — counts every candidate.
    if threshold_quantifier_type == 'available':
        neigh_len = len(get_qualified_node_groundings(interpretations_node, grounding, clause_label, interval.closed(0, 1)))
    else:
        neigh_len = len(grounding)
    qualified_neigh_len = len(qualified_grounding)
    satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold)
    return satisfaction
- for n in subsets[clause_var_2]:
- sources = reverse_neighbors[n]
- for source in sources:
- subset_source.append(source)
- subset_target.append(numba.typed.List([n]))
- # Case 4:
- # We replace Z by the neighbors of Y
- elif clause_var_1 in subsets and clause_var_2 not in subsets:
- subset_source = subsets[clause_var_1]
- subset_target = numba.typed.List([neighbors[n] for n in subset_source])
+@numba.njit(cache=True)
def check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_grounding, clause_label, threshold):
    """Check whether an edge clause's threshold holds for the given grounding.

    :param interpretations_edge: Map from edge -> world, consulted for the 'available' quantifier
    :param grounding: All candidate edges for the clause variable pair
    :param qualified_grounding: The subset of ``grounding`` satisfying the clause
    :param clause_label: Predicate label of the clause
    :param threshold: Threshold tuple; threshold[1][1] is the quantifier type ('total' or 'available')
    :return: True if the threshold is satisfied
    """
    threshold_quantifier_type = threshold[1][1]
    # 'available' counts only candidates carrying the label with a bound inside [0,1].
    # 'total' — and, as a safe default, any unrecognized quantifier (the original code
    # left neigh_len undefined in that case) — counts every candidate.
    if threshold_quantifier_type == 'available':
        neigh_len = len(get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, interval.closed(0, 1)))
    else:
        neigh_len = len(grounding)
    qualified_neigh_len = len(qualified_grounding)
    satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold)
    return satisfaction
@numba.njit(cache=True)
-def get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, neighbors):
- # The groundings for node clauses are either the source, target, neighbors of the source node, or an existing subset of nodes
- if clause_var_1 == '__source':
- subset = numba.typed.List([target_edge[0]])
- elif clause_var_1 == '__target':
- subset = numba.typed.List([target_edge[1]])
def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l, nodes):
    """Return the candidate nodes grounding a node clause variable.

    A previously grounded variable keeps its existing grounding; otherwise the
    candidates are the nodes known to carry predicate ``l`` (when the predicate
    map has an entry), falling back to all nodes.

    :param clause_var_1: The clause's variable name
    :param groundings: Map variable -> existing grounding (list of nodes)
    :param predicate_map: Map predicate label -> nodes that carry it
    :param l: The clause's predicate label
    :param nodes: All nodes in the graph
    :return: List of candidate nodes for ``clause_var_1``
    """
    # A variable grounded by an earlier clause is never widened again; checking this
    # first removes the duplicated `not in groundings` test the original ran per branch.
    if clause_var_1 in groundings:
        return groundings[clause_var_1]
    # First occurrence of this variable: narrow by the predicate map when possible.
    return predicate_map[l] if l in predicate_map else nodes
@numba.njit(cache=True)
-def get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, subsets, neighbors, reverse_neighbors, nodes):
- # There are 5 cases for predicate(Y,Z):
- # 1. Either one or both of Y, Z are the source or target node
- # 2. Both predicate variables Y and Z have not been encountered before
- # 3. The source variable Y has not been encountered before but the target variable Z has
- # 4. The target variable Z has not been encountered before but the source variable Y has
- # 5. Both predicate variables Y and Z have been encountered before
def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l, edges):
    """Return the candidate edges grounding an edge clause predicate(Y, Z).

    :param clause_var_1: Source variable Y
    :param clause_var_2: Target variable Z
    :param groundings: Map variable -> node grounding from earlier clauses
    :param groundings_edges: Map (Y, Z) -> edge grounding from earlier clauses
    :param neighbors: Map node -> forward neighbors
    :param reverse_neighbors: Map node -> reverse neighbors
    :param predicate_map: Map predicate label -> edges that carry it
    :param l: Predicate label of the clause
    :param edges: All edges in the graph
    :return: Typed list of candidate (source, target) edges
    """
    # There are 4 cases for predicate(Y,Z):
    # 1. Both predicate variables Y and Z have not been encountered before
    # 2. The source variable Y has not been encountered before but the target variable Z has
    # 3. The target variable Z has not been encountered before but the source variable Y has
    # 4. Both predicate variables Y and Z have been encountered before
    edge_groundings = numba.typed.List.empty_list(edge_type)

    # Case 1:
    # We replace Y by all nodes and Z by the neighbors of each of these nodes
    if clause_var_1 not in groundings and clause_var_2 not in groundings:
        if l in predicate_map:
            edge_groundings = predicate_map[l]
        else:
            edge_groundings = edges
    # Case 2:
    # We replace Y by the sources of Z
    elif clause_var_1 not in groundings and clause_var_2 in groundings:
        for n in groundings[clause_var_2]:
            es = numba.typed.List([(nn, n) for nn in reverse_neighbors[n]])
            edge_groundings.extend(es)
    # Case 3:
    # We replace Z by the neighbors of Y
    elif clause_var_1 in groundings and clause_var_2 not in groundings:
        for n in groundings[clause_var_1]:
            es = numba.typed.List([(n, nn) for nn in neighbors[n]])
            edge_groundings.extend(es)
    # Case 4:
    # We have seen both variables before
    else:
        # We have already seen these two variables in an edge clause
        if (clause_var_1, clause_var_2) in groundings_edges:
            edge_groundings = groundings_edges[(clause_var_1, clause_var_2)]
        # We have seen both these variables but not in an edge clause together
        else:
            # Intersect Y's neighbors with Z's grounding (set gives O(1) membership)
            groundings_clause_var_2_set = set(groundings[clause_var_2])
            for n in groundings[clause_var_1]:
                es = numba.typed.List([(n, nn) for nn in neighbors[n] if nn in groundings_clause_var_2_set])
                edge_groundings.extend(es)
    return edge_groundings
@numba.njit(cache=True)
-def get_qualified_components_edge_clause(interpretations_edge, candidates_source, candidates_target, l, bnd, reverse_graph):
- # Get all the qualified sources and targets for a particular clause
- qualified_nodes_source = numba.typed.List.empty_list(node_type)
- qualified_nodes_target = numba.typed.List.empty_list(node_type)
- for i, source in enumerate(candidates_source):
- for target in candidates_target[i]:
- edge = (source, target) if not reverse_graph else (target, source)
- if is_satisfied_edge(interpretations_edge, edge, (l, bnd)):
- qualified_nodes_source.append(source)
- qualified_nodes_target.append(target)
def get_qualified_node_groundings(interpretations_node, grounding, clause_l, clause_bnd):
    """Keep only the nodes in ``grounding`` whose interpretation satisfies (clause_l, clause_bnd)."""
    qualified = numba.typed.List.empty_list(node_type)
    criterion = (clause_l, clause_bnd)
    for candidate in grounding:
        if not is_satisfied_node(interpretations_node, candidate, criterion):
            continue
        qualified.append(candidate)
    return qualified
@numba.njit(cache=True)
-def get_qualified_components_edge_comparison_clause(interpretations_edge, candidates_source, candidates_target, l, bnd, reverse_graph):
- # Get all the qualified sources and targets for a particular clause
- qualified_nodes_source = numba.typed.List.empty_list(node_type)
- qualified_nodes_target = numba.typed.List.empty_list(node_type)
- qualified_edges_numbers = numba.typed.List.empty_list(numba.types.float64)
- for i, source in enumerate(candidates_source):
- for target in candidates_target[i]:
- edge = (source, target) if not reverse_graph else (target, source)
- result, number = is_satisfied_edge_comparison(interpretations_edge, edge, (l, bnd))
- if result:
- qualified_nodes_source.append(source)
- qualified_nodes_target.append(target)
- qualified_edges_numbers.append(number)
-
- return qualified_nodes_source, qualified_nodes_target, qualified_edges_numbers
def get_qualified_edge_groundings(interpretations_edge, grounding, clause_l, clause_bnd):
    """Keep only the edges in ``grounding`` whose interpretation satisfies (clause_l, clause_bnd)."""
    qualified = numba.typed.List.empty_list(edge_type)
    criterion = (clause_l, clause_bnd)
    for candidate in grounding:
        if not is_satisfied_edge(interpretations_edge, candidate, criterion):
            continue
        qualified.append(candidate)
    return qualified
@numba.njit(cache=True)
@@ -1514,7 +1395,7 @@ def _satisfies_threshold(num_neigh, num_qualified_component, threshold):
@numba.njit(cache=True)
-def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False):
+def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, num_ga, mode, override=False):
updated = False
# This is to prevent a key error in case the label is a specific label
try:
@@ -1525,6 +1406,11 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
# Add label to world if it is not there
if l not in world.world:
world.world[l] = interval.closed(0, 1)
+ num_ga[t_cnt] += 1
+ if l in predicate_map:
+ predicate_map[l].append(comp)
+ else:
+ predicate_map[l] = numba.typed.List([comp])
# Check if update is necessary with previous bnd
prev_bnd = world.world[l].copy()
@@ -1557,7 +1443,13 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
if updated:
ip_update_cnt = 0
for p1, p2 in ipl:
- if p1==l:
+ if p1 == l:
+ if p2 not in world.world:
+ world.world[p2] = interval.closed(0, 1)
+ if p2 in predicate_map:
+ predicate_map[p2].append(comp)
+ else:
+ predicate_map[p2] = numba.typed.List([comp])
if atom_trace:
_update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}')
lower = max(world.world[p2].lower, 1 - world.world[p1].upper)
@@ -1568,7 +1460,13 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
updated_bnds.append(world.world[p2])
if store_interpretation_changes:
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper)))
- if p2==l:
+ if p2 == l:
+ if p1 not in world.world:
+ world.world[p1] = interval.closed(0, 1)
+ if p1 in predicate_map:
+ predicate_map[p1].append(comp)
+ else:
+ predicate_map[p1] = numba.typed.List([comp])
if atom_trace:
_update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}')
lower = max(world.world[p1].lower, 1 - world.world[p2].upper)
@@ -1603,7 +1501,7 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
@numba.njit(cache=True)
-def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False):
+def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, num_ga, mode, override=False):
updated = False
# This is to prevent a key error in case the label is a specific label
try:
@@ -1614,6 +1512,11 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
# Add label to world if it is not there
if l not in world.world:
world.world[l] = interval.closed(0, 1)
+ num_ga[t_cnt] += 1
+ if l in predicate_map:
+ predicate_map[l].append(comp)
+ else:
+ predicate_map[l] = numba.typed.List([comp])
# Check if update is necessary with previous bnd
prev_bnd = world.world[l].copy()
@@ -1646,7 +1549,13 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
if updated:
ip_update_cnt = 0
for p1, p2 in ipl:
- if p1==l:
+ if p1 == l:
+ if p2 not in world.world:
+ world.world[p2] = interval.closed(0, 1)
+ if p2 in predicate_map:
+ predicate_map[p2].append(comp)
+ else:
+ predicate_map[p2] = numba.typed.List([comp])
if atom_trace:
_update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}')
lower = max(world.world[p2].lower, 1 - world.world[p1].upper)
@@ -1657,7 +1566,13 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
updated_bnds.append(world.world[p2])
if store_interpretation_changes:
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper)))
- if p2==l:
+ if p2 == l:
+ if p1 not in world.world:
+ world.world[p1] = interval.closed(0, 1)
+ if p1 in predicate_map:
+ predicate_map[p1].append(comp)
+ else:
+ predicate_map[p1] = numba.typed.List([comp])
if atom_trace:
_update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}')
lower = max(world.world[p1].lower, 1 - world.world[p2].upper)
@@ -1668,7 +1583,7 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
updated_bnds.append(world.world[p2])
if store_interpretation_changes:
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(lower, upper)))
-
+
# Gather convergence data
change = 0
if updated:
@@ -1684,7 +1599,7 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
change = max(change, max_delta)
else:
change = 1 + ip_update_cnt
-
+
return (updated, change)
except:
return (False, 0)
@@ -1693,20 +1608,20 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
@numba.njit(cache=True)
def _update_rule_trace(rule_trace, qn, qe, prev_bnd, name):
rule_trace.append((qn, qe, prev_bnd.copy(), name))
-
+
@numba.njit(cache=True)
def are_satisfied_node(interpretations, comp, nas):
result = True
- for (label, interval) in nas:
- result = result and is_satisfied_node(interpretations, comp, (label, interval))
+ for (l, bnd) in nas:
+ result = result and is_satisfied_node(interpretations, comp, (l, bnd))
return result
@numba.njit(cache=True)
def is_satisfied_node(interpretations, comp, na):
result = False
- if (not (na[0] is None or na[1] is None)):
+ if not (na[0] is None or na[1] is None):
# This is to prevent a key error in case the label is a specific label
try:
world = interpretations[comp]
@@ -1748,15 +1663,15 @@ def is_satisfied_node_comparison(interpretations, comp, na):
@numba.njit(cache=True)
def are_satisfied_edge(interpretations, comp, nas):
result = True
- for (label, interval) in nas:
- result = result and is_satisfied_edge(interpretations, comp, (label, interval))
+ for (l, bnd) in nas:
+ result = result and is_satisfied_edge(interpretations, comp, (l, bnd))
return result
@numba.njit(cache=True)
def is_satisfied_edge(interpretations, comp, na):
result = False
- if (not (na[0] is None or na[1] is None)):
+ if not (na[0] is None or na[1] is None):
# This is to prevent a key error in case the label is a specific label
try:
world = interpretations[comp]
@@ -1835,19 +1750,25 @@ def check_consistent_edge(interpretations, comp, na):
@numba.njit(cache=True)
-def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, atom_trace, rule_trace, rule_trace_atoms, store_interpretation_changes):
+def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode):
world = interpretations[comp]
if store_interpretation_changes:
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1)))
+ if mode == 'fact' or mode == 'graph-attribute-fact' and atom_trace:
+ name = facts_to_be_applied_trace[idx]
+ elif mode == 'rule' and atom_trace:
+ name = rules_to_be_applied_trace[idx][2]
+ else:
+ name = '-'
if atom_trace:
- _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[na[0]], 'Inconsistency')
+ _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[na[0]], f'Inconsistency due to {name}')
# Resolve inconsistency and set static
world.world[na[0]].set_lower_upper(0, 1)
world.world[na[0]].set_static(True)
for p1, p2 in ipl:
if p1==na[0]:
if atom_trace:
- _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], 'Inconsistency')
+ _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'Inconsistency due to {name}')
world.world[p2].set_lower_upper(0, 1)
world.world[p2].set_static(True)
if store_interpretation_changes:
@@ -1855,28 +1776,34 @@ def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, at
if p2==na[0]:
if atom_trace:
- _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], 'Inconsistency')
+ _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'Inconsistency due to {name}')
world.world[p1].set_lower_upper(0, 1)
world.world[p1].set_static(True)
if store_interpretation_changes:
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1)))
- # Add inconsistent predicates to a list
+ # Add inconsistent predicates to a list
@numba.njit(cache=True)
-def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, atom_trace, rule_trace, rule_trace_atoms, store_interpretation_changes):
+def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode):
w = interpretations[comp]
if store_interpretation_changes:
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1)))
+ if mode == 'fact' or mode == 'graph-attribute-fact' and atom_trace:
+ name = facts_to_be_applied_trace[idx]
+ elif mode == 'rule' and atom_trace:
+ name = rules_to_be_applied_trace[idx][2]
+ else:
+ name = '-'
if atom_trace:
- _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[na[0]], 'Inconsistency')
+ _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[na[0]], f'Inconsistency due to {name}')
# Resolve inconsistency and set static
w.world[na[0]].set_lower_upper(0, 1)
w.world[na[0]].set_static(True)
for p1, p2 in ipl:
if p1==na[0]:
if atom_trace:
- _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p2], 'Inconsistency')
+ _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p2], f'Inconsistency due to {name}')
w.world[p2].set_lower_upper(0, 1)
w.world[p2].set_static(True)
if store_interpretation_changes:
@@ -1884,7 +1811,7 @@ def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, at
if p2==na[0]:
if atom_trace:
- _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p1], 'Inconsistency')
+ _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p1], f'Inconsistency due to {name}')
w.world[p1].set_lower_upper(0, 1)
w.world[p1].set_static(True)
if store_interpretation_changes:
@@ -1900,7 +1827,7 @@ def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node):
@numba.njit(cache=True)
-def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge):
+def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t):
# If not a node, add to list of nodes and initialize neighbors
if source not in nodes:
_add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node)
@@ -1920,43 +1847,57 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, int
reverse_neighbors[target].append(source)
if l.value!='':
interpretations_edge[edge] = world.World(numba.typed.List([l]))
+ num_ga[t] += 1
+ if l in predicate_map:
+ predicate_map[l].append(edge)
+ else:
+ predicate_map[l] = numba.typed.List([edge])
else:
interpretations_edge[edge] = world.World(numba.typed.List.empty_list(label.label_type))
else:
if l not in interpretations_edge[edge].world and l.value!='':
new_edge = True
interpretations_edge[edge].world[l] = interval.closed(0, 1)
+ num_ga[t] += 1
return edge, new_edge
@numba.njit(cache=True)
-def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge):
+def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t):
changes = 0
edges_added = numba.typed.List.empty_list(edge_type)
for source in sources:
for target in targets:
- edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge)
+ edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t)
edges_added.append(edge)
changes = changes+1 if new_edge else changes
return edges_added, changes
@numba.njit(cache=True)
-def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge):
+def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge, predicate_map, num_ga):
source, target = edge
edges.remove(edge)
+ num_ga[-1] -= len(interpretations_edge[edge].world)
del interpretations_edge[edge]
+ for l in predicate_map:
+ if edge in predicate_map[l]:
+ predicate_map[l].remove(edge)
neighbors[source].remove(target)
reverse_neighbors[target].remove(source)
@numba.njit(cache=True)
-def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node):
+def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node, predicate_map, num_ga):
nodes.remove(node)
+ num_ga[-1] -= len(interpretations_node[node].world)
del interpretations_node[node]
del neighbors[node]
del reverse_neighbors[node]
+ for l in predicate_map:
+ if node in predicate_map[l]:
+ predicate_map[l].remove(node)
# Remove all occurrences of node in neighbors
for n in neighbors.keys():
diff --git a/pyreason/scripts/interpretation/interpretation_parallel.py b/pyreason/scripts/interpretation/interpretation_parallel.py
index 230b6415..77ac6060 100644
--- a/pyreason/scripts/interpretation/interpretation_parallel.py
+++ b/pyreason/scripts/interpretation/interpretation_parallel.py
@@ -1,3 +1,5 @@
+from typing import Union, Tuple
+
import pyreason.scripts.numba_wrapper.numba_types.world_type as world
import pyreason.scripts.numba_wrapper.numba_types.label_type as label
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
@@ -15,6 +17,12 @@
list_of_nodes = numba.types.ListType(node_type)
list_of_edges = numba.types.ListType(edge_type)
+# Type for storing clause data
+clause_data = numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string)))
+
+# Type for storing refine clause data
+refine_data = numba.types.Tuple((numba.types.string, numba.types.string, numba.types.int8))
+
# Type for facts to be applied
facts_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))
facts_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))
@@ -37,36 +45,44 @@
numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type))
))
+rules_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean))
+rules_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean))
+rules_to_be_applied_trace_type = numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string))
+edges_to_be_added_type = numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type))
+
class Interpretation:
- available_labels_node = []
- available_labels_edge = []
specific_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(node_type))
specific_edge_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(edge_type))
- def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode):
+ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, persistent, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules):
self.graph = graph
self.ipl = ipl
self.annotation_functions = annotation_functions
self.reverse_graph = reverse_graph
self.atom_trace = atom_trace
self.save_graph_attributes_to_rule_trace = save_graph_attributes_to_rule_trace
- self.canonical = canonical
+ self.persistent = persistent
self.inconsistency_check = inconsistency_check
self.store_interpretation_changes = store_interpretation_changes
self.update_mode = update_mode
+ self.allow_ground_rules = allow_ground_rules
+
+ # Counter for number of ground atoms for each timestep, start with zero for the zeroth timestep
+ self.num_ga = numba.typed.List.empty_list(numba.types.int64)
+ self.num_ga.append(0)
# For reasoning and reasoning again (contains previous time and previous fp operation cnt)
self.time = 0
self.prev_reasoning_data = numba.typed.List([0, 0])
# Initialize list of tuples for rules/facts to be applied, along with all the ground atoms that fired the rule. One to One correspondence between rules_to_be_applied_node and rules_to_be_applied_node_trace if atom_trace is true
- self.rules_to_be_applied_node_trace = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string)))
- self.rules_to_be_applied_edge_trace = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string)))
+ self.rules_to_be_applied_node_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type)
+ self.rules_to_be_applied_edge_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type)
self.facts_to_be_applied_node_trace = numba.typed.List.empty_list(numba.types.string)
self.facts_to_be_applied_edge_trace = numba.typed.List.empty_list(numba.types.string)
- self.rules_to_be_applied_node = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)))
- self.rules_to_be_applied_edge = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)))
+ self.rules_to_be_applied_node = numba.typed.List.empty_list(rules_to_be_applied_node_type)
+ self.rules_to_be_applied_edge = numba.typed.List.empty_list(rules_to_be_applied_edge_type)
self.facts_to_be_applied_node = numba.typed.List.empty_list(facts_to_be_applied_node_type)
self.facts_to_be_applied_edge = numba.typed.List.empty_list(facts_to_be_applied_edge_type)
self.edges_to_be_added_node_rule = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)))
@@ -84,18 +100,8 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace,
self.nodes.extend(numba.typed.List(self.graph.nodes()))
self.edges.extend(numba.typed.List(self.graph.edges()))
- # Make sure they are correct type
- if len(self.available_labels_node)==0:
- self.available_labels_node = numba.typed.List.empty_list(label.label_type)
- else:
- self.available_labels_node = numba.typed.List(self.available_labels_node)
- if len(self.available_labels_edge)==0:
- self.available_labels_edge = numba.typed.List.empty_list(label.label_type)
- else:
- self.available_labels_edge = numba.typed.List(self.available_labels_edge)
-
- self.interpretations_node = self._init_interpretations_node(self.nodes, self.available_labels_node, self.specific_node_labels)
- self.interpretations_edge = self._init_interpretations_edge(self.edges, self.available_labels_edge, self.specific_edge_labels)
+ self.interpretations_node, self.predicate_map_node = self._init_interpretations_node(self.nodes, self.specific_node_labels, self.num_ga)
+ self.interpretations_edge, self.predicate_map_edge = self._init_interpretations_edge(self.edges, self.specific_edge_labels, self.num_ga)
# Setup graph neighbors and reverse neighbors
self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type))
@@ -107,7 +113,7 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace,
self.reverse_neighbors = self._init_reverse_neighbors(self.neighbors)
@staticmethod
- @numba.njit(cache=False)
+ @numba.njit(cache=True)
def _init_reverse_neighbors(neighbors):
reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes)
for n, neighbor_nodes in neighbors.items():
@@ -123,35 +129,49 @@ def _init_reverse_neighbors(neighbors):
return reverse_neighbors
@staticmethod
- @numba.njit(cache=False)
- def _init_interpretations_node(nodes, available_labels, specific_labels):
+ @numba.njit(cache=True)
+ def _init_interpretations_node(nodes, specific_labels, num_ga):
interpretations = numba.typed.Dict.empty(key_type=node_type, value_type=world.world_type)
- # General labels
+ predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_nodes)
+
+ # Initialize nodes
for n in nodes:
- interpretations[n] = world.World(available_labels)
+ interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type))
+
# Specific labels
for l, ns in specific_labels.items():
for n in ns:
interpretations[n].world[l] = interval.closed(0.0, 1.0)
+ num_ga[0] += 1
+
+ for l, ns in specific_labels.items():
+ predicate_map[l] = numba.typed.List(ns)
+
+ return interpretations, predicate_map
- return interpretations
-
@staticmethod
- @numba.njit(cache=False)
- def _init_interpretations_edge(edges, available_labels, specific_labels):
+ @numba.njit(cache=True)
+ def _init_interpretations_edge(edges, specific_labels, num_ga):
interpretations = numba.typed.Dict.empty(key_type=edge_type, value_type=world.world_type)
- # General labels
- for e in edges:
- interpretations[e] = world.World(available_labels)
+ predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_edges)
+
+ # Initialize edges
+ for n in edges:
+ interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type))
+
# Specific labels
for l, es in specific_labels.items():
for e in es:
interpretations[e].world[l] = interval.closed(0.0, 1.0)
+ num_ga[0] += 1
+
+ for l, es in specific_labels.items():
+ predicate_map[l] = numba.typed.List(es)
+
+ return interpretations, predicate_map
- return interpretations
-
@staticmethod
- @numba.njit(cache=False)
+ @numba.njit(cache=True)
def _init_convergence(convergence_bound_threshold, convergence_threshold):
if convergence_bound_threshold==-1 and convergence_threshold==-1:
convergence_mode = 'perfect_convergence'
@@ -171,7 +191,7 @@ def start_fp(self, tmax, facts_node, facts_edge, rules, verbose, convergence_thr
self._start_fp(rules, max_facts_time, verbose, again)
@staticmethod
- @numba.njit(cache=False)
+ @numba.njit(cache=True)
def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, atom_trace):
max_time = 0
for fact in facts_node:
@@ -193,7 +213,7 @@ def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_ap
return max_time
def _start_fp(self, rules, max_facts_time, verbose, again):
- fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.canonical, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, verbose, again)
+ fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.predicate_map_node, self.predicate_map_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.persistent, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, self.allow_ground_rules, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, self.num_ga, verbose, again)
self.time = t - 1
# If we need to reason again, store the next timestep to start from
self.prev_reasoning_data[0] = t
@@ -202,23 +222,24 @@ def _start_fp(self, rules, max_facts_time, verbose, again):
print('Fixed Point iterations:', fp_cnt)
@staticmethod
- @numba.njit(cache=False)
- def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, max_facts_time, annotation_functions, convergence_mode, convergence_delta, verbose, again):
+ @numba.njit(cache=True, parallel=True)
+ def reason(interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, persistent, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules, max_facts_time, annotation_functions, convergence_mode, convergence_delta, num_ga, verbose, again):
t = prev_reasoning_data[0]
fp_cnt = prev_reasoning_data[1]
max_rules_time = 0
timestep_loop = True
facts_to_be_applied_node_new = numba.typed.List.empty_list(facts_to_be_applied_node_type)
facts_to_be_applied_edge_new = numba.typed.List.empty_list(facts_to_be_applied_edge_type)
- rules_to_remove_idx = numba.typed.List.empty_list(numba.types.int64)
+ rules_to_remove_idx = set()
+ rules_to_remove_idx.add(-1)
while timestep_loop:
if t==tmax:
timestep_loop = False
if verbose:
with objmode():
print('Timestep:', t, flush=True)
- # Reset Interpretation at beginning of timestep if non-canonical
- if t>0 and not canonical:
+ # Reset Interpretation at beginning of timestep if non-persistent
+ if t>0 and not persistent:
# Reset nodes (only if not static)
for n in nodes:
w = interpretations_node[n].world
@@ -238,24 +259,18 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
bound_delta = 0
update = False
- # Parameters for immediate rules
- immediate_node_rule_fire = False
- immediate_edge_rule_fire = False
- immediate_rule_applied = False
- # When delta_t = 0, we don't want to check the same rule with the same node/edge after coming back to the fp operator
- nodes_to_skip = numba.typed.Dict.empty(key_type=numba.types.int64, value_type=list_of_nodes)
- edges_to_skip = numba.typed.Dict.empty(key_type=numba.types.int64, value_type=list_of_edges)
- # Initialize the above
- for i in range(len(rules)):
- nodes_to_skip[i] = numba.typed.List.empty_list(node_type)
- edges_to_skip[i] = numba.typed.List.empty_list(edge_type)
-
# Start by applying facts
# Nodes
facts_to_be_applied_node_new.clear()
+ nodes_set = set(nodes)
for i in range(len(facts_to_be_applied_node)):
- if facts_to_be_applied_node[i][0]==t:
+ if facts_to_be_applied_node[i][0] == t:
comp, l, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5]
+ # If the component is not in the graph, add it
+ if comp not in nodes_set:
+ _add_node(comp, neighbors, reverse_neighbors, nodes, interpretations_node)
+ nodes_set.add(comp)
+
# Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well
if l in interpretations_node[comp].world and interpretations_node[comp].world[l].is_static():
# Check if we should even store any of the changes to the rule trace etc.
@@ -273,13 +288,13 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[comp].world[p1]))
if atom_trace:
_update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p1], facts_to_be_applied_node_trace[i])
-
+
else:
# Check for inconsistencies (multiple facts)
if check_consistent_node(interpretations_node, comp, (l, bnd)):
mode = 'graph-attribute-fact' if graph_attribute else 'fact'
override = True if update_mode == 'override' else False
- u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=override)
+ u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=override)
update = u or update
# Update convergence params
@@ -289,11 +304,11 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
changes_cnt += changes
# Resolve inconsistency if necessary otherwise override bounds
else:
+ mode = 'graph-attribute-fact' if graph_attribute else 'fact'
if inconsistency_check:
- resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_node, rule_trace_node_atoms, store_interpretation_changes)
+ resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode)
else:
- mode = 'graph-attribute-fact' if graph_attribute else 'fact'
- u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=True)
+ u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=True)
update = u or update
# Update convergence params
@@ -315,9 +330,15 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Edges
facts_to_be_applied_edge_new.clear()
+ edges_set = set(edges)
for i in range(len(facts_to_be_applied_edge)):
if facts_to_be_applied_edge[i][0]==t:
comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5]
+ # If the component is not in the graph, add it
+ if comp not in edges_set:
+ _add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
+ edges_set.add(comp)
+
# Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well
if l in interpretations_edge[comp].world and interpretations_edge[comp].world[l].is_static():
# Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute
@@ -339,7 +360,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
if check_consistent_edge(interpretations_edge, comp, (l, bnd)):
mode = 'graph-attribute-fact' if graph_attribute else 'fact'
override = True if update_mode == 'override' else False
- u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=override)
+ u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=override)
update = u or update
# Update convergence params
@@ -349,11 +370,11 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
changes_cnt += changes
# Resolve inconsistency
else:
+ mode = 'graph-attribute-fact' if graph_attribute else 'fact'
if inconsistency_check:
- resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes)
+ resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode)
else:
- mode = 'graph-attribute-fact' if graph_attribute else 'fact'
- u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=True)
+ u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=True)
update = u or update
# Update convergence params
@@ -382,50 +403,25 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Nodes
rules_to_remove_idx.clear()
for idx, i in enumerate(rules_to_be_applied_node):
- # If we are coming here from an immediate rule firing with delta_t=0 we have to apply that one rule. Which was just added to the list to_be_applied
- if immediate_node_rule_fire and rules_to_be_applied_node[-1][4]:
- i = rules_to_be_applied_node[-1]
- idx = len(rules_to_be_applied_node) - 1
-
- if i[0]==t:
- comp, l, bnd, immediate, set_static = i[1], i[2], i[3], i[4], i[5]
- sources, targets, edge_l = edges_to_be_added_node_rule[idx]
- edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge)
- changes_cnt += changes
-
- # Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally
- if edge_l.value!='':
- for e in edges_added:
- if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)):
- override = True if update_mode == 'override' else False
- u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override)
-
- update = u or update
-
- # Update convergence params
- if convergence_mode=='delta_bound':
- bound_delta = max(bound_delta, changes)
- else:
- changes_cnt += changes
- # Resolve inconsistency
- else:
- if inconsistency_check:
- resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes)
- else:
- u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True)
-
- update = u or update
+ if i[0] == t:
+ comp, l, bnd, set_static = i[1], i[2], i[3], i[4]
+ # Check for inconsistencies
+ if check_consistent_node(interpretations_node, comp, (l, bnd)):
+ override = True if update_mode == 'override' else False
+ u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)
- # Update convergence params
- if convergence_mode=='delta_bound':
- bound_delta = max(bound_delta, changes)
- else:
- changes_cnt += changes
+ update = u or update
+ # Update convergence params
+ if convergence_mode=='delta_bound':
+ bound_delta = max(bound_delta, changes)
+ else:
+ changes_cnt += changes
+ # Resolve inconsistency
else:
- # Check for inconsistencies
- if check_consistent_node(interpretations_node, comp, (l, bnd)):
- override = True if update_mode == 'override' else False
- u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=override)
+ if inconsistency_check:
+ resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule')
+ else:
+ u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=True)
update = u or update
# Update convergence params
@@ -433,32 +429,9 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
bound_delta = max(bound_delta, changes)
else:
changes_cnt += changes
- # Resolve inconsistency
- else:
- if inconsistency_check:
- resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_node, rule_trace_node_atoms, store_interpretation_changes)
- else:
- u, changes = _update_node(interpretations_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=True)
-
- update = u or update
- # Update convergence params
- if convergence_mode=='delta_bound':
- bound_delta = max(bound_delta, changes)
- else:
- changes_cnt += changes
# Delete rules that have been applied from list by adding index to list
- rules_to_remove_idx.append(idx)
-
- # Break out of the apply rules loop if a rule is immediate. Then we go to the fp operator and check for other applicable rules then come back
- if immediate:
- # If delta_t=0 we want to apply one rule and go back to the fp operator
- # If delta_t>0 we want to come back here and apply the rest of the rules
- if immediate_edge_rule_fire:
- break
- elif not immediate_edge_rule_fire and u:
- immediate_rule_applied = True
- break
+ rules_to_remove_idx.add(idx)
# Remove from rules to be applied and edges to be applied lists after coming out from loop
rules_to_be_applied_node[:] = numba.typed.List([rules_to_be_applied_node[i] for i in range(len(rules_to_be_applied_node)) if i not in rules_to_remove_idx])
@@ -469,26 +442,20 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Edges
rules_to_remove_idx.clear()
for idx, i in enumerate(rules_to_be_applied_edge):
- # If we broke from above loop to apply more rules, then break from here
- if immediate_rule_applied and not immediate_edge_rule_fire:
- break
- # If we are coming here from an immediate rule firing with delta_t=0 we have to apply that one rule. Which was just added to the list to_be_applied
- if immediate_edge_rule_fire and rules_to_be_applied_edge[-1][4]:
- i = rules_to_be_applied_edge[-1]
- idx = len(rules_to_be_applied_edge) - 1
-
- if i[0]==t:
- comp, l, bnd, immediate, set_static = i[1], i[2], i[3], i[4], i[5]
+ if i[0] == t:
+ comp, l, bnd, set_static = i[1], i[2], i[3], i[4]
sources, targets, edge_l = edges_to_be_added_edge_rule[idx]
- edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge)
+ edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
changes_cnt += changes
# Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally
- if edge_l.value!='':
+ if edge_l.value != '':
for e in edges_added:
+ if interpretations_edge[e].world[edge_l].is_static():
+ continue
if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)):
override = True if update_mode == 'override' else False
- u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override)
+ u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)
update = u or update
@@ -500,9 +467,9 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Resolve inconsistency
else:
if inconsistency_check:
- resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes)
+ resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule')
else:
- u, changes = _update_edge(interpretations_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True)
+ u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True)
update = u or update
@@ -516,7 +483,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Check for inconsistencies
if check_consistent_edge(interpretations_edge, comp, (l, bnd)):
override = True if update_mode == 'override' else False
- u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override)
+ u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)
update = u or update
# Update convergence params
@@ -527,9 +494,9 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Resolve inconsistency
else:
if inconsistency_check:
- resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, atom_trace, rule_trace_edge, rule_trace_edge_atoms, store_interpretation_changes)
+ resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule')
else:
- u, changes = _update_edge(interpretations_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True)
+ u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True)
update = u or update
# Update convergence params
@@ -539,17 +506,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
changes_cnt += changes
# Delete rules that have been applied from list by adding the index to list
- rules_to_remove_idx.append(idx)
-
- # Break out of the apply rules loop if a rule is immediate. Then we go to the fp operator and check for other applicable rules then come back
- if immediate:
- # If t=0 we want to apply one rule and go back to the fp operator
- # If t>0 we want to come back here and apply the rest of the rules
- if immediate_edge_rule_fire:
- break
- elif not immediate_edge_rule_fire and u:
- immediate_rule_applied = True
- break
+ rules_to_remove_idx.add(idx)
# Remove from rules to be applied and edges to be applied lists after coming out from loop
rules_to_be_applied_edge[:] = numba.typed.List([rules_to_be_applied_edge[i] for i in range(len(rules_to_be_applied_edge)) if i not in rules_to_remove_idx])
@@ -558,61 +515,45 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
rules_to_be_applied_edge_trace[:] = numba.typed.List([rules_to_be_applied_edge_trace[i] for i in range(len(rules_to_be_applied_edge_trace)) if i not in rules_to_remove_idx])
# Fixed point
- # if update or immediate_node_rule_fire or immediate_edge_rule_fire or immediate_rule_applied:
if update:
- # Increase fp operator count only if not an immediate rule
- if not (immediate_node_rule_fire or immediate_edge_rule_fire):
- fp_cnt += 1
+ # Increase fp operator count
+ fp_cnt += 1
- for i in range(len(rules)):
+ # Lists or threadsafe operations (when parallel is on)
+ rules_to_be_applied_node_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_node_type) for _ in range(len(rules))])
+ rules_to_be_applied_edge_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_edge_type) for _ in range(len(rules))])
+ if atom_trace:
+ rules_to_be_applied_node_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
+ rules_to_be_applied_edge_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
+ edges_to_be_added_edge_rule_threadsafe = numba.typed.List([numba.typed.List.empty_list(edges_to_be_added_type) for _ in range(len(rules))])
+
+ for i in prange(len(rules)):
rule = rules[i]
- immediate_rule = rule.is_immediate_rule()
- immediate_node_rule_fire = False
- immediate_edge_rule_fire = False
# Only go through if the rule can be applied within the given timesteps, or we're running until convergence
delta_t = rule.get_delta()
if t + delta_t <= tmax or tmax == -1 or again:
- applicable_node_rules = _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, neighbors, reverse_neighbors, atom_trace, reverse_graph, nodes_to_skip[i])
- applicable_edge_rules = _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, reverse_graph, edges_to_skip[i])
+ applicable_node_rules, applicable_edge_rules = _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules, num_ga, t)
# Loop through applicable rules and add them to the rules to be applied for later or next fp operation
for applicable_rule in applicable_node_rules:
- n, annotations, qualified_nodes, qualified_edges, edges_to_add = applicable_rule
+ n, annotations, qualified_nodes, qualified_edges, _ = applicable_rule
# If there is an edge to add or the predicate doesn't exist or the interpretation is not static
- if len(edges_to_add[0]) > 0 or rule.get_target() not in interpretations_node[n].world or not interpretations_node[n].world[rule.get_target()].is_static():
+ if rule.get_target() not in interpretations_node[n].world or not interpretations_node[n].world[rule.get_target()].is_static():
bnd = annotate(annotation_functions, rule, annotations, rule.get_weights())
# Bound annotations in between 0 and 1
bnd_l = min(max(bnd[0], 0), 1)
bnd_u = min(max(bnd[1], 0), 1)
bnd = interval.closed(bnd_l, bnd_u)
max_rules_time = max(max_rules_time, t + delta_t)
- edges_to_be_added_node_rule.append(edges_to_add)
- rules_to_be_applied_node.append((numba.types.uint16(t + delta_t), n, rule.get_target(), bnd, immediate_rule, rule.is_static_rule()))
+ rules_to_be_applied_node_threadsafe[i].append((numba.types.uint16(t + delta_t), n, rule.get_target(), bnd, rule.is_static_rule()))
if atom_trace:
- rules_to_be_applied_node_trace.append((qualified_nodes, qualified_edges, rule.get_name()))
-
- # We apply a rule on a node/edge only once in each timestep to prevent it from being added to the to_be_added list continuously (this will improve performance
- # It's possible to have an annotation function that keeps changing the value of a node/edge. Do this only for delta_t>0
- if delta_t != 0:
- nodes_to_skip[i].append(n)
+ rules_to_be_applied_node_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name()))
- # Handle loop parameters for the next (maybe) fp operation
- # If it is a t=0 rule or an immediate rule we want to go back for another fp operation to check for new rules that may fire
- # Next fp operation we will skip this rule on this node because anyway there won't be an update
+ # If delta_t is zero we apply the rules and check if more are applicable
if delta_t == 0:
in_loop = True
update = False
- if immediate_rule and delta_t == 0:
- # immediate_rule_fire becomes True because we still need to check for more eligible rules, we're not done.
- in_loop = True
- update = True
- immediate_node_rule_fire = True
- break
-
- # Break, apply immediate rule then come back to check for more applicable rules
- if immediate_node_rule_fire:
- break
for applicable_rule in applicable_edge_rules:
e, annotations, qualified_nodes, qualified_edges, edges_to_add = applicable_rule
@@ -624,51 +565,43 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
bnd_u = min(max(bnd[1], 0), 1)
bnd = interval.closed(bnd_l, bnd_u)
max_rules_time = max(max_rules_time, t+delta_t)
- edges_to_be_added_edge_rule.append(edges_to_add)
- rules_to_be_applied_edge.append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, immediate_rule, rule.is_static_rule()))
+ # edges_to_be_added_edge_rule.append(edges_to_add)
+ edges_to_be_added_edge_rule_threadsafe[i].append(edges_to_add)
+ rules_to_be_applied_edge_threadsafe[i].append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, rule.is_static_rule()))
if atom_trace:
- rules_to_be_applied_edge_trace.append((qualified_nodes, qualified_edges, rule.get_name()))
-
- # We apply a rule on a node/edge only once in each timestep to prevent it from being added to the to_be_added list continuously (this will improve performance
- # It's possible to have an annotation function that keeps changing the value of a node/edge. Do this only for delta_t>0
- if delta_t != 0:
- edges_to_skip[i].append(e)
+ # rules_to_be_applied_edge_trace.append((qualified_nodes, qualified_edges, rule.get_name()))
+ rules_to_be_applied_edge_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name()))
- # Handle loop parameters for the next (maybe) fp operation
- # If it is a t=0 rule or an immediate rule we want to go back for another fp operation to check for new rules that may fire
- # Next fp operation we will skip this rule on this node because anyway there won't be an update
+ # If delta_t is zero we apply the rules and check if more are applicable
if delta_t == 0:
in_loop = True
update = False
- if immediate_rule and delta_t == 0:
- # immediate_rule_fire becomes True because we still need to check for more eligible rules, we're not done.
- in_loop = True
- update = True
- immediate_edge_rule_fire = True
- break
-
- # Break, apply immediate rule then come back to check for more applicable rules
- if immediate_edge_rule_fire:
- break
-
- # Go through all the rules and go back to applying the rules if we came here because of an immediate rule where delta_t>0
- if immediate_rule_applied and not (immediate_node_rule_fire or immediate_edge_rule_fire):
- immediate_rule_applied = False
- in_loop = True
- update = False
- continue
-
+
+ # Update lists after parallel run
+ for i in range(len(rules)):
+ if len(rules_to_be_applied_node_threadsafe[i]) > 0:
+ rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i])
+ if len(rules_to_be_applied_edge_threadsafe[i]) > 0:
+ rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i])
+ if atom_trace:
+ if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0:
+ rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i])
+ if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0:
+ rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i])
+ if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0:
+ edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i])
+
# Check for convergence after each timestep (perfect convergence or convergence specified by user)
# Check number of changed interpretations or max bound change
# User specified convergence
- if convergence_mode=='delta_interpretation':
+ if convergence_mode == 'delta_interpretation':
if changes_cnt <= convergence_delta:
if verbose:
print(f'\nConverged at time: {t} with {int(changes_cnt)} changes from the previous interpretation')
# Be consistent with time returned when we don't converge
t += 1
break
- elif convergence_mode=='delta_bound':
+ elif convergence_mode == 'delta_bound':
if bound_delta <= convergence_delta:
if verbose:
print(f'\nConverged at time: {t} with {float_to_str(bound_delta)} as the maximum bound change from the previous interpretation')
@@ -678,22 +611,23 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Perfect convergence
# Make sure there are no rules to be applied, and no facts that will be applied in the future. We do this by checking the max time any rule/fact is applicable
# If no more rules/facts to be applied
- elif convergence_mode=='perfect_convergence':
- if t>=max_facts_time and t>=max_rules_time:
+ elif convergence_mode == 'perfect_convergence':
+ if t >= max_facts_time and t >= max_rules_time:
if verbose:
print(f'\nConverged at time: {t}')
# Be consistent with time returned when we don't converge
t += 1
break
- # Increment t
+ # Increment t, update number of ground atoms
t += 1
+ num_ga.append(num_ga[-1])
return fp_cnt, t
def add_edge(self, edge, l):
# This function is useful for pyreason gym, called externally
- _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge)
+ _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, self.num_ga, -1)
def add_node(self, node, labels):
# This function is useful for pyreason gym, called externally
@@ -704,19 +638,19 @@ def add_node(self, node, labels):
def delete_edge(self, edge):
# This function is useful for pyreason gym, called externally
- _delete_edge(edge, self.neighbors, self.reverse_neighbors, self.edges, self.interpretations_edge)
+ _delete_edge(edge, self.neighbors, self.reverse_neighbors, self.edges, self.interpretations_edge, self.predicate_map_edge, self.num_ga)
def delete_node(self, node):
# This function is useful for pyreason gym, called externally
- _delete_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node)
+ _delete_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node, self.predicate_map_node, self.num_ga)
- def get_interpretation_dict(self):
+ def get_dict(self):
# This function can be called externally to retrieve a dict of the interpretation values
# Only values in the rule trace will be added
# Initialize interpretations for each time and node and edge
interpretations = {}
- for t in range(self.tmax+1):
+ for t in range(self.time+1):
interpretations[t] = {}
for node in self.nodes:
interpretations[t][node] = InterpretationDict()
@@ -728,9 +662,9 @@ def get_interpretation_dict(self):
time, _, node, l, bnd = change
interpretations[time][node][l._value] = (bnd.lower, bnd.upper)
- # If canonical, update all following timesteps as well
- if self. canonical:
- for t in range(time+1, self.tmax+1):
+ # If persistent, update all following timesteps as well
+ if self.persistent:
+ for t in range(time+1, self.time+1):
interpretations[t][node][l._value] = (bnd.lower, bnd.upper)
# Update interpretation edges
@@ -738,768 +672,697 @@ def get_interpretation_dict(self):
time, _, edge, l, bnd, = change
interpretations[time][edge][l._value] = (bnd.lower, bnd.upper)
- # If canonical, update all following timesteps as well
- if self. canonical:
- for t in range(time+1, self.tmax+1):
+ # If persistent, update all following timesteps as well
+ if self. persistent:
+ for t in range(time+1, self.time+1):
interpretations[t][edge][l._value] = (bnd.lower, bnd.upper)
return interpretations
+ def get_final_num_ground_atoms(self):
+ """
+ This function returns the number of ground atoms after the reasoning process, for the final timestep
+ :return: int: Number of ground atoms in the interpretation after reasoning
+ """
+ ga_cnt = 0
+
+ for node in self.nodes:
+ for l in self.interpretations_node[node].world:
+ ga_cnt += 1
+ for edge in self.edges:
+ for l in self.interpretations_edge[edge].world:
+ ga_cnt += 1
+
+ return ga_cnt
+
+ def get_num_ground_atoms(self):
+ """
+ This function returns the number of ground atoms after the reasoning process, for each timestep
+ :return: list: Number of ground atoms in the interpretation after reasoning for each timestep
+ """
+ if self.num_ga[-1] == 0:
+ self.num_ga.pop()
+ return self.num_ga
+
+ def query(self, query, return_bool=True) -> Union[bool, Tuple[float, float]]:
+ """
+ This function is used to query the graph after reasoning
+ :param query: A PyReason query object
+ :param return_bool: If True, returns boolean of query, else the bounds associated with it
+ :return: bool, or bounds
+ """
+
+ comp_type = query.get_component_type()
+ component = query.get_component()
+ pred = query.get_predicate()
+ bnd = query.get_bounds()
+
+ # Check if the component exists
+ if comp_type == 'node':
+ if component not in self.nodes:
+ return False if return_bool else (0, 0)
+ else:
+ if component not in self.edges:
+ return False if return_bool else (0, 0)
-@numba.njit(cache=False, parallel=True)
-def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, neighbors, reverse_neighbors, atom_trace, reverse_graph, nodes_to_skip):
- # Extract rule params
- rule_type = rule.get_type()
- clauses = rule.get_clauses()
- thresholds = rule.get_thresholds()
- ann_fn = rule.get_annotation_function()
- rule_edges = rule.get_edges()
-
- # We return a list of tuples which specify the target nodes/edges that have made the rule body true
- applicable_rules = numba.typed.List.empty_list(node_applicable_rule_type)
-
- # Create pre-allocated data structure so that parallel code does not need to use "append" to be threadsafe
- # One array for each node, then condense into a single list later
- applicable_rules_threadsafe = numba.typed.List([numba.typed.List.empty_list(node_applicable_rule_type) for _ in nodes])
-
- # Return empty list if rule is not node rule and if we are not inferring edges
- if rule_type != 'node' and rule_edges[0] == '':
- return applicable_rules
-
- # Steps
- # 1. Loop through all nodes and evaluate each clause with that node and check the truth with the thresholds
- # 2. Inside the clause loop it may be necessary to loop through all nodes/edges while grounding the variables
- # 3. If the clause is true add the qualified nodes and qualified edges to the atom trace, if on. Break otherwise
- # 4. After going through all clauses, add to the annotations list all the annotations of the specified subset. These will be passed to the annotation function
- # 5. Finally, if there are any edges to be added, place them in the list
-
- for piter in prange(len(nodes)):
- target_node = nodes[piter]
- if target_node in nodes_to_skip:
- continue
- # Initialize dictionary where keys are strings (x1, x2 etc.) and values are lists of qualified neighbors
- # Keep track of qualified nodes and qualified edges
- # If it's a node clause update (x1 or x2 etc.) qualified neighbors, if it's an edge clause update the qualified neighbors for the source and target (x1, x2)
- subsets = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes)
- qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
- qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
- annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
- edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
-
- satisfaction = True
- for i, clause in enumerate(clauses):
- # Unpack clause variables
- clause_type = clause[0]
- clause_label = clause[1]
- clause_variables = clause[2]
- clause_bnd = clause[3]
- clause_operator = clause[4]
-
- # Unpack thresholds
- # This value is total/available
- threshold_quantifier_type = thresholds[i][1][1]
-
- # This is a node clause
- # The groundings for node clauses are either the target node, neighbors of the target node, or an existing subset of nodes
- if clause_type == 'node':
- clause_var_1 = clause_variables[0]
- subset = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, neighbors)
-
- subsets[clause_var_1] = get_qualified_components_node_clause(interpretations_node, subset, clause_label, clause_bnd)
-
- if atom_trace:
- qualified_nodes.append(numba.typed.List(subsets[clause_var_1]))
- qualified_edges.append(numba.typed.List.empty_list(edge_type))
-
- # Add annotations if necessary
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qn in subsets[clause_var_1]:
- a.append(interpretations_node[qn].world[clause_label])
- annotations.append(a)
-
- # This is an edge clause
- elif clause_type == 'edge':
- clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
- subset_source, subset_target = get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, subsets, neighbors, reverse_neighbors, nodes)
-
- # Get qualified edges
- qe = get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, clause_bnd, reverse_graph)
- subsets[clause_var_1] = qe[0]
- subsets[clause_var_2] = qe[1]
+ # Check if the predicate exists
+ if comp_type == 'node':
+ if pred not in self.interpretations_node[component].world:
+ return False if return_bool else (0, 0)
+ else:
+ if pred not in self.interpretations_edge[component].world:
+ return False if return_bool else (0, 0)
- if atom_trace:
- qualified_nodes.append(numba.typed.List.empty_list(node_type))
- qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])))
-
- # Add annotations if necessary
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])):
- a.append(interpretations_edge[qe].world[clause_label])
- annotations.append(a)
+ # Check if the bounds are satisfied
+ if comp_type == 'node':
+ if self.interpretations_node[component].world[pred] in bnd:
+ return True if return_bool else (self.interpretations_node[component].world[pred].lower, self.interpretations_node[component].world[pred].upper)
else:
- # This is a comparison clause
- # Make sure there is at least one ground atom such that pred-num(x) : [1,1] or pred-num(x,y) : [1,1]
- # Remember that the predicate in the clause will not contain the "-num" where num is some number.
- # We have to remove that manually while checking
- # Steps:
- # 1. get qualified nodes/edges as well as number associated for first predicate
- # 2. get qualified nodes/edges as well as number associated for second predicate
- # 3. if there's no number in steps 1 or 2 return false clause
- # 4. do comparison with each qualified component from step 1 with each qualified component in step 2
-
- # It's a node comparison
- if len(clause_variables) == 2:
- clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
- subset_1 = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, neighbors)
- subset_2 = get_node_rule_node_clause_subset(clause_var_2, target_node, subsets, neighbors)
-
- # 1, 2
- qualified_nodes_for_comparison_1, numbers_1 = get_qualified_components_node_comparison_clause(interpretations_node, subset_1, clause_label, clause_bnd)
- qualified_nodes_for_comparison_2, numbers_2 = get_qualified_components_node_comparison_clause(interpretations_node, subset_2, clause_label, clause_bnd)
-
- # It's an edge comparison
- elif len(clause_variables) == 4:
- clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables[0], clause_variables[1], clause_variables[2], clause_variables[3]
- subset_1_source, subset_1_target = get_node_rule_edge_clause_subset(clause_var_1_source, clause_var_1_target, target_node, subsets, neighbors, reverse_neighbors, nodes)
- subset_2_source, subset_2_target = get_node_rule_edge_clause_subset(clause_var_2_source, clause_var_2_target, target_node, subsets, neighbors, reverse_neighbors, nodes)
-
- # 1, 2
- qualified_nodes_for_comparison_1_source, qualified_nodes_for_comparison_1_target, numbers_1 = get_qualified_components_edge_comparison_clause(interpretations_edge, subset_1_source, subset_1_target, clause_label, clause_bnd, reverse_graph)
- qualified_nodes_for_comparison_2_source, qualified_nodes_for_comparison_2_target, numbers_2 = get_qualified_components_edge_comparison_clause(interpretations_edge, subset_2_source, subset_2_target, clause_label, clause_bnd, reverse_graph)
-
- # Check if thresholds are satisfied
- # If it's a comparison clause we just need to check if the numbers list is not empty (no threshold support)
- if clause_type == 'comparison':
- if len(numbers_1) == 0 or len(numbers_2) == 0:
- satisfaction = False
- # Node comparison. Compare stage
- elif len(clause_variables) == 2:
- satisfaction, qualified_nodes_1, qualified_nodes_2 = compare_numbers_node_predicate(numbers_1, numbers_2, clause_operator, qualified_nodes_for_comparison_1, qualified_nodes_for_comparison_2)
-
- # Update subsets with final qualified nodes
- subsets[clause_var_1] = qualified_nodes_1
- subsets[clause_var_2] = qualified_nodes_2
- qualified_comparison_nodes = numba.typed.List(qualified_nodes_1)
- qualified_comparison_nodes.extend(qualified_nodes_2)
-
- if atom_trace:
- qualified_nodes.append(qualified_comparison_nodes)
- qualified_edges.append(numba.typed.List.empty_list(edge_type))
-
- # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qn in qualified_comparison_nodes:
- a.append(interval.closed(1, 1))
- annotations.append(a)
- # Edge comparison. Compare stage
- else:
- satisfaction, qualified_nodes_1_source, qualified_nodes_1_target, qualified_nodes_2_source, qualified_nodes_2_target = compare_numbers_edge_predicate(numbers_1, numbers_2, clause_operator,
- qualified_nodes_for_comparison_1_source,
- qualified_nodes_for_comparison_1_target,
- qualified_nodes_for_comparison_2_source,
- qualified_nodes_for_comparison_2_target)
- # Update subsets with final qualified nodes
- subsets[clause_var_1_source] = qualified_nodes_1_source
- subsets[clause_var_1_target] = qualified_nodes_1_target
- subsets[clause_var_2_source] = qualified_nodes_2_source
- subsets[clause_var_2_target] = qualified_nodes_2_target
-
- qualified_comparison_nodes_1 = numba.typed.List(zip(qualified_nodes_1_source, qualified_nodes_1_target))
- qualified_comparison_nodes_2 = numba.typed.List(zip(qualified_nodes_2_source, qualified_nodes_2_target))
- qualified_comparison_nodes = numba.typed.List(qualified_comparison_nodes_1)
- qualified_comparison_nodes.extend(qualified_comparison_nodes_2)
-
- if atom_trace:
- qualified_nodes.append(numba.typed.List.empty_list(node_type))
- qualified_edges.append(qualified_comparison_nodes)
-
- # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qe in qualified_comparison_nodes:
- a.append(interval.closed(1, 1))
- annotations.append(a)
-
- # Non comparison clause
+ return False if return_bool else (0, 0)
+ else:
+ if self.interpretations_edge[component].world[pred] in bnd:
+ return True if return_bool else (self.interpretations_edge[component].world[pred].lower, self.interpretations_edge[component].world[pred].upper)
else:
- if threshold_quantifier_type == 'total':
- if clause_type == 'node':
- neigh_len = len(subset)
- else:
- neigh_len = sum([len(l) for l in subset_target])
-
- # Available is all neighbors that have a particular label with bound inside [0,1]
- elif threshold_quantifier_type == 'available':
- if clause_type == 'node':
- neigh_len = len(get_qualified_components_node_clause(interpretations_node, subset, clause_label, interval.closed(0,1)))
- else:
- neigh_len = len(get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, interval.closed(0,1), reverse_graph)[0])
-
- qualified_neigh_len = len(subsets[clause_var_1])
- satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, thresholds[i]) and satisfaction
-
- # Exit loop if even one clause is not satisfied
- if not satisfaction:
- break
-
- if satisfaction:
- # Collect edges to be added
- source, target, _ = rule_edges
-
- # Edges to be added
- if source != '' and target != '':
- # Check if edge nodes are target
- if source == '__target':
- edges_to_be_added[0].append(target_node)
- elif source in subsets:
- edges_to_be_added[0].extend(subsets[source])
- else:
- edges_to_be_added[0].append(source)
-
- if target == '__target':
- edges_to_be_added[1].append(target_node)
- elif target in subsets:
- edges_to_be_added[1].extend(subsets[target])
- else:
- edges_to_be_added[1].append(target)
+ return False if return_bool else (0, 0)
- # node/edge, annotations, qualified nodes, qualified edges, edges to be added
- applicable_rules_threadsafe[piter] = numba.typed.List([(target_node, annotations, qualified_nodes, qualified_edges, edges_to_be_added)])
- # Merge all threadsafe rules into one single array
- for applicable_rule in applicable_rules_threadsafe:
- if len(applicable_rule) > 0:
- applicable_rules.append(applicable_rule[0])
-
- return applicable_rules
-
-
-@numba.njit(cache=False, parallel=True)
-def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, reverse_graph, edges_to_skip):
+@numba.njit(cache=True)
+def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules, num_ga, t):
# Extract rule params
rule_type = rule.get_type()
+ head_variables = rule.get_head_variables()
clauses = rule.get_clauses()
thresholds = rule.get_thresholds()
ann_fn = rule.get_annotation_function()
rule_edges = rule.get_edges()
- # We return a list of tuples which specify the target nodes/edges that have made the rule body true
- applicable_rules = numba.typed.List.empty_list(edge_applicable_rule_type)
-
- # Create pre-allocated data structure so that parallel code does not need to use "append" to be threadsafe
- # One array for each node, then condense into a single list later
- applicable_rules_threadsafe = numba.typed.List([numba.typed.List.empty_list(edge_applicable_rule_type) for _ in edges])
-
- # Return empty list if rule is not node rule
- if rule_type != 'edge':
- return applicable_rules
-
- # Steps
- # 1. Loop through all nodes and evaluate each clause with that node and check the truth with the thresholds
- # 2. Inside the clause loop it may be necessary to loop through all nodes/edges while grounding the variables
- # 3. If the clause is true add the qualified nodes and qualified edges to the atom trace, if on. Break otherwise
- # 4. After going through all clauses, add to the annotations list all the annotations of the specified subset. These will be passed to the annotation function
- # 5. Finally, if there are any edges to be added, place them in the list
-
- for piter in prange(len(edges)):
- target_edge = edges[piter]
- if target_edge in edges_to_skip:
- continue
- # Initialize dictionary where keys are strings (x1, x2 etc.) and values are lists of qualified neighbors
- # Keep track of qualified nodes and qualified edges
- # If it's a node clause update (x1 or x2 etc.) qualified neighbors, if it's an edge clause update the qualified neighbors for the source and target (x1, x2)
- subsets = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes)
- qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
- qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
- annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
- edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
-
- satisfaction = True
- for i, clause in enumerate(clauses):
- # Unpack clause variables
- clause_type = clause[0]
- clause_label = clause[1]
- clause_variables = clause[2]
- clause_bnd = clause[3]
- clause_operator = clause[4]
-
- # Unpack thresholds
- # This value is total/available
- threshold_quantifier_type = thresholds[i][1][1]
-
- # This is a node clause
- # The groundings for node clauses are either the source, target, neighbors of the source node, or an existing subset of nodes
- if clause_type == 'node':
- clause_var_1 = clause_variables[0]
- subset = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, neighbors)
-
- subsets[clause_var_1] = get_qualified_components_node_clause(interpretations_node, subset, clause_label, clause_bnd)
- if atom_trace:
- qualified_nodes.append(numba.typed.List(subsets[clause_var_1]))
- qualified_edges.append(numba.typed.List.empty_list(edge_type))
-
- # Add annotations if necessary
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qn in subsets[clause_var_1]:
- a.append(interpretations_node[qn].world[clause_label])
- annotations.append(a)
-
- # This is an edge clause
- elif clause_type == 'edge':
- clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
- subset_source, subset_target = get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, subsets, neighbors, reverse_neighbors, nodes)
-
- # Get qualified edges
- qe = get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, clause_bnd, reverse_graph)
- subsets[clause_var_1] = qe[0]
- subsets[clause_var_2] = qe[1]
+ if rule_type == 'node':
+ head_var_1 = head_variables[0]
+ else:
+ head_var_1, head_var_2 = head_variables[0], head_variables[1]
- if atom_trace:
- qualified_nodes.append(numba.typed.List.empty_list(node_type))
- qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])))
-
- # Add annotations if necessary
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])):
- a.append(interpretations_edge[qe].world[clause_label])
- annotations.append(a)
-
+ # We return a list of tuples which specify the target nodes/edges that have made the rule body true
+ applicable_rules_node = numba.typed.List.empty_list(node_applicable_rule_type)
+ applicable_rules_edge = numba.typed.List.empty_list(edge_applicable_rule_type)
+
+ # Grounding procedure
+ # 1. Go through each clause and check which variables have not been initialized in groundings
+ # 2. Check satisfaction of variables based on the predicate in the clause
+
+ # Grounding variable that maps variables in the body to a list of grounded nodes
+ # Grounding edges that maps edge variables to a list of edges
+ groundings = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes)
+ groundings_edges = numba.typed.Dict.empty(key_type=edge_type, value_type=list_of_edges)
+
+ # Dependency graph that keeps track of the connections between the variables in the body
+ dependency_graph_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes)
+ dependency_graph_reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes)
+
+ nodes_set = set(nodes)
+ edges_set = set(edges)
+
+ satisfaction = True
+ for i, clause in enumerate(clauses):
+ # Unpack clause variables
+ clause_type = clause[0]
+ clause_label = clause[1]
+ clause_variables = clause[2]
+ clause_bnd = clause[3]
+ clause_operator = clause[4]
+
+ # This is a node clause
+ if clause_type == 'node':
+ clause_var_1 = clause_variables[0]
+
+ # Get subset of nodes that can be used to ground the variable
+ # If we allow ground atoms, we can use the nodes directly
+ if allow_ground_rules and clause_var_1 in nodes_set:
+ grounding = numba.typed.List([clause_var_1])
else:
- # This is a comparison clause
- # Make sure there is at least one ground atom such that pred-num(x) : [1,1] or pred-num(x,y) : [1,1]
- # Remember that the predicate in the clause will not contain the "-num" where num is some number.
- # We have to remove that manually while checking
- # Steps:
- # 1. get qualified nodes/edges as well as number associated for first predicate
- # 2. get qualified nodes/edges as well as number associated for second predicate
- # 3. if there's no number in steps 1 or 2 return false clause
- # 4. do comparison with each qualified component from step 1 with each qualified component in step 2
-
- # It's a node comparison
- if len(clause_variables) == 2:
- clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
- subset_1 = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, neighbors)
- subset_2 = get_edge_rule_node_clause_subset(clause_var_2, target_edge, subsets, neighbors)
-
- # 1, 2
- qualified_nodes_for_comparison_1, numbers_1 = get_qualified_components_node_comparison_clause(interpretations_node, subset_1, clause_label, clause_bnd)
- qualified_nodes_for_comparison_2, numbers_2 = get_qualified_components_node_comparison_clause(interpretations_node, subset_2, clause_label, clause_bnd)
-
- # It's an edge comparison
- elif len(clause_variables) == 4:
- clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables[0], clause_variables[1], clause_variables[2], clause_variables[3]
- subset_1_source, subset_1_target = get_edge_rule_edge_clause_subset(clause_var_1_source, clause_var_1_target, target_edge, subsets, neighbors, reverse_neighbors, nodes)
- subset_2_source, subset_2_target = get_edge_rule_edge_clause_subset(clause_var_2_source, clause_var_2_target, target_edge, subsets, neighbors, reverse_neighbors, nodes)
-
- # 1, 2
- qualified_nodes_for_comparison_1_source, qualified_nodes_for_comparison_1_target, numbers_1 = get_qualified_components_edge_comparison_clause(interpretations_edge, subset_1_source, subset_1_target, clause_label, clause_bnd, reverse_graph)
- qualified_nodes_for_comparison_2_source, qualified_nodes_for_comparison_2_target, numbers_2 = get_qualified_components_edge_comparison_clause(interpretations_edge, subset_2_source, subset_2_target, clause_label, clause_bnd, reverse_graph)
-
- # Check if thresholds are satisfied
- # If it's a comparison clause we just need to check if the numbers list is not empty (no threshold support)
- if clause_type == 'comparison':
- if len(numbers_1) == 0 or len(numbers_2) == 0:
- satisfaction = False
- # Node comparison. Compare stage
- elif len(clause_variables) == 2:
- satisfaction, qualified_nodes_1, qualified_nodes_2 = compare_numbers_node_predicate(numbers_1, numbers_2, clause_operator, qualified_nodes_for_comparison_1, qualified_nodes_for_comparison_2)
-
- # Update subsets with final qualified nodes
- subsets[clause_var_1] = qualified_nodes_1
- subsets[clause_var_2] = qualified_nodes_2
- qualified_comparison_nodes = numba.typed.List(qualified_nodes_1)
- qualified_comparison_nodes.extend(qualified_nodes_2)
-
- if atom_trace:
- qualified_nodes.append(qualified_comparison_nodes)
- qualified_edges.append(numba.typed.List.empty_list(edge_type))
-
- # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qn in qualified_comparison_nodes:
- a.append(interval.closed(1, 1))
- annotations.append(a)
- # Edge comparison. Compare stage
- else:
- satisfaction, qualified_nodes_1_source, qualified_nodes_1_target, qualified_nodes_2_source, qualified_nodes_2_target = compare_numbers_edge_predicate(numbers_1, numbers_2, clause_operator,
- qualified_nodes_for_comparison_1_source,
- qualified_nodes_for_comparison_1_target,
- qualified_nodes_for_comparison_2_source,
- qualified_nodes_for_comparison_2_target)
- # Update subsets with final qualified nodes
- subsets[clause_var_1_source] = qualified_nodes_1_source
- subsets[clause_var_1_target] = qualified_nodes_1_target
- subsets[clause_var_2_source] = qualified_nodes_2_source
- subsets[clause_var_2_target] = qualified_nodes_2_target
-
- qualified_comparison_nodes_1 = numba.typed.List(zip(qualified_nodes_1_source, qualified_nodes_1_target))
- qualified_comparison_nodes_2 = numba.typed.List(zip(qualified_nodes_2_source, qualified_nodes_2_target))
- qualified_comparison_nodes = numba.typed.List(qualified_comparison_nodes_1)
- qualified_comparison_nodes.extend(qualified_comparison_nodes_2)
-
- if atom_trace:
- qualified_nodes.append(numba.typed.List.empty_list(node_type))
- qualified_edges.append(qualified_comparison_nodes)
-
- # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
- if ann_fn != '':
- a = numba.typed.List.empty_list(interval.interval_type)
- for qe in qualified_comparison_nodes:
- a.append(interval.closed(1, 1))
- annotations.append(a)
-
- # Non comparison clause
+ grounding = get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map_node, clause_label, nodes)
+
+ # Narrow subset based on predicate
+ qualified_groundings = get_qualified_node_groundings(interpretations_node, grounding, clause_label, clause_bnd)
+ groundings[clause_var_1] = qualified_groundings
+ qualified_groundings_set = set(qualified_groundings)
+ for c1, c2 in groundings_edges:
+ if c1 == clause_var_1:
+ groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[0] in qualified_groundings_set])
+ if c2 == clause_var_1:
+ groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[1] in qualified_groundings_set])
+
+ # Check satisfaction of those nodes wrt the threshold
+ # Only check satisfaction if the default threshold is used. This saves us from grounding the rest of the rule
+ # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges
+ # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0:
+ satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction
+
+ # This is an edge clause
+ elif clause_type == 'edge':
+ clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+
+ # Get subset of edges that can be used to ground the variables
+ # If we allow ground atoms, we can use the nodes directly
+ if allow_ground_rules and (clause_var_1, clause_var_2) in edges_set:
+ grounding = numba.typed.List([(clause_var_1, clause_var_2)])
else:
- if threshold_quantifier_type == 'total':
- if clause_type == 'node':
- neigh_len = len(subset)
- else:
- neigh_len = sum([len(l) for l in subset_target])
-
- # Available is all neighbors that have a particular label with bound inside [0,1]
- elif threshold_quantifier_type == 'available':
- if clause_type == 'node':
- neigh_len = len(get_qualified_components_node_clause(interpretations_node, subset, clause_label, interval.closed(0, 1)))
- else:
- neigh_len = len(get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, interval.closed(0, 1), reverse_graph)[0])
-
- qualified_neigh_len = len(subsets[clause_var_1])
- satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, thresholds[i]) and satisfaction
-
- # Exit loop if even one clause is not satisfied
- if not satisfaction:
- break
-
- # Here we are done going through each clause of the rule
- # If all clauses we're satisfied, proceed to collect annotations and prepare edges to be added
- if satisfaction:
- # Collect edges to be added
- source, target, _ = rule_edges
+ grounding = get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map_edge, clause_label, edges)
+
+ # Narrow subset based on predicate (save the edges that are qualified to use for finding future groundings faster)
+ qualified_groundings = get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, clause_bnd)
+
+ # Check satisfaction of those edges wrt the threshold
+ # Only check satisfaction if the default threshold is used. This saves us from grounding the rest of the rule
+ # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges
+ # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0:
+ satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction
+
+ # Update the groundings
+ groundings[clause_var_1] = numba.typed.List.empty_list(node_type)
+ groundings[clause_var_2] = numba.typed.List.empty_list(node_type)
+ groundings_clause_1_set = set(groundings[clause_var_1])
+ groundings_clause_2_set = set(groundings[clause_var_2])
+ for e in qualified_groundings:
+ if e[0] not in groundings_clause_1_set:
+ groundings[clause_var_1].append(e[0])
+ groundings_clause_1_set.add(e[0])
+ if e[1] not in groundings_clause_2_set:
+ groundings[clause_var_2].append(e[1])
+ groundings_clause_2_set.add(e[1])
+
+ # Update the edge groundings (to use later for grounding other clauses with the same variables)
+ groundings_edges[(clause_var_1, clause_var_2)] = qualified_groundings
+
+ # Update dependency graph
+ # Add a connection between clause_var_1 -> clause_var_2 and vice versa
+ if clause_var_1 not in dependency_graph_neighbors:
+ dependency_graph_neighbors[clause_var_1] = numba.typed.List([clause_var_2])
+ elif clause_var_2 not in dependency_graph_neighbors[clause_var_1]:
+ dependency_graph_neighbors[clause_var_1].append(clause_var_2)
+ if clause_var_2 not in dependency_graph_reverse_neighbors:
+ dependency_graph_reverse_neighbors[clause_var_2] = numba.typed.List([clause_var_1])
+ elif clause_var_1 not in dependency_graph_reverse_neighbors[clause_var_2]:
+ dependency_graph_reverse_neighbors[clause_var_2].append(clause_var_1)
+
+ # This is a comparison clause
+ else:
+ pass
- # Edges to be added
- if source != '' and target != '':
- # Check if edge nodes are source/target
- if source == '__source':
- edges_to_be_added[0].append(target_edge[0])
- elif source == '__target':
- edges_to_be_added[0].append(target_edge[1])
- elif source in subsets:
- edges_to_be_added[0].extend(subsets[source])
- else:
- edges_to_be_added[0].append(source)
-
- if target == '__source':
- edges_to_be_added[1].append(target_edge[0])
- elif target == '__target':
- edges_to_be_added[1].append(target_edge[1])
- elif target in subsets:
- edges_to_be_added[1].extend(subsets[target])
- else:
- edges_to_be_added[1].append(target)
+ # Refine the subsets based on any updates
+ if satisfaction:
+ refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors)
+
+ # If satisfaction is false, break
+ if not satisfaction:
+ break
+
+ # If satisfaction is still true, one final refinement to check if each edge pair is valid in edge rules
+ # Then continue to setup any edges to be added and annotations
+ # Fill out the rules to be applied lists
+ if satisfaction:
+ # Create temp grounding containers to verify if the head groundings are valid (only for edge rules)
+ # Setup edges to be added and fill rules to be applied
+ # Setup traces and inputs for annotation function
+ # Loop through the clause data and setup final annotations and trace variables
+                # Three cases: 1. node rule, 2. edge rule with infer edges, 3. edge rule
+ if rule_type == 'node':
+ # Loop through all the head variable groundings and add it to the rules to be applied
+ # Loop through the clauses and add appropriate trace data and annotations
+
+ # If there is no grounding for head_var_1, we treat it as a ground atom and add it to the graph
+ head_var_1_in_nodes = head_var_1 in nodes
+ add_head_var_node_to_graph = False
+ if allow_ground_rules and head_var_1_in_nodes:
+ groundings[head_var_1] = numba.typed.List([head_var_1])
+ elif head_var_1 not in groundings:
+ if not head_var_1_in_nodes:
+ add_head_var_node_to_graph = True
+ groundings[head_var_1] = numba.typed.List([head_var_1])
+
+ for head_grounding in groundings[head_var_1]:
+ qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
+ qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
+ annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
+ edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
+
+ # Check for satisfaction one more time in case the refining process has changed the groundings
+ satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges)
+ if not satisfaction:
+ continue
+
+ for i, clause in enumerate(clauses):
+ clause_type = clause[0]
+ clause_label = clause[1]
+ clause_variables = clause[2]
- # node/edge, annotations, qualified nodes, qualified edges, edges to be added
- applicable_rules_threadsafe[piter] = numba.typed.List([(target_edge, annotations, qualified_nodes, qualified_edges, edges_to_be_added)])
+ if clause_type == 'node':
+ clause_var_1 = clause_variables[0]
- # Merge all threadsafe rules into one single array
- for applicable_rule in applicable_rules_threadsafe:
- if len(applicable_rule) > 0:
- applicable_rules.append(applicable_rule[0])
+ # 1.
+ if atom_trace:
+ if clause_var_1 == head_var_1:
+ qualified_nodes.append(numba.typed.List([head_grounding]))
+ else:
+ qualified_nodes.append(numba.typed.List(groundings[clause_var_1]))
+ qualified_edges.append(numba.typed.List.empty_list(edge_type))
+ # 2.
+ if ann_fn != '':
+ a = numba.typed.List.empty_list(interval.interval_type)
+ if clause_var_1 == head_var_1:
+ a.append(interpretations_node[head_grounding].world[clause_label])
+ else:
+ for qn in groundings[clause_var_1]:
+ a.append(interpretations_node[qn].world[clause_label])
+ annotations.append(a)
+
+ elif clause_type == 'edge':
+ clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+ # 1.
+ if atom_trace:
+ # Cases: Both equal, one equal, none equal
+ qualified_nodes.append(numba.typed.List.empty_list(node_type))
+ if clause_var_1 == head_var_1:
+ es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_grounding])
+ qualified_edges.append(es)
+ elif clause_var_2 == head_var_1:
+ es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_grounding])
+ qualified_edges.append(es)
+ else:
+ qualified_edges.append(numba.typed.List(groundings_edges[(clause_var_1, clause_var_2)]))
+ # 2.
+ if ann_fn != '':
+ a = numba.typed.List.empty_list(interval.interval_type)
+ if clause_var_1 == head_var_1:
+ for e in groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_2 == head_var_1:
+ for e in groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[1] == head_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ else:
+ for qe in groundings_edges[(clause_var_1, clause_var_2)]:
+ a.append(interpretations_edge[qe].world[clause_label])
+ annotations.append(a)
+ else:
+ # Comparison clause (we do not handle for now)
+ pass
+
+ # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules)
+ if add_head_var_node_to_graph:
+ _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)
+
+ # For each grounding add a rule to be applied
+ applicable_rules_node.append((head_grounding, annotations, qualified_nodes, qualified_edges, edges_to_be_added))
+
+ elif rule_type == 'edge':
+ head_var_1 = head_variables[0]
+ head_var_2 = head_variables[1]
+
+ # If there is no grounding for head_var_1 or head_var_2, we treat it as a ground atom and add it to the graph
+ head_var_1_in_nodes = head_var_1 in nodes
+ head_var_2_in_nodes = head_var_2 in nodes
+ add_head_var_1_node_to_graph = False
+ add_head_var_2_node_to_graph = False
+ add_head_edge_to_graph = False
+ if allow_ground_rules and head_var_1_in_nodes:
+ groundings[head_var_1] = numba.typed.List([head_var_1])
+ if allow_ground_rules and head_var_2_in_nodes:
+ groundings[head_var_2] = numba.typed.List([head_var_2])
+
+ if head_var_1 not in groundings:
+ if not head_var_1_in_nodes:
+ add_head_var_1_node_to_graph = True
+ groundings[head_var_1] = numba.typed.List([head_var_1])
+ if head_var_2 not in groundings:
+ if not head_var_2_in_nodes:
+ add_head_var_2_node_to_graph = True
+ groundings[head_var_2] = numba.typed.List([head_var_2])
+
+ # Artificially connect the head variables with an edge if both of them were not in the graph
+ if not head_var_1_in_nodes and not head_var_2_in_nodes:
+ add_head_edge_to_graph = True
+
+ head_var_1_groundings = groundings[head_var_1]
+ head_var_2_groundings = groundings[head_var_2]
- return applicable_rules
+ source, target, _ = rule_edges
+ infer_edges = True if source != '' and target != '' else False
+
+ # Prepare the edges that we will loop over.
+ # For infer edges we loop over each combination pair
+ # Else we loop over the valid edges in the graph
+ valid_edge_groundings = numba.typed.List.empty_list(edge_type)
+ for g1 in head_var_1_groundings:
+ for g2 in head_var_2_groundings:
+ if infer_edges:
+ valid_edge_groundings.append((g1, g2))
+ else:
+ if (g1, g2) in edges_set:
+ valid_edge_groundings.append((g1, g2))
+
+ # Loop through the head variable groundings
+ for valid_e in valid_edge_groundings:
+ head_var_1_grounding, head_var_2_grounding = valid_e[0], valid_e[1]
+ qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
+ qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
+ annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
+ edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
+
+ # Containers to keep track of groundings to make sure that the edge pair is valid
+ # We do this because we cannot know beforehand the edge matches from source groundings to target groundings
+ temp_groundings = groundings.copy()
+ temp_groundings_edges = groundings_edges.copy()
+
+ # Refine the temp groundings for the specific edge head grounding
+ # We update the edge collection as well depending on if there's a match between the clause variables and head variables
+ temp_groundings[head_var_1] = numba.typed.List([head_var_1_grounding])
+ temp_groundings[head_var_2] = numba.typed.List([head_var_2_grounding])
+ for c1, c2 in temp_groundings_edges.keys():
+ if c1 == head_var_1 and c2 == head_var_2:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_1_grounding, head_var_2_grounding)])
+ elif c1 == head_var_2 and c2 == head_var_1:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_2_grounding, head_var_1_grounding)])
+ elif c1 == head_var_1:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_1_grounding])
+ elif c2 == head_var_1:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_1_grounding])
+ elif c1 == head_var_2:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_2_grounding])
+ elif c2 == head_var_2:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_2_grounding])
+
+ refine_groundings(head_variables, temp_groundings, temp_groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors)
+
+ # Check if the thresholds are still satisfied
+ # Check if all clauses are satisfied again in case the refining process changed anything
+ satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, temp_groundings, temp_groundings_edges)
+
+ if not satisfaction:
+ continue
+
+ if infer_edges:
+ # Prevent self loops while inferring edges if the clause variables are not the same
+ if source != target and head_var_1_grounding == head_var_2_grounding:
+ continue
+ edges_to_be_added[0].append(head_var_1_grounding)
+ edges_to_be_added[1].append(head_var_2_grounding)
+ for i, clause in enumerate(clauses):
+ clause_type = clause[0]
+ clause_label = clause[1]
+ clause_variables = clause[2]
-@numba.njit(cache=False)
-def get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, neighbors):
- # The groundings for node clauses are either the target node, neighbors of the target node, or an existing subset of nodes
- if clause_var_1 == '__target':
- subset = numba.typed.List([target_node])
+ if clause_type == 'node':
+ clause_var_1 = clause_variables[0]
+ # 1.
+ if atom_trace:
+ if clause_var_1 == head_var_1:
+ qualified_nodes.append(numba.typed.List([head_var_1_grounding]))
+ elif clause_var_1 == head_var_2:
+ qualified_nodes.append(numba.typed.List([head_var_2_grounding]))
+ else:
+ qualified_nodes.append(numba.typed.List(temp_groundings[clause_var_1]))
+ qualified_edges.append(numba.typed.List.empty_list(edge_type))
+ # 2.
+ if ann_fn != '':
+ a = numba.typed.List.empty_list(interval.interval_type)
+ if clause_var_1 == head_var_1:
+ a.append(interpretations_node[head_var_1_grounding].world[clause_label])
+ elif clause_var_1 == head_var_2:
+ a.append(interpretations_node[head_var_2_grounding].world[clause_label])
+ else:
+ for qn in temp_groundings[clause_var_1]:
+ a.append(interpretations_node[qn].world[clause_label])
+ annotations.append(a)
+
+ elif clause_type == 'edge':
+ clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+ # 1.
+ if atom_trace:
+ # Cases:
+ # 1. Both equal (cv1 = hv1 and cv2 = hv2 or cv1 = hv2 and cv2 = hv1)
+ # 2. One equal (cv1 = hv1 or cv2 = hv1 or cv1 = hv2 or cv2 = hv2)
+ # 3. None equal
+ qualified_nodes.append(numba.typed.List.empty_list(node_type))
+ if clause_var_1 == head_var_1 and clause_var_2 == head_var_2:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding])
+ qualified_edges.append(es)
+ elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding])
+ qualified_edges.append(es)
+ elif clause_var_1 == head_var_1:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding])
+ qualified_edges.append(es)
+ elif clause_var_1 == head_var_2:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding])
+ qualified_edges.append(es)
+ elif clause_var_2 == head_var_1:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_1_grounding])
+ qualified_edges.append(es)
+ elif clause_var_2 == head_var_2:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_2_grounding])
+ qualified_edges.append(es)
+ else:
+ qualified_edges.append(numba.typed.List(temp_groundings_edges[(clause_var_1, clause_var_2)]))
+
+ # 2.
+ if ann_fn != '':
+ a = numba.typed.List.empty_list(interval.interval_type)
+ if clause_var_1 == head_var_1 and clause_var_2 == head_var_2:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_1 == head_var_1:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_var_1_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_1 == head_var_2:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_var_2_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_2 == head_var_1:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[1] == head_var_1_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_2 == head_var_2:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[1] == head_var_2_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ else:
+ for qe in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ a.append(interpretations_edge[qe].world[clause_label])
+ annotations.append(a)
+
+ # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules)
+ if add_head_var_1_node_to_graph and head_var_1_grounding == head_var_1:
+ _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)
+ if add_head_var_2_node_to_graph and head_var_2_grounding == head_var_2:
+ _add_node(head_var_2, neighbors, reverse_neighbors, nodes, interpretations_node)
+ if add_head_edge_to_graph and (head_var_1, head_var_2) == (head_var_1_grounding, head_var_2_grounding):
+ _add_edge(head_var_1, head_var_2, neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
+
+ # For each grounding combination add a rule to be applied
+ # Only if all the clauses have valid groundings
+ # if satisfaction:
+ e = (head_var_1_grounding, head_var_2_grounding)
+ applicable_rules_edge.append((e, annotations, qualified_nodes, qualified_edges, edges_to_be_added))
+
+ # Return the applicable rules
+ return applicable_rules_node, applicable_rules_edge
+
+
+@numba.njit(cache=True)
+def check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges):
+ # Check if the thresholds are satisfied for each clause
+ satisfaction = True
+ for i, clause in enumerate(clauses):
+ # Unpack clause variables
+ clause_type = clause[0]
+ clause_label = clause[1]
+ clause_variables = clause[2]
+
+ if clause_type == 'node':
+ clause_var_1 = clause_variables[0]
+ satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, groundings[clause_var_1], groundings[clause_var_1], clause_label, thresholds[i]) and satisfaction
+ elif clause_type == 'edge':
+ clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+ satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, groundings_edges[(clause_var_1, clause_var_2)], groundings_edges[(clause_var_1, clause_var_2)], clause_label, thresholds[i]) and satisfaction
+ return satisfaction
+
+
+@numba.njit(cache=True)
+def refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors):
+ # Loop through the dependency graph and refine the groundings that have connections
+ all_variables_refined = numba.typed.List(clause_variables)
+ variables_just_refined = numba.typed.List(clause_variables)
+ new_variables_refined = numba.typed.List.empty_list(numba.types.string)
+ while len(variables_just_refined) > 0:
+ for refined_variable in variables_just_refined:
+ # Refine all the neighbors of the refined variable
+ if refined_variable in dependency_graph_neighbors:
+ for neighbor in dependency_graph_neighbors[refined_variable]:
+ old_edge_groundings = groundings_edges[(refined_variable, neighbor)]
+ new_node_groundings = groundings[refined_variable]
+
+                    # Reset the neighbor's groundings; they are rebuilt below from the surviving edges
+ del groundings[neighbor]
+ groundings[neighbor] = numba.typed.List.empty_list(node_type)
+
+ # Update the edge groundings and node groundings
+ qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[0] in new_node_groundings])
+ groundings_neighbor_set = set(groundings[neighbor])
+ for e in qualified_groundings:
+ if e[1] not in groundings_neighbor_set:
+ groundings[neighbor].append(e[1])
+ groundings_neighbor_set.add(e[1])
+ groundings_edges[(refined_variable, neighbor)] = qualified_groundings
+
+ # Add the neighbor to the list of refined variables so that we can refine for all its neighbors
+ if neighbor not in all_variables_refined:
+ new_variables_refined.append(neighbor)
+
+ if refined_variable in dependency_graph_reverse_neighbors:
+ for reverse_neighbor in dependency_graph_reverse_neighbors[refined_variable]:
+ old_edge_groundings = groundings_edges[(reverse_neighbor, refined_variable)]
+ new_node_groundings = groundings[refined_variable]
+
+                    # Reset the reverse neighbor's groundings; they are rebuilt below from the surviving edges
+ del groundings[reverse_neighbor]
+ groundings[reverse_neighbor] = numba.typed.List.empty_list(node_type)
+
+ # Update the edge groundings and node groundings
+ qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[1] in new_node_groundings])
+ groundings_reverse_neighbor_set = set(groundings[reverse_neighbor])
+ for e in qualified_groundings:
+ if e[0] not in groundings_reverse_neighbor_set:
+ groundings[reverse_neighbor].append(e[0])
+ groundings_reverse_neighbor_set.add(e[0])
+ groundings_edges[(reverse_neighbor, refined_variable)] = qualified_groundings
+
+ # Add the neighbor to the list of refined variables so that we can refine for all its neighbors
+ if reverse_neighbor not in all_variables_refined:
+ new_variables_refined.append(reverse_neighbor)
+
+ variables_just_refined = numba.typed.List(new_variables_refined)
+ all_variables_refined.extend(new_variables_refined)
+ new_variables_refined.clear()
+
+
+@numba.njit(cache=True)
+def check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_grounding, clause_label, threshold):
+ threshold_quantifier_type = threshold[1][1]
+ if threshold_quantifier_type == 'total':
+ neigh_len = len(grounding)
+
+ # Available is all neighbors that have a particular label with bound inside [0,1]
+ elif threshold_quantifier_type == 'available':
+ neigh_len = len(get_qualified_node_groundings(interpretations_node, grounding, clause_label, interval.closed(0, 1)))
+
+ qualified_neigh_len = len(qualified_grounding)
+ satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold)
+ return satisfaction
+
+
+@numba.njit(cache=True)
+def check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_grounding, clause_label, threshold):
+ threshold_quantifier_type = threshold[1][1]
+ if threshold_quantifier_type == 'total':
+ neigh_len = len(grounding)
+
+ # Available is all neighbors that have a particular label with bound inside [0,1]
+ elif threshold_quantifier_type == 'available':
+ neigh_len = len(get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, interval.closed(0, 1)))
+
+ qualified_neigh_len = len(qualified_grounding)
+ satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold)
+ return satisfaction
+
+
+@numba.njit(cache=True)
+def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l, nodes):
+    # The groundings for a node clause are a previous grounding if one exists, otherwise the nodes carrying the clause predicate (or all nodes if the predicate is unseen)
+ if l in predicate_map:
+ grounding = predicate_map[l] if clause_var_1 not in groundings else groundings[clause_var_1]
else:
- subset = neighbors[target_node] if clause_var_1 not in subsets else subsets[clause_var_1]
- return subset
+ grounding = nodes if clause_var_1 not in groundings else groundings[clause_var_1]
+ return grounding
-@numba.njit(cache=False)
-def get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, subsets, neighbors, reverse_neighbors, nodes):
- # There are 5 cases for predicate(Y,Z):
- # 1. Either one or both of Y, Z are the target node
- # 2. Both predicate variables Y and Z have not been encountered before
- # 3. The source variable Y has not been encountered before but the target variable Z has
- # 4. The target variable Z has not been encountered before but the source variable Y has
- # 5. Both predicate variables Y and Z have been encountered before
+@numba.njit(cache=True)
+def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l, edges):
+ # There are 4 cases for predicate(Y,Z):
+ # 1. Both predicate variables Y and Z have not been encountered before
+ # 2. The source variable Y has not been encountered before but the target variable Z has
+ # 3. The target variable Z has not been encountered before but the source variable Y has
+ # 4. Both predicate variables Y and Z have been encountered before
+ edge_groundings = numba.typed.List.empty_list(edge_type)
# Case 1:
- # Check if 1st variable or 1st and 2nd variables are the target
- if clause_var_1 == '__target':
- subset_source = numba.typed.List([target_node])
-
- # If both variables are the same
- if clause_var_2 == '__target':
- subset_target = numba.typed.List([numba.typed.List([target_node])])
- elif clause_var_2 in subsets:
- subset_target = numba.typed.List([subsets[clause_var_2]])
+    # Ground (Y, Z) with all edges carrying the clause predicate, or all edges in the graph
+ if clause_var_1 not in groundings and clause_var_2 not in groundings:
+ if l in predicate_map:
+ edge_groundings = predicate_map[l]
else:
- subset_target = numba.typed.List([neighbors[target_node]])
-
- # Check if 2nd variable is the target (this means 1st variable isn't the target)
- elif clause_var_2 == '__target':
- subset_source = reverse_neighbors[target_node] if clause_var_1 not in subsets else subsets[clause_var_1]
- subset_target = numba.typed.List([numba.typed.List([target_node]) for _ in subset_source])
+ edge_groundings = edges
# Case 2:
- # We replace Y by all nodes and Z by the neighbors of each of these nodes
- elif clause_var_1 not in subsets and clause_var_2 not in subsets:
- subset_source = numba.typed.List(nodes)
- subset_target = numba.typed.List([neighbors[n] for n in subset_source])
-
- # Case 3:
# We replace Y by the sources of Z
- elif clause_var_1 not in subsets and clause_var_2 in subsets:
- subset_source = numba.typed.List.empty_list(node_type)
- subset_target = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
-
- for n in subsets[clause_var_2]:
- sources = reverse_neighbors[n]
- for source in sources:
- subset_source.append(source)
- subset_target.append(numba.typed.List([n]))
+ elif clause_var_1 not in groundings and clause_var_2 in groundings:
+ for n in groundings[clause_var_2]:
+ es = numba.typed.List([(nn, n) for nn in reverse_neighbors[n]])
+ edge_groundings.extend(es)
- # Case 4:
+ # Case 3:
# We replace Z by the neighbors of Y
- elif clause_var_1 in subsets and clause_var_2 not in subsets:
- subset_source = subsets[clause_var_1]
- subset_target = numba.typed.List([neighbors[n] for n in subset_source])
+ elif clause_var_1 in groundings and clause_var_2 not in groundings:
+ for n in groundings[clause_var_1]:
+ es = numba.typed.List([(n, nn) for nn in neighbors[n]])
+ edge_groundings.extend(es)
- # Case 5:
+ # Case 4:
+ # We have seen both variables before
else:
- subset_source = subsets[clause_var_1]
- subset_target = numba.typed.List([subsets[clause_var_2] for _ in subset_source])
+ # We have already seen these two variables in an edge clause
+ if (clause_var_1, clause_var_2) in groundings_edges:
+ edge_groundings = groundings_edges[(clause_var_1, clause_var_2)]
+ # We have seen both these variables but not in an edge clause together
+ else:
+ groundings_clause_var_2_set = set(groundings[clause_var_2])
+ for n in groundings[clause_var_1]:
+ es = numba.typed.List([(n, nn) for nn in neighbors[n] if nn in groundings_clause_var_2_set])
+ edge_groundings.extend(es)
- return subset_source, subset_target
+ return edge_groundings
-@numba.njit(cache=False)
-def get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, neighbors):
- # The groundings for node clauses are either the source, target, neighbors of the source node, or an existing subset of nodes
- if clause_var_1 == '__source':
- subset = numba.typed.List([target_edge[0]])
- elif clause_var_1 == '__target':
- subset = numba.typed.List([target_edge[1]])
- else:
- subset = neighbors[target_edge[0]] if clause_var_1 not in subsets else subsets[clause_var_1]
- return subset
-
-
-@numba.njit(cache=False)
-def get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, subsets, neighbors, reverse_neighbors, nodes):
- # There are 5 cases for predicate(Y,Z):
- # 1. Either one or both of Y, Z are the source or target node
- # 2. Both predicate variables Y and Z have not been encountered before
- # 3. The source variable Y has not been encountered before but the target variable Z has
- # 4. The target variable Z has not been encountered before but the source variable Y has
- # 5. Both predicate variables Y and Z have been encountered before
- # Case 1:
- # Check if 1st variable is the source
- if clause_var_1 == '__source':
- subset_source = numba.typed.List([target_edge[0]])
-
- # If 2nd variable is source/target/something else
- if clause_var_2 == '__source':
- subset_target = numba.typed.List([numba.typed.List([target_edge[0]])])
- elif clause_var_2 == '__target':
- subset_target = numba.typed.List([numba.typed.List([target_edge[1]])])
- elif clause_var_2 in subsets:
- subset_target = numba.typed.List([subsets[clause_var_2]])
- else:
- subset_target = numba.typed.List([neighbors[target_edge[0]]])
-
- # if 1st variable is the target
- elif clause_var_1 == '__target':
- subset_source = numba.typed.List([target_edge[1]])
-
- # if 2nd variable is source/target/something else
- if clause_var_2 == '__source':
- subset_target = numba.typed.List([numba.typed.List([target_edge[0]])])
- elif clause_var_2 == '__target':
- subset_target = numba.typed.List([numba.typed.List([target_edge[1]])])
- elif clause_var_2 in subsets:
- subset_target = numba.typed.List([subsets[clause_var_2]])
- else:
- subset_target = numba.typed.List([neighbors[target_edge[1]]])
-
- # Handle the cases where the 2nd variable is source/target but the 1st is something else (cannot be source/target)
- elif clause_var_2 == '__source':
- subset_source = reverse_neighbors[target_edge[0]] if clause_var_1 not in subsets else subsets[clause_var_1]
- subset_target = numba.typed.List([numba.typed.List([target_edge[0]]) for _ in subset_source])
+@numba.njit(cache=True)
+def get_qualified_node_groundings(interpretations_node, grounding, clause_l, clause_bnd):
+ # Filter the grounding by the predicate and bound of the clause
+ qualified_groundings = numba.typed.List.empty_list(node_type)
+ for n in grounding:
+ if is_satisfied_node(interpretations_node, n, (clause_l, clause_bnd)):
+ qualified_groundings.append(n)
- elif clause_var_2 == '__target':
- subset_source = reverse_neighbors[target_edge[1]] if clause_var_1 not in subsets else subsets[clause_var_1]
- subset_target = numba.typed.List([numba.typed.List([target_edge[1]]) for _ in subset_source])
+ return qualified_groundings
- # Case 2:
- # We replace Y by all nodes and Z by the neighbors of each of these nodes
- elif clause_var_1 not in subsets and clause_var_2 not in subsets:
- subset_source = numba.typed.List(nodes)
- subset_target = numba.typed.List([neighbors[n] for n in subset_source])
- # Case 3:
- # We replace Y by the sources of Z
- elif clause_var_1 not in subsets and clause_var_2 in subsets:
- subset_source = numba.typed.List.empty_list(node_type)
- subset_target = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
+@numba.njit(cache=True)
+def get_qualified_edge_groundings(interpretations_edge, grounding, clause_l, clause_bnd):
+ # Filter the grounding by the predicate and bound of the clause
+ qualified_groundings = numba.typed.List.empty_list(edge_type)
+ for e in grounding:
+ if is_satisfied_edge(interpretations_edge, e, (clause_l, clause_bnd)):
+ qualified_groundings.append(e)
- for n in subsets[clause_var_2]:
- sources = reverse_neighbors[n]
- for source in sources:
- subset_source.append(source)
- subset_target.append(numba.typed.List([n]))
+ return qualified_groundings
- # Case 4:
- # We replace Z by the neighbors of Y
- elif clause_var_1 in subsets and clause_var_2 not in subsets:
- subset_source = subsets[clause_var_1]
- subset_target = numba.typed.List([neighbors[n] for n in subset_source])
- # Case 5:
- else:
- subset_source = subsets[clause_var_1]
- subset_target = numba.typed.List([subsets[clause_var_2] for _ in subset_source])
-
- return subset_source, subset_target
-
-
-@numba.njit(cache=False)
-def get_qualified_components_node_clause(interpretations_node, candidates, l, bnd):
- # Get all the qualified neighbors for a particular clause
- qualified_nodes = numba.typed.List.empty_list(node_type)
- for n in candidates:
- if is_satisfied_node(interpretations_node, n, (l, bnd)):
- qualified_nodes.append(n)
-
- return qualified_nodes
-
-
-@numba.njit(cache=False)
-def get_qualified_components_node_comparison_clause(interpretations_node, candidates, l, bnd):
- # Get all the qualified neighbors for a particular comparison clause and return them along with the number associated
- qualified_nodes = numba.typed.List.empty_list(node_type)
- qualified_nodes_numbers = numba.typed.List.empty_list(numba.types.float64)
- for n in candidates:
- result, number = is_satisfied_node_comparison(interpretations_node, n, (l, bnd))
- if result:
- qualified_nodes.append(n)
- qualified_nodes_numbers.append(number)
-
- return qualified_nodes, qualified_nodes_numbers
-
-
-@numba.njit(cache=False)
-def get_qualified_components_edge_clause(interpretations_edge, candidates_source, candidates_target, l, bnd, reverse_graph):
- # Get all the qualified sources and targets for a particular clause
- qualified_nodes_source = numba.typed.List.empty_list(node_type)
- qualified_nodes_target = numba.typed.List.empty_list(node_type)
- for i, source in enumerate(candidates_source):
- for target in candidates_target[i]:
- edge = (source, target) if not reverse_graph else (target, source)
- if is_satisfied_edge(interpretations_edge, edge, (l, bnd)):
- qualified_nodes_source.append(source)
- qualified_nodes_target.append(target)
-
- return qualified_nodes_source, qualified_nodes_target
-
-
-@numba.njit(cache=False)
-def get_qualified_components_edge_comparison_clause(interpretations_edge, candidates_source, candidates_target, l, bnd, reverse_graph):
- # Get all the qualified sources and targets for a particular clause
- qualified_nodes_source = numba.typed.List.empty_list(node_type)
- qualified_nodes_target = numba.typed.List.empty_list(node_type)
- qualified_edges_numbers = numba.typed.List.empty_list(numba.types.float64)
- for i, source in enumerate(candidates_source):
- for target in candidates_target[i]:
- edge = (source, target) if not reverse_graph else (target, source)
- result, number = is_satisfied_edge_comparison(interpretations_edge, edge, (l, bnd))
- if result:
- qualified_nodes_source.append(source)
- qualified_nodes_target.append(target)
- qualified_edges_numbers.append(number)
-
- return qualified_nodes_source, qualified_nodes_target, qualified_edges_numbers
-
-
-@numba.njit(cache=False)
-def compare_numbers_node_predicate(numbers_1, numbers_2, op, qualified_nodes_1, qualified_nodes_2):
- result = False
- final_qualified_nodes_1 = numba.typed.List.empty_list(node_type)
- final_qualified_nodes_2 = numba.typed.List.empty_list(node_type)
- for i in range(len(numbers_1)):
- for j in range(len(numbers_2)):
- if op == '<':
- if numbers_1[i] < numbers_2[j]:
- result = True
- elif op == '<=':
- if numbers_1[i] <= numbers_2[j]:
- result = True
- elif op == '>':
- if numbers_1[i] > numbers_2[j]:
- result = True
- elif op == '>=':
- if numbers_1[i] >= numbers_2[j]:
- result = True
- elif op == '==':
- if numbers_1[i] == numbers_2[j]:
- result = True
- elif op == '!=':
- if numbers_1[i] != numbers_2[j]:
- result = True
-
- if result:
- final_qualified_nodes_1.append(qualified_nodes_1[i])
- final_qualified_nodes_2.append(qualified_nodes_2[j])
- return result, final_qualified_nodes_1, final_qualified_nodes_2
-
-
-@numba.njit(cache=False)
-def compare_numbers_edge_predicate(numbers_1, numbers_2, op, qualified_nodes_1a, qualified_nodes_1b, qualified_nodes_2a, qualified_nodes_2b):
- result = False
- final_qualified_nodes_1a = numba.typed.List.empty_list(node_type)
- final_qualified_nodes_1b = numba.typed.List.empty_list(node_type)
- final_qualified_nodes_2a = numba.typed.List.empty_list(node_type)
- final_qualified_nodes_2b = numba.typed.List.empty_list(node_type)
- for i in range(len(numbers_1)):
- for j in range(len(numbers_2)):
- if op == '<':
- if numbers_1[i] < numbers_2[j]:
- result = True
- elif op == '<=':
- if numbers_1[i] <= numbers_2[j]:
- result = True
- elif op == '>':
- if numbers_1[i] > numbers_2[j]:
- result = True
- elif op == '>=':
- if numbers_1[i] >= numbers_2[j]:
- result = True
- elif op == '==':
- if numbers_1[i] == numbers_2[j]:
- result = True
- elif op == '!=':
- if numbers_1[i] != numbers_2[j]:
- result = True
-
- if result:
- final_qualified_nodes_1a.append(qualified_nodes_1a[i])
- final_qualified_nodes_1b.append(qualified_nodes_1b[i])
- final_qualified_nodes_2a.append(qualified_nodes_2a[j])
- final_qualified_nodes_2b.append(qualified_nodes_2b[j])
- return result, final_qualified_nodes_1a, final_qualified_nodes_1b, final_qualified_nodes_2a, final_qualified_nodes_2b
-
-
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def _satisfies_threshold(num_neigh, num_qualified_component, threshold):
# Checks if qualified neighbors satisfy threshold. This is for one clause
if threshold[1][0]=='number':
@@ -1531,8 +1394,8 @@ def _satisfies_threshold(num_neigh, num_qualified_component, threshold):
return result
-@numba.njit(cache=False)
-def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False):
+@numba.njit(cache=True)
+def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, num_ga, mode, override=False):
updated = False
# This is to prevent a key error in case the label is a specific label
try:
@@ -1543,6 +1406,11 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
# Add label to world if it is not there
if l not in world.world:
world.world[l] = interval.closed(0, 1)
+ num_ga[t_cnt] += 1
+ if l in predicate_map:
+ predicate_map[l].append(comp)
+ else:
+ predicate_map[l] = numba.typed.List([comp])
# Check if update is necessary with previous bnd
prev_bnd = world.world[l].copy()
@@ -1575,7 +1443,13 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
if updated:
ip_update_cnt = 0
for p1, p2 in ipl:
- if p1==l:
+ if p1 == l:
+ if p2 not in world.world:
+ world.world[p2] = interval.closed(0, 1)
+ if p2 in predicate_map:
+ predicate_map[p2].append(comp)
+ else:
+ predicate_map[p2] = numba.typed.List([comp])
if atom_trace:
_update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}')
lower = max(world.world[p2].lower, 1 - world.world[p1].upper)
@@ -1586,7 +1460,13 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
updated_bnds.append(world.world[p2])
if store_interpretation_changes:
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper)))
- if p2==l:
+ if p2 == l:
+ if p1 not in world.world:
+ world.world[p1] = interval.closed(0, 1)
+ if p1 in predicate_map:
+ predicate_map[p1].append(comp)
+ else:
+ predicate_map[p1] = numba.typed.List([comp])
if atom_trace:
_update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}')
lower = max(world.world[p1].lower, 1 - world.world[p2].upper)
@@ -1620,8 +1500,8 @@ def _update_node(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
return (False, 0)
-@numba.njit(cache=False)
-def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False):
+@numba.njit(cache=True)
+def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, num_ga, mode, override=False):
updated = False
# This is to prevent a key error in case the label is a specific label
try:
@@ -1632,6 +1512,11 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
# Add label to world if it is not there
if l not in world.world:
world.world[l] = interval.closed(0, 1)
+ num_ga[t_cnt] += 1
+ if l in predicate_map:
+ predicate_map[l].append(comp)
+ else:
+ predicate_map[l] = numba.typed.List([comp])
# Check if update is necessary with previous bnd
prev_bnd = world.world[l].copy()
@@ -1664,7 +1549,13 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
if updated:
ip_update_cnt = 0
for p1, p2 in ipl:
- if p1==l:
+ if p1 == l:
+ if p2 not in world.world:
+ world.world[p2] = interval.closed(0, 1)
+ if p2 in predicate_map:
+ predicate_map[p2].append(comp)
+ else:
+ predicate_map[p2] = numba.typed.List([comp])
if atom_trace:
_update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}')
lower = max(world.world[p2].lower, 1 - world.world[p1].upper)
@@ -1675,7 +1566,13 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
updated_bnds.append(world.world[p2])
if store_interpretation_changes:
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper)))
- if p2==l:
+ if p2 == l:
+ if p1 not in world.world:
+ world.world[p1] = interval.closed(0, 1)
+ if p1 in predicate_map:
+ predicate_map[p1].append(comp)
+ else:
+ predicate_map[p1] = numba.typed.List([comp])
if atom_trace:
_update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}')
lower = max(world.world[p1].lower, 1 - world.world[p2].upper)
@@ -1686,7 +1583,7 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
updated_bnds.append(world.world[p2])
if store_interpretation_changes:
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(lower, upper)))
-
+
# Gather convergence data
change = 0
if updated:
@@ -1702,29 +1599,29 @@ def _update_edge(interpretations, comp, na, ipl, rule_trace, fp_cnt, t_cnt, stat
change = max(change, max_delta)
else:
change = 1 + ip_update_cnt
-
+
return (updated, change)
except:
return (False, 0)
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def _update_rule_trace(rule_trace, qn, qe, prev_bnd, name):
rule_trace.append((qn, qe, prev_bnd.copy(), name))
-
-@numba.njit(cache=False)
+
+@numba.njit(cache=True)
def are_satisfied_node(interpretations, comp, nas):
result = True
- for (label, interval) in nas:
- result = result and is_satisfied_node(interpretations, comp, (label, interval))
+ for (l, bnd) in nas:
+ result = result and is_satisfied_node(interpretations, comp, (l, bnd))
return result
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def is_satisfied_node(interpretations, comp, na):
result = False
- if (not (na[0] is None or na[1] is None)):
+ if not (na[0] is None or na[1] is None):
# This is to prevent a key error in case the label is a specific label
try:
world = interpretations[comp]
@@ -1736,7 +1633,7 @@ def is_satisfied_node(interpretations, comp, na):
return result
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def is_satisfied_node_comparison(interpretations, comp, na):
result = False
number = 0
@@ -1763,18 +1660,18 @@ def is_satisfied_node_comparison(interpretations, comp, na):
return result, number
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def are_satisfied_edge(interpretations, comp, nas):
result = True
- for (label, interval) in nas:
- result = result and is_satisfied_edge(interpretations, comp, (label, interval))
+ for (l, bnd) in nas:
+ result = result and is_satisfied_edge(interpretations, comp, (l, bnd))
return result
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def is_satisfied_edge(interpretations, comp, na):
result = False
- if (not (na[0] is None or na[1] is None)):
+ if not (na[0] is None or na[1] is None):
# This is to prevent a key error in case the label is a specific label
try:
world = interpretations[comp]
@@ -1786,7 +1683,7 @@ def is_satisfied_edge(interpretations, comp, na):
return result
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def is_satisfied_edge_comparison(interpretations, comp, na):
result = False
number = 0
@@ -1813,7 +1710,7 @@ def is_satisfied_edge_comparison(interpretations, comp, na):
return result, number
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def annotate(annotation_functions, rule, annotations, weights):
func_name = rule.get_annotation_function()
if func_name == '':
@@ -1826,7 +1723,7 @@ def annotate(annotation_functions, rule, annotations, weights):
return annotation
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def check_consistent_node(interpretations, comp, na):
world = interpretations[comp]
if na[0] in world.world:
@@ -1839,7 +1736,7 @@ def check_consistent_node(interpretations, comp, na):
return True
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def check_consistent_edge(interpretations, comp, na):
world = interpretations[comp]
if na[0] in world.world:
@@ -1852,20 +1749,26 @@ def check_consistent_edge(interpretations, comp, na):
return True
-@numba.njit(cache=False)
-def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, atom_trace, rule_trace, rule_trace_atoms, store_interpretation_changes):
+@numba.njit(cache=True)
+def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode):
world = interpretations[comp]
if store_interpretation_changes:
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1)))
+ if mode == 'fact' or mode == 'graph-attribute-fact' and atom_trace:
+ name = facts_to_be_applied_trace[idx]
+ elif mode == 'rule' and atom_trace:
+ name = rules_to_be_applied_trace[idx][2]
+ else:
+ name = '-'
if atom_trace:
- _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[na[0]], 'Inconsistency')
+ _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[na[0]], f'Inconsistency due to {name}')
# Resolve inconsistency and set static
world.world[na[0]].set_lower_upper(0, 1)
world.world[na[0]].set_static(True)
for p1, p2 in ipl:
if p1==na[0]:
if atom_trace:
- _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], 'Inconsistency')
+ _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'Inconsistency due to {name}')
world.world[p2].set_lower_upper(0, 1)
world.world[p2].set_static(True)
if store_interpretation_changes:
@@ -1873,28 +1776,34 @@ def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, at
if p2==na[0]:
if atom_trace:
- _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], 'Inconsistency')
+ _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'Inconsistency due to {name}')
world.world[p1].set_lower_upper(0, 1)
world.world[p1].set_static(True)
if store_interpretation_changes:
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1)))
- # Add inconsistent predicates to a list
+ # Add inconsistent predicates to a list
-@numba.njit(cache=False)
-def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, atom_trace, rule_trace, rule_trace_atoms, store_interpretation_changes):
+@numba.njit(cache=True)
+def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode):
w = interpretations[comp]
if store_interpretation_changes:
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1)))
+ if mode == 'fact' or mode == 'graph-attribute-fact' and atom_trace:
+ name = facts_to_be_applied_trace[idx]
+ elif mode == 'rule' and atom_trace:
+ name = rules_to_be_applied_trace[idx][2]
+ else:
+ name = '-'
if atom_trace:
- _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[na[0]], 'Inconsistency')
+ _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[na[0]], f'Inconsistency due to {name}')
# Resolve inconsistency and set static
w.world[na[0]].set_lower_upper(0, 1)
w.world[na[0]].set_static(True)
for p1, p2 in ipl:
if p1==na[0]:
if atom_trace:
- _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p2], 'Inconsistency')
+ _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p2], f'Inconsistency due to {name}')
w.world[p2].set_lower_upper(0, 1)
w.world[p2].set_static(True)
if store_interpretation_changes:
@@ -1902,14 +1811,14 @@ def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, at
if p2==na[0]:
if atom_trace:
- _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p1], 'Inconsistency')
+ _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p1], f'Inconsistency due to {name}')
w.world[p1].set_lower_upper(0, 1)
w.world[p1].set_static(True)
if store_interpretation_changes:
rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1)))
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node):
nodes.append(node)
neighbors[node] = numba.typed.List.empty_list(node_type)
@@ -1917,8 +1826,8 @@ def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node):
interpretations_node[node] = world.World(numba.typed.List.empty_list(label.label_type))
-@numba.njit(cache=False)
-def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge):
+@numba.njit(cache=True)
+def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t):
# If not a node, add to list of nodes and initialize neighbors
if source not in nodes:
_add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node)
@@ -1938,43 +1847,57 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, int
reverse_neighbors[target].append(source)
if l.value!='':
interpretations_edge[edge] = world.World(numba.typed.List([l]))
+ num_ga[t] += 1
+ if l in predicate_map:
+ predicate_map[l].append(edge)
+ else:
+ predicate_map[l] = numba.typed.List([edge])
else:
interpretations_edge[edge] = world.World(numba.typed.List.empty_list(label.label_type))
else:
if l not in interpretations_edge[edge].world and l.value!='':
new_edge = True
interpretations_edge[edge].world[l] = interval.closed(0, 1)
+ num_ga[t] += 1
return edge, new_edge
-@numba.njit(cache=False)
-def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge):
+@numba.njit(cache=True)
+def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t):
changes = 0
edges_added = numba.typed.List.empty_list(edge_type)
for source in sources:
for target in targets:
- edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge)
+ edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t)
edges_added.append(edge)
changes = changes+1 if new_edge else changes
return edges_added, changes
-@numba.njit(cache=False)
-def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge):
+@numba.njit(cache=True)
+def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge, predicate_map, num_ga):
source, target = edge
edges.remove(edge)
+ num_ga[-1] -= len(interpretations_edge[edge].world)
del interpretations_edge[edge]
+ for l in predicate_map:
+ if edge in predicate_map[l]:
+ predicate_map[l].remove(edge)
neighbors[source].remove(target)
reverse_neighbors[target].remove(source)
-@numba.njit(cache=False)
-def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node):
+@numba.njit(cache=True)
+def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node, predicate_map, num_ga):
nodes.remove(node)
+ num_ga[-1] -= len(interpretations_node[node].world)
del interpretations_node[node]
del neighbors[node]
del reverse_neighbors[node]
+ for l in predicate_map:
+ if node in predicate_map[l]:
+ predicate_map[l].remove(node)
# Remove all occurrences of node in neighbors
for n in neighbors.keys():
@@ -1985,7 +1908,7 @@ def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node
reverse_neighbors[n].remove(node)
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def float_to_str(value):
number = int(value)
decimal = int(value % 1 * 1000)
@@ -1993,7 +1916,7 @@ def float_to_str(value):
return float_str
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def str_to_float(value):
decimal_pos = value.find('.')
if decimal_pos != -1:
@@ -2006,7 +1929,7 @@ def str_to_float(value):
return value
-@numba.njit(cache=False)
+@numba.njit(cache=True)
def str_to_int(value):
if value[0] == '-':
negative = True
diff --git a/pyreason/scripts/interpretation/temp.py b/pyreason/scripts/interpretation/temp.py
new file mode 100644
index 00000000..9f12fe20
--- /dev/null
+++ b/pyreason/scripts/interpretation/temp.py
@@ -0,0 +1,3012 @@
+from networkx.classes import edges
+
+import pyreason.scripts.numba_wrapper.numba_types.world_type as world
+import pyreason.scripts.numba_wrapper.numba_types.label_type as label
+import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
+from pyreason.scripts.interpretation.interpretation_dict import InterpretationDict
+
+import numba
+from numba import objmode, prange
+import time
+
+
+# Types for the dictionaries
+node_type = numba.types.string
+edge_type = numba.types.UniTuple(numba.types.string, 2)
+
+# Type for storing list of qualified nodes/edges
+list_of_nodes = numba.types.ListType(node_type)
+list_of_edges = numba.types.ListType(edge_type)
+
+# Type for storing clause data
+clause_data = numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string)))
+
+# Type for storing refine clause data
+refine_data = numba.types.Tuple((numba.types.string, numba.types.string, numba.types.int8))
+
+# Type for facts to be applied
+facts_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))
+facts_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))
+
+# Type for returning list of applicable rules for a certain rule
+# node/edge, annotations, qualified nodes, qualified edges, edges to be added
+node_applicable_rule_type = numba.types.Tuple((
+ node_type,
+ numba.types.ListType(numba.types.ListType(interval.interval_type)),
+ numba.types.ListType(numba.types.ListType(node_type)),
+ numba.types.ListType(numba.types.ListType(edge_type)),
+ numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type))
+))
+
+edge_applicable_rule_type = numba.types.Tuple((
+ edge_type,
+ numba.types.ListType(numba.types.ListType(interval.interval_type)),
+ numba.types.ListType(numba.types.ListType(node_type)),
+ numba.types.ListType(numba.types.ListType(edge_type)),
+ numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type))
+))
+
+rules_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))
+rules_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean))
+rules_to_be_applied_trace_type = numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string))
+edges_to_be_added_type = numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type))
+
+
+class Interpretation:
+ available_labels_node = []
+ available_labels_edge = []
+ specific_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(node_type))
+ specific_edge_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(edge_type))
+
+ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules):
+ self.graph = graph
+ self.ipl = ipl
+ self.annotation_functions = annotation_functions
+ self.reverse_graph = reverse_graph
+ self.atom_trace = atom_trace
+ self.save_graph_attributes_to_rule_trace = save_graph_attributes_to_rule_trace
+ self.canonical = canonical
+ self.inconsistency_check = inconsistency_check
+ self.store_interpretation_changes = store_interpretation_changes
+ self.update_mode = update_mode
+ self.allow_ground_rules = allow_ground_rules
+
+ # For reasoning and reasoning again (contains previous time and previous fp operation cnt)
+ self.time = 0
+ self.prev_reasoning_data = numba.typed.List([0, 0])
+
+ # Initialize list of tuples for rules/facts to be applied, along with all the ground atoms that fired the rule. One to One correspondence between rules_to_be_applied_node and rules_to_be_applied_node_trace if atom_trace is true
+ self.rules_to_be_applied_node_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type)
+ self.rules_to_be_applied_edge_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type)
+ self.facts_to_be_applied_node_trace = numba.typed.List.empty_list(numba.types.string)
+ self.facts_to_be_applied_edge_trace = numba.typed.List.empty_list(numba.types.string)
+ self.rules_to_be_applied_node = numba.typed.List.empty_list(rules_to_be_applied_node_type)
+ self.rules_to_be_applied_edge = numba.typed.List.empty_list(rules_to_be_applied_edge_type)
+ self.facts_to_be_applied_node = numba.typed.List.empty_list(facts_to_be_applied_node_type)
+ self.facts_to_be_applied_edge = numba.typed.List.empty_list(facts_to_be_applied_edge_type)
+ self.edges_to_be_added_node_rule = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)))
+ self.edges_to_be_added_edge_rule = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)))
+
+ # Keep track of all the rules that have affected each node/edge at each timestep/fp operation, and all ground atoms that have affected the rules as well. Keep track of previous bounds and name of the rule/fact here
+ self.rule_trace_node_atoms = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), interval.interval_type, numba.types.string)))
+ self.rule_trace_edge_atoms = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), interval.interval_type, numba.types.string)))
+ self.rule_trace_node = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, numba.types.uint16, node_type, label.label_type, interval.interval_type)))
+ self.rule_trace_edge = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, numba.types.uint16, edge_type, label.label_type, interval.interval_type)))
+
+ # Nodes and edges of the graph
+ self.nodes = numba.typed.List.empty_list(node_type)
+ self.edges = numba.typed.List.empty_list(edge_type)
+ self.nodes.extend(numba.typed.List(self.graph.nodes()))
+ self.edges.extend(numba.typed.List(self.graph.edges()))
+
+ # Make sure they are correct type
+ if len(self.available_labels_node)==0:
+ self.available_labels_node = numba.typed.List.empty_list(label.label_type)
+ else:
+ self.available_labels_node = numba.typed.List(self.available_labels_node)
+ if len(self.available_labels_edge)==0:
+ self.available_labels_edge = numba.typed.List.empty_list(label.label_type)
+ else:
+ self.available_labels_edge = numba.typed.List(self.available_labels_edge)
+
+ self.interpretations_node, self.predicate_map_node = self._init_interpretations_node(self.nodes, self.available_labels_node, self.specific_node_labels)
+ self.interpretations_edge, self.predicate_map_edge = self._init_interpretations_edge(self.edges, self.available_labels_edge, self.specific_edge_labels)
+
+ # Setup graph neighbors and reverse neighbors
+ self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type))
+ for n in self.graph.nodes():
+ l = numba.typed.List.empty_list(node_type)
+ [l.append(neigh) for neigh in self.graph.neighbors(n)]
+ self.neighbors[n] = l
+
+ self.reverse_neighbors = self._init_reverse_neighbors(self.neighbors)
+
+ @staticmethod
+ @numba.njit(cache=True)
+ def _init_reverse_neighbors(neighbors):
+ reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes)
+ for n, neighbor_nodes in neighbors.items():
+ for neighbor_node in neighbor_nodes:
+ if neighbor_node in reverse_neighbors and n not in reverse_neighbors[neighbor_node]:
+ reverse_neighbors[neighbor_node].append(n)
+ else:
+ reverse_neighbors[neighbor_node] = numba.typed.List([n])
+ # This makes sure each node has a value
+ if n not in reverse_neighbors:
+ reverse_neighbors[n] = numba.typed.List.empty_list(node_type)
+
+ return reverse_neighbors
+
+ @staticmethod
+ @numba.njit(cache=True)
+ def _init_interpretations_node(nodes, available_labels, specific_labels):
+ interpretations = numba.typed.Dict.empty(key_type=node_type, value_type=world.world_type)
+ predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_nodes)
+ # General labels
+ for n in nodes:
+ interpretations[n] = world.World(available_labels)
+ # Specific labels
+ for l, ns in specific_labels.items():
+ for n in ns:
+ interpretations[n].world[l] = interval.closed(0.0, 1.0)
+
+ for l in available_labels:
+ predicate_map[l] = numba.typed.List(nodes)
+
+ for l, ns in specific_labels.items():
+ predicate_map[l] = numba.typed.List(ns)
+
+ return interpretations, predicate_map
+
+ @staticmethod
+ @numba.njit(cache=True)
+ def _init_interpretations_edge(edges, available_labels, specific_labels):
+ interpretations = numba.typed.Dict.empty(key_type=edge_type, value_type=world.world_type)
+ predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_edges)
+ # General labels
+ for e in edges:
+ interpretations[e] = world.World(available_labels)
+ # Specific labels
+ for l, es in specific_labels.items():
+ for e in es:
+ interpretations[e].world[l] = interval.closed(0.0, 1.0)
+
+ for l in available_labels:
+ predicate_map[l] = numba.typed.List(edges)
+
+ for l, es in specific_labels.items():
+ predicate_map[l] = numba.typed.List(es)
+
+ return interpretations, predicate_map
+
+ @staticmethod
+ @numba.njit(cache=True)
+ def _init_convergence(convergence_bound_threshold, convergence_threshold):
+ if convergence_bound_threshold==-1 and convergence_threshold==-1:
+ convergence_mode = 'perfect_convergence'
+ convergence_delta = 0
+ elif convergence_bound_threshold==-1:
+ convergence_mode = 'delta_interpretation'
+ convergence_delta = convergence_threshold
+ else:
+ convergence_mode = 'delta_bound'
+ convergence_delta = convergence_bound_threshold
+ return convergence_mode, convergence_delta
+
+ def start_fp(self, tmax, facts_node, facts_edge, rules, verbose, convergence_threshold, convergence_bound_threshold, again=False):
+ self.tmax = tmax
+ self._convergence_mode, self._convergence_delta = self._init_convergence(convergence_bound_threshold, convergence_threshold)
+ max_facts_time = self._init_facts(facts_node, facts_edge, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.atom_trace)
+ self._start_fp(rules, max_facts_time, verbose, again)
+
+ @staticmethod
+ @numba.njit(cache=True)
+ def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, atom_trace):
+ max_time = 0
+ for fact in facts_node:
+ for t in range(fact.get_time_lower(), fact.get_time_upper() + 1):
+ max_time = max(max_time, t)
+ name = fact.get_name()
+ graph_attribute = True if name=='graph-attribute-fact' else False
+ facts_to_be_applied_node.append((numba.types.uint16(t), fact.get_component(), fact.get_label(), fact.get_bound(), fact.static, graph_attribute))
+ if atom_trace:
+ facts_to_be_applied_node_trace.append(fact.get_name())
+ for fact in facts_edge:
+ for t in range(fact.get_time_lower(), fact.get_time_upper() + 1):
+ max_time = max(max_time, t)
+ name = fact.get_name()
+ graph_attribute = True if name=='graph-attribute-fact' else False
+ facts_to_be_applied_edge.append((numba.types.uint16(t), fact.get_component(), fact.get_label(), fact.get_bound(), fact.static, graph_attribute))
+ if atom_trace:
+ facts_to_be_applied_edge_trace.append(fact.get_name())
+ return max_time
+
    def _start_fp(self, rules, max_facts_time, verbose, again):
        """
        Run the compiled reasoning kernel over the stored interpretation state and
        record where it stopped so a later call can resume (reason-again mode).

        :param rules: typed list of rules to ground and apply
        :param max_facts_time: largest timestep at which any fact is scheduled (used by perfect convergence)
        :param verbose: print timestep / fixed-point progress information
        :param again: True when resuming reasoning from a previous run
        """
        fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.predicate_map_node, self.predicate_map_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.canonical, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, self.allow_ground_rules, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, verbose, again)
        # reason() returns the next timestep to process, so the last completed one is t - 1
        self.time = t - 1
        # If we need to reason again, store the next timestep to start from
        self.prev_reasoning_data[0] = t
        self.prev_reasoning_data[1] = fp_cnt
        if verbose:
            print('Fixed Point iterations:', fp_cnt)
+
    @staticmethod
    @numba.njit(cache=True, parallel=False)
    def reason(interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules, max_facts_time, annotation_functions, convergence_mode, convergence_delta, verbose, again):
        # Main reasoning kernel (numba nopython).
        # Per timestep: (1) apply scheduled node/edge facts, (2) run an inner
        # fixed-point loop that applies due rules and re-grounds all rules until no
        # rule with delta_t == 0 fires, (3) check convergence. The run is resumable:
        # prev_reasoning_data = [start timestep, fixed-point count so far].
        # Returns (fp_cnt, t) where t is the next timestep to resume from.
        # NOTE(review): the `with objmode()` blocks are profiling instrumentation
        # that prints wall-clock timings from inside nopython mode.
        t = prev_reasoning_data[0]
        fp_cnt = prev_reasoning_data[1]
        max_rules_time = 0
        timestep_loop = True
        facts_to_be_applied_node_new = numba.typed.List.empty_list(facts_to_be_applied_node_type)
        facts_to_be_applied_edge_new = numba.typed.List.empty_list(facts_to_be_applied_edge_type)
        # Seed with a sentinel -1, presumably so numba can infer the set's element
        # type before .clear() is called -- TODO confirm
        rules_to_remove_idx = set()
        rules_to_remove_idx.add(-1)
        while timestep_loop:
            if t==tmax:
                timestep_loop = False
            if verbose:
                with objmode():
                    print('Timestep:', t, flush=True)
            # Reset Interpretation at beginning of timestep if non-canonical
            if t>0 and not canonical:
                # Reset nodes (only if not static)
                for n in nodes:
                    w = interpretations_node[n].world
                    for l in w:
                        if not w[l].is_static():
                            w[l].reset()

                # Reset edges (only if not static)
                for e in edges:
                    w = interpretations_edge[e].world
                    for l in w:
                        if not w[l].is_static():
                            w[l].reset()

            # Convergence parameters
            changes_cnt = 0
            bound_delta = 0
            update = False

            # Start by applying facts
            # Nodes
            with objmode(start='f8'):
                start=time.time()
            facts_to_be_applied_node_new.clear()
            nodes_set = set(nodes)
            for i in range(len(facts_to_be_applied_node)):
                if facts_to_be_applied_node[i][0] == t:
                    comp, l, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5]
                    # If the component is not in the graph, add it
                    if comp not in nodes_set:
                        _add_node(comp, neighbors, reverse_neighbors, nodes, interpretations_node)
                        nodes_set.add(comp)

                    # Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well
                    if l in interpretations_node[comp].world and interpretations_node[comp].world[l].is_static():
                        # Check if we should even store any of the changes to the rule trace etc.
                        # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute
                        if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes:
                            rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, bnd))
                            if atom_trace:
                                _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_node_trace[i])
                            # Record the inverse-predicate-list complement's current bounds too
                            for p1, p2 in ipl:
                                if p1==l:
                                    rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_node[comp].world[p2]))
                                    if atom_trace:
                                        _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p2], facts_to_be_applied_node_trace[i])
                                elif p2==l:
                                    rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[comp].world[p1]))
                                    if atom_trace:
                                        _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p1], facts_to_be_applied_node_trace[i])

                    else:
                        # Check for inconsistencies (multiple facts)
                        if check_consistent_node(interpretations_node, comp, (l, bnd)):
                            mode = 'graph-attribute-fact' if graph_attribute else 'fact'
                            override = True if update_mode == 'override' else False
                            u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=override)

                            update = u or update
                            # Update convergence params
                            if convergence_mode=='delta_bound':
                                bound_delta = max(bound_delta, changes)
                            else:
                                changes_cnt += changes
                        # Resolve inconsistency if necessary otherwise override bounds
                        else:
                            mode = 'graph-attribute-fact' if graph_attribute else 'fact'
                            if inconsistency_check:
                                resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode)
                            else:
                                u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=True)

                                update = u or update
                                # Update convergence params
                                if convergence_mode=='delta_bound':
                                    bound_delta = max(bound_delta, changes)
                                else:
                                    changes_cnt += changes

                    # Static facts stay scheduled: re-queue for the next timestep
                    if static:
                        facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute))

                # If time doesn't match, fact to be applied later
                else:
                    facts_to_be_applied_node_new.append(facts_to_be_applied_node[i])

            # Update list of facts with ones that have not been applied yet (delete applied facts)
            facts_to_be_applied_node[:] = facts_to_be_applied_node_new.copy()
            facts_to_be_applied_node_new.clear()

            with objmode():
                print('Time taken for node facts:', time.time()-start, flush=True)


            # Edges
            with objmode(start='f8'):
                start = time.time()
            facts_to_be_applied_edge_new.clear()
            edges_set = set(edges)
            for i in range(len(facts_to_be_applied_edge)):
                if facts_to_be_applied_edge[i][0]==t:
                    comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5]
                    # If the component is not in the graph, add it
                    if comp not in edges_set:
                        _add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge)
                        edges_set.add(comp)

                    # Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well
                    if l in interpretations_edge[comp].world and interpretations_edge[comp].world[l].is_static():
                        # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute
                        if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes:
                            # NOTE(review): this appends the stored world bound rather than the
                            # fact's bnd, unlike the node branch above -- confirm the asymmetry is intended
                            rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, interpretations_edge[comp].world[l]))
                            if atom_trace:
                                _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_edge_trace[i])
                            for p1, p2 in ipl:
                                if p1==l:
                                    rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_edge[comp].world[p2]))
                                    if atom_trace:
                                        _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p2], facts_to_be_applied_edge_trace[i])
                                elif p2==l:
                                    rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_edge[comp].world[p1]))
                                    if atom_trace:
                                        _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p1], facts_to_be_applied_edge_trace[i])
                    else:
                        # Check for inconsistencies
                        if check_consistent_edge(interpretations_edge, comp, (l, bnd)):
                            mode = 'graph-attribute-fact' if graph_attribute else 'fact'
                            override = True if update_mode == 'override' else False
                            u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=override)

                            update = u or update
                            # Update convergence params
                            if convergence_mode=='delta_bound':
                                bound_delta = max(bound_delta, changes)
                            else:
                                changes_cnt += changes
                        # Resolve inconsistency
                        else:
                            mode = 'graph-attribute-fact' if graph_attribute else 'fact'
                            if inconsistency_check:
                                resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode)
                            else:
                                u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=True)

                                update = u or update
                                # Update convergence params
                                if convergence_mode=='delta_bound':
                                    bound_delta = max(bound_delta, changes)
                                else:
                                    changes_cnt += changes

                    # Static facts stay scheduled: re-queue for the next timestep
                    if static:
                        facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute))

                # Time doesn't match, fact to be applied later
                else:
                    facts_to_be_applied_edge_new.append(facts_to_be_applied_edge[i])

            # Update list of facts with ones that have not been applied yet (delete applied facts)
            facts_to_be_applied_edge[:] = facts_to_be_applied_edge_new.copy()
            facts_to_be_applied_edge_new.clear()

            with objmode():
                print('Time taken for edge facts:', time.time()-start, flush=True)

            # Inner fixed-point loop: repeats while some delta_t == 0 rule fired
            in_loop = True
            while in_loop:
                with objmode():
                    print('FP Iteration:', fp_cnt, flush=True)
                # This will become true only if delta_t = 0 for some rule, otherwise we go to the next timestep
                in_loop = False

                # Apply the rules that need to be applied at this timestep
                # Nodes
                with objmode(start='f8'):
                    start = time.time()
                rules_to_remove_idx.clear()
                for idx, i in enumerate(rules_to_be_applied_node):
                    if i[0] == t:
                        comp, l, bnd, immediate, set_static = i[1], i[2], i[3], i[4], i[5]
                        # Check for inconsistencies
                        if check_consistent_node(interpretations_node, comp, (l, bnd)):
                            override = True if update_mode == 'override' else False
                            u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=override)

                            update = u or update
                            # Update convergence params
                            if convergence_mode=='delta_bound':
                                bound_delta = max(bound_delta, changes)
                            else:
                                changes_cnt += changes
                        # Resolve inconsistency
                        else:
                            if inconsistency_check:
                                resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule')
                            else:
                                u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=True)

                                update = u or update
                                # Update convergence params
                                if convergence_mode=='delta_bound':
                                    bound_delta = max(bound_delta, changes)
                                else:
                                    changes_cnt += changes

                        # Delete rules that have been applied from list by adding index to list
                        rules_to_remove_idx.add(idx)

                # Remove from rules to be applied and edges to be applied lists after coming out from loop
                rules_to_be_applied_node[:] = numba.typed.List([rules_to_be_applied_node[i] for i in range(len(rules_to_be_applied_node)) if i not in rules_to_remove_idx])
                edges_to_be_added_node_rule[:] = numba.typed.List([edges_to_be_added_node_rule[i] for i in range(len(edges_to_be_added_node_rule)) if i not in rules_to_remove_idx])
                if atom_trace:
                    rules_to_be_applied_node_trace[:] = numba.typed.List([rules_to_be_applied_node_trace[i] for i in range(len(rules_to_be_applied_node_trace)) if i not in rules_to_remove_idx])

                with objmode():
                    print('Time taken for node rules:', time.time()-start, flush=True)

                # Edges
                with objmode(start='f8'):
                    start = time.time()
                rules_to_remove_idx.clear()
                for idx, i in enumerate(rules_to_be_applied_edge):
                    if i[0] == t:
                        comp, l, bnd, immediate, set_static = i[1], i[2], i[3], i[4], i[5]
                        sources, targets, edge_l = edges_to_be_added_edge_rule[idx]
                        edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge)
                        changes_cnt += changes

                        # Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally
                        if edge_l.value != '':
                            for e in edges_added:
                                if interpretations_edge[e].world[edge_l].is_static():
                                    continue
                                if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)):
                                    override = True if update_mode == 'override' else False
                                    u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override)

                                    update = u or update

                                    # Update convergence params
                                    if convergence_mode=='delta_bound':
                                        bound_delta = max(bound_delta, changes)
                                    else:
                                        changes_cnt += changes
                                # Resolve inconsistency
                                else:
                                    if inconsistency_check:
                                        resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule')
                                    else:
                                        u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True)

                                        update = u or update

                                        # Update convergence params
                                        if convergence_mode=='delta_bound':
                                            bound_delta = max(bound_delta, changes)
                                        else:
                                            changes_cnt += changes

                        else:
                            # Check for inconsistencies
                            if check_consistent_edge(interpretations_edge, comp, (l, bnd)):
                                override = True if update_mode == 'override' else False
                                u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override)

                                update = u or update
                                # Update convergence params
                                if convergence_mode=='delta_bound':
                                    bound_delta = max(bound_delta, changes)
                                else:
                                    changes_cnt += changes
                            # Resolve inconsistency
                            else:
                                if inconsistency_check:
                                    resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule')
                                else:
                                    u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True)

                                    update = u or update
                                    # Update convergence params
                                    if convergence_mode=='delta_bound':
                                        bound_delta = max(bound_delta, changes)
                                    else:
                                        changes_cnt += changes

                        # Delete rules that have been applied from list by adding the index to list
                        rules_to_remove_idx.add(idx)

                # Remove from rules to be applied and edges to be applied lists after coming out from loop
                rules_to_be_applied_edge[:] = numba.typed.List([rules_to_be_applied_edge[i] for i in range(len(rules_to_be_applied_edge)) if i not in rules_to_remove_idx])
                edges_to_be_added_edge_rule[:] = numba.typed.List([edges_to_be_added_edge_rule[i] for i in range(len(edges_to_be_added_edge_rule)) if i not in rules_to_remove_idx])
                if atom_trace:
                    rules_to_be_applied_edge_trace[:] = numba.typed.List([rules_to_be_applied_edge_trace[i] for i in range(len(rules_to_be_applied_edge_trace)) if i not in rules_to_remove_idx])

                with objmode():
                    print('Time taken for edge rules:', time.time()-start, flush=True)

                # Fixed point
                # if update or immediate_node_rule_fire or immediate_edge_rule_fire or immediate_rule_applied:
                if update:
                    # Increase fp operator count
                    fp_cnt += 1

                    # Lists or threadsafe operations (when parallel is on)
                    # Per-rule buffers so a parallel prange does not contend on shared lists
                    rules_to_be_applied_node_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_node_type) for _ in range(len(rules))])
                    rules_to_be_applied_edge_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_edge_type) for _ in range(len(rules))])
                    if atom_trace:
                        rules_to_be_applied_node_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
                        rules_to_be_applied_edge_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
                    edges_to_be_added_edge_rule_threadsafe = numba.typed.List([numba.typed.List.empty_list(edges_to_be_added_type) for _ in range(len(rules))])

                    with objmode(startrules='f8'):
                        startrules = time.time()
                    for i in prange(len(rules)):
                        rule = rules[i]
                        immediate_rule = rule.is_immediate_rule()

                        # Only go through if the rule can be applied within the given timesteps, or we're running until convergence
                        delta_t = rule.get_delta()
                        if t + delta_t <= tmax or tmax == -1 or again:
                            with objmode(start='f8'):
                                start = time.time()
                            applicable_node_rules, applicable_edge_rules = _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules)
                            with objmode():
                                print('Time taken for grounding:', time.time()-start, flush=True)

                            # Loop through applicable rules and add them to the rules to be applied for later or next fp operation
                            for applicable_rule in applicable_node_rules:
                                n, annotations, qualified_nodes, qualified_edges, _ = applicable_rule
                                # If there is an edge to add or the predicate doesn't exist or the interpretation is not static
                                if rule.get_target() not in interpretations_node[n].world or not interpretations_node[n].world[rule.get_target()].is_static():
                                    bnd = annotate(annotation_functions, rule, annotations, rule.get_weights())
                                    # Bound annotations in between 0 and 1
                                    bnd_l = min(max(bnd[0], 0), 1)
                                    bnd_u = min(max(bnd[1], 0), 1)
                                    bnd = interval.closed(bnd_l, bnd_u)
                                    max_rules_time = max(max_rules_time, t + delta_t)
                                    rules_to_be_applied_node_threadsafe[i].append((numba.types.uint16(t + delta_t), n, rule.get_target(), bnd, immediate_rule, rule.is_static_rule()))
                                    if atom_trace:
                                        rules_to_be_applied_node_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name()))

                                    # If delta_t is zero we apply the rules and check if more are applicable
                                    if delta_t == 0:
                                        in_loop = True
                                        update = False

                            for applicable_rule in applicable_edge_rules:
                                e, annotations, qualified_nodes, qualified_edges, edges_to_add = applicable_rule
                                # If there is an edge to add or the predicate doesn't exist or the interpretation is not static
                                if len(edges_to_add[0]) > 0 or rule.get_target() not in interpretations_edge[e].world or not interpretations_edge[e].world[rule.get_target()].is_static():
                                    bnd = annotate(annotation_functions, rule, annotations, rule.get_weights())
                                    # Bound annotations in between 0 and 1
                                    bnd_l = min(max(bnd[0], 0), 1)
                                    bnd_u = min(max(bnd[1], 0), 1)
                                    bnd = interval.closed(bnd_l, bnd_u)
                                    max_rules_time = max(max_rules_time, t+delta_t)
                                    # edges_to_be_added_edge_rule.append(edges_to_add)
                                    edges_to_be_added_edge_rule_threadsafe[i].append(edges_to_add)
                                    # rules_to_be_applied_edge.append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, immediate_rule, rule.is_static_rule()))
                                    rules_to_be_applied_edge_threadsafe[i].append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, immediate_rule, rule.is_static_rule()))
                                    if atom_trace:
                                        # rules_to_be_applied_edge_trace.append((qualified_nodes, qualified_edges, rule.get_name()))
                                        rules_to_be_applied_edge_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name()))

                                    # If delta_t is zero we apply the rules and check if more are applicable
                                    if delta_t == 0:
                                        in_loop = True
                                        update = False

                    with objmode():
                        print('Time taken for ALL rule groundings:', time.time()-startrules, flush=True)
                    # Update lists after parallel run
                    for i in range(len(rules)):
                        if len(rules_to_be_applied_node_threadsafe[i]) > 0:
                            rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i])
                        if len(rules_to_be_applied_edge_threadsafe[i]) > 0:
                            rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i])
                        if atom_trace:
                            if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0:
                                rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i])
                            if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0:
                                rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i])
                        if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0:
                            edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i])

            # Check for convergence after each timestep (perfect convergence or convergence specified by user)
            # Check number of changed interpretations or max bound change
            # User specified convergence
            if convergence_mode == 'delta_interpretation':
                if changes_cnt <= convergence_delta:
                    if verbose:
                        print(f'\nConverged at time: {t} with {int(changes_cnt)} changes from the previous interpretation')
                    # Be consistent with time returned when we don't converge
                    t += 1
                    break
            elif convergence_mode == 'delta_bound':
                if bound_delta <= convergence_delta:
                    if verbose:
                        print(f'\nConverged at time: {t} with {float_to_str(bound_delta)} as the maximum bound change from the previous interpretation')
                    # Be consistent with time returned when we don't converge
                    t += 1
                    break
            # Perfect convergence
            # Make sure there are no rules to be applied, and no facts that will be applied in the future. We do this by checking the max time any rule/fact is applicable
            # If no more rules/facts to be applied
            elif convergence_mode == 'perfect_convergence':
                if t>=max_facts_time and t >= max_rules_time:
                    if verbose:
                        print(f'\nConverged at time: {t}')
                    # Be consistent with time returned when we don't converge
                    t += 1
                    break

            # Increment t
            t += 1

        return fp_cnt, t
+
+ def add_edge(self, edge, l):
+ # This function is useful for pyreason gym, called externally
+ _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge)
+
+ def add_node(self, node, labels):
+ # This function is useful for pyreason gym, called externally
+ if node not in self.nodes:
+ _add_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node)
+ for l in labels:
+ self.interpretations_node[node].world[label.Label(l)] = interval.closed(0, 1)
+
    def delete_edge(self, edge):
        # This function is useful for pyreason gym, called externally.
        # Removes `edge` = (source, target) from the edge list / adjacency structures
        # and drops its interpretation and predicate-map entries.
        _delete_edge(edge, self.neighbors, self.reverse_neighbors, self.edges, self.interpretations_edge, self.predicate_map_edge)
+
    def delete_node(self, node):
        # This function is useful for pyreason gym, called externally.
        # Removes `node` from the node list / adjacency structures and drops its
        # interpretation and predicate-map entries.
        # NOTE(review): incident edges are not removed here -- presumably the caller
        # is expected to delete them first; confirm against delete_edge usage.
        _delete_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node, self.predicate_map_node)
+
+ def get_dict(self):
+ # This function can be called externally to retrieve a dict of the interpretation values
+ # Only values in the rule trace will be added
+
+ # Initialize interpretations for each time and node and edge
+ interpretations = {}
+ for t in range(self.time+1):
+ interpretations[t] = {}
+ for node in self.nodes:
+ interpretations[t][node] = InterpretationDict()
+ for edge in self.edges:
+ interpretations[t][edge] = InterpretationDict()
+
+ # Update interpretation nodes
+ for change in self.rule_trace_node:
+ time, _, node, l, bnd = change
+ interpretations[time][node][l._value] = (bnd.lower, bnd.upper)
+
+ # If canonical, update all following timesteps as well
+ if self. canonical:
+ for t in range(time+1, self.time+1):
+ interpretations[t][node][l._value] = (bnd.lower, bnd.upper)
+
+ # Update interpretation edges
+ for change in self.rule_trace_edge:
+ time, _, edge, l, bnd, = change
+ interpretations[time][edge][l._value] = (bnd.lower, bnd.upper)
+
+ # If canonical, update all following timesteps as well
+ if self. canonical:
+ for t in range(time+1, self.time+1):
+ interpretations[t][edge][l._value] = (bnd.lower, bnd.upper)
+
+ return interpretations
+
+ def query(self, query, return_bool=True):
+ """
+ This function is used to query the graph after reasoning
+ :param query: The query string of for `pred(node)` or `pred(edge)` or `pred(node) : [l, u]`
+ :param return_bool: If True, returns boolean of query, else the bounds associated with it
+ :return: bool, or bounds
+ """
+ # Parse the query
+ query = query.replace(' ', '')
+
+ if ':' in query:
+ pred_comp, bounds = query.split(':')
+ bounds = bounds.replace('[', '').replace(']', '')
+ l, u = bounds.split(',')
+ l, u = float(l), float(u)
+ else:
+ if query[0] == '~':
+ pred_comp = query[1:]
+ l, u = 0, 0
+ else:
+ pred_comp = query
+ l, u = 1, 1
+
+ bnd = interval.closed(l, u)
+
+ # Split predicate and component
+ idx = pred_comp.find('(')
+ pred = label.Label(pred_comp[:idx])
+ component = pred_comp[idx + 1:-1]
+
+ if ',' in component:
+ component = tuple(component.split(','))
+ comp_type = 'edge'
+ else:
+ comp_type = 'node'
+
+ # Check if the component exists
+ if comp_type == 'node':
+ if component not in self.nodes:
+ return False if return_bool else (0, 0)
+ else:
+ if component not in self.edges:
+ return False if return_bool else (0, 0)
+
+ # Check if the predicate exists
+ if comp_type == 'node':
+ if pred not in self.interpretations_node[component].world:
+ return False if return_bool else (0, 0)
+ else:
+ if pred not in self.interpretations_edge[component].world:
+ return False if return_bool else (0, 0)
+
+ # Check if the bounds are satisfied
+ if comp_type == 'node':
+ if self.interpretations_node[component].world[pred] in bnd:
+ return True if return_bool else (self.interpretations_node[component].world[pred].lower, self.interpretations_node[component].world[pred].upper)
+ else:
+ return False if return_bool else (0, 0)
+ else:
+ if self.interpretations_edge[component].world[pred] in bnd:
+ return True if return_bool else (self.interpretations_edge[component].world[pred].lower, self.interpretations_edge[component].world[pred].upper)
+ else:
+ return False if return_bool else (0, 0)
+
+
+@numba.njit(cache=True)
+def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules):
+ # Extract rule params
+ rule_type = rule.get_type()
+ head_variables = rule.get_head_variables()
+ clauses = rule.get_clauses()
+ thresholds = rule.get_thresholds()
+ ann_fn = rule.get_annotation_function()
+ rule_edges = rule.get_edges()
+
+ if rule_type == 'node':
+ head_var_1 = head_variables[0]
+ else:
+ head_var_1, head_var_2 = head_variables[0], head_variables[1]
+
+ # We return a list of tuples which specify the target nodes/edges that have made the rule body true
+ applicable_rules_node = numba.typed.List.empty_list(node_applicable_rule_type)
+ applicable_rules_edge = numba.typed.List.empty_list(edge_applicable_rule_type)
+
+ # Grounding procedure
+ # 1. Go through each clause and check which variables have not been initialized in groundings
+ # 2. Check satisfaction of variables based on the predicate in the clause
+
+ # Grounding variable that maps variables in the body to a list of grounded nodes
+ # Grounding edges that maps edge variables to a list of edges
+ groundings = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes)
+ groundings_edges = numba.typed.Dict.empty(key_type=edge_type, value_type=list_of_edges)
+
+ # Dependency graph that keeps track of the connections between the variables in the body
+ dependency_graph_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes)
+ dependency_graph_reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes)
+
+ nodes_set = set(nodes)
+ edges_set = set(edges)
+
+ satisfaction = True
+ for i, clause in enumerate(clauses):
+ # Unpack clause variables
+ clause_type = clause[0]
+ clause_label = clause[1]
+ clause_variables = clause[2]
+ clause_bnd = clause[3]
+ clause_operator = clause[4]
+
+ # This is a node clause
+ if clause_type == 'node':
+ clause_var_1 = clause_variables[0]
+
+ # Get subset of nodes that can be used to ground the variable
+ # If we allow ground atoms, we can use the nodes directly
+ with objmode(start='f8'):
+ start = time.time()
+ if allow_ground_rules and clause_var_1 in nodes_set:
+ grounding = numba.typed.List([clause_var_1])
+ else:
+ grounding = get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map_node, clause_label)
+ with objmode():
+ print('__ get node groundings:', time.time()-start, flush=True)
+
+ # Narrow subset based on predicate
+ with objmode(start='f8'):
+ start = time.time()
+ qualified_groundings = get_qualified_node_groundings(interpretations_node, grounding, clause_label, clause_bnd)
+ groundings[clause_var_1] = qualified_groundings
+ qualified_groundings_set = set(qualified_groundings)
+ for c1, c2 in groundings_edges:
+ if c1 == clause_var_1:
+ groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[0] in qualified_groundings_set])
+ if c2 == clause_var_1:
+ groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[1] in qualified_groundings_set])
+
+ with objmode():
+ print('__ get node qualified groundings:', time.time()-start, flush=True)
+
+ # Check satisfaction of those nodes wrt the threshold
+ # Only check satisfaction if the default threshold is used. This saves us from grounding the rest of the rule
+ # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges
+ # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0:
+ satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction
+
+ # This is an edge clause
+ elif clause_type == 'edge':
+ clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+
+ # Get subset of edges that can be used to ground the variables
+ # If we allow ground atoms, we can use the nodes directly
+ with objmode(start='f8'):
+ start = time.time()
+ if allow_ground_rules and (clause_var_1, clause_var_2) in edges_set:
+ grounding = numba.typed.List([(clause_var_1, clause_var_2)])
+ else:
+ grounding = get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map_edge, clause_label)
+
+ with objmode():
+ print('__ get edge groundings:', time.time()-start, flush=True)
+
+ with objmode(start='f8'):
+ start = time.time()
+ # Narrow subset based on predicate (save the edges that are qualified to use for finding future groundings faster)
+ qualified_groundings = get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, clause_bnd)
+
+ with objmode():
+ print('__ get edge qualified groundings:', time.time()-start, flush=True)
+
+ # Check satisfaction of those edges wrt the threshold
+ # Only check satisfaction if the default threshold is used. This saves us from grounding the rest of the rule
+ # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges
+ # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0:
+ satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction
+
+ # Update the groundings
+ with objmode(start='f8'):
+ start = time.time()
+ groundings[clause_var_1] = numba.typed.List.empty_list(node_type)
+ groundings[clause_var_2] = numba.typed.List.empty_list(node_type)
+ groundings_clause_1_set = set(groundings[clause_var_1])
+ groundings_clause_2_set = set(groundings[clause_var_2])
+ for e in qualified_groundings:
+ if e[0] not in groundings_clause_1_set:
+ groundings[clause_var_1].append(e[0])
+ groundings_clause_1_set.add(e[0])
+ if e[1] not in groundings_clause_2_set:
+ groundings[clause_var_2].append(e[1])
+ groundings_clause_2_set.add(e[1])
+
+ # Update the edge groundings (to use later for grounding other clauses with the same variables)
+ groundings_edges[(clause_var_1, clause_var_2)] = qualified_groundings
+
+ # Update dependency graph
+ # Add a connection between clause_var_1 -> clause_var_2 and vice versa
+ if clause_var_1 not in dependency_graph_neighbors:
+ dependency_graph_neighbors[clause_var_1] = numba.typed.List([clause_var_2])
+ elif clause_var_2 not in dependency_graph_neighbors[clause_var_1]:
+ dependency_graph_neighbors[clause_var_1].append(clause_var_2)
+ if clause_var_2 not in dependency_graph_reverse_neighbors:
+ dependency_graph_reverse_neighbors[clause_var_2] = numba.typed.List([clause_var_1])
+ elif clause_var_1 not in dependency_graph_reverse_neighbors[clause_var_2]:
+ dependency_graph_reverse_neighbors[clause_var_2].append(clause_var_1)
+
+ with objmode():
+ print('__ updated dependency graph and groundings:', time.time()-start, flush=True)
+
+ # This is a comparison clause
+ else:
+ pass
+
+ # Refine the subsets based on any updates
+ if satisfaction:
+ with objmode(start='f8'):
+ start = time.time()
+ refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors)
+ with objmode():
+ print('__ refined groundings:', time.time()-start, flush=True)
+
+ # If satisfaction is false, break
+ if not satisfaction:
+ break
+
+ # If satisfaction is still true, one final refinement to check if each edge pair is valid in edge rules
+ # Then continue to setup any edges to be added and annotations
+ # Fill out the rules to be applied lists
+ if satisfaction:
+ # Create temp grounding containers to verify if the head groundings are valid (only for edge rules)
+ # Setup edges to be added and fill rules to be applied
+ # Setup traces and inputs for annotation function
+ # Loop through the clause data and setup final annotations and trace variables
+ # Three cases: 1.node rule, 2. edge rule with infer edges, 3. edge rule
+ if rule_type == 'node':
+ # Loop through all the head variable groundings and add it to the rules to be applied
+ # Loop through the clauses and add appropriate trace data and annotations
+ with objmode(start='f8'):
+ start = time.time()
+
+ # If there is no grounding for head_var_1, we treat it as a ground atom and add it to the graph
+ head_var_1_in_nodes = head_var_1 in nodes
+ add_head_var_node_to_graph = False
+ if allow_ground_rules and head_var_1_in_nodes:
+ groundings[head_var_1] = numba.typed.List([head_var_1])
+ elif head_var_1 not in groundings:
+ if not head_var_1_in_nodes:
+ add_head_var_node_to_graph = True
+ groundings[head_var_1] = numba.typed.List([head_var_1])
+
+ for head_grounding in groundings[head_var_1]:
+ qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
+ qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
+ annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
+ edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
+
+ # Check for satisfaction one more time in case the refining process has changed the groundings
+ satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges)
+ if not satisfaction:
+ continue
+
+ for i, clause in enumerate(clauses):
+ clause_type = clause[0]
+ clause_label = clause[1]
+ clause_variables = clause[2]
+
+ if clause_type == 'node':
+ clause_var_1 = clause_variables[0]
+
+ # 1.
+ if atom_trace:
+ if clause_var_1 == head_var_1:
+ qualified_nodes.append(numba.typed.List([head_grounding]))
+ else:
+ qualified_nodes.append(numba.typed.List(groundings[clause_var_1]))
+ qualified_edges.append(numba.typed.List.empty_list(edge_type))
+ # 2.
+ if ann_fn != '':
+ a = numba.typed.List.empty_list(interval.interval_type)
+ if clause_var_1 == head_var_1:
+ a.append(interpretations_node[head_grounding].world[clause_label])
+ else:
+ for qn in groundings[clause_var_1]:
+ a.append(interpretations_node[qn].world[clause_label])
+ annotations.append(a)
+
+ elif clause_type == 'edge':
+ clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+ # 1.
+ if atom_trace:
+ # Cases: Both equal, one equal, none equal
+ qualified_nodes.append(numba.typed.List.empty_list(node_type))
+ if clause_var_1 == head_var_1:
+ es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_grounding])
+ qualified_edges.append(es)
+ elif clause_var_2 == head_var_1:
+ es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_grounding])
+ qualified_edges.append(es)
+ else:
+ qualified_edges.append(numba.typed.List(groundings_edges[(clause_var_1, clause_var_2)]))
+ # 2.
+ if ann_fn != '':
+ a = numba.typed.List.empty_list(interval.interval_type)
+ if clause_var_1 == head_var_1:
+ for e in groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_2 == head_var_1:
+ for e in groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[1] == head_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ else:
+ for qe in groundings_edges[(clause_var_1, clause_var_2)]:
+ a.append(interpretations_edge[qe].world[clause_label])
+ annotations.append(a)
+ else:
+ # Comparison clause (we do not handle for now)
+ pass
+
+ # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules)
+ if add_head_var_node_to_graph:
+ _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)
+
+ # For each grounding add a rule to be applied
+ applicable_rules_node.append((head_grounding, annotations, qualified_nodes, qualified_edges, edges_to_be_added))
+
+ with objmode():
+ print('__ rule fire node head grounding:', time.time()-start, flush=True)
+
+ elif rule_type == 'edge':
+ head_var_1 = head_variables[0]
+ head_var_2 = head_variables[1]
+ with objmode(start='f8'):
+ start = time.time()
+
+ # If there is no grounding for head_var_1 or head_var_2, we treat it as a ground atom and add it to the graph
+ head_var_1_in_nodes = head_var_1 in nodes
+ head_var_2_in_nodes = head_var_2 in nodes
+ add_head_var_1_node_to_graph = False
+ add_head_var_2_node_to_graph = False
+ add_head_edge_to_graph = False
+ if allow_ground_rules and head_var_1_in_nodes:
+ groundings[head_var_1] = numba.typed.List([head_var_1])
+ if allow_ground_rules and head_var_2_in_nodes:
+ groundings[head_var_2] = numba.typed.List([head_var_2])
+
+ if head_var_1 not in groundings:
+ if not head_var_1_in_nodes:
+ add_head_var_1_node_to_graph = True
+ groundings[head_var_1] = numba.typed.List([head_var_1])
+ if head_var_2 not in groundings:
+ if not head_var_2_in_nodes:
+ add_head_var_2_node_to_graph = True
+ groundings[head_var_2] = numba.typed.List([head_var_2])
+
+ # Artificially connect the head variables with an edge if both of them were not in the graph
+ if not head_var_1_in_nodes and not head_var_2_in_nodes:
+ add_head_edge_to_graph = True
+
+ head_var_1_groundings = groundings[head_var_1]
+ head_var_2_groundings = groundings[head_var_2]
+
+ source, target, _ = rule_edges
+ infer_edges = True if source != '' and target != '' else False
+
+ # Prepare the edges that we will loop over.
+ # For infer edges we loop over each combination pair
+ # Else we loop over the valid edges in the graph
+ valid_edge_groundings = numba.typed.List.empty_list(edge_type)
+ for g1 in head_var_1_groundings:
+ for g2 in head_var_2_groundings:
+ if infer_edges:
+ valid_edge_groundings.append((g1, g2))
+ else:
+ if (g1, g2) in edges_set:
+ valid_edge_groundings.append((g1, g2))
+
+ # Loop through the head variable groundings
+ for valid_e in valid_edge_groundings:
+ head_var_1_grounding, head_var_2_grounding = valid_e[0], valid_e[1]
+ qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
+ qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
+ annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
+ edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
+
+ # Containers to keep track of groundings to make sure that the edge pair is valid
+ # We do this because we cannot know beforehand the edge matches from source groundings to target groundings
+ temp_groundings = groundings.copy()
+ temp_groundings_edges = groundings_edges.copy()
+
+ # Refine the temp groundings for the specific edge head grounding
+ # We update the edge collection as well depending on if there's a match between the clause variables and head variables
+ temp_groundings[head_var_1] = numba.typed.List([head_var_1_grounding])
+ temp_groundings[head_var_2] = numba.typed.List([head_var_2_grounding])
+ for c1, c2 in temp_groundings_edges.keys():
+ if c1 == head_var_1 and c2 == head_var_2:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_1_grounding, head_var_2_grounding)])
+ elif c1 == head_var_2 and c2 == head_var_1:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_2_grounding, head_var_1_grounding)])
+ elif c1 == head_var_1:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_1_grounding])
+ elif c2 == head_var_1:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_1_grounding])
+ elif c1 == head_var_2:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_2_grounding])
+ elif c2 == head_var_2:
+ temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_2_grounding])
+
+ refine_groundings(head_variables, temp_groundings, temp_groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors)
+
+ # Check if the thresholds are still satisfied
+ # Check if all clauses are satisfied again in case the refining process changed anything
+ satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, temp_groundings, temp_groundings_edges)
+
+ if not satisfaction:
+ continue
+
+ if infer_edges:
+ # Prevent self loops while inferring edges if the clause variables are not the same
+ if source != target and head_var_1_grounding == head_var_2_grounding:
+ continue
+ edges_to_be_added[0].append(head_var_1_grounding)
+ edges_to_be_added[1].append(head_var_2_grounding)
+
+ for i, clause in enumerate(clauses):
+ clause_type = clause[0]
+ clause_label = clause[1]
+ clause_variables = clause[2]
+
+ if clause_type == 'node':
+ clause_var_1 = clause_variables[0]
+ # 1.
+ if atom_trace:
+ if clause_var_1 == head_var_1:
+ qualified_nodes.append(numba.typed.List([head_var_1_grounding]))
+ elif clause_var_1 == head_var_2:
+ qualified_nodes.append(numba.typed.List([head_var_2_grounding]))
+ else:
+ qualified_nodes.append(numba.typed.List(temp_groundings[clause_var_1]))
+ qualified_edges.append(numba.typed.List.empty_list(edge_type))
+ # 2.
+ if ann_fn != '':
+ a = numba.typed.List.empty_list(interval.interval_type)
+ if clause_var_1 == head_var_1:
+ a.append(interpretations_node[head_var_1_grounding].world[clause_label])
+ elif clause_var_1 == head_var_2:
+ a.append(interpretations_node[head_var_2_grounding].world[clause_label])
+ else:
+ for qn in temp_groundings[clause_var_1]:
+ a.append(interpretations_node[qn].world[clause_label])
+ annotations.append(a)
+
+ elif clause_type == 'edge':
+ clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+ # 1.
+ if atom_trace:
+ # Cases:
+ # 1. Both equal (cv1 = hv1 and cv2 = hv2 or cv1 = hv2 and cv2 = hv1)
+ # 2. One equal (cv1 = hv1 or cv2 = hv1 or cv1 = hv2 or cv2 = hv2)
+ # 3. None equal
+ qualified_nodes.append(numba.typed.List.empty_list(node_type))
+ if clause_var_1 == head_var_1 and clause_var_2 == head_var_2:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding])
+ qualified_edges.append(es)
+ elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding])
+ qualified_edges.append(es)
+ elif clause_var_1 == head_var_1:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding])
+ qualified_edges.append(es)
+ elif clause_var_1 == head_var_2:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding])
+ qualified_edges.append(es)
+ elif clause_var_2 == head_var_1:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_1_grounding])
+ qualified_edges.append(es)
+ elif clause_var_2 == head_var_2:
+ es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_2_grounding])
+ qualified_edges.append(es)
+ else:
+ qualified_edges.append(numba.typed.List(temp_groundings_edges[(clause_var_1, clause_var_2)]))
+
+ # 2.
+ if ann_fn != '':
+ a = numba.typed.List.empty_list(interval.interval_type)
+ if clause_var_1 == head_var_1 and clause_var_2 == head_var_2:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_1 == head_var_1:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_var_1_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_1 == head_var_2:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[0] == head_var_2_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_2 == head_var_1:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[1] == head_var_1_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ elif clause_var_2 == head_var_2:
+ for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ if e[1] == head_var_2_grounding:
+ a.append(interpretations_edge[e].world[clause_label])
+ else:
+ for qe in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+ a.append(interpretations_edge[qe].world[clause_label])
+ annotations.append(a)
+
+ # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules)
+ if add_head_var_1_node_to_graph and head_var_1_grounding == head_var_1:
+ _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)
+ if add_head_var_2_node_to_graph and head_var_2_grounding == head_var_2:
+ _add_node(head_var_2, neighbors, reverse_neighbors, nodes, interpretations_node)
+ if add_head_edge_to_graph and (head_var_1, head_var_2) == (head_var_1_grounding, head_var_2_grounding):
+ _add_edge(head_var_1, head_var_2, neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge)
+
+ # For each grounding combination add a rule to be applied
+ # Only if all the clauses have valid groundings
+ # if satisfaction:
+ e = (head_var_1_grounding, head_var_2_grounding)
+ applicable_rules_edge.append((e, annotations, qualified_nodes, qualified_edges, edges_to_be_added))
+
+ with objmode():
+ print('__ rule fire edge head grounding:', time.time()-start, flush=True)
+
+ # Return the applicable rules
+ return applicable_rules_node, applicable_rules_edge
+
+
+@numba.njit(cache=True)
+def check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges):
+    """Re-verify that every clause threshold still holds for the current groundings.
+
+    Intended to be called after grounding refinement may have shrunk the
+    node/edge groundings. For each clause the (possibly refined) grounding
+    list is passed as both the candidate set and the qualified set, so the
+    threshold in thresholds[i] is evaluated directly against it.
+
+    NOTE(review): only 'node' and 'edge' clause types are checked here;
+    comparison clauses fall through unexamined — confirm that is intended.
+    Returns True only if all checked clause thresholds are satisfied.
+    """
+    # Check if the thresholds are satisfied for each clause
+    satisfaction = True
+    for i, clause in enumerate(clauses):
+        # Unpack clause variables
+        clause_type = clause[0]
+        clause_label = clause[1]
+        clause_variables = clause[2]
+
+        if clause_type == 'node':
+            clause_var_1 = clause_variables[0]
+            # AND the result across clauses; the loop does not break early, so
+            # every clause is evaluated even after a failure.
+            satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, groundings[clause_var_1], groundings[clause_var_1], clause_label, thresholds[i]) and satisfaction
+        elif clause_type == 'edge':
+            clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+            satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, groundings_edges[(clause_var_1, clause_var_2)], groundings_edges[(clause_var_1, clause_var_2)], clause_label, thresholds[i]) and satisfaction
+    return satisfaction
+
+
+@numba.njit(cache=True)
+def _ground_node_rule(rule, interpretations_node, interpretations_edge, nodes, neighbors, reverse_neighbors, atom_trace, reverse_graph, nodes_to_skip):
+    """Ground a node rule against every candidate target node.
+
+    For each node not in nodes_to_skip, grounds all clauses of the rule body
+    (node, edge, and comparison clauses) against the graph, checking each
+    clause's threshold as it goes and refining the variable subsets after
+    every satisfied clause. Returns a typed list of applicable-rule tuples:
+    (target_node, annotations, qualified_nodes, qualified_edges, edges_to_be_added).
+
+    NOTE(review): the outer loop uses prange, but the decorator does not set
+    parallel=True, so under this compilation it runs serially — the
+    per-node pre-allocated result list keeps it safe either way; confirm
+    whether a parallel-enabled variant is compiled elsewhere.
+    """
+    # Extract rule params
+    rule_type = rule.get_type()
+    clauses = rule.get_clauses()
+    thresholds = rule.get_thresholds()
+    ann_fn = rule.get_annotation_function()
+    rule_edges = rule.get_edges()
+
+    # We return a list of tuples which specify the target nodes/edges that have made the rule body true
+    applicable_rules = numba.typed.List.empty_list(node_applicable_rule_type)
+
+    # Create pre-allocated data structure so that parallel code does not need to use "append" to be threadsafe
+    # One array for each node, then condense into a single list later
+    applicable_rules_threadsafe = numba.typed.List([numba.typed.List.empty_list(node_applicable_rule_type) for _ in nodes])
+
+    # Return empty list if rule is not node rule and if we are not inferring edges
+    if rule_type != 'node' and rule_edges[0] == '':
+        return applicable_rules
+
+    # Steps
+    # 1. Loop through all nodes and evaluate each clause with that node and check the truth with the thresholds
+    # 2. Inside the clause loop it may be necessary to loop through all nodes/edges while grounding the variables
+    # 3. If the clause is true add the qualified nodes and qualified edges to the atom trace, if on. Break otherwise
+    # 4. After going through all clauses, add to the annotations list all the annotations of the specified subset. These will be passed to the annotation function
+    # 5. Finally, if there are any edges to be added, place them in the list
+
+    for piter in prange(len(nodes)):
+        target_node = nodes[piter]
+        if target_node in nodes_to_skip:
+            continue
+        # Initialize dictionary where keys are strings (x1, x2 etc.) and values are lists of qualified neighbors
+        # Keep track of qualified nodes and qualified edges
+        # If it's a node clause update (x1 or x2 etc.) qualified neighbors, if it's an edge clause update the qualified neighbors for the source and target (x1, x2)
+        subsets = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes)
+        qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
+        qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
+        annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
+        edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
+        # Per-clause record of (kind, label, variables) replayed later to build
+        # the atom trace and annotation inputs once satisfaction is known.
+        clause_type_and_variables = numba.typed.List.empty_list(clause_data)
+
+        satisfaction = True
+        for i, clause in enumerate(clauses):
+            # Unpack clause variables
+            clause_type = clause[0]
+            clause_label = clause[1]
+            clause_variables = clause[2]
+            clause_bnd = clause[3]
+            clause_operator = clause[4]
+
+            # This is a node clause
+            # The groundings for node clauses are either the target node, neighbors of the target node, or an existing subset of nodes
+            if clause_type == 'node':
+                clause_var_1 = clause_variables[0]
+                subset = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, nodes)
+
+                subsets[clause_var_1] = get_qualified_components_node_clause(interpretations_node, subset, clause_label, clause_bnd)
+
+                # Save data for annotations and atom trace
+                clause_type_and_variables.append(('node', clause_label, numba.typed.List([clause_var_1])))
+
+            # This is an edge clause
+            elif clause_type == 'edge':
+                clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+                subset_source, subset_target = get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, subsets, neighbors, reverse_neighbors, nodes)
+
+                # Get qualified edges
+                qe = get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, clause_bnd, reverse_graph)
+                subsets[clause_var_1] = qe[0]
+                subsets[clause_var_2] = qe[1]
+
+                # Save data for annotations and atom trace
+                clause_type_and_variables.append(('edge', clause_label, numba.typed.List([clause_var_1, clause_var_2])))
+
+            else:
+                # This is a comparison clause
+                # Make sure there is at least one ground atom such that pred-num(x) : [1,1] or pred-num(x,y) : [1,1]
+                # Remember that the predicate in the clause will not contain the "-num" where num is some number.
+                # We have to remove that manually while checking
+                # Steps:
+                # 1. get qualified nodes/edges as well as number associated for first predicate
+                # 2. get qualified nodes/edges as well as number associated for second predicate
+                # 3. if there's no number in steps 1 or 2 return false clause
+                # 4. do comparison with each qualified component from step 1 with each qualified component in step 2
+
+                # It's a node comparison
+                if len(clause_variables) == 2:
+                    clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+                    subset_1 = get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, nodes)
+                    subset_2 = get_node_rule_node_clause_subset(clause_var_2, target_node, subsets, nodes)
+
+                    # 1, 2
+                    qualified_nodes_for_comparison_1, numbers_1 = get_qualified_components_node_comparison_clause(interpretations_node, subset_1, clause_label, clause_bnd)
+                    qualified_nodes_for_comparison_2, numbers_2 = get_qualified_components_node_comparison_clause(interpretations_node, subset_2, clause_label, clause_bnd)
+
+                # It's an edge comparison
+                elif len(clause_variables) == 4:
+                    clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables[0], clause_variables[1], clause_variables[2], clause_variables[3]
+                    subset_1_source, subset_1_target = get_node_rule_edge_clause_subset(clause_var_1_source, clause_var_1_target, target_node, subsets, neighbors, reverse_neighbors, nodes)
+                    subset_2_source, subset_2_target = get_node_rule_edge_clause_subset(clause_var_2_source, clause_var_2_target, target_node, subsets, neighbors, reverse_neighbors, nodes)
+
+                    # 1, 2
+                    qualified_nodes_for_comparison_1_source, qualified_nodes_for_comparison_1_target, numbers_1 = get_qualified_components_edge_comparison_clause(interpretations_edge, subset_1_source, subset_1_target, clause_label, clause_bnd, reverse_graph)
+                    qualified_nodes_for_comparison_2_source, qualified_nodes_for_comparison_2_target, numbers_2 = get_qualified_components_edge_comparison_clause(interpretations_edge, subset_2_source, subset_2_target, clause_label, clause_bnd, reverse_graph)
+
+            # Check if thresholds are satisfied
+            # If it's a comparison clause we just need to check if the numbers list is not empty (no threshold support)
+            if clause_type == 'comparison':
+                if len(numbers_1) == 0 or len(numbers_2) == 0:
+                    satisfaction = False
+                # Node comparison. Compare stage
+                elif len(clause_variables) == 2:
+                    satisfaction, qualified_nodes_1, qualified_nodes_2 = compare_numbers_node_predicate(numbers_1, numbers_2, clause_operator, qualified_nodes_for_comparison_1, qualified_nodes_for_comparison_2)
+
+                    # Update subsets with final qualified nodes
+                    subsets[clause_var_1] = qualified_nodes_1
+                    subsets[clause_var_2] = qualified_nodes_2
+
+                    # Save data for annotations and atom trace
+                    clause_type_and_variables.append(('node-comparison', clause_label, numba.typed.List([clause_var_1, clause_var_2])))
+
+                # Edge comparison. Compare stage
+                else:
+                    satisfaction, qualified_nodes_1_source, qualified_nodes_1_target, qualified_nodes_2_source, qualified_nodes_2_target = compare_numbers_edge_predicate(numbers_1, numbers_2, clause_operator,
+                                                                                                                                                                          qualified_nodes_for_comparison_1_source,
+                                                                                                                                                                          qualified_nodes_for_comparison_1_target,
+                                                                                                                                                                          qualified_nodes_for_comparison_2_source,
+                                                                                                                                                                          qualified_nodes_for_comparison_2_target)
+                    # Update subsets with final qualified nodes
+                    subsets[clause_var_1_source] = qualified_nodes_1_source
+                    subsets[clause_var_1_target] = qualified_nodes_1_target
+                    subsets[clause_var_2_source] = qualified_nodes_2_source
+                    subsets[clause_var_2_target] = qualified_nodes_2_target
+
+                    # Save data for annotations and atom trace
+                    clause_type_and_variables.append(('edge-comparison', clause_label, numba.typed.List([clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target])))
+
+            # Non comparison clause
+            else:
+                if clause_type == 'node':
+                    satisfaction = check_node_clause_satisfaction(interpretations_node, subsets, subset, clause_var_1, clause_label, thresholds[i]) and satisfaction
+                else:
+                    satisfaction = check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, clause_label, thresholds[i], reverse_graph) and satisfaction
+
+            # Refine subsets based on any updates
+            if satisfaction:
+                satisfaction = refine_subsets_node_rule(interpretations_edge, clauses, i, subsets, target_node, neighbors, reverse_neighbors, nodes, thresholds, reverse_graph) and satisfaction
+
+            # Exit loop if even one clause is not satisfied
+            if not satisfaction:
+                break
+
+        if satisfaction:
+            # Collect edges to be added
+            source, target, _ = rule_edges
+
+            # Edges to be added
+            if source != '' and target != '':
+                # Check if edge nodes are target
+                if source == '__target':
+                    edges_to_be_added[0].append(target_node)
+                elif source in subsets:
+                    edges_to_be_added[0].extend(subsets[source])
+                else:
+                    edges_to_be_added[0].append(source)
+
+                if target == '__target':
+                    edges_to_be_added[1].append(target_node)
+                elif target in subsets:
+                    edges_to_be_added[1].extend(subsets[target])
+                else:
+                    edges_to_be_added[1].append(target)
+
+            # Loop through the clause data and setup final annotations and trace variables
+            # 1. Add qualified nodes/edges to trace
+            # 2. Add annotations to annotation function variable
+            for i, clause in enumerate(clause_type_and_variables):
+                clause_type = clause[0]
+                clause_label = clause[1]
+                clause_variables = clause[2]
+
+                if clause_type == 'node':
+                    clause_var_1 = clause_variables[0]
+                    # 1.
+                    if atom_trace:
+                        qualified_nodes.append(numba.typed.List(subsets[clause_var_1]))
+                        qualified_edges.append(numba.typed.List.empty_list(edge_type))
+                    # 2.
+                    if ann_fn != '':
+                        a = numba.typed.List.empty_list(interval.interval_type)
+                        for qn in subsets[clause_var_1]:
+                            a.append(interpretations_node[qn].world[clause_label])
+                        annotations.append(a)
+
+                elif clause_type == 'edge':
+                    clause_var_1, clause_var_2 = clause_variables
+                    # 1.
+                    if atom_trace:
+                        qualified_nodes.append(numba.typed.List.empty_list(node_type))
+                        # Source/target subsets are kept index-aligned, so zipping
+                        # them reconstructs the qualified edge pairs.
+                        qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])))
+                    # 2.
+                    if ann_fn != '':
+                        a = numba.typed.List.empty_list(interval.interval_type)
+                        for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])):
+                            a.append(interpretations_edge[qe].world[clause_label])
+                        annotations.append(a)
+
+                elif clause_type == 'node-comparison':
+                    clause_var_1, clause_var_2 = clause_variables
+                    qualified_nodes_1 = subsets[clause_var_1]
+                    qualified_nodes_2 = subsets[clause_var_2]
+                    # Trace both sides of the comparison as one combined list.
+                    qualified_comparison_nodes = numba.typed.List(qualified_nodes_1)
+                    qualified_comparison_nodes.extend(qualified_nodes_2)
+                    # 1.
+                    if atom_trace:
+                        qualified_nodes.append(qualified_comparison_nodes)
+                        qualified_edges.append(numba.typed.List.empty_list(edge_type))
+                    # 2.
+                    # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
+                    if ann_fn != '':
+                        a = numba.typed.List.empty_list(interval.interval_type)
+                        for qn in qualified_comparison_nodes:
+                            a.append(interval.closed(1, 1))
+                        annotations.append(a)
+
+                elif clause_type == 'edge-comparison':
+                    clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables
+                    qualified_nodes_1_source = subsets[clause_var_1_source]
+                    qualified_nodes_1_target = subsets[clause_var_1_target]
+                    qualified_nodes_2_source = subsets[clause_var_2_source]
+                    qualified_nodes_2_target = subsets[clause_var_2_target]
+                    qualified_comparison_nodes_1 = numba.typed.List(zip(qualified_nodes_1_source, qualified_nodes_1_target))
+                    qualified_comparison_nodes_2 = numba.typed.List(zip(qualified_nodes_2_source, qualified_nodes_2_target))
+                    qualified_comparison_nodes = numba.typed.List(qualified_comparison_nodes_1)
+                    qualified_comparison_nodes.extend(qualified_comparison_nodes_2)
+                    # 1.
+                    if atom_trace:
+                        qualified_nodes.append(numba.typed.List.empty_list(node_type))
+                        qualified_edges.append(qualified_comparison_nodes)
+                    # 2.
+                    # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
+                    if ann_fn != '':
+                        a = numba.typed.List.empty_list(interval.interval_type)
+                        for qe in qualified_comparison_nodes:
+                            a.append(interval.closed(1, 1))
+                        annotations.append(a)
+
+            # node/edge, annotations, qualified nodes, qualified edges, edges to be added
+            applicable_rules_threadsafe[piter] = numba.typed.List([(target_node, annotations, qualified_nodes, qualified_edges, edges_to_be_added)])
+
+    # Merge all threadsafe rules into one single array
+    for applicable_rule in applicable_rules_threadsafe:
+        if len(applicable_rule) > 0:
+            applicable_rules.append(applicable_rule[0])
+
+    return applicable_rules
+
+
@numba.njit(cache=True)
def _ground_edge_rule(rule, interpretations_node, interpretations_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, reverse_graph, edges_to_skip):
    """Ground an edge rule over every candidate target edge of the graph.

    For each edge (except those in ``edges_to_skip``) the rule's clauses are
    grounded one by one and checked against their thresholds. Edges whose
    clauses are all satisfied are collected together with the per-clause
    annotations and, when tracing is on, the qualified nodes/edges.

    :param rule: the rule to ground; anything other than an 'edge' rule yields an empty result
    :param interpretations_node: map node -> interpretation (``.world`` maps label -> bound interval)
    :param interpretations_edge: map edge -> interpretation
    :param nodes: all graph nodes
    :param edges: all graph edges (the candidate rule targets)
    :param neighbors: map node -> list of successor nodes
    :param reverse_neighbors: map node -> list of predecessor nodes
    :param atom_trace: when True, record per-clause qualified nodes/edges for explanations
    :param reverse_graph: when True, edge clauses are evaluated with source/target swapped
    :param edges_to_skip: edges that must not be considered as targets
    :return: typed list of tuples (target_edge, annotations, qualified_nodes, qualified_edges, edges_to_be_added)
    """
    # Extract rule params
    rule_type = rule.get_type()
    clauses = rule.get_clauses()
    thresholds = rule.get_thresholds()
    ann_fn = rule.get_annotation_function()
    rule_edges = rule.get_edges()

    # We return a list of tuples which specify the target nodes/edges that have made the rule body true
    applicable_rules = numba.typed.List.empty_list(edge_applicable_rule_type)

    # Create pre-allocated data structure so that parallel code does not need to use "append" to be threadsafe
    # One array for each edge, then condense into a single list later
    applicable_rules_threadsafe = numba.typed.List([numba.typed.List.empty_list(edge_applicable_rule_type) for _ in edges])

    # Return empty list if rule is not an edge rule
    if rule_type != 'edge':
        return applicable_rules

    # Steps
    # 1. Loop through all edges and evaluate each clause with that edge and check the truth with the thresholds
    # 2. Inside the clause loop it may be necessary to loop through all nodes/edges while grounding the variables
    # 3. If the clause is true add the qualified nodes and qualified edges to the atom trace, if on. Break otherwise
    # 4. After going through all clauses, add to the annotations list all the annotations of the specified subset. These will be passed to the annotation function
    # 5. Finally, if there are any edges to be added, place them in the list

    for piter in prange(len(edges)):
        target_edge = edges[piter]
        if target_edge in edges_to_skip:
            continue
        # Initialize dictionary where keys are strings (x1, x2 etc.) and values are lists of qualified neighbors
        # Keep track of qualified nodes and qualified edges
        # If it's a node clause update (x1 or x2 etc.) qualified neighbors, if it's an edge clause update the qualified neighbors for the source and target (x1, x2)
        subsets = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes)
        qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
        qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
        annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
        edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
        clause_type_and_variables = numba.typed.List.empty_list(clause_data)

        satisfaction = True
        for i, clause in enumerate(clauses):
            # Unpack clause variables
            clause_type = clause[0]
            clause_label = clause[1]
            clause_variables = clause[2]
            clause_bnd = clause[3]
            clause_operator = clause[4]

            # This is a node clause
            # The groundings for node clauses are either the source, target, neighbors of the source node, or an existing subset of nodes
            if clause_type == 'node':
                clause_var_1 = clause_variables[0]
                subset = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, nodes)

                subsets[clause_var_1] = get_qualified_components_node_clause(interpretations_node, subset, clause_label, clause_bnd)

                # Save data for annotations and atom trace
                clause_type_and_variables.append(('node', clause_label, numba.typed.List([clause_var_1])))

            # This is an edge clause
            elif clause_type == 'edge':
                clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
                subset_source, subset_target = get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, subsets, neighbors, reverse_neighbors, nodes)

                # Get qualified edges
                qe = get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, clause_bnd, reverse_graph)
                subsets[clause_var_1] = qe[0]
                subsets[clause_var_2] = qe[1]

                # Save data for annotations and atom trace
                clause_type_and_variables.append(('edge', clause_label, numba.typed.List([clause_var_1, clause_var_2])))

            else:
                # This is a comparison clause
                # Make sure there is at least one ground atom such that pred-num(x) : [1,1] or pred-num(x,y) : [1,1]
                # Remember that the predicate in the clause will not contain the "-num" where num is some number.
                # We have to remove that manually while checking
                # Steps:
                # 1. get qualified nodes/edges as well as number associated for first predicate
                # 2. get qualified nodes/edges as well as number associated for second predicate
                # 3. if there's no number in steps 1 or 2 return false clause
                # 4. do comparison with each qualified component from step 1 with each qualified component in step 2

                # It's a node comparison
                if len(clause_variables) == 2:
                    clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
                    subset_1 = get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, nodes)
                    subset_2 = get_edge_rule_node_clause_subset(clause_var_2, target_edge, subsets, nodes)

                    # 1, 2
                    qualified_nodes_for_comparison_1, numbers_1 = get_qualified_components_node_comparison_clause(interpretations_node, subset_1, clause_label, clause_bnd)
                    qualified_nodes_for_comparison_2, numbers_2 = get_qualified_components_node_comparison_clause(interpretations_node, subset_2, clause_label, clause_bnd)

                # It's an edge comparison
                elif len(clause_variables) == 4:
                    clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables[0], clause_variables[1], clause_variables[2], clause_variables[3]
                    subset_1_source, subset_1_target = get_edge_rule_edge_clause_subset(clause_var_1_source, clause_var_1_target, target_edge, subsets, neighbors, reverse_neighbors, nodes)
                    subset_2_source, subset_2_target = get_edge_rule_edge_clause_subset(clause_var_2_source, clause_var_2_target, target_edge, subsets, neighbors, reverse_neighbors, nodes)

                    # 1, 2
                    qualified_nodes_for_comparison_1_source, qualified_nodes_for_comparison_1_target, numbers_1 = get_qualified_components_edge_comparison_clause(interpretations_edge, subset_1_source, subset_1_target, clause_label, clause_bnd, reverse_graph)
                    qualified_nodes_for_comparison_2_source, qualified_nodes_for_comparison_2_target, numbers_2 = get_qualified_components_edge_comparison_clause(interpretations_edge, subset_2_source, subset_2_target, clause_label, clause_bnd, reverse_graph)

            # Check if thresholds are satisfied
            # If it's a comparison clause we just need to check if the numbers list is not empty (no threshold support)
            if clause_type == 'comparison':
                if len(numbers_1) == 0 or len(numbers_2) == 0:
                    satisfaction = False
                # Node comparison. Compare stage
                elif len(clause_variables) == 2:
                    satisfaction, qualified_nodes_1, qualified_nodes_2 = compare_numbers_node_predicate(numbers_1, numbers_2, clause_operator, qualified_nodes_for_comparison_1, qualified_nodes_for_comparison_2)

                    # Update subsets with final qualified nodes
                    subsets[clause_var_1] = qualified_nodes_1
                    subsets[clause_var_2] = qualified_nodes_2

                    # Save data for annotations and atom trace
                    clause_type_and_variables.append(('node-comparison', clause_label, numba.typed.List([clause_var_1, clause_var_2])))

                # Edge comparison. Compare stage
                else:
                    satisfaction, qualified_nodes_1_source, qualified_nodes_1_target, qualified_nodes_2_source, qualified_nodes_2_target = compare_numbers_edge_predicate(numbers_1, numbers_2, clause_operator,
                                                                                                                                                                          qualified_nodes_for_comparison_1_source,
                                                                                                                                                                          qualified_nodes_for_comparison_1_target,
                                                                                                                                                                          qualified_nodes_for_comparison_2_source,
                                                                                                                                                                          qualified_nodes_for_comparison_2_target)
                    # Update subsets with final qualified nodes
                    subsets[clause_var_1_source] = qualified_nodes_1_source
                    subsets[clause_var_1_target] = qualified_nodes_1_target
                    subsets[clause_var_2_source] = qualified_nodes_2_source
                    subsets[clause_var_2_target] = qualified_nodes_2_target

                    # Save data for annotations and atom trace
                    clause_type_and_variables.append(('edge-comparison', clause_label, numba.typed.List([clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target])))

            # Non comparison clause
            else:
                if clause_type == 'node':
                    satisfaction = check_node_clause_satisfaction(interpretations_node, subsets, subset, clause_var_1, clause_label, thresholds[i]) and satisfaction
                else:
                    satisfaction = check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, clause_label, thresholds[i], reverse_graph) and satisfaction

            # Refine subsets based on any updates
            if satisfaction:
                satisfaction = refine_subsets_edge_rule(interpretations_edge, clauses, i, subsets, target_edge, neighbors, reverse_neighbors, nodes, thresholds, reverse_graph) and satisfaction

            # Exit loop if even one clause is not satisfied
            if not satisfaction:
                break

        # Here we are done going through each clause of the rule
        # If all clauses were satisfied, proceed to collect annotations and prepare edges to be added
        if satisfaction:
            # Loop through the clause data and setup final annotations and trace variables
            # 1. Add qualified nodes/edges to trace
            # 2. Add annotations to annotation function variable
            for i, clause in enumerate(clause_type_and_variables):
                clause_type = clause[0]
                clause_label = clause[1]
                clause_variables = clause[2]

                if clause_type == 'node':
                    clause_var_1 = clause_variables[0]
                    # 1.
                    if atom_trace:
                        qualified_nodes.append(numba.typed.List(subsets[clause_var_1]))
                        qualified_edges.append(numba.typed.List.empty_list(edge_type))
                    # 2.
                    if ann_fn != '':
                        a = numba.typed.List.empty_list(interval.interval_type)
                        for qn in subsets[clause_var_1]:
                            a.append(interpretations_node[qn].world[clause_label])
                        annotations.append(a)

                elif clause_type == 'edge':
                    clause_var_1, clause_var_2 = clause_variables
                    # 1.
                    if atom_trace:
                        qualified_nodes.append(numba.typed.List.empty_list(node_type))
                        qualified_edges.append(numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])))
                    # 2.
                    if ann_fn != '':
                        a = numba.typed.List.empty_list(interval.interval_type)
                        for qe in numba.typed.List(zip(subsets[clause_var_1], subsets[clause_var_2])):
                            a.append(interpretations_edge[qe].world[clause_label])
                        annotations.append(a)

                elif clause_type == 'node-comparison':
                    clause_var_1, clause_var_2 = clause_variables
                    qualified_nodes_1 = subsets[clause_var_1]
                    qualified_nodes_2 = subsets[clause_var_2]
                    qualified_comparison_nodes = numba.typed.List(qualified_nodes_1)
                    qualified_comparison_nodes.extend(qualified_nodes_2)
                    # 1.
                    if atom_trace:
                        qualified_nodes.append(qualified_comparison_nodes)
                        qualified_edges.append(numba.typed.List.empty_list(edge_type))
                    # 2.
                    # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
                    if ann_fn != '':
                        a = numba.typed.List.empty_list(interval.interval_type)
                        for qn in qualified_comparison_nodes:
                            a.append(interval.closed(1, 1))
                        annotations.append(a)

                elif clause_type == 'edge-comparison':
                    clause_var_1_source, clause_var_1_target, clause_var_2_source, clause_var_2_target = clause_variables
                    qualified_nodes_1_source = subsets[clause_var_1_source]
                    qualified_nodes_1_target = subsets[clause_var_1_target]
                    qualified_nodes_2_source = subsets[clause_var_2_source]
                    qualified_nodes_2_target = subsets[clause_var_2_target]
                    qualified_comparison_nodes_1 = numba.typed.List(zip(qualified_nodes_1_source, qualified_nodes_1_target))
                    qualified_comparison_nodes_2 = numba.typed.List(zip(qualified_nodes_2_source, qualified_nodes_2_target))
                    qualified_comparison_nodes = numba.typed.List(qualified_comparison_nodes_1)
                    qualified_comparison_nodes.extend(qualified_comparison_nodes_2)
                    # 1.
                    if atom_trace:
                        qualified_nodes.append(numba.typed.List.empty_list(node_type))
                        qualified_edges.append(qualified_comparison_nodes)
                    # 2.
                    # Add annotations for comparison clause. For now, we don't distinguish between LHS and RHS annotations
                    if ann_fn != '':
                        a = numba.typed.List.empty_list(interval.interval_type)
                        for qe in qualified_comparison_nodes:
                            a.append(interval.closed(1, 1))
                        annotations.append(a)
            # node/edge, annotations, qualified nodes, qualified edges, edges to be added
            applicable_rules_threadsafe[piter] = numba.typed.List([(target_edge, annotations, qualified_nodes, qualified_edges, edges_to_be_added)])

    # Merge all threadsafe rules into one single array
    for applicable_rule in applicable_rules_threadsafe:
        if len(applicable_rule) > 0:
            applicable_rules.append(applicable_rule[0])

    return applicable_rules
@numba.njit(cache=True)
def refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors):
    """Propagate a refinement of some variables' groundings through the dependency graph.

    After the groundings for ``clause_variables`` have been narrowed, any
    variable connected to one of them through an edge clause (as recorded in
    the dependency graph) may hold stale groundings. Starting from the refined
    variables, each incident edge grounding is filtered against the surviving
    node groundings, the neighbor variable's node groundings are rebuilt from
    the filtered edges, and the neighbor is queued for further propagation.
    Iterates until no newly-refined variables remain.

    Mutates ``groundings`` and ``groundings_edges`` in place; returns nothing.
    """
    # Loop through the dependency graph and refine the groundings that have connections
    all_variables_refined = numba.typed.List(clause_variables)
    variables_just_refined = numba.typed.List(clause_variables)
    new_variables_refined = numba.typed.List.empty_list(numba.types.string)
    while len(variables_just_refined) > 0:
        for refined_variable in variables_just_refined:
            # Refine all the neighbors of the refined variable
            if refined_variable in dependency_graph_neighbors:
                for neighbor in dependency_graph_neighbors[refined_variable]:
                    old_edge_groundings = groundings_edges[(refined_variable, neighbor)]
                    new_node_groundings = groundings[refined_variable]

                    # Delete old groundings for the variable being refined
                    del groundings[neighbor]
                    groundings[neighbor] = numba.typed.List.empty_list(node_type)

                    # Update the edge groundings and node groundings
                    # Keep only edges whose source survived the refinement; the set tracks
                    # targets already added so the node grounding list stays duplicate-free
                    qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[0] in new_node_groundings])
                    groundings_neighbor_set = set(groundings[neighbor])
                    for e in qualified_groundings:
                        if e[1] not in groundings_neighbor_set:
                            groundings[neighbor].append(e[1])
                            groundings_neighbor_set.add(e[1])
                    groundings_edges[(refined_variable, neighbor)] = qualified_groundings

                    # Add the neighbor to the list of refined variables so that we can refine for all its neighbors
                    if neighbor not in all_variables_refined:
                        new_variables_refined.append(neighbor)

            if refined_variable in dependency_graph_reverse_neighbors:
                for reverse_neighbor in dependency_graph_reverse_neighbors[refined_variable]:
                    old_edge_groundings = groundings_edges[(reverse_neighbor, refined_variable)]
                    new_node_groundings = groundings[refined_variable]

                    # Delete old groundings for the variable being refined
                    del groundings[reverse_neighbor]
                    groundings[reverse_neighbor] = numba.typed.List.empty_list(node_type)

                    # Update the edge groundings and node groundings
                    # Mirror case: keep only edges whose target survived, and rebuild the
                    # reverse neighbor's node groundings from the surviving edge sources
                    qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[1] in new_node_groundings])
                    groundings_reverse_neighbor_set = set(groundings[reverse_neighbor])
                    for e in qualified_groundings:
                        if e[0] not in groundings_reverse_neighbor_set:
                            groundings[reverse_neighbor].append(e[0])
                            groundings_reverse_neighbor_set.add(e[0])
                    groundings_edges[(reverse_neighbor, refined_variable)] = qualified_groundings

                    # Add the neighbor to the list of refined variables so that we can refine for all its neighbors
                    if reverse_neighbor not in all_variables_refined:
                        new_variables_refined.append(reverse_neighbor)

        variables_just_refined = numba.typed.List(new_variables_refined)
        all_variables_refined.extend(new_variables_refined)
        new_variables_refined.clear()
+
@numba.njit(cache=True)
def refine_subsets_node_rule(interpretations_edge, clauses, i, subsets, target_node, neighbors, reverse_neighbors, nodes, thresholds, reverse_graph):
    """NOTE: DEPRECATED.

    After clause ``i`` of a node rule has narrowed some variable subsets,
    revisit earlier edge clauses that share those variables, recompute the
    paired variable's subset, and re-check the earlier clause's threshold.
    Propagates transitively until no new variables need refining.

    Mutates ``subsets`` in place and returns False as soon as a re-checked
    clause no longer satisfies its threshold, True otherwise. Comparison
    clauses are not refined.
    """
    # Loop through all clauses till clause i-1 and update subsets recursively
    # Then check if the clause still satisfies the thresholds
    clause = clauses[i]
    clause_type = clause[0]
    clause_label = clause[1]
    clause_variables = clause[2]
    clause_bnd = clause[3]
    clause_operator = clause[4]

    # Keep track of the variables that were refined (start with clause_variables) and variables that need refining
    satisfaction = True
    all_variables_refined = numba.typed.List(clause_variables)
    variables_just_refined = numba.typed.List(clause_variables)
    new_variables_refined = numba.typed.List.empty_list(numba.types.string)
    while len(variables_just_refined) > 0:
        for j in range(i):
            c = clauses[j]
            c_type = c[0]
            c_label = c[1]
            c_variables = c[2]
            c_bnd = c[3]
            c_operator = c[4]

            # If it is an edge clause or edge comparison clause, check if any of clause_variables are in c_variables
            # If yes, then update the variable that is with it in the clause
            if c_type == 'edge' or (c_type == 'comparison' and len(c_variables) > 2):
                for v in variables_just_refined:
                    for k, cv in enumerate(c_variables):
                        if cv == v:
                            # Find which variable needs to be refined, 1st or 2nd.
                            # Edge comparison clauses carry two variable pairs (indices 0-1 and 2-3);
                            # the partner of the matched variable within its pair is refined.
                            # 2nd variable needs refining
                            if k == 0:
                                refine_idx = 1
                                refine_v = c_variables[1]
                            # 1st variable needs refining
                            elif k == 1:
                                refine_idx = 0
                                refine_v = c_variables[0]
                            # 2nd variable needs refining
                            elif k == 2:
                                refine_idx = 1
                                refine_v = c_variables[3]
                            # 1st variable needs refining
                            else:
                                refine_idx = 0
                                refine_v = c_variables[2]

                            # Refine the variable
                            if refine_v not in all_variables_refined:
                                new_variables_refined.append(refine_v)

                            if c_type == 'edge':
                                clause_var_1, clause_var_2 = (refine_v, cv) if refine_idx == 0 else (cv, refine_v)
                                del subsets[refine_v]
                                subset_source, subset_target = get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, subsets, neighbors, reverse_neighbors, nodes)

                                # Get qualified edges
                                qe = get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, c_label, c_bnd, reverse_graph)
                                subsets[clause_var_1] = qe[0]
                                subsets[clause_var_2] = qe[1]

                                # Check if we still satisfy the clause
                                satisfaction = check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, c_label, thresholds[j], reverse_graph) and satisfaction
                            else:
                                # We do not support refinement for comparison clauses
                                pass

                            if not satisfaction:
                                return satisfaction

        variables_just_refined = numba.typed.List(new_variables_refined)
        all_variables_refined.extend(new_variables_refined)
        new_variables_refined.clear()

    return satisfaction
+
@numba.njit(cache=True)
def refine_subsets_edge_rule(interpretations_edge, clauses, i, subsets, target_edge, neighbors, reverse_neighbors, nodes, thresholds, reverse_graph):
    """NOTE: DEPRECATED.

    Edge-rule counterpart of ``refine_subsets_node_rule``: after clause ``i``
    has narrowed some variable subsets, revisit earlier edge clauses that
    share those variables, recompute the paired variable's subset, and
    re-check the earlier clause's threshold. Propagates transitively until no
    new variables need refining.

    Mutates ``subsets`` in place and returns False as soon as a re-checked
    clause no longer satisfies its threshold, True otherwise. Comparison
    clauses are not refined.
    """
    # Loop through all clauses till clause i-1 and update subsets recursively
    # Then check if the clause still satisfies the thresholds
    clause = clauses[i]
    clause_type = clause[0]
    clause_label = clause[1]
    clause_variables = clause[2]
    clause_bnd = clause[3]
    clause_operator = clause[4]

    # Keep track of the variables that were refined (start with clause_variables) and variables that need refining
    satisfaction = True
    all_variables_refined = numba.typed.List(clause_variables)
    variables_just_refined = numba.typed.List(clause_variables)
    new_variables_refined = numba.typed.List.empty_list(numba.types.string)
    while len(variables_just_refined) > 0:
        for j in range(i):
            c = clauses[j]
            c_type = c[0]
            c_label = c[1]
            c_variables = c[2]
            c_bnd = c[3]
            c_operator = c[4]

            # If it is an edge clause or edge comparison clause, check if any of clause_variables are in c_variables
            # If yes, then update the variable that is with it in the clause
            if c_type == 'edge' or (c_type == 'comparison' and len(c_variables) > 2):
                for v in variables_just_refined:
                    for k, cv in enumerate(c_variables):
                        if cv == v:
                            # Find which variable needs to be refined, 1st or 2nd.
                            # Edge comparison clauses carry two variable pairs (indices 0-1 and 2-3);
                            # the partner of the matched variable within its pair is refined.
                            # 2nd variable needs refining
                            if k == 0:
                                refine_idx = 1
                                refine_v = c_variables[1]
                            # 1st variable needs refining
                            elif k == 1:
                                refine_idx = 0
                                refine_v = c_variables[0]
                            # 2nd variable needs refining
                            elif k == 2:
                                refine_idx = 1
                                refine_v = c_variables[3]
                            # 1st variable needs refining
                            else:
                                refine_idx = 0
                                refine_v = c_variables[2]

                            # Refine the variable
                            if refine_v not in all_variables_refined:
                                new_variables_refined.append(refine_v)

                            if c_type == 'edge':
                                clause_var_1, clause_var_2 = (refine_v, cv) if refine_idx == 0 else (cv, refine_v)
                                del subsets[refine_v]
                                subset_source, subset_target = get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, subsets, neighbors, reverse_neighbors, nodes)

                                # Get qualified edges
                                qe = get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, c_label, c_bnd, reverse_graph)
                                subsets[clause_var_1] = qe[0]
                                subsets[clause_var_2] = qe[1]

                                # Check if we still satisfy the clause
                                satisfaction = check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, c_label, thresholds[j], reverse_graph) and satisfaction
                            else:
                                # We do not support refinement for comparison clauses
                                pass

                            if not satisfaction:
                                return satisfaction

        variables_just_refined = numba.typed.List(new_variables_refined)
        all_variables_refined.extend(new_variables_refined)
        new_variables_refined.clear()

    return satisfaction
+
@numba.njit(cache=True)
def check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_grounding, clause_label, threshold):
    """Decide whether a node clause's threshold holds for the given groundings.

    The threshold's quantifier base picks the denominator: 'total' counts
    every candidate grounding, while 'available' counts only the groundings
    whose bound for ``clause_label`` lies inside [0, 1].
    """
    quantifier_base = threshold[1][1]
    if quantifier_base == 'total':
        base_count = len(grounding)
    elif quantifier_base == 'available':
        # Only groundings that carry the label at all are counted
        base_count = len(get_qualified_node_groundings(interpretations_node, grounding, clause_label, interval.closed(0, 1)))

    return _satisfies_threshold(base_count, len(qualified_grounding), threshold)
+
@numba.njit(cache=True)
def check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_grounding, clause_label, threshold):
    """Decide whether an edge clause's threshold holds for the given groundings.

    The threshold's quantifier base picks the denominator: 'total' counts
    every candidate edge grounding, while 'available' counts only the edges
    whose bound for ``clause_label`` lies inside [0, 1].
    """
    quantifier_base = threshold[1][1]
    if quantifier_base == 'total':
        base_count = len(grounding)
    elif quantifier_base == 'available':
        # Only edges that carry the label at all are counted
        base_count = len(get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, interval.closed(0, 1)))

    return _satisfies_threshold(base_count, len(qualified_grounding), threshold)
+
@numba.njit(cache=True)
def check_node_clause_satisfaction(interpretations_node, subsets, subset, clause_var_1, clause_label, threshold):
    """NOTE: DEPRECATED.

    Check a node clause's threshold against its refined subset. 'total' uses
    all candidate nodes as the denominator; 'available' uses only nodes whose
    bound for ``clause_label`` lies inside [0, 1].
    """
    quantifier_base = threshold[1][1]
    if quantifier_base == 'total':
        base_count = len(subset)
    elif quantifier_base == 'available':
        base_count = len(get_qualified_components_node_clause(interpretations_node, subset, clause_label, interval.closed(0, 1)))

    # Subsets for var_1 and var_2 are kept at equal length, so var_1 suffices
    return _satisfies_threshold(base_count, len(subsets[clause_var_1]), threshold)
+
@numba.njit(cache=True)
def check_edge_clause_satisfaction(interpretations_edge, subsets, subset_source, subset_target, clause_var_1, clause_label, threshold, reverse_graph):
    """NOTE: DEPRECATED.

    Check an edge clause's threshold against its refined subset. 'total'
    counts one candidate per (source, target) pair; 'available' counts only
    edges whose bound for ``clause_label`` lies inside [0, 1].
    """
    quantifier_base = threshold[1][1]
    if quantifier_base == 'total':
        # One candidate edge per target node of each source
        base_count = 0
        for targets in subset_target:
            base_count += len(targets)
    elif quantifier_base == 'available':
        base_count = len(get_qualified_components_edge_clause(interpretations_edge, subset_source, subset_target, clause_label, interval.closed(0, 1), reverse_graph)[0])

    return _satisfies_threshold(base_count, len(subsets[clause_var_1]), threshold)
+
@numba.njit(cache=True)
def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l):
    """Return the groundings for a node clause variable.

    Reuses the variable's existing groundings when a previous clause already
    grounded it; otherwise starts from every node carrying predicate ``l``.
    """
    if clause_var_1 in groundings:
        return groundings[clause_var_1]
    return predicate_map[l]
+
@numba.njit(cache=True)
def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l):
    """Return the candidate edge groundings for an edge clause predicate(Y, Z).

    The result depends on which of the two variables already have node
    groundings from earlier clauses; see the four cases below. Returns a
    typed list of (source, target) tuples.
    """
    # There are 4 cases for predicate(Y,Z):
    # 1. Both predicate variables Y and Z have not been encountered before
    # 2. The source variable Y has not been encountered before but the target variable Z has
    # 3. The target variable Z has not been encountered before but the source variable Y has
    # 4. Both predicate variables Y and Z have been encountered before
    edge_groundings = numba.typed.List.empty_list(edge_type)

    # Case 1:
    # We replace Y by all nodes and Z by the neighbors of each of these nodes
    if clause_var_1 not in groundings and clause_var_2 not in groundings:
        # No prior constraints: start from every edge carrying predicate l
        edge_groundings = predicate_map[l]

    # Case 2:
    # We replace Y by the sources of Z
    elif clause_var_1 not in groundings and clause_var_2 in groundings:
        for n in groundings[clause_var_2]:
            es = numba.typed.List([(nn, n) for nn in reverse_neighbors[n]])
            edge_groundings.extend(es)

    # Case 3:
    # We replace Z by the neighbors of Y
    elif clause_var_1 in groundings and clause_var_2 not in groundings:
        for n in groundings[clause_var_1]:
            es = numba.typed.List([(n, nn) for nn in neighbors[n]])
            edge_groundings.extend(es)

    # Case 4:
    # We have seen both variables before
    else:
        # We have already seen these two variables in an edge clause
        if (clause_var_1, clause_var_2) in groundings_edges:
            edge_groundings = groundings_edges[(clause_var_1, clause_var_2)]
        # We have seen both these variables but not in an edge clause together
        else:
            # Intersect Y's neighbors with Z's groundings; set lookup keeps this O(1) per check
            groundings_clause_var_2_set = set(groundings[clause_var_2])
            for n in groundings[clause_var_1]:
                es = numba.typed.List([(n, nn) for nn in neighbors[n] if nn in groundings_clause_var_2_set])
                edge_groundings.extend(es)

    return edge_groundings
+
@numba.njit(cache=True)
def get_node_rule_node_clause_subset(clause_var_1, target_node, subsets, nodes):
    """NOTE: DEPRECATED.

    Candidate nodes for a node clause variable of a node rule: the target
    itself for '__target', an existing subset when an earlier clause already
    grounded the variable, otherwise every node except the target.
    """
    if clause_var_1 == '__target':
        return numba.typed.List([target_node])
    if clause_var_1 in subsets:
        return subsets[clause_var_1]
    return numba.typed.List([n for n in nodes if n != target_node])
+
@numba.njit(cache=True)
def get_node_rule_edge_clause_subset(clause_var_1, clause_var_2, target_node, subsets, neighbors, reverse_neighbors, nodes):
    """NOTE: DEPRECATED.

    Candidate (source, targets) groundings for an edge clause of a node rule.
    Returns a pair: a list of source nodes and, aligned with it, a list of
    target-node lists (one list of targets per source).
    """
    # There are 5 cases for predicate(Y,Z):
    # 1. Either one or both of Y, Z are the target node
    # 2. Both predicate variables Y and Z have not been encountered before
    # 3. The source variable Y has not been encountered before but the target variable Z has
    # 4. The target variable Z has not been encountered before but the source variable Y has
    # 5. Both predicate variables Y and Z have been encountered before

    # Case 1:
    # Check if 1st variable or 1st and 2nd variables are the target
    if clause_var_1 == '__target':
        subset_source = numba.typed.List([target_node])

        # If both variables are the same
        if clause_var_2 == '__target':
            subset_target = numba.typed.List([numba.typed.List([target_node])])
        elif clause_var_2 in subsets:
            subset_target = numba.typed.List([subsets[clause_var_2]])
        else:
            subset_target = numba.typed.List([neighbors[target_node]])

    # Check if 2nd variable is the target (this means 1st variable isn't the target)
    elif clause_var_2 == '__target':
        subset_source = reverse_neighbors[target_node] if clause_var_1 not in subsets else subsets[clause_var_1]
        subset_target = numba.typed.List([numba.typed.List([target_node]) for _ in subset_source])

    # Case 2:
    # We replace Y by all nodes (except target_node) and Z by the neighbors of each of these nodes
    elif clause_var_1 not in subsets and clause_var_2 not in subsets:
        subset_source = numba.typed.List([n for n in nodes if n != target_node])
        subset_target = numba.typed.List([numba.typed.List([nn for nn in neighbors[n] if nn != target_node]) for n in subset_source])

    # Case 3:
    # We replace Y by the sources of Z
    elif clause_var_1 not in subsets and clause_var_2 in subsets:
        subset_source = numba.typed.List.empty_list(node_type)
        subset_target = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))

        # One (source, [target]) entry per predecessor of each already-grounded target
        for n in subsets[clause_var_2]:
            sources = reverse_neighbors[n]
            for source in sources:
                if source != target_node:
                    subset_source.append(source)
                    subset_target.append(numba.typed.List([n]))

    # Case 4:
    # We replace Z by the neighbors of Y
    elif clause_var_1 in subsets and clause_var_2 not in subsets:
        subset_source = subsets[clause_var_1]
        subset_target = numba.typed.List([numba.typed.List([nn for nn in neighbors[n] if nn != target_node]) for n in subset_source])

    # Case 5:
    else:
        subset_source = subsets[clause_var_1]
        subset_target = numba.typed.List([subsets[clause_var_2] for _ in subset_source])

    # If any of the subsets are empty return them in the correct type
    # (numba needs explicitly-typed empty lists; comprehensions over empty input lose the type)
    if len(subset_source) == 0:
        subset_source = numba.typed.List.empty_list(node_type)
        subset_target = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
    # If any sub lists in subset target are empty, add correct type for empty list
    for i, t in enumerate(subset_target):
        if len(t) == 0:
            subset_target[i] = numba.typed.List.empty_list(node_type)

    return subset_source, subset_target
+
@numba.njit(cache=True)
def get_edge_rule_node_clause_subset(clause_var_1, target_edge, subsets, nodes):
    """NOTE: DEPRECATED.

    Candidate nodes for a node clause variable of an edge rule: the edge's
    source for '__source', its target for '__target', an existing subset when
    an earlier clause already grounded the variable, otherwise every node
    other than the edge's endpoints.
    """
    if clause_var_1 == '__source':
        return numba.typed.List([target_edge[0]])
    if clause_var_1 == '__target':
        return numba.typed.List([target_edge[1]])
    if clause_var_1 in subsets:
        return subsets[clause_var_1]
    return numba.typed.List([n for n in nodes if n != target_edge[0] and n != target_edge[1]])
+
@numba.njit(cache=True)
def get_edge_rule_edge_clause_subset(clause_var_1, clause_var_2, target_edge, subsets, neighbors, reverse_neighbors, nodes):
    """NOTE: DEPRECATED

    Ground an edge clause predicate(Y, Z) that appears in an edge rule.
    Returns (subset_source, subset_target) where subset_target[i] holds the
    candidate target nodes paired with subset_source[i].
    """
    # There are 5 cases for predicate(Y,Z):
    # 1. Either one or both of Y, Z are the source or target node
    # 2. Both predicate variables Y and Z have not been encountered before
    # 3. The source variable Y has not been encountered before but the target variable Z has
    # 4. The target variable Z has not been encountered before but the source variable Y has
    # 5. Both predicate variables Y and Z have been encountered before
    # Case 1:
    # Check if 1st variable is the source
    if clause_var_1 == '__source':
        subset_source = numba.typed.List([target_edge[0]])

        # If 2nd variable is source/target/something else
        if clause_var_2 == '__source':
            subset_target = numba.typed.List([numba.typed.List([target_edge[0]])])
        elif clause_var_2 == '__target':
            subset_target = numba.typed.List([numba.typed.List([target_edge[1]])])
        elif clause_var_2 in subsets:
            subset_target = numba.typed.List([subsets[clause_var_2]])
        else:
            subset_target = numba.typed.List([neighbors[target_edge[0]]])

    # if 1st variable is the target
    elif clause_var_1 == '__target':
        subset_source = numba.typed.List([target_edge[1]])

        # if 2nd variable is source/target/something else
        if clause_var_2 == '__source':
            subset_target = numba.typed.List([numba.typed.List([target_edge[0]])])
        elif clause_var_2 == '__target':
            subset_target = numba.typed.List([numba.typed.List([target_edge[1]])])
        elif clause_var_2 in subsets:
            subset_target = numba.typed.List([subsets[clause_var_2]])
        else:
            subset_target = numba.typed.List([neighbors[target_edge[1]]])

    # Handle the cases where the 2nd variable is source/target but the 1st is something else (cannot be source/target)
    elif clause_var_2 == '__source':
        subset_source = reverse_neighbors[target_edge[0]] if clause_var_1 not in subsets else subsets[clause_var_1]
        subset_target = numba.typed.List([numba.typed.List([target_edge[0]]) for _ in subset_source])

    elif clause_var_2 == '__target':
        subset_source = reverse_neighbors[target_edge[1]] if clause_var_1 not in subsets else subsets[clause_var_1]
        subset_target = numba.typed.List([numba.typed.List([target_edge[1]]) for _ in subset_source])

    # Case 2:
    # We replace Y by all nodes (except source/target) and Z by the neighbors of each of these nodes
    elif clause_var_1 not in subsets and clause_var_2 not in subsets:
        subset_source = numba.typed.List([n for n in nodes if n != target_edge[0] and n != target_edge[1]])
        subset_target = numba.typed.List([numba.typed.List([nn for nn in neighbors[n] if nn != target_edge[0] and nn != target_edge[1]]) for n in subset_source])

    # Case 3:
    # We replace Y by the sources of Z
    elif clause_var_1 not in subsets and clause_var_2 in subsets:
        subset_source = numba.typed.List.empty_list(node_type)
        subset_target = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))

        # One (source, [target]) pair per incoming edge of each grounded Z node
        for n in subsets[clause_var_2]:
            sources = reverse_neighbors[n]
            for source in sources:
                if source != target_edge[0] and source != target_edge[1]:
                    subset_source.append(source)
                    subset_target.append(numba.typed.List([n]))

    # Case 4:
    # We replace Z by the neighbors of Y
    elif clause_var_1 in subsets and clause_var_2 not in subsets:
        subset_source = subsets[clause_var_1]
        subset_target = numba.typed.List([numba.typed.List([nn for nn in neighbors[n] if nn != target_edge[0] and nn != target_edge[1]]) for n in subset_source])

    # Case 5:
    else:
        subset_source = subsets[clause_var_1]
        subset_target = numba.typed.List([subsets[clause_var_2] for _ in subset_source])

    # If any of the subsets are empty return them in the correct type
    # (numba typed lists built empty by comprehension would otherwise be untyped)
    if len(subset_source) == 0:
        subset_source = numba.typed.List.empty_list(node_type)
        subset_target = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
    # If any sub lists in subset target are empty, add correct type for empty list
    for i, t in enumerate(subset_target):
        if len(t) == 0:
            subset_target[i] = numba.typed.List.empty_list(node_type)

    return subset_source, subset_target
+
+
@numba.njit(cache=True)
def get_qualified_node_groundings(interpretations_node, grounding, clause_l, clause_bnd):
    """Return the nodes in `grounding` whose interpretation satisfies
    (clause_l, clause_bnd)."""
    # Filter the grounding by the predicate and bound of the clause
    qualified_groundings = numba.typed.List.empty_list(node_type)
    for n in grounding:
        if is_satisfied_node(interpretations_node, n, (clause_l, clause_bnd)):
            qualified_groundings.append(n)

    return qualified_groundings
+
+
@numba.njit(cache=True)
def get_qualified_edge_groundings(interpretations_edge, grounding, clause_l, clause_bnd):
    """Return the edges in `grounding` whose interpretation satisfies
    (clause_l, clause_bnd)."""
    # Filter the grounding by the predicate and bound of the clause
    qualified_groundings = numba.typed.List.empty_list(edge_type)
    for e in grounding:
        if is_satisfied_edge(interpretations_edge, e, (clause_l, clause_bnd)):
            qualified_groundings.append(e)

    return qualified_groundings
+
+
@numba.njit(cache=True)
def get_qualified_components_node_clause(interpretations_node, candidates, l, bnd):
    """NOTE: DEPRECATED

    Return the candidates that satisfy (l, bnd), de-duplicated (the `not in`
    check makes this O(n^2), acceptable for the small candidate lists it was
    used with).
    """
    # Get all the qualified neighbors for a particular clause
    qualified_nodes = numba.typed.List.empty_list(node_type)
    for n in candidates:
        if is_satisfied_node(interpretations_node, n, (l, bnd)) and n not in qualified_nodes:
            qualified_nodes.append(n)

    return qualified_nodes
+
+
@numba.njit(cache=True)
def get_qualified_components_node_comparison_clause(interpretations_node, candidates, l, bnd):
    """NOTE: DEPRECATED

    Like get_qualified_components_node_clause but for comparison clauses:
    returns the qualified nodes together with the numeric value parsed from
    each node's matching label suffix.
    """
    # Get all the qualified neighbors for a particular comparison clause and return them along with the number associated
    qualified_nodes = numba.typed.List.empty_list(node_type)
    qualified_nodes_numbers = numba.typed.List.empty_list(numba.types.float64)
    for n in candidates:
        result, number = is_satisfied_node_comparison(interpretations_node, n, (l, bnd))
        if result:
            qualified_nodes.append(n)
            qualified_nodes_numbers.append(number)

    return qualified_nodes, qualified_nodes_numbers
+
+
@numba.njit(cache=True)
def get_qualified_components_edge_clause(interpretations_edge, candidates_source, candidates_target, l, bnd, reverse_graph):
    """NOTE: DEPRECATED

    For each candidate source i, test every edge (source, target) with
    target in candidates_target[i] (direction flipped when reverse_graph)
    and return the parallel source/target lists of the satisfying pairs.
    """
    # Get all the qualified sources and targets for a particular clause
    qualified_nodes_source = numba.typed.List.empty_list(node_type)
    qualified_nodes_target = numba.typed.List.empty_list(node_type)
    for i, source in enumerate(candidates_source):
        for target in candidates_target[i]:
            edge = (source, target) if not reverse_graph else (target, source)
            if is_satisfied_edge(interpretations_edge, edge, (l, bnd)):
                qualified_nodes_source.append(source)
                qualified_nodes_target.append(target)

    return qualified_nodes_source, qualified_nodes_target
+
+
@numba.njit(cache=True)
def get_qualified_components_edge_comparison_clause(interpretations_edge, candidates_source, candidates_target, l, bnd, reverse_graph):
    """NOTE: DEPRECATED

    Comparison-clause variant of get_qualified_components_edge_clause:
    additionally returns the numeric value parsed from each satisfying
    edge's label suffix.
    """
    # Get all the qualified sources and targets for a particular clause
    qualified_nodes_source = numba.typed.List.empty_list(node_type)
    qualified_nodes_target = numba.typed.List.empty_list(node_type)
    qualified_edges_numbers = numba.typed.List.empty_list(numba.types.float64)
    for i, source in enumerate(candidates_source):
        for target in candidates_target[i]:
            edge = (source, target) if not reverse_graph else (target, source)
            result, number = is_satisfied_edge_comparison(interpretations_edge, edge, (l, bnd))
            if result:
                qualified_nodes_source.append(source)
                qualified_nodes_target.append(target)
                qualified_edges_numbers.append(number)

    return qualified_nodes_source, qualified_nodes_target, qualified_edges_numbers
+
+
@numba.njit(cache=True)
def compare_numbers_node_predicate(numbers_1, numbers_2, op, qualified_nodes_1, qualified_nodes_2):
    """NOTE: DEPRECATED

    Compare every pair (numbers_1[i], numbers_2[j]) with `op` and collect the
    corresponding node pairs that satisfy the comparison. Returns
    (any_pair_satisfied, nodes_1, nodes_2).

    NOTE(review): `result` is never reset inside the loops, so once any pair
    satisfies the comparison, every subsequent pair is appended regardless of
    its own comparison outcome. This looks unintended, but the function is
    deprecated — confirm before relying on the returned lists.
    """
    result = False
    final_qualified_nodes_1 = numba.typed.List.empty_list(node_type)
    final_qualified_nodes_2 = numba.typed.List.empty_list(node_type)
    for i in range(len(numbers_1)):
        for j in range(len(numbers_2)):
            if op == '<':
                if numbers_1[i] < numbers_2[j]:
                    result = True
            elif op == '<=':
                if numbers_1[i] <= numbers_2[j]:
                    result = True
            elif op == '>':
                if numbers_1[i] > numbers_2[j]:
                    result = True
            elif op == '>=':
                if numbers_1[i] >= numbers_2[j]:
                    result = True
            elif op == '==':
                if numbers_1[i] == numbers_2[j]:
                    result = True
            elif op == '!=':
                if numbers_1[i] != numbers_2[j]:
                    result = True

            if result:
                final_qualified_nodes_1.append(qualified_nodes_1[i])
                final_qualified_nodes_2.append(qualified_nodes_2[j])
    return result, final_qualified_nodes_1, final_qualified_nodes_2
+
+
@numba.njit(cache=True)
def compare_numbers_edge_predicate(numbers_1, numbers_2, op, qualified_nodes_1a, qualified_nodes_1b, qualified_nodes_2a, qualified_nodes_2b):
    """NOTE: DEPRECATED

    Edge counterpart of compare_numbers_node_predicate: compares every pair
    (numbers_1[i], numbers_2[j]) with `op` and collects the source/target node
    pairs of both edges for satisfying pairs.

    NOTE(review): as in the node version, `result` is sticky — after the first
    satisfying pair, all later pairs are appended too; confirm before relying
    on the returned lists (function is deprecated).
    """
    result = False
    final_qualified_nodes_1a = numba.typed.List.empty_list(node_type)
    final_qualified_nodes_1b = numba.typed.List.empty_list(node_type)
    final_qualified_nodes_2a = numba.typed.List.empty_list(node_type)
    final_qualified_nodes_2b = numba.typed.List.empty_list(node_type)
    for i in range(len(numbers_1)):
        for j in range(len(numbers_2)):
            if op == '<':
                if numbers_1[i] < numbers_2[j]:
                    result = True
            elif op == '<=':
                if numbers_1[i] <= numbers_2[j]:
                    result = True
            elif op == '>':
                if numbers_1[i] > numbers_2[j]:
                    result = True
            elif op == '>=':
                if numbers_1[i] >= numbers_2[j]:
                    result = True
            elif op == '==':
                if numbers_1[i] == numbers_2[j]:
                    result = True
            elif op == '!=':
                if numbers_1[i] != numbers_2[j]:
                    result = True

            if result:
                final_qualified_nodes_1a.append(qualified_nodes_1a[i])
                final_qualified_nodes_1b.append(qualified_nodes_1b[i])
                final_qualified_nodes_2a.append(qualified_nodes_2a[j])
                final_qualified_nodes_2b.append(qualified_nodes_2b[j])
    return result, final_qualified_nodes_1a, final_qualified_nodes_1b, final_qualified_nodes_2a, final_qualified_nodes_2b
+
+
+@numba.njit(cache=True)
+def _satisfies_threshold(num_neigh, num_qualified_component, threshold):
+ # Checks if qualified neighbors satisfy threshold. This is for one clause
+ if threshold[1][0]=='number':
+ if threshold[0]=='greater_equal':
+ result = True if num_qualified_component >= threshold[2] else False
+ elif threshold[0]=='greater':
+ result = True if num_qualified_component > threshold[2] else False
+ elif threshold[0]=='less_equal':
+ result = True if num_qualified_component <= threshold[2] else False
+ elif threshold[0]=='less':
+ result = True if num_qualified_component < threshold[2] else False
+ elif threshold[0]=='equal':
+ result = True if num_qualified_component == threshold[2] else False
+
+ elif threshold[1][0]=='percent':
+ if num_neigh==0:
+ result = False
+ elif threshold[0]=='greater_equal':
+ result = True if num_qualified_component/num_neigh >= threshold[2]*0.01 else False
+ elif threshold[0]=='greater':
+ result = True if num_qualified_component/num_neigh > threshold[2]*0.01 else False
+ elif threshold[0]=='less_equal':
+ result = True if num_qualified_component/num_neigh <= threshold[2]*0.01 else False
+ elif threshold[0]=='less':
+ result = True if num_qualified_component/num_neigh < threshold[2]*0.01 else False
+ elif threshold[0]=='equal':
+ result = True if num_qualified_component/num_neigh == threshold[2]*0.01 else False
+
+ return result
+
+
@numba.njit(cache=True)
def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False):
    """Apply the (label, bound) update `na` to node `comp`.

    Updates the node's world, keeps `predicate_map` in sync, records the
    rule/atom traces when enabled, propagates the new bound to
    inverse-predicate-list (ipl) complements, and returns (updated, change)
    where `change` feeds the convergence check ('delta_bound' mode: maximum
    bound delta across the updated intervals; otherwise: number of
    interpretations that changed).
    """
    updated = False
    # This is to prevent a key error in case the label is a specific label
    # NOTE(review): the bare `except` at the bottom swallows every error, not
    # only KeyError — presumably a numba-compatibility choice; confirm.
    try:
        world = interpretations[comp]
        l, bnd = na
        updated_bnds = numba.typed.List.empty_list(interval.interval_type)

        # Add label to world if it is not there
        if l not in world.world:
            world.world[l] = interval.closed(0, 1)
            if l in predicate_map:
                predicate_map[l].append(comp)
            else:
                predicate_map[l] = numba.typed.List([comp])

        # Check if update is necessary with previous bnd
        prev_bnd = world.world[l].copy()

        # override will not check for inconsistencies
        if override:
            world.world[l].set_lower_upper(bnd.lower, bnd.upper)
        else:
            world.update(l, bnd)
        world.world[l].set_static(static)
        if world.world[l]!=prev_bnd:
            updated = True
            updated_bnds.append(world.world[l])

            # Add to rule trace if update happened and add to atom trace if necessary
            if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes:
                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy()))
                if atom_trace:
                    # Mode can be fact or rule, updation of trace will happen accordingly
                    if mode=='fact' or mode=='graph-attribute-fact':
                        qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
                        qe = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
                        name = facts_to_be_applied_trace[idx]
                        _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name)
                    elif mode=='rule':
                        qn, qe, name = rules_to_be_applied_trace[idx]
                        _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name)

        # Update complement of predicate (if exists) based on new knowledge of predicate
        # NOTE(review): `if p1 == l` and `if p2 == l` are independent ifs, so a
        # self-complementary pair (p1 == p2 == l) would run both branches.
        if updated:
            ip_update_cnt = 0
            for p1, p2 in ipl:
                if p1 == l:
                    if p2 not in world.world:
                        world.world[p2] = interval.closed(0, 1)
                        if p2 in predicate_map:
                            predicate_map[p2].append(comp)
                        else:
                            predicate_map[p2] = numba.typed.List([comp])
                    if atom_trace:
                        _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}')
                    # Complement bound: p2 is clamped to [1-upper(p1), 1-lower(p1)]
                    lower = max(world.world[p2].lower, 1 - world.world[p1].upper)
                    upper = min(world.world[p2].upper, 1 - world.world[p1].lower)
                    world.world[p2].set_lower_upper(lower, upper)
                    world.world[p2].set_static(static)
                    ip_update_cnt += 1
                    updated_bnds.append(world.world[p2])
                    if store_interpretation_changes:
                        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper)))
                if p2 == l:
                    if p1 not in world.world:
                        world.world[p1] = interval.closed(0, 1)
                        if p1 in predicate_map:
                            predicate_map[p1].append(comp)
                        else:
                            predicate_map[p1] = numba.typed.List([comp])
                    if atom_trace:
                        _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}')
                    lower = max(world.world[p1].lower, 1 - world.world[p2].upper)
                    upper = min(world.world[p1].upper, 1 - world.world[p2].lower)
                    world.world[p1].set_lower_upper(lower, upper)
                    world.world[p1].set_static(static)
                    ip_update_cnt += 1
                    updated_bnds.append(world.world[p1])
                    if store_interpretation_changes:
                        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(lower, upper)))

        # Gather convergence data
        change = 0
        if updated:
            # Find out if it has changed from previous interp
            current_bnd = world.world[l]
            prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper)
            if current_bnd != prev_t_bnd:
                if convergence_mode=='delta_bound':
                    for i in updated_bnds:
                        lower_delta = abs(i.lower-prev_t_bnd.lower)
                        upper_delta = abs(i.upper-prev_t_bnd.upper)
                        max_delta = max(lower_delta, upper_delta)
                        change = max(change, max_delta)
                else:
                    change = 1 + ip_update_cnt

        return (updated, change)

    except:
        return (False, 0)
+
+
@numba.njit(cache=True)
def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False):
    """Apply the (label, bound) update `na` to edge `comp`.

    Edge counterpart of `_update_node`: updates the edge's world, keeps
    `predicate_map` in sync, records the rule/atom traces when enabled,
    propagates the new bound to inverse-predicate-list (ipl) complements, and
    returns (updated, change) where `change` feeds the convergence check
    ('delta_bound' mode: maximum bound delta across the updated intervals;
    otherwise: number of interpretations that changed).
    """
    updated = False
    # This is to prevent a key error in case the label is a specific label
    try:
        world = interpretations[comp]
        l, bnd = na
        updated_bnds = numba.typed.List.empty_list(interval.interval_type)

        # Add label to world if it is not there
        if l not in world.world:
            world.world[l] = interval.closed(0, 1)
            if l in predicate_map:
                predicate_map[l].append(comp)
            else:
                predicate_map[l] = numba.typed.List([comp])

        # Check if update is necessary with previous bnd
        prev_bnd = world.world[l].copy()

        # override will not check for inconsistencies
        if override:
            world.world[l].set_lower_upper(bnd.lower, bnd.upper)
        else:
            world.update(l, bnd)
        world.world[l].set_static(static)
        if world.world[l]!=prev_bnd:
            updated = True
            updated_bnds.append(world.world[l])

            # Add to rule trace if update happened and add to atom trace if necessary
            if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes:
                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy()))
                if atom_trace:
                    # Mode can be fact or rule, updation of trace will happen accordingly
                    if mode=='fact' or mode=='graph-attribute-fact':
                        qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
                        qe = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
                        name = facts_to_be_applied_trace[idx]
                        _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name)
                    elif mode=='rule':
                        qn, qe, name = rules_to_be_applied_trace[idx]
                        _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name)

        # Update complement of predicate (if exists) based on new knowledge of predicate
        if updated:
            ip_update_cnt = 0
            for p1, p2 in ipl:
                if p1 == l:
                    if p2 not in world.world:
                        world.world[p2] = interval.closed(0, 1)
                        if p2 in predicate_map:
                            predicate_map[p2].append(comp)
                        else:
                            predicate_map[p2] = numba.typed.List([comp])
                    if atom_trace:
                        _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}')
                    # Complement bound: p2 is clamped to [1-upper(p1), 1-lower(p1)]
                    lower = max(world.world[p2].lower, 1 - world.world[p1].upper)
                    upper = min(world.world[p2].upper, 1 - world.world[p1].lower)
                    world.world[p2].set_lower_upper(lower, upper)
                    world.world[p2].set_static(static)
                    ip_update_cnt += 1
                    updated_bnds.append(world.world[p2])
                    if store_interpretation_changes:
                        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper)))
                if p2 == l:
                    if p1 not in world.world:
                        world.world[p1] = interval.closed(0, 1)
                        if p1 in predicate_map:
                            predicate_map[p1].append(comp)
                        else:
                            predicate_map[p1] = numba.typed.List([comp])
                    if atom_trace:
                        _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}')
                    lower = max(world.world[p1].lower, 1 - world.world[p2].upper)
                    upper = min(world.world[p1].upper, 1 - world.world[p2].lower)
                    world.world[p1].set_lower_upper(lower, upper)
                    world.world[p1].set_static(static)
                    ip_update_cnt += 1
                    # BUG FIX: this previously appended world.world[p2]; the bound
                    # updated in this branch is p1's (mirrors _update_node), so the
                    # delta_bound convergence check was fed the wrong interval.
                    updated_bnds.append(world.world[p1])
                    if store_interpretation_changes:
                        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(lower, upper)))

        # Gather convergence data
        change = 0
        if updated:
            # Find out if it has changed from previous interp
            current_bnd = world.world[l]
            prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper)
            if current_bnd != prev_t_bnd:
                if convergence_mode=='delta_bound':
                    for i in updated_bnds:
                        lower_delta = abs(i.lower-prev_t_bnd.lower)
                        upper_delta = abs(i.upper-prev_t_bnd.upper)
                        max_delta = max(lower_delta, upper_delta)
                        change = max(change, max_delta)
                else:
                    change = 1 + ip_update_cnt

        return (updated, change)
    except:
        return (False, 0)
+
+
@numba.njit(cache=True)
def _update_rule_trace(rule_trace, qn, qe, prev_bnd, name):
    """Record one atom-trace entry: the qualified nodes/edges per clause, a
    copy of the bound before the update, and the name of the fact/rule."""
    rule_trace.append((qn, qe, prev_bnd.copy(), name))
+
+
@numba.njit(cache=True)
def are_satisfied_node(interpretations, comp, nas):
    """Return True iff every (label, bound) pair in `nas` is satisfied for node `comp`."""
    for na in nas:
        if not is_satisfied_node(interpretations, comp, na):
            return False
    return True
+
+
@numba.njit(cache=True)
def is_satisfied_node(interpretations, comp, na):
    """Check whether node `comp` satisfies the (label, bound) pair `na`.

    An empty clause (either part None) is trivially satisfied; a missing
    component/label yields False (the try/except guards the lookup).
    """
    l, bnd = na
    if l is None or bnd is None:
        return True
    try:
        world = interpretations[comp]
        return world.is_satisfied(l, bnd)
    except:
        return False
+
+
@numba.njit(cache=True)
def is_satisfied_node_comparison(interpretations, comp, na):
    """Check a comparison clause for node `comp`.

    Scans the node's world for a label whose text contains `l` and whose
    remaining suffix is numeric; returns (satisfied, suffix_number).

    NOTE(review): `l_str in world_l_str` is a substring test, not a prefix
    test, while the suffix slice `[len(l_str)+1:]` assumes the label starts
    with `l_str` followed by one separator character — confirm that labels
    are always formed that way.
    """
    result = False
    number = 0
    l, bnd = na
    l_str = l.value

    if not (l is None or bnd is None):
        # This is to prevent a key error in case the label is a specific label
        try:
            world = interpretations[comp]
            for world_l in world.world.keys():
                world_l_str = world_l.value
                if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit():
                    # The label is contained in the world
                    result = world.is_satisfied(world_l, na[1])
                    # Find the suffix number
                    number = str_to_float(world_l_str[len(l_str)+1:])
                    break

        except:
            result = False
    else:
        # An empty clause is trivially satisfied
        result = True
    return result, number
+
+
@numba.njit(cache=True)
def are_satisfied_edge(interpretations, comp, nas):
    """Return True iff every (label, bound) pair in `nas` is satisfied for edge `comp`."""
    for na in nas:
        if not is_satisfied_edge(interpretations, comp, na):
            return False
    return True
+
+
@numba.njit(cache=True)
def is_satisfied_edge(interpretations, comp, na):
    """Check whether edge `comp` satisfies the (label, bound) pair `na`.

    An empty clause (either part None) is trivially satisfied; a missing
    component/label yields False (the try/except guards the lookup).
    """
    l, bnd = na
    if l is None or bnd is None:
        return True
    try:
        world = interpretations[comp]
        return world.is_satisfied(l, bnd)
    except:
        return False
+
+
@numba.njit(cache=True)
def is_satisfied_edge_comparison(interpretations, comp, na):
    """Check a comparison clause for edge `comp`.

    Edge counterpart of `is_satisfied_node_comparison`: scans the edge's world
    for a label containing `l` with a numeric suffix and returns
    (satisfied, suffix_number).

    NOTE(review): same caveats as the node version — substring (not prefix)
    match plus a `[len(l_str)+1:]` slice that assumes a one-character
    separator; confirm label construction.
    """
    result = False
    number = 0
    l, bnd = na
    l_str = l.value

    if not (l is None or bnd is None):
        # This is to prevent a key error in case the label is a specific label
        try:
            world = interpretations[comp]
            for world_l in world.world.keys():
                world_l_str = world_l.value
                if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit():
                    # The label is contained in the world
                    result = world.is_satisfied(world_l, na[1])
                    # Find the suffix number
                    number = str_to_float(world_l_str[len(l_str)+1:])
                    break

        except:
            result = False
    else:
        # An empty clause is trivially satisfied
        result = True
    return result, number
+
+
@numba.njit(cache=True)
def annotate(annotation_functions, rule, annotations, weights):
    """Compute the annotation (lower, upper) for a fired rule.

    If the rule names no annotation function, the rule's own bound is used;
    otherwise the matching Python function is called in object mode.

    NOTE(review): if `func_name` matches none of `annotation_functions`,
    `annotation` is never assigned inside the objmode block — presumably the
    caller guarantees the function exists; confirm.
    """
    func_name = rule.get_annotation_function()
    if func_name == '':
        return rule.get_bnd().lower, rule.get_bnd().upper
    else:
        # Annotation functions are arbitrary Python, so leave nopython mode
        with numba.objmode(annotation='Tuple((float64, float64))'):
            for func in annotation_functions:
                if func.__name__ == func_name:
                    annotation = func(annotations, weights)
        return annotation
+
+
@numba.njit(cache=True)
def check_consistent_node(interpretations, comp, na):
    """Return False iff the candidate bound na[1] is disjoint from node
    `comp`'s current bound for label na[0] (missing label counts as [0, 1])."""
    world = interpretations[comp]
    if na[0] in world.world:
        current = world.world[na[0]]
    else:
        current = interval.closed(0, 1)
    # Consistent exactly when the two intervals overlap
    return not (na[1].lower > current.upper or current.lower > na[1].upper)
+
+
@numba.njit(cache=True)
def check_consistent_edge(interpretations, comp, na):
    """Return False iff the candidate bound na[1] is disjoint from edge
    `comp`'s current bound for label na[0] (missing label counts as [0, 1])."""
    world = interpretations[comp]
    if na[0] in world.world:
        current = world.world[na[0]]
    else:
        current = interval.closed(0, 1)
    # Consistent exactly when the two intervals overlap
    return not (na[1].lower > current.upper or current.lower > na[1].upper)
+
+
@numba.njit(cache=True)
def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode):
    """Resolve an inconsistency on node `comp` for label na[0] by resetting
    the bound to [0, 1] and freezing it (static), doing the same for any IPL
    complement, and recording the changes in the traces when enabled.

    NOTE(review): `name` is only bound when store_interpretation_changes is
    True — if atom_trace can be True without it, the trace f-string below
    would hit an unbound name; confirm the caller invariant.
    """
    world = interpretations[comp]
    if store_interpretation_changes:
        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1)))
        if mode == 'fact' or mode == 'graph-attribute-fact':
            name = facts_to_be_applied_trace[idx]
        elif mode == 'rule':
            name = rules_to_be_applied_trace[idx][2]
        if atom_trace:
            _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[na[0]], f'Inconsistency due to {name}')
    # Resolve inconsistency and set static
    world.world[na[0]].set_lower_upper(0, 1)
    world.world[na[0]].set_static(True)
    for p1, p2 in ipl:
        if p1==na[0]:
            if atom_trace:
                _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'Inconsistency due to {name}')
            world.world[p2].set_lower_upper(0, 1)
            world.world[p2].set_static(True)
            if store_interpretation_changes:
                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(0,1)))

        if p2==na[0]:
            if atom_trace:
                _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'Inconsistency due to {name}')
            world.world[p1].set_lower_upper(0, 1)
            world.world[p1].set_static(True)
            if store_interpretation_changes:
                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1)))
    # Add inconsistent predicates to a list
+
+
@numba.njit(cache=True)
def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode):
    """Edge counterpart of `resolve_inconsistency_node`: reset edge `comp`'s
    bound for label na[0] (and any IPL complement) to [0, 1], mark it static,
    and record the changes in the traces when enabled.

    NOTE(review): as in the node version, `name` is only bound when
    store_interpretation_changes is True — confirm the caller invariant
    before relying on atom_trace alone.
    """
    w = interpretations[comp]
    if store_interpretation_changes:
        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1)))
        if mode == 'fact' or mode == 'graph-attribute-fact':
            name = facts_to_be_applied_trace[idx]
        elif mode == 'rule':
            name = rules_to_be_applied_trace[idx][2]
        if atom_trace:
            _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[na[0]], f'Inconsistency due to {name}')
    # Resolve inconsistency and set static
    w.world[na[0]].set_lower_upper(0, 1)
    w.world[na[0]].set_static(True)
    for p1, p2 in ipl:
        if p1==na[0]:
            if atom_trace:
                _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p2], f'Inconsistency due to {name}')
            w.world[p2].set_lower_upper(0, 1)
            w.world[p2].set_static(True)
            if store_interpretation_changes:
                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(0,1)))

        if p2==na[0]:
            if atom_trace:
                _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p1], f'Inconsistency due to {name}')
            w.world[p1].set_lower_upper(0, 1)
            w.world[p1].set_static(True)
            if store_interpretation_changes:
                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1)))
+
+
@numba.njit(cache=True)
def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node):
    """Register a brand-new node: empty adjacency in both directions and an
    empty interpretation world."""
    interpretations_node[node] = world.World(numba.typed.List.empty_list(label.label_type))
    neighbors[node] = numba.typed.List.empty_list(node_type)
    reverse_neighbors[node] = numba.typed.List.empty_list(node_type)
    nodes.append(node)
+
+
@numba.njit(cache=True)
def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map):
    """Add the edge (source, target), creating missing endpoint nodes.

    Returns (edge, new_edge) where new_edge is True when the edge was created
    or when an existing edge gained the (non-empty) label `l` for the first
    time.
    """
    # If not a node, add to list of nodes and initialize neighbors
    if source not in nodes:
        _add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node)

    if target not in nodes:
        _add_node(target, neighbors, reverse_neighbors, nodes, interpretations_node)

    # Make sure edge doesn't already exist
    # Make sure, if l=='', not to add the label
    # Make sure, if edge exists, that we don't override the l label if it exists
    edge = (source, target)
    new_edge = False
    if edge not in edges:
        new_edge = True
        edges.append(edge)
        neighbors[source].append(target)
        reverse_neighbors[target].append(source)
        if l.value!='':
            # New edge born with label l; register it in the predicate map
            interpretations_edge[edge] = world.World(numba.typed.List([l]))
            if l in predicate_map:
                predicate_map[l].append(edge)
            else:
                predicate_map[l] = numba.typed.List([edge])
        else:
            interpretations_edge[edge] = world.World(numba.typed.List.empty_list(label.label_type))
    else:
        # Existing edge: only attach l if it is non-empty and not already present
        if l not in interpretations_edge[edge].world and l.value!='':
            new_edge = True
            interpretations_edge[edge].world[l] = interval.closed(0, 1)

    return edge, new_edge
+
+
@numba.njit(cache=True)
def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map):
    """Add every edge in the cartesian product sources x targets.

    Returns (edges touched, number of genuinely new edges/labels)."""
    changes = 0
    edges_added = numba.typed.List.empty_list(edge_type)
    for src in sources:
        for tgt in targets:
            edge, new_edge = _add_edge(src, tgt, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map)
            edges_added.append(edge)
            if new_edge:
                changes += 1
    return edges_added, changes
+
+
@numba.njit(cache=True)
def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge, predicate_map):
    """Remove `edge` from the edge list, its interpretation, both adjacency
    structures and every predicate grounding list it appears in."""
    source, target = edge
    neighbors[source].remove(target)
    reverse_neighbors[target].remove(source)
    edges.remove(edge)
    del interpretations_edge[edge]
    # Drop the edge from every predicate that currently grounds it
    for lbl in predicate_map:
        if edge in predicate_map[lbl]:
            predicate_map[lbl].remove(edge)
+
+
@numba.njit(cache=True)
def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node, predicate_map):
    """Remove `node` and every reference to it: its interpretation, its own
    adjacency entries, its predicate-map groundings, and its occurrences in
    every other node's neighbor/reverse-neighbor lists."""
    nodes.remove(node)
    del interpretations_node[node]
    # Delete the node's own adjacency entries first so the loops below only
    # visit the remaining nodes
    del neighbors[node]
    del reverse_neighbors[node]
    for l in predicate_map:
        if node in predicate_map[l]:
            predicate_map[l].remove(node)

    # Remove all occurrences of node in neighbors
    for n in neighbors.keys():
        if node in neighbors[n]:
            neighbors[n].remove(node)
    for n in reverse_neighbors.keys():
        if node in reverse_neighbors[n]:
            reverse_neighbors[n].remove(node)
+
+
+@numba.njit(cache=True)
def float_to_str(value):
    """Render `value` as a string with exactly three fractional digits.

    BUG FIX: the fractional part is now zero-padded to three digits.
    Previously e.g. 0.05 rendered as '0.50' (int(0.05 * 1000) == 50 printed
    without padding), which `str_to_float` parsed back as 0.5 — a broken
    round-trip. With padding, 0.05 -> '0.050' -> 0.05.

    NOTE(review): negative inputs still mis-render (Python's `% 1` is always
    non-negative and int() truncates toward zero) — presumably callers only
    pass non-negative values; confirm.
    """
    number = int(value)
    decimal = int(value % 1 * 1000)
    # Zero-pad to three digits (avoid format specs for numba compatibility)
    if decimal < 10:
        float_str = f'{number}.00{decimal}'
    elif decimal < 100:
        float_str = f'{number}.0{decimal}'
    else:
        float_str = f'{number}.{decimal}'
    return float_str
+
+
+@numba.njit(cache=True)
def str_to_float(value):
    """Parse a decimal string into a float without using float() (numba-safe):
    strip the dot, parse the digits with str_to_int, then scale back down by
    the number of fractional digits."""
    dot = value.find('.')
    frac_digits = 0 if dot == -1 else len(value) - dot - 1
    digits = value.replace('.', '')
    return str_to_int(digits) / 10 ** frac_digits
+
+
+@numba.njit(cache=True)
def str_to_int(value):
    """Parse a (possibly negative) integer string without int() (numba-safe),
    accumulating digits left to right."""
    negative = value[0] == '-'
    if negative:
        value = value.replace('-', '')
    result = 0
    for ch in value:
        # ord(ch) - 48 maps '0'..'9' to 0..9
        result = result * 10 + (ord(ch) - 48)
    return -result if negative else result
diff --git a/pyreason/scripts/interval/interval.py b/pyreason/scripts/interval/interval.py
index a6ab76e3..56252274 100755
--- a/pyreason/scripts/interval/interval.py
+++ b/pyreason/scripts/interval/interval.py
@@ -2,6 +2,7 @@
from numba import njit
import numpy as np
+
class Interval(structref.StructRefProxy):
def __new__(cls, l, u, s=False):
return structref.StructRefProxy.__new__(cls, l, u, s, l, u)
@@ -79,6 +80,9 @@ def __eq__(self, interval):
def __repr__(self):
return f'[{self.lower},{self.upper}]'
+ def __hash__(self):
+ return hash((self.lower, self.upper))
+
def __contains__(self, item):
if self.lower <= item.lower and self.upper >= item.upper:
return True
diff --git a/pyreason/scripts/numba_wrapper/numba_types/rule_type.py b/pyreason/scripts/numba_wrapper/numba_types/rule_type.py
index 970d7106..7da3bb9a 100755
--- a/pyreason/scripts/numba_wrapper/numba_types/rule_type.py
+++ b/pyreason/scripts/numba_wrapper/numba_types/rule_type.py
@@ -32,8 +32,8 @@ def typeof_rule(val, c):
# Construct object from Numba functions (Doesn't work. We don't need this currently)
@type_callable(Rule)
def type_rule(context):
- def typer(rule_name, type, target, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static, immediate_rule):
- if isinstance(rule_name, types.UnicodeType) and isinstance(type, types.UnicodeType) and isinstance(target, label.LabelType) and isinstance(delta, types.Integer) and isinstance(clauses, (types.NoneType, types.ListType)) and isinstance(bnd, interval.IntervalType) and isinstance(thresholds, types.ListType) and isinstance(ann_fn, types.UnicodeType) and isinstance(weights, types.Array) and isinstance(edges, types.Tuple) and isinstance(static, types.Boolean) and isinstance(immediate_rule, types.Boolean):
+ def typer(rule_name, type, target, head_variables, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static):
+ if isinstance(rule_name, types.UnicodeType) and isinstance(type, types.UnicodeType) and isinstance(target, label.LabelType) and isinstance(head_variables, types.ListType) and isinstance(delta, types.Integer) and isinstance(clauses, (types.NoneType, types.ListType)) and isinstance(bnd, interval.IntervalType) and isinstance(thresholds, types.ListType) and isinstance(ann_fn, types.UnicodeType) and isinstance(weights, types.Array) and isinstance(edges, types.Tuple) and isinstance(static, types.Boolean):
return rule_type
return typer
@@ -46,6 +46,7 @@ def __init__(self, dmm, fe_type):
('rule_name', types.string),
('type', types.string),
('target', label.label_type),
+ ('head_variables', types.ListType(types.string)),
('delta', types.uint16),
('clauses', types.ListType(types.Tuple((types.string, label.label_type, types.ListType(types.string), interval.interval_type, types.string)))),
('bnd', interval.interval_type),
@@ -54,7 +55,6 @@ def __init__(self, dmm, fe_type):
('weights', types.float64[::1]),
('edges', types.Tuple((types.string, types.string, label.label_type))),
('static', types.boolean),
- ('immediate_rule', types.boolean)
]
models.StructModel.__init__(self, dmm, fe_type, members)
@@ -63,6 +63,7 @@ def __init__(self, dmm, fe_type):
make_attribute_wrapper(RuleType, 'rule_name', 'rule_name')
make_attribute_wrapper(RuleType, 'type', 'type')
make_attribute_wrapper(RuleType, 'target', 'target')
+make_attribute_wrapper(RuleType, 'head_variables', 'head_variables')
make_attribute_wrapper(RuleType, 'delta', 'delta')
make_attribute_wrapper(RuleType, 'clauses', 'clauses')
make_attribute_wrapper(RuleType, 'bnd', 'bnd')
@@ -71,20 +72,21 @@ def __init__(self, dmm, fe_type):
make_attribute_wrapper(RuleType, 'weights', 'weights')
make_attribute_wrapper(RuleType, 'edges', 'edges')
make_attribute_wrapper(RuleType, 'static', 'static')
-make_attribute_wrapper(RuleType, 'immediate_rule', 'immediate_rule')
# Implement constructor
-@lower_builtin(Rule, types.string, types.string, label.label_type, types.uint16, types.ListType(types.Tuple((types.string, label.label_type, types.ListType(types.string), interval.interval_type, types.string))), interval.interval_type, types.ListType(types.ListType(types.Tuple((types.string, types.string, types.float64)))), types.string, types.float64[::1], types.Tuple((types.string, types.string, label.label_type)), types.boolean, types.boolean)
+@lower_builtin(Rule, types.string, types.string, label.label_type, types.ListType(types.string), types.uint16, types.ListType(types.Tuple((types.string, label.label_type, types.ListType(types.string), interval.interval_type, types.string))), interval.interval_type, types.ListType(types.ListType(types.Tuple((types.string, types.string, types.float64)))), types.string, types.float64[::1], types.Tuple((types.string, types.string, label.label_type)), types.boolean, types.boolean)
def impl_rule(context, builder, sig, args):
typ = sig.return_type
- rule_name, type, target, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static, immediate_rule = args
+ rule_name, type, target, head_variables, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static = args
+ context.nrt.incref(builder, types.ListType(types.string), head_variables)
context.nrt.incref(builder, types.ListType(types.Tuple((types.string, label.label_type, types.ListType(types.string), interval.interval_type, types.string))), clauses)
context.nrt.incref(builder, types.ListType(types.Tuple((types.string, types.UniTuple(types.string, 2), types.float64))), thresholds)
rule = cgutils.create_struct_proxy(typ)(context, builder)
rule.rule_name = rule_name
rule.type = type
rule.target = target
+ rule.head_variables = head_variables
rule.delta = delta
rule.clauses = clauses
rule.bnd = bnd
@@ -93,7 +95,6 @@ def impl_rule(context, builder, sig, args):
rule.weights = weights
rule.edges = edges
rule.static = static
- rule.immediate_rule = immediate_rule
return rule._getvalue()
@@ -119,6 +120,13 @@ def getter(rule):
return getter
+@overload_method(RuleType, "get_head_variables")
+def get_head_variables(rule):
+ def getter(rule):
+ return rule.head_variables
+ return getter
+
+
@overload_method(RuleType, "get_delta")
def get_delta(rule):
def getter(rule):
@@ -133,6 +141,13 @@ def getter(rule):
return getter
+@overload_method(RuleType, "set_clauses")
+def set_clauses(rule):
+ def setter(rule, clauses):
+ rule.clauses = clauses
+ return setter
+
+
@overload_method(RuleType, "get_bnd")
def get_bnd(rule):
def impl(rule):
@@ -175,19 +190,13 @@ def impl(rule):
return impl
-@overload_method(RuleType, "is_immediate_rule")
-def is_immediate_rule(rule):
- def impl(rule):
- return rule.immediate_rule
- return impl
-
-
# Tell numba how to make native
@unbox(RuleType)
def unbox_rule(typ, obj, c):
name_obj = c.pyapi.object_getattr_string(obj, "_rule_name")
type_obj = c.pyapi.object_getattr_string(obj, "_type")
target_obj = c.pyapi.object_getattr_string(obj, "_target")
+ head_variables_obj = c.pyapi.object_getattr_string(obj, "_head_variables")
delta_obj = c.pyapi.object_getattr_string(obj, "_delta")
clauses_obj = c.pyapi.object_getattr_string(obj, "_clauses")
bnd_obj = c.pyapi.object_getattr_string(obj, "_bnd")
@@ -196,11 +205,11 @@ def unbox_rule(typ, obj, c):
weights_obj = c.pyapi.object_getattr_string(obj, "_weights")
edges_obj = c.pyapi.object_getattr_string(obj, "_edges")
static_obj = c.pyapi.object_getattr_string(obj, "_static")
- immediate_rule_obj = c.pyapi.object_getattr_string(obj, "_immediate_rule")
rule = cgutils.create_struct_proxy(typ)(c.context, c.builder)
rule.rule_name = c.unbox(types.string, name_obj).value
rule.type = c.unbox(types.string, type_obj).value
rule.target = c.unbox(label.label_type, target_obj).value
+ rule.head_variables = c.unbox(types.ListType(types.string), head_variables_obj).value
rule.delta = c.unbox(types.uint16, delta_obj).value
rule.clauses = c.unbox(types.ListType(types.Tuple((types.string, label.label_type, types.ListType(types.string), interval.interval_type, types.string))), clauses_obj).value
rule.bnd = c.unbox(interval.interval_type, bnd_obj).value
@@ -209,10 +218,10 @@ def unbox_rule(typ, obj, c):
rule.weights = c.unbox(types.float64[::1], weights_obj).value
rule.edges = c.unbox(types.Tuple((types.string, types.string, label.label_type)), edges_obj).value
rule.static = c.unbox(types.boolean, static_obj).value
- rule.immediate_rule = c.unbox(types.boolean, immediate_rule_obj).value
c.pyapi.decref(name_obj)
c.pyapi.decref(type_obj)
c.pyapi.decref(target_obj)
+ c.pyapi.decref(head_variables_obj)
c.pyapi.decref(delta_obj)
c.pyapi.decref(clauses_obj)
c.pyapi.decref(bnd_obj)
@@ -221,7 +230,6 @@ def unbox_rule(typ, obj, c):
c.pyapi.decref(weights_obj)
c.pyapi.decref(edges_obj)
c.pyapi.decref(static_obj)
- c.pyapi.decref(immediate_rule_obj)
is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
return NativeValue(rule._getvalue(), is_error=is_error)
@@ -233,6 +241,7 @@ def box_rule(typ, val, c):
name_obj = c.box(types.string, rule.rule_name)
type_obj = c.box(types.string, rule.type)
target_obj = c.box(label.label_type, rule.target)
+ head_variables_obj = c.box(types.ListType(types.string), rule.head_variables)
delta_obj = c.box(types.uint16, rule.delta)
clauses_obj = c.box(types.ListType(types.Tuple((types.string, label.label_type, types.ListType(types.string), interval.interval_type, types.string))), rule.clauses)
bnd_obj = c.box(interval.interval_type, rule.bnd)
@@ -241,11 +250,11 @@ def box_rule(typ, val, c):
weights_obj = c.box(types.float64[::1], rule.weights)
edges_obj = c.box(types.Tuple((types.string, types.string, label.label_type)), rule.edges)
static_obj = c.box(types.boolean, rule.static)
- immediate_rule_obj = c.box(types.boolean, rule.immediate_rule)
- res = c.pyapi.call_function_objargs(class_obj, (name_obj, type_obj, target_obj, delta_obj, clauses_obj, bnd_obj, thresholds_obj, ann_fn_obj, weights_obj, edges_obj, static_obj, immediate_rule_obj))
+ res = c.pyapi.call_function_objargs(class_obj, (name_obj, type_obj, target_obj, head_variables_obj, delta_obj, clauses_obj, bnd_obj, thresholds_obj, ann_fn_obj, weights_obj, edges_obj, static_obj))
c.pyapi.decref(name_obj)
c.pyapi.decref(type_obj)
c.pyapi.decref(target_obj)
+ c.pyapi.decref(head_variables_obj)
c.pyapi.decref(delta_obj)
c.pyapi.decref(clauses_obj)
c.pyapi.decref(ann_fn_obj)
@@ -254,6 +263,5 @@ def box_rule(typ, val, c):
c.pyapi.decref(weights_obj)
c.pyapi.decref(edges_obj)
c.pyapi.decref(static_obj)
- c.pyapi.decref(immediate_rule_obj)
c.pyapi.decref(class_obj)
return res
diff --git a/pyreason/scripts/program/program.py b/pyreason/scripts/program/program.py
index 4c992ae9..f8ca8413 100755
--- a/pyreason/scripts/program/program.py
+++ b/pyreason/scripts/program/program.py
@@ -3,12 +3,10 @@
class Program:
- available_labels_node = []
- available_labels_edge = []
specific_node_labels = []
specific_edge_labels = []
- def __init__(self, graph, facts_node, facts_edge, rules, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, parallel_computing, update_mode):
+ def __init__(self, graph, facts_node, facts_edge, rules, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, parallel_computing, update_mode, allow_ground_rules):
self._graph = graph
self._facts_node = facts_node
self._facts_edge = facts_edge
@@ -23,21 +21,20 @@ def __init__(self, graph, facts_node, facts_edge, rules, ipl, annotation_functio
self._store_interpretation_changes = store_interpretation_changes
self._parallel_computing = parallel_computing
self._update_mode = update_mode
+ self._allow_ground_rules = allow_ground_rules
self.interp = None
def reason(self, tmax, convergence_threshold, convergence_bound_threshold, verbose=True):
self._tmax = tmax
# Set up available labels
- Interpretation.available_labels_node = self.available_labels_node
- Interpretation.available_labels_edge = self.available_labels_edge
Interpretation.specific_node_labels = self.specific_node_labels
Interpretation.specific_edge_labels = self.specific_edge_labels
# Instantiate correct interpretation class based on whether we parallelize the code or not. (We cannot parallelize with cache on)
if self._parallel_computing:
- self.interp = InterpretationParallel(self._graph, self._ipl, self._annotation_functions, self._reverse_graph, self._atom_trace, self._save_graph_attributes_to_rule_trace, self._canonical, self._inconsistency_check, self._store_interpretation_changes, self._update_mode)
+ self.interp = InterpretationParallel(self._graph, self._ipl, self._annotation_functions, self._reverse_graph, self._atom_trace, self._save_graph_attributes_to_rule_trace, self._canonical, self._inconsistency_check, self._store_interpretation_changes, self._update_mode, self._allow_ground_rules)
else:
- self.interp = Interpretation(self._graph, self._ipl, self._annotation_functions, self._reverse_graph, self._atom_trace, self._save_graph_attributes_to_rule_trace, self._canonical, self._inconsistency_check, self._store_interpretation_changes, self._update_mode)
+ self.interp = Interpretation(self._graph, self._ipl, self._annotation_functions, self._reverse_graph, self._atom_trace, self._save_graph_attributes_to_rule_trace, self._canonical, self._inconsistency_check, self._store_interpretation_changes, self._update_mode, self._allow_ground_rules)
self.interp.start_fp(self._tmax, self._facts_node, self._facts_edge, self._rules, verbose, convergence_threshold, convergence_bound_threshold)
return self.interp
diff --git a/pyreason/scripts/query/__init__.py b/pyreason/scripts/query/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pyreason/scripts/query/query.py b/pyreason/scripts/query/query.py
new file mode 100644
index 00000000..b5071353
--- /dev/null
+++ b/pyreason/scripts/query/query.py
@@ -0,0 +1,36 @@
+from pyreason.scripts.utils.query_parser import parse_query
+
+
+class Query:
+ def __init__(self, query_text: str):
+ """
+ PyReason query object which is parsed from a string of form:
+ `pred(node)` or `pred(edge)` or `pred(node) : [l, u]`
+ If bounds are not specified, they are set to [1, 1] by default. A tilde `~` before the predicate means that the bounds
+ are inverted, i.e. [0, 0] for [1, 1] and vice versa.
+
+ Queries can be used to analyze the graph and extract information about the graph after the reasoning process.
+    Queries can also be used as input to the reasoner to filter the ruleset based on which rules are applicable to the query.
+
+ :param query_text: The query string of form described above
+ """
+ self.__pred, self.__component, self.__comp_type, self.__bnd = parse_query(query_text)
+ self.query_text = query_text
+
+ def get_predicate(self):
+ return self.__pred
+
+ def get_component(self):
+ return self.__component
+
+ def get_component_type(self):
+ return self.__comp_type
+
+ def get_bounds(self):
+ return self.__bnd
+
+ def __str__(self):
+ return self.query_text
+
+ def __repr__(self):
+ return self.query_text
diff --git a/pyreason/scripts/rules/rule.py b/pyreason/scripts/rules/rule.py
index 73824c4d..7055bb01 100755
--- a/pyreason/scripts/rules/rule.py
+++ b/pyreason/scripts/rules/rule.py
@@ -7,16 +7,16 @@ class Rule:
`'pred1(x,y) : [0.2, 1] <- pred2(a, b) : [1,1], pred3(b, c)'`
1. It is not possible to have weights for different clauses. Weights are 1 by default with bias 0
- TODO: Add weights as a parameter
"""
- def __init__(self, rule_text: str, name: str, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False, custom_thresholds=None):
+ def __init__(self, rule_text: str, name: str = None, infer_edges: bool = False, set_static: bool = False, custom_thresholds=None, weights=None):
"""
:param rule_text: The rule in text format
:param name: The name of the rule. This will appear in the rule trace
:param infer_edges: Whether to infer new edges after edge rule fires
:param set_static: Whether to set the atom in the head as static if the rule fires. The bounds will no longer change
- :param immediate_rule: Whether the rule is immediate. Immediate rules check for more applicable rules immediately after being applied
+    :param custom_thresholds: Custom thresholds for the rule clauses. If not specified, the default thresholds for ANY are used. It can be a list
+       whose length equals the number of clauses, or a map from clause index to threshold
+ :param weights: A list of weights for the rule clauses. This is passed to an annotation function. If not specified,
+ the weights array is a list of 1s with the length as number of clauses.
"""
- if custom_thresholds is None:
- custom_thresholds = []
- self.rule = rule_parser.parse_rule(rule_text, name, custom_thresholds, infer_edges, set_static, immediate_rule)
+ self.rule = rule_parser.parse_rule(rule_text, name, custom_thresholds, infer_edges, set_static, weights)
diff --git a/pyreason/scripts/rules/rule_internal.py b/pyreason/scripts/rules/rule_internal.py
index 69de3ca3..e4962008 100755
--- a/pyreason/scripts/rules/rule_internal.py
+++ b/pyreason/scripts/rules/rule_internal.py
@@ -1,9 +1,10 @@
class Rule:
- def __init__(self, rule_name, rule_type, target, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static, immediate_rule):
+ def __init__(self, rule_name, rule_type, target, head_variables, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static):
self._rule_name = rule_name
self._type = rule_type
self._target = target
+ self._head_variables = head_variables
self._delta = delta
self._clauses = clauses
self._bnd = bnd
@@ -12,28 +13,39 @@ def __init__(self, rule_name, rule_type, target, delta, clauses, bnd, thresholds
self._weights = weights
self._edges = edges
self._static = static
- self._immediate_rule = immediate_rule
def get_rule_name(self):
return self._rule_name
+ def set_rule_name(self, rule_name):
+ self._rule_name = rule_name
+
def get_rule_type(self):
return self._type
def get_target(self):
return self._target
+ def get_head_variables(self):
+ return self._head_variables
+
def get_delta(self):
return self._delta
- def get_neigh_criteria(self):
+ def get_clauses(self):
return self._clauses
+
+ def set_clauses(self, clauses):
+ self._clauses = clauses
def get_bnd(self):
return self._bnd
def get_thresholds(self):
- return self._thresholds
+ return self._thresholds
+
+ def set_thresholds(self, thresholds):
+ self._thresholds = thresholds
def get_annotation_function(self):
return self._ann_fn
@@ -41,8 +53,30 @@ def get_annotation_function(self):
def get_edges(self):
return self._edges
+ def get_weights(self):
+ return self._weights
+
def is_static(self):
return self._static
- def is_immediate_rule(self):
- return self._immediate_rule
+ def __eq__(self, other):
+ if not isinstance(other, Rule):
+ return False
+ clause_eq = []
+ other_clause_eq = []
+ for c in self._clauses:
+ clause_eq.append((c[0], c[1], tuple(c[2]), c[3], c[4]))
+ for c in other.get_clauses():
+ other_clause_eq.append((c[0], c[1], tuple(c[2]), c[3], c[4]))
+ if self._rule_name == other.get_rule_name() and self._type == other.get_rule_type() and self._target == other.get_target() and self._head_variables == other.get_head_variables() and self._delta == other.get_delta() and tuple(clause_eq) == tuple(other_clause_eq) and self._bnd == other.get_bnd():
+ return True
+ else:
+ return False
+
+ def __hash__(self):
+ clause_hashes = []
+ for c in self._clauses:
+ clause_hash = (c[0], c[1], tuple(c[2]), c[3], c[4])
+ clause_hashes.append(clause_hash)
+
+ return hash((self._rule_name, self._type, self._target.get_value(), *self._head_variables, self._delta, *clause_hashes, self._bnd))
diff --git a/pyreason/scripts/threshold/threshold.py b/pyreason/scripts/threshold/threshold.py
index 39722631..1a4ee646 100644
--- a/pyreason/scripts/threshold/threshold.py
+++ b/pyreason/scripts/threshold/threshold.py
@@ -38,4 +38,4 @@ def to_tuple(self):
Returns:
tuple: A tuple representation of the Threshold instance.
"""
- return (self.quantifier, self.quantifier_type, self.thresh)
\ No newline at end of file
+ return self.quantifier, self.quantifier_type, self.thresh
diff --git a/pyreason/scripts/utils/fact_parser.py b/pyreason/scripts/utils/fact_parser.py
new file mode 100644
index 00000000..6b3c922c
--- /dev/null
+++ b/pyreason/scripts/utils/fact_parser.py
@@ -0,0 +1,40 @@
+import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
+
+
+def parse_fact(fact_text):
+ f = fact_text.replace(' ', '')
+
+ # Separate into predicate-component and bound. If there is no bound it means it's true
+ if ':' in f:
+ pred_comp, bound = f.split(':')
+ else:
+ pred_comp = f
+ if pred_comp[0] == '~':
+ bound = 'False'
+ pred_comp = pred_comp[1:]
+ else:
+ bound = 'True'
+
+ # Check if bound is a boolean or a list of floats
+ bound = bound.lower()
+ if bound == 'true':
+ bound = interval.closed(1, 1)
+ elif bound == 'false':
+ bound = interval.closed(0, 0)
+ else:
+ bound = [float(b) for b in bound[1:-1].split(',')]
+ bound = interval.closed(*bound)
+
+ # Split the predicate and component
+ idx = pred_comp.find('(')
+ pred = pred_comp[:idx]
+ component = pred_comp[idx + 1:-1]
+
+ # Check if it is a node or edge fact
+ if ',' in component:
+ fact_type = 'edge'
+ component = tuple(component.split(','))
+ else:
+ fact_type = 'node'
+
+ return pred, component, bound, fact_type
diff --git a/pyreason/scripts/utils/filter_ruleset.py b/pyreason/scripts/utils/filter_ruleset.py
new file mode 100644
index 00000000..848e4c75
--- /dev/null
+++ b/pyreason/scripts/utils/filter_ruleset.py
@@ -0,0 +1,34 @@
+def filter_ruleset(queries, rules):
+ """
+ Filter the ruleset based on the queries provided.
+
+ :param queries: List of Query objects
+ :param rules: List of Rule objects
+ :return: List of Rule objects that are applicable to the queries
+ """
+
+ # Helper function to collect all rules that can support making a given rule true
+ def applicable_rules_from_query(query):
+ # Start with rules that match the query directly
+ applicable = []
+
+ for rule in rules:
+ # If the rule's target matches the query
+ if query == rule.get_target():
+ # Add the rule to the applicable set
+ applicable.append(rule)
+ # Recursively check rules that can lead up to this rule
+ for clause in rule.get_clauses():
+ # Find supporting rules with the clause as the target
+ supporting_rules = applicable_rules_from_query(clause[1])
+ applicable.extend(supporting_rules)
+
+ return applicable
+
+ # Collect applicable rules for each query and eliminate duplicates
+ filtered_rules = []
+ for q in queries:
+ filtered_rules.extend(applicable_rules_from_query(q.get_predicate()))
+
+ # Use set to avoid duplicates if a rule supports multiple queries
+ return list(set(filtered_rules))
diff --git a/pyreason/scripts/utils/output.py b/pyreason/scripts/utils/output.py
index e0d12f96..e680083d 100755
--- a/pyreason/scripts/utils/output.py
+++ b/pyreason/scripts/utils/output.py
@@ -4,8 +4,9 @@
class Output:
- def __init__(self, timestamp):
+ def __init__(self, timestamp, clause_map=None):
self.timestamp = timestamp
+ self.clause_map = clause_map
self.rule_trace_node = None
self.rule_trace_edge = None
@@ -80,6 +81,14 @@ def _parse_internal_rule_trace(self, interpretation):
# Store the trace in a DataFrame
self.rule_trace_edge = pd.DataFrame(data, columns=header_edge)
+ # Now do the reordering
+ if self.clause_map is not None:
+ offset = 7
+ columns_to_reorder_node = header_node[offset:]
+ columns_to_reorder_edge = header_edge[offset:]
+ self.rule_trace_node = self.rule_trace_node.apply(self._reorder_row, axis=1, map_dict=self.clause_map, columns_to_reorder=columns_to_reorder_node)
+ self.rule_trace_edge = self.rule_trace_edge.apply(self._reorder_row, axis=1, map_dict=self.clause_map, columns_to_reorder=columns_to_reorder_edge)
+
def save_rule_trace(self, interpretation, folder='./'):
if self.rule_trace_node is None and self.rule_trace_edge is None:
self._parse_internal_rule_trace(interpretation)
@@ -94,3 +103,14 @@ def get_rule_trace(self, interpretation):
self._parse_internal_rule_trace(interpretation)
return self.rule_trace_node, self.rule_trace_edge
+
+ @staticmethod
+ def _reorder_row(row, map_dict, columns_to_reorder):
+ if row['Occurred Due To'] in map_dict:
+ original_values = row[columns_to_reorder].values
+ new_values = [None] * len(columns_to_reorder)
+ for orig_pos, target_pos in map_dict[row['Occurred Due To']].items():
+ new_values[target_pos] = original_values[orig_pos]
+ for i, col in enumerate(columns_to_reorder):
+ row[col] = new_values[i]
+ return row
diff --git a/pyreason/scripts/utils/query_parser.py b/pyreason/scripts/utils/query_parser.py
new file mode 100644
index 00000000..7c5bfdb8
--- /dev/null
+++ b/pyreason/scripts/utils/query_parser.py
@@ -0,0 +1,34 @@
+import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
+import pyreason.scripts.numba_wrapper.numba_types.label_type as label
+
+
+def parse_query(query: str):
+ query = query.replace(' ', '')
+
+ if ':' in query:
+ pred_comp, bounds = query.split(':')
+ bounds = bounds.replace('[', '').replace(']', '')
+ l, u = bounds.split(',')
+ l, u = float(l), float(u)
+ else:
+ if query[0] == '~':
+ pred_comp = query[1:]
+ l, u = 0, 0
+ else:
+ pred_comp = query
+ l, u = 1, 1
+
+ bnd = interval.closed(l, u)
+
+ # Split predicate and component
+ idx = pred_comp.find('(')
+ pred = label.Label(pred_comp[:idx])
+ component = pred_comp[idx + 1:-1]
+
+ if ',' in component:
+ component = tuple(component.split(','))
+ comp_type = 'edge'
+ else:
+ comp_type = 'node'
+
+ return pred, component, comp_type, bnd
diff --git a/pyreason/scripts/utils/reorder_clauses.py b/pyreason/scripts/utils/reorder_clauses.py
new file mode 100644
index 00000000..11408ffb
--- /dev/null
+++ b/pyreason/scripts/utils/reorder_clauses.py
@@ -0,0 +1,30 @@
+import numba
+import pyreason.scripts.numba_wrapper.numba_types.label_type as label
+import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
+
+
+def reorder_clauses(rule):
+ # Go through all clauses in the rule and re-order them if necessary
+ # It is faster for grounding to have node clauses first and then edge clauses
+ # Move all the node clauses to the front of the list
+ reordered_clauses = numba.typed.List.empty_list(numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string), interval.interval_type, numba.types.string)))
+ reordered_thresholds = numba.typed.List.empty_list(numba.types.Tuple((numba.types.string, numba.types.UniTuple(numba.types.string, 2), numba.types.float64)))
+ node_clauses = []
+ edge_clauses = []
+ reordered_clauses_map = {}
+
+ for index, clause in enumerate(rule.get_clauses()):
+ if clause[0] == 'node':
+ node_clauses.append((index, clause))
+ else:
+ edge_clauses.append((index, clause))
+
+ thresholds = rule.get_thresholds()
+ for new_index, (original_index, clause) in enumerate(node_clauses + edge_clauses):
+ reordered_clauses.append(clause)
+ reordered_thresholds.append(thresholds[original_index])
+ reordered_clauses_map[new_index] = original_index
+
+ rule.set_clauses(reordered_clauses)
+ rule.set_thresholds(reordered_thresholds)
+ return rule, reordered_clauses_map
diff --git a/pyreason/scripts/utils/rule_parser.py b/pyreason/scripts/utils/rule_parser.py
index 36bdede0..dc1e7728 100644
--- a/pyreason/scripts/utils/rule_parser.py
+++ b/pyreason/scripts/utils/rule_parser.py
@@ -1,12 +1,15 @@
import numba
import numpy as np
+from typing import Union
import pyreason.scripts.numba_wrapper.numba_types.rule_type as rule
+# import pyreason.scripts.rules.rule_internal as rule
import pyreason.scripts.numba_wrapper.numba_types.label_type as label
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
+from pyreason.scripts.threshold.threshold import Threshold
-def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False) -> rule.Rule:
+def parse_rule(rule_text: str, name: str, custom_thresholds: Union[None, list, dict], infer_edges: bool = False, set_static: bool = False, weights: Union[None, np.ndarray] = None) -> rule.Rule:
# First remove all spaces from line
r = rule_text.replace(' ', '')
@@ -33,7 +36,7 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges:
# 2. replace ) by )) and ] by ]] so that we can split without damaging the string
# 3. Split with ), and then for each element of list, split with ], and add to new list
# 4. Then replace ]] with ] and )) with ) in for loop
- # 5. Add :[1,1] to the end of each element if a bound is not specified
+ # 5. Add :[1,1] or :[0,0] to the end of each element if a bound is not specified
# 6. Then split each element with :
# 7. Transform bound strings into pr.intervals
@@ -54,7 +57,9 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges:
# 5
for i in range(len(split_body)):
- if split_body[i][-1] != ']':
+ if split_body[i][0] == '~':
+ split_body[i] = split_body[i][1:] + ':[0,0]'
+ elif split_body[i][-1] != ']':
split_body[i] += ':[1,1]'
# 6
@@ -65,6 +70,14 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges:
body_clauses.append(clause)
body_bounds.append(bound)
+ # Check if there are custom thresholds for the rule such as forall in string form
+ for i, b in enumerate(body_clauses.copy()):
+ if 'forall(' in b:
+ if not custom_thresholds:
+ custom_thresholds = {}
+ custom_thresholds[i] = Threshold("greater_equal", ("percent", "total"), 100)
+ body_clauses[i] = b[:-1].replace('forall(', '')
+
# 7
for i in range(len(body_bounds)):
bound = body_bounds[i]
@@ -79,7 +92,10 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges:
# This means there is no bound or annotation function specified
if head[-1] == ')':
- head += ':[1,1]'
+ if head[0] == '~':
+ head = head[1:] + ':[0,0]'
+ else:
+ head += ':[1,1]'
head, head_bound = head.split(':')
# Check if we have a bound or annotation function
@@ -123,25 +139,6 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges:
if rule_type == 'node':
infer_edges = False
- # Replace the variables in the body with source/target if they match the variables in the head
- # If infer_edges is true, then we consider all rules to be node rules, we infer the 2nd variable of the target predicate from the rule body
- # Else we consider the rule to be an edge rule and replace variables with source/target
- # Node rules with possibility of adding edges
- if infer_edges or len(head_variables) == 1:
- head_source_variable = head_variables[0]
- for i in range(len(body_variables)):
- for j in range(len(body_variables[i])):
- if body_variables[i][j] == head_source_variable:
- body_variables[i][j] = '__target'
- # Edge rule, no edges to be added
- elif len(head_variables) == 2:
- for i in range(len(body_variables)):
- for j in range(len(body_variables[i])):
- if body_variables[i][j] == head_variables[0]:
- body_variables[i][j] = '__source'
- elif body_variables[i][j] == head_variables[1]:
- body_variables[i][j] = '__target'
-
# Start setting up clauses
# clauses = [c1, c2, c3, c4]
# thresholds = [t1, t2, t3, t4]
@@ -155,18 +152,25 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges:
# gather count of clauses for threshold validation
num_clauses = len(body_clauses)
- if custom_thresholds and (len(custom_thresholds) != num_clauses):
- raise Exception('The length of custom thresholds {} is not equal to number of clauses {}'
- .format(len(custom_thresholds), num_clauses))
-
+ if isinstance(custom_thresholds, list):
+ if len(custom_thresholds) != num_clauses:
+ raise Exception(f'The length of custom thresholds {len(custom_thresholds)} is not equal to number of clauses {num_clauses}')
+ for threshold in custom_thresholds:
+ thresholds.append(threshold.to_tuple())
+ elif isinstance(custom_thresholds, dict):
+ if max(custom_thresholds.keys()) >= num_clauses:
+ raise Exception(f'The max clause index in the custom thresholds map {max(custom_thresholds.keys())} is greater than number of clauses {num_clauses}')
+ for i in range(num_clauses):
+ if i in custom_thresholds:
+ thresholds.append(custom_thresholds[i].to_tuple())
+ else:
+ thresholds.append(('greater_equal', ('number', 'total'), 1.0))
+
# If no custom thresholds provided, use defaults
# otherwise loop through user-defined thresholds and convert to numba compatible format
- if not custom_thresholds:
+ elif not custom_thresholds:
for _ in range(num_clauses):
thresholds.append(('greater_equal', ('number', 'total'), 1.0))
- else:
- for threshold in custom_thresholds:
- thresholds.append(threshold.to_tuple())
# # Loop though clauses
for body_clause, predicate, variables, bounds in zip(body_clauses, body_predicates, body_variables, body_bounds):
@@ -184,15 +188,20 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges:
# Assert that there are two variables in the head of the rule if we infer edges
# Add edges between head variables if necessary
if infer_edges:
- var = '__target' if head_variables[0] == head_variables[1] else head_variables[1]
- edges = ('__target', var, target)
+ # var = '__target' if head_variables[0] == head_variables[1] else head_variables[1]
+ # edges = ('__target', var, target)
+ edges = (head_variables[0], head_variables[1], target)
else:
edges = ('', '', label.Label(''))
- weights = np.ones(len(body_predicates), dtype=np.float64)
- weights = np.append(weights, 0)
+ if weights is None:
+ weights = np.ones(len(body_predicates), dtype=np.float64)
+ elif len(weights) != len(body_predicates):
+ raise Exception(f'Number of weights {len(weights)} is not equal to number of clauses {len(body_predicates)}')
+
+ head_variables = numba.typed.List(head_variables)
- r = rule.Rule(name, rule_type, target, numba.types.uint16(t), clauses, target_bound, thresholds, ann_fn, weights, edges, set_static, immediate_rule)
+ r = rule.Rule(name, rule_type, target, head_variables, numba.types.uint16(t), clauses, target_bound, thresholds, ann_fn, weights, edges, set_static)
return r
diff --git a/requirements.txt b/requirements.txt
index b54fb593..25f9bc11 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,7 @@ numba==0.59.1
numpy==1.26.4
memory_profiler
pytest
+setuptools_scm
sphinx_rtd_theme
sphinx
diff --git a/rule_trace_edges_20241119-012005.csv b/rule_trace_edges_20241119-012005.csv
new file mode 100644
index 00000000..1b0030fe
--- /dev/null
+++ b/rule_trace_edges_20241119-012005.csv
@@ -0,0 +1 @@
+Time,Fixed-Point-Operation,Edge,Label,Old Bound,New Bound,Occurred Due To
diff --git a/rule_trace_edges_20241125-114246.csv b/rule_trace_edges_20241125-114246.csv
new file mode 100644
index 00000000..1b0030fe
--- /dev/null
+++ b/rule_trace_edges_20241125-114246.csv
@@ -0,0 +1 @@
+Time,Fixed-Point-Operation,Edge,Label,Old Bound,New Bound,Occurred Due To
diff --git a/setup.py b/setup.py
index 5cc9b95f..8705509e 100644
--- a/setup.py
+++ b/setup.py
@@ -4,11 +4,11 @@
from pathlib import Path
this_directory = Path(__file__).parent
-long_description = (this_directory / "README.md").read_text()
+long_description = (this_directory / "README.md").read_text(encoding='UTF-8')
setup(
name='pyreason',
- version='2.3.0',
+ version='3.0.0',
author='Dyuman Aditya',
author_email='dyuman.aditya@gmail.com',
description='An explainable inference software supporting annotated, real valued, graph based and temporal logic',
@@ -35,6 +35,8 @@
'memory_profiler',
'pytest'
],
+ use_scm_version=True,
+ setup_requires=['setuptools_scm'],
packages=find_packages(),
include_package_data=True
)
diff --git a/tests/group_chat_graph.graphml b/tests/group_chat_graph.graphml
index 7c76e29b..852d05cf 100644
--- a/tests/group_chat_graph.graphml
+++ b/tests/group_chat_graph.graphml
@@ -1,26 +1,23 @@
-
-
-
+
+
-
- 1
+ 1
- 1
-
-
-
- 1
+ 1
- 1
+ 1
+
+
+ 1
diff --git a/tests/knowledge_graph_test_subset.graphml b/tests/knowledge_graph_test_subset.graphml
new file mode 100644
index 00000000..72e5c23b
--- /dev/null
+++ b/tests/knowledge_graph_test_subset.graphml
@@ -0,0 +1,71 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+ 1
+
+
+
+
\ No newline at end of file
diff --git a/tests/test_annotation_function.py b/tests/test_annotation_function.py
new file mode 100644
index 00000000..12522c73
--- /dev/null
+++ b/tests/test_annotation_function.py
@@ -0,0 +1,36 @@
+# Test if annotation functions work
+import pyreason as pr
+import numba
+import numpy as np
+
+
+@numba.njit
+def probability_func(annotations, weights):
+ prob_A = annotations[0][0].lower
+ prob_B = annotations[1][0].lower
+ union_prob = prob_A + prob_B
+ union_prob = np.round(union_prob, 3)
+ return union_prob, 1
+
+
+def test_annotation_function():
+ # Reset PyReason
+ pr.reset()
+ pr.reset_rules()
+
+ pr.settings.allow_ground_rules = True
+
+ pr.add_fact(pr.Fact('P(A) : [0.01, 1]'))
+ pr.add_fact(pr.Fact('P(B) : [0.2, 1]'))
+ pr.add_annotation_function(probability_func)
+ pr.add_rule(pr.Rule('union_probability(A, B):probability_func <- P(A):[0, 1], P(B):[0, 1]', infer_edges=True))
+
+ interpretation = pr.reason(timesteps=1)
+
+ dataframes = pr.filter_and_sort_edges(interpretation, ['union_probability'])
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+
+ assert interpretation.query(pr.Query('union_probability(A, B) : [0.21, 1]')), 'Union probability should be 0.21'
diff --git a/tests/test_anyBurl_infer_edges_rules.py b/tests/test_anyBurl_infer_edges_rules.py
new file mode 100644
index 00000000..0ae5df74
--- /dev/null
+++ b/tests/test_anyBurl_infer_edges_rules.py
@@ -0,0 +1,139 @@
+import pyreason as pr
+
+
+def test_anyBurl_rule_1():
+ graph_path = './tests/knowledge_graph_test_subset.graphml'
+ pr.reset()
+ pr.reset_rules()
+ # Modify pyreason settings to make verbose and to save the rule trace to a file
+ pr.settings.verbose = True
+ pr.settings.atom_trace = True
+ pr.settings.memory_profile = False
+ pr.settings.canonical = True
+ pr.settings.inconsistency_check = False
+ pr.settings.static_graph_facts = False
+ pr.settings.output_to_file = False
+ pr.settings.store_interpretation_changes = True
+ pr.settings.save_graph_attributes_to_trace = True
+ # Load all the files into pyreason
+ pr.load_graphml(graph_path)
+ pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True))
+
+ # Run the program for two timesteps to see the diffusion take place
+ interpretation = pr.reason(timesteps=1)
+ # pr.save_rule_trace(interpretation)
+
+ # Display the changes in the interpretation for each timestep
+ dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+ assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations'
+ assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
+ assert ('Vnukovo_International_Airport', 'Riga_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Riga_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps'
+
+
+def test_anyBurl_rule_2():
+ graph_path = './tests/knowledge_graph_test_subset.graphml'
+ pr.reset()
+ pr.reset_rules()
+ # Modify pyreason settings to make verbose and to save the rule trace to a file
+ pr.settings.verbose = True
+ pr.settings.atom_trace = True
+ pr.settings.memory_profile = False
+ pr.settings.canonical = True
+ pr.settings.inconsistency_check = False
+ pr.settings.static_graph_facts = False
+ pr.settings.output_to_file = False
+ pr.settings.store_interpretation_changes = True
+ pr.settings.save_graph_attributes_to_trace = True
+ pr.settings.parallel_computing = False
+ # Load all the files into pyreason
+ pr.load_graphml(graph_path)
+
+ pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_2', infer_edges=True))
+
+ # Run the program for two timesteps to see the diffusion take place
+ interpretation = pr.reason(timesteps=1)
+ # pr.save_rule_trace(interpretation)
+
+ # Display the changes in the interpretation for each timestep
+ dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+ assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations'
+ assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
+ assert ('Riga_International_Airport', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Riga_International_Airport, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps'
+
+
+def test_anyBurl_rule_3():
+ graph_path = './tests/knowledge_graph_test_subset.graphml'
+ pr.reset()
+ pr.reset_rules()
+ # Modify pyreason settings to make verbose and to save the rule trace to a file
+ pr.settings.verbose = True
+ pr.settings.atom_trace = True
+ pr.settings.memory_profile = False
+ pr.settings.canonical = True
+ pr.settings.inconsistency_check = False
+ pr.settings.static_graph_facts = False
+ pr.settings.output_to_file = False
+ pr.settings.store_interpretation_changes = True
+ pr.settings.save_graph_attributes_to_trace = True
+ pr.settings.parallel_computing = False
+ # Load all the files into pyreason
+ pr.load_graphml(graph_path)
+
+ pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_3', infer_edges=True))
+
+ # Run the program for two timesteps to see the diffusion take place
+ interpretation = pr.reason(timesteps=1)
+ # pr.save_rule_trace(interpretation)
+
+ # Display the changes in the interpretation for each timestep
+ dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+ assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations'
+ assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
+ assert ('Vnukovo_International_Airport', 'Yali') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Yali) should have isConnectedTo bounds [1,1] for t=1 timesteps'
+
+
+def test_anyBurl_rule_4():
+ graph_path = './tests/knowledge_graph_test_subset.graphml'
+ pr.reset()
+ pr.reset_rules()
+ # Modify pyreason settings to make verbose and to save the rule trace to a file
+ pr.settings.verbose = True
+ pr.settings.atom_trace = True
+ pr.settings.memory_profile = False
+ pr.settings.canonical = True
+ pr.settings.inconsistency_check = False
+ pr.settings.static_graph_facts = False
+ pr.settings.output_to_file = False
+ pr.settings.store_interpretation_changes = True
+ pr.settings.save_graph_attributes_to_trace = True
+ pr.settings.parallel_computing = False
+ # Load all the files into pyreason
+ pr.load_graphml(graph_path)
+
+ pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_4', infer_edges=True))
+
+ # Run the program for two timesteps to see the diffusion take place
+ interpretation = pr.reason(timesteps=1)
+ # pr.save_rule_trace(interpretation)
+
+ # Display the changes in the interpretation for each timestep
+ dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo'])
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+ assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations'
+ assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
+ assert ('Yali', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Yali, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps'
diff --git a/tests/test_custom_thresholds.py b/tests/test_custom_thresholds.py
index b982bf36..e1ae437a 100644
--- a/tests/test_custom_thresholds.py
+++ b/tests/test_custom_thresholds.py
@@ -11,7 +11,8 @@ def test_custom_thresholds():
# Modify the paths based on where you've stored the files we made above
graph_path = "./tests/group_chat_graph.graphml"
- # Modify pyreason settings to make verbose and to save the rule trace to a file
+ # Modify pyreason settings to make verbose
+ pr.reset_settings()
pr.settings.verbose = True # Print info to screen
# Load all the files into pyreason
@@ -25,16 +26,16 @@ def test_custom_thresholds():
pr.add_rule(
pr.Rule(
- "ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)",
+ "ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)",
"viewed_by_all_rule",
custom_thresholds=user_defined_thresholds,
)
)
- pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 3))
- pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 3))
- pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 3))
- pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 3))
+ pr.add_fact(pr.Fact("Viewed(Zach)", "seen-fact-zach", 0, 3))
+ pr.add_fact(pr.Fact("Viewed(Justin)", "seen-fact-justin", 0, 3))
+ pr.add_fact(pr.Fact("Viewed(Michelle)", "seen-fact-michelle", 1, 3))
+ pr.add_fact(pr.Fact("Viewed(Amy)", "seen-fact-amy", 2, 3))
# Run the program for three timesteps to see the diffusion take place
interpretation = pr.reason(timesteps=3)
diff --git a/tests/test_hello_world.py b/tests/test_hello_world.py
index c932daff..51674d5d 100644
--- a/tests/test_hello_world.py
+++ b/tests/test_hello_world.py
@@ -1,24 +1,28 @@
# Test if the simple hello world program works
import pyreason as pr
+import faulthandler
def test_hello_world():
# Reset PyReason
pr.reset()
pr.reset_rules()
+ pr.reset_settings()
# Modify the paths based on where you've stored the files we made above
graph_path = './tests/friends_graph.graphml'
- # Modify pyreason settings to make verbose and to save the rule trace to a file
+ # Modify pyreason settings to make verbose
pr.settings.verbose = True # Print info to screen
+ # pr.settings.optimize_rules = False # Disable rule optimization for debugging
# Load all the files into pyreason
pr.load_graphml(graph_path)
pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule'))
- pr.add_fact(pr.Fact('popular-fact', 'Mary', 'popular', [1, 1], 0, 2))
+ pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2))
# Run the program for two timesteps to see the diffusion take place
+ faulthandler.enable()
interpretation = pr.reason(timesteps=2)
# Display the changes in the interpretation for each timestep
@@ -29,8 +33,8 @@ def test_hello_world():
print()
assert len(dataframes[0]) == 1, 'At t=0 there should be one popular person'
- assert len(dataframes[1]) == 2, 'At t=0 there should be two popular people'
- assert len(dataframes[2]) == 3, 'At t=0 there should be three popular people'
+ assert len(dataframes[1]) == 2, 'At t=1 there should be two popular people'
+ assert len(dataframes[2]) == 3, 'At t=2 there should be three popular people'
# Mary should be popular in all three timesteps
assert 'Mary' in dataframes[0]['component'].values and dataframes[0].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps'
@@ -43,3 +47,5 @@ def test_hello_world():
# John should be popular in timestep 3
assert 'John' in dataframes[2]['component'].values and dataframes[2].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps'
+
+test_hello_world()
\ No newline at end of file
diff --git a/tests/test_hello_world_parallel.py b/tests/test_hello_world_parallel.py
index 1b7ee03c..fe47a337 100644
--- a/tests/test_hello_world_parallel.py
+++ b/tests/test_hello_world_parallel.py
@@ -10,14 +10,14 @@ def test_hello_world_parallel():
# Modify the paths based on where you've stored the files we made above
graph_path = './tests/friends_graph.graphml'
- # Modify pyreason settings to make verbose and to save the rule trace to a file
+ # Modify pyreason settings to make verbose
+ pr.reset_settings()
pr.settings.verbose = True # Print info to screen
- pr.settings.parallel_computing = True
# Load all the files into pyreason
pr.load_graphml(graph_path)
pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule'))
- pr.add_fact(pr.Fact('popular-fact', 'Mary', 'popular', [1, 1], 0, 2))
+ pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2))
# Run the program for two timesteps to see the diffusion take place
interpretation = pr.reason(timesteps=2)
@@ -30,8 +30,8 @@ def test_hello_world_parallel():
print()
assert len(dataframes[0]) == 1, 'At t=0 there should be one popular person'
- assert len(dataframes[1]) == 2, 'At t=0 there should be two popular people'
- assert len(dataframes[2]) == 3, 'At t=0 there should be three popular people'
+ assert len(dataframes[1]) == 2, 'At t=1 there should be two popular people'
+ assert len(dataframes[2]) == 3, 'At t=2 there should be three popular people'
# Mary should be popular in all three timesteps
assert 'Mary' in dataframes[0]['component'].values and dataframes[0].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps'
diff --git a/tests/test_num_ga.py b/tests/test_num_ga.py
new file mode 100644
index 00000000..220e1948
--- /dev/null
+++ b/tests/test_num_ga.py
@@ -0,0 +1,37 @@
+# Test if the simple hello world program works
+import pyreason as pr
+
+
+def test_num_ga():
+ graph_path = './tests/knowledge_graph_test_subset.graphml'
+ pr.reset()
+ pr.reset_rules()
+ # Modify pyreason settings to make verbose and to save the rule trace to a file
+ pr.settings.verbose = True
+ pr.settings.atom_trace = True
+ pr.settings.canonical = True
+ pr.settings.inconsistency_check = False
+ pr.settings.static_graph_facts = False
+ pr.settings.output_to_file = False
+ pr.settings.store_interpretation_changes = True
+ pr.settings.save_graph_attributes_to_trace = True
+
+ # Load all the files into pyreason
+ pr.load_graphml(graph_path)
+ pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True))
+ # pr.add_fact(pr.Fact('dummy(Riga_International_Airport): [0, 1]', 'dummy_fact', 0, 1))
+
+ # Run the program for two timesteps to see the diffusion take place
+ interpretation = pr.reason(timesteps=1)
+ # pr.save_rule_trace(interpretation)
+
+ # Find number of ground atoms from dictionary
+ ga_cnt = []
+ d = interpretation.get_dict()
+ for time, atoms in d.items():
+ ga_cnt.append(0)
+ for comp, label_bnds in atoms.items():
+ ga_cnt[time] += len(label_bnds)
+
+ # Make sure the computed number of ground atoms is correct
+ assert ga_cnt == list(interpretation.get_num_ground_atoms()), 'Number of ground atoms should be the same as the computed number of ground atoms'
diff --git a/tests/test_reorder_clauses.py b/tests/test_reorder_clauses.py
new file mode 100644
index 00000000..6407f9b5
--- /dev/null
+++ b/tests/test_reorder_clauses.py
@@ -0,0 +1,52 @@
+# Test if the simple hello world program works
+import pyreason as pr
+
+
+def test_reorder_clauses():
+ # Reset PyReason
+ pr.reset()
+ pr.reset_rules()
+ pr.reset_settings()
+
+ # Modify the paths based on where you've stored the files we made above
+ graph_path = './tests/friends_graph.graphml'
+
+ # Modify pyreason settings to make verbose
+ pr.settings.verbose = True # Print info to screen
+ pr.settings.atom_trace = True # Print atom trace
+
+ # Load all the files into pyreason
+ pr.load_graphml(graph_path)
+ pr.add_rule(pr.Rule('popular(x) <-1 Friends(x,y), popular(y), owns(y,z), owns(x,z)', 'popular_rule'))
+ pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2))
+
+ # Run the program for two timesteps to see the diffusion take place
+ interpretation = pr.reason(timesteps=2)
+
+ # Display the changes in the interpretation for each timestep
+ dataframes = pr.filter_and_sort_nodes(interpretation, ['popular'])
+ for t, df in enumerate(dataframes):
+ print(f'TIMESTEP - {t}')
+ print(df)
+ print()
+
+ assert len(dataframes[0]) == 1, 'At t=0 there should be one popular person'
+ assert len(dataframes[1]) == 2, 'At t=1 there should be two popular people'
+ assert len(dataframes[2]) == 3, 'At t=2 there should be three popular people'
+
+ # Mary should be popular in all three timesteps
+ assert 'Mary' in dataframes[0]['component'].values and dataframes[0].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps'
+ assert 'Mary' in dataframes[1]['component'].values and dataframes[1].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=1 timesteps'
+ assert 'Mary' in dataframes[2]['component'].values and dataframes[2].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=2 timesteps'
+
+ # Justin should be popular in timesteps 1, 2
+ assert 'Justin' in dataframes[1]['component'].values and dataframes[1].iloc[1].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=1 timesteps'
+ assert 'Justin' in dataframes[2]['component'].values and dataframes[2].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps'
+
+ # John should be popular in timestep 3
+ assert 'John' in dataframes[2]['component'].values and dataframes[2].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps'
+
+ # Now look at the trace and make sure the order has gone back to the original rule
+ # The second row, clause 1 should be the edge grounding ('Justin', 'Mary')
+ rule_trace_node, _ = pr.get_rule_trace(interpretation)
+ assert rule_trace_node.iloc[2]['Clause-1'][0] == ('Justin', 'Mary')
diff --git a/tests/test_rule_filtering.py b/tests/test_rule_filtering.py
new file mode 100644
index 00000000..ab348089
--- /dev/null
+++ b/tests/test_rule_filtering.py
@@ -0,0 +1,37 @@
+import pyreason as pr
+
+
+def test_rule_filtering():
+ # Reset PyReason
+ pr.reset()
+ pr.reset_rules()
+ pr.reset_settings()
+
+ # Modify the paths based on where you've stored the files we made above
+ graph_path = './tests/friends_graph.graphml'
+
+ # Modify pyreason settings to make verbose
+ pr.settings.verbose = True # Print info to screen
+ pr.settings.atom_trace = True # Print atom trace
+
+ # Load all the files into pyreason
+ pr.load_graphml(graph_path)
+ pr.add_rule(pr.Rule('head1(x) <-1 pred1(x,y), pred2(y,z), pred3(z, w)', 'rule1')) # Should fire
+ pr.add_rule(pr.Rule('head1(x) <-1 pred1(x,y), pred4(y,z), pred3(z, w)', 'rule2')) # Should fire
+ pr.add_rule(pr.Rule('head2(x) <-1 pred1(x,y), pred2(y,z), pred3(z, w)', 'rule3')) # Should not fire
+
+ # Dependency rules
+ pr.add_rule(pr.Rule('pred1(x,y) <-1 pred2(x,y)', 'rule4')) # Should fire
+ pr.add_rule(pr.Rule('pred2(x,y) <-1 pred3(x,y)', 'rule5')) # Should fire
+
+ # Define the query
+ query = pr.Query('head1(x)')
+
+ # Filter the rules
+ filtered_rules = pr.ruleset_filter.filter_ruleset([query], pr.get_rules())
+ filtered_rule_names = [r.get_rule_name() for r in filtered_rules]
+ assert 'rule1' in filtered_rule_names, 'Rule 1 should be in the filtered rules'
+ assert 'rule2' in filtered_rule_names, 'Rule 2 should be in the filtered rules'
+ assert 'rule4' in filtered_rule_names, 'Rule 4 should be in the filtered rules'
+ assert 'rule5' in filtered_rule_names, 'Rule 5 should be in the filtered rules'
+ assert 'rule3' not in filtered_rule_names, 'Rule 3 should not be in the filtered rules'