from NodeType import NodeType import numpy as np import random class Node(object): """An SPN node""" def add_parent(self, parent): self.parents.append(parent) def add_child(self, child): self.children.append(child) child.add_parent(self) def update_map_weight_counts(self): raise NotImplementedError() def get_value(self, max_mode=False): raise NotImplementedError() def __str__(self): return '{type} node {name}' \ .format(type=self.type.name, name=self.name) class SumNode(Node): """An SPN sum node""" def __init__(self, name, children=[]): self.name = name self.children = children # Set random weights self.links = dict() for child in self.children: child.add_parent(self) link = dict() link['child'] = child link['weight'] = random.random() link['count'] = 0 self.links[child.name] = link self.normalise_weights() self.parents = [] self.type = NodeType.SUM self.value = 0.0 def add_child(self, child, weight=None): self.children.append(child) child.add_parent(self) link = dict() link['child'] = child link['weight'] = weight if weight else random.random() link['count'] = 0 self.links[child.name] = link def normalise_weights(self): total = np.sum(list(map(lambda l: l['weight'], self.links.values()))) for link in self.links.values(): link['weight'] /= total def get_value(self, max_mode=False): if not max_mode: self.value = 0.0 for child in self.children: self.value += self.links[child.name]['weight'] * child.get_value() return self.value else: max_child = {'node': None, 'value': None} for child in self.children: value = self.links[child.name]['weight'] * child.get_value(max_mode=True) if not max_child['value'] or max_child['value'] < value: max_child['node'] = child max_child['value'] = value self.value = max_child['value'] return self.value def update_map_weight_counts(self): maximum = {'value': 0, 'node': None} for child in self.children: val = child.value * self.links[child.name]['weight'] if val >= maximum['value']: maximum['value'] = val maximum['node'] = child self.links[maximum['node'].name]['count'] += 1 maximum['node'].update_map_weight_counts() def normalise_counts_as_weights(self): """Normalise the counts by normalising so they sum to one, and then set that as the weights""" # Calculates the total by summing the counts of all links total = np.sum(list(map(lambda l: l['count'], self.links.values()))) if total == 0: # Set all weights and counts to zero for link in self.links.values(): link['weight'] = 0.0 link['count'] = 0 else: for link in self.links.values(): link['weight'] = 1.0 * link['count'] / total link['count'] = 0 class ProdNode(Node): """An SPN product node""" def __init__(self, name, children=[]): self.name = name self.parents = [] self.links = dict() self.children = children for child in self.children: child.add_parent(self) self.links[child.name] = {'child': child, 'count': 0.0} self.value = 0.0 self.type = NodeType.PRODUCT def get_value(self, max_mode=False): self.value = 1.0 for child in self.children: self.value = self.value * child.get_value(max_mode=max_mode) return self.value def update_map_weight_counts(self): for child in self.children: child.update_map_weight_counts() class LeafNode(Node): """An SPN leaf node""" def __init__(self, name, value=1.0): self.name = name self.parents = [] self.value = value self.children = [] self.type = NodeType.LEAF def get_value(self, max_mode=False): return self.value def update_map_weight_counts(self): pass