Skip to content

Commit

Permalink
[Enhancements] Update GraphKernel and GEDModel, unify the function of…
Browse files Browse the repository at this point in the history
… getting the corresponding metric matrices.
  • Loading branch information
jajupmochi committed Jan 18, 2024
1 parent 30a1757 commit c865eef
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 14 deletions.
58 changes: 46 additions & 12 deletions gklearn/ged/model/ged_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,13 @@ def transform(
return dis_matrix


def fit_transform(self, X, y=None, save_dm_train=False, **kwargs):
def fit_transform(
self,
X,
y=None,
save_dm_train=False,
save_mm_train: bool = False,
**kwargs):
"""Fit and transform: compute GED distance matrix on the same data.
Parameters
Expand Down Expand Up @@ -230,7 +236,7 @@ def fit_transform(self, X, y=None, save_dm_train=False, **kwargs):
# finally:
# np.seterr(**old_settings)

if save_dm_train:
if save_mm_train or save_dm_train:
self._dm_train = dis_matrix
# If the model is refitted and the `save_dm_train` flag is not set, then
# remove the previously computed dm_train to prevent conflicts.
Expand Down Expand Up @@ -270,6 +276,10 @@ def validate_parameters(self):
None.
"""
if self.parallel == False:
self.parallel = None
elif self.parallel == True:
self.parallel = 'imap_unordered'
if self.parallel is not None and self.parallel != 'imap_unordered':
raise ValueError('Parallel mode is not set correctly.')

Expand Down Expand Up @@ -316,25 +326,27 @@ def compute_distance_matrix(self, Y=None, **kwargs):
Parameters
----------
Y : list of graphs, optional
The target graphs. The default is None. If None kernel is computed
The target graphs. The default is None. If None distance is computed
between X and itself.
Returns
-------
kernel_matrix : numpy array, shape = [n_targets, n_inputs]
The computed kernel matrix.
dis_matrix : numpy array, shape = [n_targets, n_inputs]
The computed distance matrix.
"""
if Y is None:
# Compute Gram matrix for self._graphs (X).
# Compute metric matrix for self._graphs (X).
dis_matrix = self._compute_X_distance_matrix(**kwargs)
# self._gram_matrix_unnorm = np.copy(self._gram_matrix)

else:
# Compute kernel matrix between Y and self._graphs (X).
# Compute metric matrix between Y and self._graphs (X).
Y_copy = ([g.copy() for g in Y] if self.copy_graphs else Y)
graphs_copy = ([g.copy() for g in
self._graphs] if self.copy_graphs else self._graphs)
graphs_copy = (
[g.copy() for g in self._graphs]
if self.copy_graphs else self._graphs
)

start_time = time.time()

Expand Down Expand Up @@ -786,8 +798,8 @@ def is_graph(self, graph):
if isinstance(graph, nx.MultiDiGraph):
return True
return False


def __repr__(self):
return (
f"{self.__class__.__name__}("
Expand Down Expand Up @@ -833,20 +845,31 @@ def graphs(self):
def run_time(self):
return self._run_time


@property
def test_run_time(self):
return self._test_run_time


@property
def dis_matrix(self):
return self._dm_train


@dis_matrix.setter
def dis_matrix(self, value):
self._dm_train = value


@property
def metric_matrix(self):
return self._dm_train


@metric_matrix.setter
def metric_matrix(self, value):
self._dm_train = value


@property
def edit_cost_constants(self):
return self._edit_cost_constants
Expand All @@ -860,6 +883,17 @@ def edit_cost_constants(self):
# def gram_matrix_unnorm(self, value):
# self._gram_matrix_unnorm = value

@property
def n_pairs(self):
"""
The number of pairs of graphs between which the GEDs are computed.
"""
try:
check_is_fitted(self, '_dm_train')
return len(self._dm_train) * (len(self._dm_train) - 1) / 2
except NotFittedError:
return None


def _init_worker_ged_mat(gn_toshare):
global G_gn
Expand Down
31 changes: 29 additions & 2 deletions gklearn/kernels/graph_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,12 @@ def transform(self, X=None, load_gm_train=False):
return kernel_matrix


def fit_transform(self, X, save_gm_train=False):
def fit_transform(
self,
X,
save_gm_train: bool = False,
save_mm_train: bool = False,
):
"""Fit and transform: compute Gram matrix on the same data.
Parameters
Expand Down Expand Up @@ -187,7 +192,7 @@ def fit_transform(self, X, save_gm_train=False):
finally:
np.seterr(**old_settings)

if save_gm_train:
if save_mm_train or save_gm_train:
self._gm_train = gram_matrix

return gram_matrix
Expand Down Expand Up @@ -783,6 +788,16 @@ def gram_matrix(self, value):
self._gm_train = value


@property
def metric_matrix(self):
return self._gm_train


@metric_matrix.setter
def metric_matrix(self, value):
self._gm_train = value


@property
def gram_matrix_unnorm(self):
return self._gram_matrix_unnorm
Expand All @@ -791,3 +806,15 @@ def gram_matrix_unnorm(self):
@gram_matrix_unnorm.setter
def gram_matrix_unnorm(self, value):
self._gram_matrix_unnorm = value


@property
def n_pairs(self):
"""
The number of pairs of graphs between which the kernels are computed.
"""
try:
check_is_fitted(self, '_gm_train')
return len(self._gm_train) * (len(self._gm_train) + 1) / 2
except NotFittedError:
return None

0 comments on commit c865eef

Please sign in to comment.