Skip to content

Commit

Permalink
improve recalls and pointless relations
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewcropper committed Dec 13, 2024
1 parent 6d977ab commit 72781eb
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 189 deletions.
90 changes: 29 additions & 61 deletions popper/bkcons.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import operator
import pkg_resources
import time
from . util import rule_is_recursive, format_rule, Constraint, order_prog, Literal
from . util import rule_is_recursive, format_rule, Constraint, order_prog, Literal, suppress_stdout_stderr
from clingo import Function, Number, Tuple_
from collections import defaultdict
from itertools import permutations
from itertools import permutations, product
from pysat.card import *
from pysat.formula import CNF
from pysat.solvers import Solver
Expand Down Expand Up @@ -738,98 +738,75 @@ def deduce_bk_cons(settings, tester):


def generate_binary_strings(bit_count):
    """Return the mixed 0/1 patterns of length ``bit_count``.

    Each pattern is a tuple of ints, listed in ``itertools.product`` order.
    The all-zero and all-one patterns (the first and last items of that
    order) are dropped, so only patterns containing at least one 0 and at
    least one 1 remain. For ``bit_count`` of 0 or 1 the result is empty.
    """
    # product((0, 1), ...) yields (0, ..., 0) first and (1, ..., 1) last;
    # slicing with [1:-1] removes exactly those two patterns.
    return list(product((0, 1), repeat=bit_count))[1:-1]

def deduce_recalls(settings):
# Jan Struyf, Hendrik Blockeel: Query Optimization in Inductive Logic Programming by Reordering Literals. ILP 2003: 329-346

# recall for a subset of arguments, e.g. when A and C are ground in a call to add(A,B,C)
counts = {}
# maximum recall for a predicate symbol
counts_all = {}

with open(settings.bk_file) as f:
bk = f.read()
solver = clingo.Control(['-Wnone'])
solver.add('base', [], bk)
solver.ground([('base', [])])
try:

with open(settings.bk_file) as f:
bk = f.read()
solver = clingo.Control(['-Wnone'])
with suppress_stdout_stderr():
solver.add('base', [], bk)
solver.ground([('base', [])])
except Exception as Err:
print('ERROR deducing recalls', Err)
return None


for pred, arity in settings.body_preds:
# print(pred, arity)
counts_all[pred] = 0
counts[pred] = {}
# we find all facts for a given predicate symbol
d = counts[pred]
binary_strings = generate_binary_strings(arity)

for atom in solver.symbolic_atoms.by_signature(pred, arity=arity):
args = []
for i in range(arity):
arg = atom.symbol.arguments[i]
t = arg.type
if t == clingo.SymbolType.Number:
x = arg.number
elif t == clingo.SymbolType.String:
x = arg.string
else:
x = arg.name
args.append(x)
for var_subset in binary_strings:
d[var_subset] = defaultdict(set)

# print('X', pred, args)
for atom in solver.symbolic_atoms.by_signature(pred, arity=arity):
counts_all[pred] +=1
# x_args = [x[arg] for arg in args]
# we now enumerate all subsets of possible input/ground arguments
# for instance, for a predicate symbol p/2 we consider p(10) and p(01), where 1 denotes input
# note that p(00) is the max recall and p(11) is 1 since it is a boolean check
binary_strings = generate_binary_strings(arity)[1:-1]

args = list(map(str, atom.symbol.arguments))

for var_subset in binary_strings:
# print('var_subset', var_subset)
if var_subset not in counts[pred]:
counts[pred][var_subset] = {}
key = []
value = []
for i in range(arity):
if var_subset[i] == '1':
if var_subset[i]:
key.append(args[i])
else:
value.append(args[i])
key = tuple(key)
value = tuple(value)
# print('\t', key, value)
if key not in counts[pred][var_subset]:
counts[pred][var_subset][key] = set()
counts[pred][var_subset][key].add(value)
d[var_subset][key].add(value)

# we now calculate the maximum recall
all_recalls = {}
for pred, arity in settings.body_preds:
d1 = counts[pred]
all_recalls[(pred, '0'*arity)] = counts_all[pred]
all_recalls[(pred, (0,)*arity)] = counts_all[pred]
for args, d2 in d1.items():
recall = max(len(xs) for xs in d2.values())
# print(pred, args, recall)
all_recalls[(pred, args)] = recall

settings.recall = all_recalls

# for k, v in all_recalls.items():
# for k, v in sorted(all_recalls.items()):
# print(k ,v)

out = []

for (pred, key), recall in all_recalls.items():
if recall > 4:
continue
if '1' not in key:
if 1 not in key:
pass
# continue
arity = len(key)
args = [f'V{i}' for i in range(arity)]
args_str = ','.join(args)
Expand All @@ -840,29 +817,20 @@ def deduce_recalls(settings):
fixer = []

for x, y in zip(key, args):
if x == '0':
if x == 0:
subset.append(y)
fixer.append('_')
else:
fixer.append(y)


subset_str = ','.join(subset)
fixer_str = ','.join(fixer)
if len(fixer) == 1:
fixer_str+= ','
# print(pred, key, fixer, fixer_str)
# print(args_str)

con2 = f':- body_literal(Rule,{pred},_,({fixer_str})), #count{{{subset_str}: body_literal(Rule,{pred},_,({args_str}))}} > {recall}.'
# print(con2)
out.append(con2)

# for x in settings.recall.items():
# print(x)
# print(out)
return out
# settings.deduced_bkcons += '\n' + '\n'.join(out)

def deduce_type_cons(settings):

Expand Down
1 change: 1 addition & 0 deletions popper/gen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self, settings, bkcons=[]):

for p,a in settings.pointless:
bias_text = re.sub(rf'body_pred\({p},{a}\).','', bias_text)
bias_text = re.sub(rf'constant\({p},.*?\).*', '', bias_text, flags=re.MULTILINE)

encoding.append(bias_text)
encoding.append(f'max_clauses({settings.max_rules}).')
Expand Down
138 changes: 13 additions & 125 deletions popper/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,41 +1643,36 @@ def popper(settings, tester, bkcons):
def get_bk_cons(settings, tester):
bkcons = []

pointless = settings.pointless = set()
if settings.debug:
try:
pointless = settings.pointless = find_pointless_relations(settings)
settings.datalog = True
except:
settings.datalog = False
else:
with suppress_stdout_stderr():
try:
pointless = settings.pointless = find_pointless_relations(settings)
settings.datalog = True
except:
settings.datalog = False
with settings.stats.duration('find_pointless_relations'):
pointless = settings.pointless = tester.find_pointless_relations()

for p,a in pointless:
if settings.showcons:
print('remove pointless relation', p, a)
settings.body_preds.remove((p,a))

# if settings.datalog:
settings.logger.debug(f'Loading recalls')
if settings.datalog:
with settings.stats.duration('recalls'):
recalls = tuple(deduce_recalls(settings))
with settings.stats.duration('recalls'):
recalls = deduce_recalls(settings)

if recalls == None:
settings.datalog = False
else:
settings.datalog = True
if settings.showcons:
for x in recalls:
print('recall', x)
bkcons.extend(recalls)

if settings.datalog:
type_cons = tuple(deduce_type_cons(settings))
if settings.showcons:
for x in type_cons:
print('type_con', x)
bkcons.extend(type_cons)


if not settings.datalog:
settings.logger.debug(f'Loading recalls FAILURE')
else:
Expand Down Expand Up @@ -1862,111 +1857,4 @@ def non_empty_powerset(iterable):

def non_empty_subset(iterable):
    """Lazily yield every non-empty, proper subset of ``iterable``.

    Subsets are produced as tuples, smallest size first, in
    ``itertools.combinations`` order. The upper bound of the size range is
    ``len(s)`` (exclusive), so the full set itself is never produced; inputs
    with fewer than two elements therefore yield nothing.
    """
    # Materialise once so the input can be measured and re-iterated per size.
    s = tuple(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(1, len(s)))

def find_pointless_relations(settings):
    """Find body predicates that are duplicates of another body predicate.

    For every ordered pair of distinct predicates in ``settings.body_preds``
    with the same arity (and, when ``settings.body_types`` is non-empty, the
    same entry in ``settings.body_types``), ASP rules are generated that
    derive ``diff(p,q)`` when one predicate holds for some arguments and the
    other does not, and ``same(p,q)`` when ``p<q`` and no ``diff(p,q)`` is
    derivable. The rules are joined with the contents of
    ``settings.bk_file``, grounded and solved with clingo, and the shown
    ``same/2`` atoms are used to split the predicates into a kept set and a
    pointless set.

    Returns:
        A frozenset of ``(predicate_name, arity)`` pairs for the predicates
        classified as pointless.

    Raises:
        AssertionError: if a ``same`` atom relates two predicates that have
            both already been chosen as kept.
        Any exception raised by opening ``settings.bk_file`` or by clingo
        while parsing, grounding or solving propagates to the caller.
    """
    import clingo

    encoding = ['#show same/2.']
    arities = {}

    for p, pa in settings.body_preds:
        arities[p] = pa
        for q, qa in settings.body_preds:
            if p == q or pa != qa:
                continue
            if settings.body_types and settings.body_types[p] != settings.body_types[q]:
                continue

            arg_str = ','.join(f'V{i}' for i in range(pa))
            encoding.extend([
                f'diff({p},{q}):- {p}({arg_str}), not {q}({arg_str}).',
                f'diff({p},{q}):- {q}({arg_str}), not {p}({arg_str}).',
                f'same({p},{q}):- {p}<{q}, not diff({p},{q}).',
            ])

    encoding.append('\n')
    with open(settings.bk_file) as f:
        encoding.append(f.read())
    program = '\n'.join(encoding)

    solver = clingo.Control(['-Wnone'])
    solver.add('base', [], program)
    solver.ground([('base', [])])

    keep = set()
    pointless = set()

    with solver.solve(yield_=True) as handle:
        for model in handle:
            for atom in model.symbols(shown=True):
                # Only same/2 is shown, so each atom has exactly two
                # arguments: the names of two interchangeable predicates.
                a, b = (str(arg) for arg in atom.arguments)

                # NOTE(review): this looks reachable when two pairs from one
                # equivalence class are seen before the atom linking their
                # kept members (order of atoms matters) -- confirm.
                assert not (a in keep and b in keep), \
                    f'both {a} and {b} were already kept'

                if a not in pointless and b not in pointless:
                    # Neither side has been dropped yet: drop the one that
                    # is not kept, keeping `a` when neither is kept.
                    if a in keep:
                        pointless.add(b)
                    elif b in keep:
                        pointless.add(a)
                    else:
                        keep.add(a)
                        pointless.add(b)
                else:
                    # At least one side is already dropped: drop every side
                    # that has not been explicitly kept.
                    if a not in keep:
                        pointless.add(a)
                    if b not in keep:
                        pointless.add(b)

    return frozenset((p, arities[p]) for p in pointless)
return chain.from_iterable(combinations(s, r) for r in range(1, len(s)))
Loading

0 comments on commit 72781eb

Please sign in to comment.