diff --git a/gklearn/ged/model/ged_model.py b/gklearn/ged/model/ged_model.py index eb3f6afc8b..ea6793f04c 100644 --- a/gklearn/ged/model/ged_model.py +++ b/gklearn/ged/model/ged_model.py @@ -200,7 +200,13 @@ def transform( return dis_matrix - def fit_transform(self, X, y=None, save_dm_train=False, **kwargs): + def fit_transform( + self, + X, + y=None, + save_dm_train=False, + save_mm_train: bool = False, + **kwargs): """Fit and transform: compute GED distance matrix on the same data. Parameters @@ -230,7 +236,7 @@ def fit_transform(self, X, y=None, save_dm_train=False, **kwargs): # finally: # np.seterr(**old_settings) - if save_dm_train: + if save_mm_train or save_dm_train: self._dm_train = dis_matrix # If the model is refitted and the `save_dm_train` flag is not set, then # remove the previously computed dm_train to prevent conflicts. @@ -270,6 +276,10 @@ def validate_parameters(self): None. """ + if self.parallel == False: + self.parallel = None + elif self.parallel == True: + self.parallel = 'imap_unordered' if self.parallel is not None and self.parallel != 'imap_unordered': raise ValueError('Parallel mode is not set correctly.') @@ -316,25 +326,27 @@ def compute_distance_matrix(self, Y=None, **kwargs): Parameters ---------- Y : list of graphs, optional - The target graphs. The default is None. If None kernel is computed + The target graphs. The default is None. If None distance is computed between X and itself. Returns ------- - kernel_matrix : numpy array, shape = [n_targets, n_inputs] - The computed kernel matrix. + dis_matrix : numpy array, shape = [n_targets, n_inputs] + The computed distance matrix. """ if Y is None: - # Compute Gram matrix for self._graphs (X). + # Compute metric matrix for self._graphs (X). dis_matrix = self._compute_X_distance_matrix(**kwargs) # self._gram_matrix_unnorm = np.copy(self._gram_matrix) else: - # Compute kernel matrix between Y and self._graphs (X). + # Compute metric matrix between Y and self._graphs (X). Y_copy = ([g.copy() for g in Y] if self.copy_graphs else Y) - graphs_copy = ([g.copy() for g in - self._graphs] if self.copy_graphs else self._graphs) + graphs_copy = ( + [g.copy() for g in self._graphs] + if self.copy_graphs else self._graphs + ) start_time = time.time() @@ -786,8 +798,8 @@ def is_graph(self, graph): if isinstance(graph, nx.MultiDiGraph): return True return False - - + + def __repr__(self): return ( f"{self.__class__.__name__}(" @@ -833,20 +845,31 @@ def graphs(self): def run_time(self): return self._run_time + @property def test_run_time(self): return self._test_run_time + @property def dis_matrix(self): return self._dm_train - @dis_matrix.setter def dis_matrix(self, value): self._dm_train = value + @property + def metric_matrix(self): + return self._dm_train + + + @metric_matrix.setter + def metric_matrix(self, value): + self._dm_train = value + + @property def edit_cost_constants(self): return self._edit_cost_constants @@ -860,6 +883,17 @@ def edit_cost_constants(self): # def gram_matrix_unnorm(self, value): # self._gram_matrix_unnorm = value + @property + def n_pairs(self): + """ + The number of pairs of graphs between which the GEDs are computed. + """ + try: + check_is_fitted(self, '_dm_train') + return len(self._dm_train) * (len(self._dm_train) - 1) / 2 + except NotFittedError: + return None + def _init_worker_ged_mat(gn_toshare): global G_gn diff --git a/gklearn/kernels/graph_kernel.py b/gklearn/kernels/graph_kernel.py index 7992396942..3a9ae9fa1f 100644 --- a/gklearn/kernels/graph_kernel.py +++ b/gklearn/kernels/graph_kernel.py @@ -154,7 +154,12 @@ def transform(self, X=None, load_gm_train=False): return kernel_matrix - def fit_transform(self, X, save_gm_train=False): + def fit_transform( + self, + X, + save_gm_train: bool = False, + save_mm_train: bool = False, + ): """Fit and transform: compute Gram matrix on the same data. Parameters @@ -187,7 +192,7 @@ def fit_transform(self, X, save_gm_train=False): finally: np.seterr(**old_settings) - if save_gm_train: + if save_mm_train or save_gm_train: self._gm_train = gram_matrix return gram_matrix @@ -783,6 +788,16 @@ def gram_matrix(self, value): self._gm_train = value + @property + def metric_matrix(self): + return self._gm_train + + + @metric_matrix.setter + def metric_matrix(self, value): + self._gm_train = value + + @property def gram_matrix_unnorm(self): return self._gram_matrix_unnorm @@ -791,3 +806,15 @@ def gram_matrix_unnorm(self): @gram_matrix_unnorm.setter def gram_matrix_unnorm(self, value): self._gram_matrix_unnorm = value + + + @property + def n_pairs(self): + """ + The number of pairs of graphs between which the kernels are computed. + """ + try: + check_is_fitted(self, '_gm_train') + return len(self._gm_train) * (len(self._gm_train) + 1) / 2 + except NotFittedError: + return None