Skip to content

Commit

Permalink
[MNT] add and modify docstring in bridge folder
Browse files Browse the repository at this point in the history
  • Loading branch information
WaTerminator committed Dec 13, 2023
1 parent 9e305e8 commit c8f537a
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 7 deletions.
34 changes: 28 additions & 6 deletions abl/bridge/base_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,27 @@


class BaseBridge(metaclass=ABCMeta):
"""
A base class for bridging machine learning and reasoning parts.
This class provides necessary methods that need to be overridden in subclasses
to construct a typical pipeline of Abductive learning (corresponding to ``train``),
which involves the following four methods:
- predict: Predict class indices on the given data samples.
- idx_to_pseudo_label: Map indices into pseudo labels.
- abduce_pseudo_label: Revise pseudo labels based on abdutive reasoning.
- pseudo_label_to_idx: Map revised pseudo labels back into indices.
Parameters
----------
model : ABLModel
The machine learning model wrapped in ``ABLModel``, which is mainly used for
prediction and model training.
reasoner : Reasoner
The reasoning part wrapped in ``Reasoner``, which is used for pseudo label revision.
"""

def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:
if not isinstance(model, ABLModel):
raise TypeError(
Expand All @@ -22,24 +43,25 @@ def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:

@abstractmethod
def predict(self, data_samples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]:
"""Placeholder for predicting labels from input."""
"""Placeholder for predicting class indices from input."""

@abstractmethod
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for abducing pseudo labels."""
"""Placeholder for revising pseudo labels based on abdutive reasoning."""

@abstractmethod
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for mapping indexes to pseudo labels."""
"""Placeholder for mapping indices to pseudo labels."""

@abstractmethod
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for mapping pseudo labels to indexes."""
"""Placeholder for mapping pseudo labels to indices."""

def filter_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
'''Default filter function for pseudo label.'''
"""Default filter function for pseudo label."""
non_empty_idx = [
i for i in range(len(data_samples.abduced_pseudo_label))
i
for i in range(len(data_samples.abduced_pseudo_label))
if data_samples.abduced_pseudo_label[i]
]
data_samples.update(data_samples[non_empty_idx])
Expand Down
159 changes: 158 additions & 1 deletion abl/bridge/simple_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,29 @@


class SimpleBridge(BaseBridge):
"""
A basic implementation for bridging machine learning and reasoning parts.
This class implements the typical pipeline of Abductive learning, which involves
the following five steps:
- Predict class probabilities and indices for the given data samples.
- Map indices into pseudo labels.
- Revise pseudo labels based on abdutive reasoning.
- Map the revised pseudo labels to indices.
- Train the model.
Parameters
----------
model : ABLModel
The machine learning model wrapped in ``ABLModel``, which is mainly used for
prediction and model training.
reasoner : Reasoner
The reasoning part wrapped in ``Reasoner``, which is used for pseudo label revision.
metric_list : List[BaseMetric]
A list of metrics used for evaluating the model's performance.
"""

def __init__(
self,
model: ABLModel,
Expand All @@ -22,21 +45,74 @@ def __init__(
self.metric_list = metric_list

def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
"""
Predict class indices and probabilities (if ``predict_proba`` is implemented in
``self.model.base_model``) on the given data samples.
Parameters
----------
data_samples : ListData
Data samples on which predictions are to be made.
Returns
-------
Tuple[List[ndarray], List[ndarray]]
A tuple containing lists of predicted indices and probabilities.
"""
self.model.predict(data_samples)
return data_samples.pred_idx, data_samples.pred_prob

def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""
Revise predicted pseudo labels of the given data samples using abduction.
Parameters
----------
data_samples : ListData
Data samples containing predicted pseudo labels.
Returns
-------
List[List[Any]]
A list of abduced pseudo labels for the given data samples.
"""
self.reasoner.batch_abduce(data_samples)
return data_samples.abduced_pseudo_label

def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""
Map indices of data samples into pseudo labels.
Parameters
----------
data_samples : ListData
Data samples containing the indices.
Returns
-------
List[List[Any]]
A list of pseudo labels converted from indices.
"""
pred_idx = data_samples.pred_idx
data_samples.pred_pseudo_label = [
[self.reasoner.mapping[_idx] for _idx in sub_list] for sub_list in pred_idx
]
return data_samples.pred_pseudo_label

def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
"""
Map pseudo labels of data samples into indices.
Parameters
----------
data_samples : ListData
Data samples containing pseudo labels.
Returns
-------
List[List[Any]]
A list of indices converted from pseudo labels.
"""
abduced_idx = [
[self.reasoner.remapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
for sub_list in data_samples.abduced_pseudo_label
Expand All @@ -49,6 +125,21 @@ def data_preprocess(
prefix: str,
data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
) -> ListData:
"""
Transform data in the form of (X, gt_pseudo_label, Y) into ListData.
Parameters
----------
prefix : str
A prefix indicating the type of data processing (e.g., 'train', 'test').
data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Data to be preprocessed. Can be ListData or a tuple of lists.
Returns
-------
ListData
The preprocessed ListData object.
"""
if isinstance(data, ListData):
data_samples = data
if not (
Expand All @@ -69,6 +160,21 @@ def data_preprocess(
def concat_data_samples(
self, unlabel_data_samples: ListData, label_data_samples: Optional[ListData]
) -> ListData:
"""
Concatenate unlabeled and labeled data samples. ``abduced_pseudo_label`` of unlabeled data samples and ``gt_pseudo_label`` of labeled data samples will be used to train the model.
Parameters
----------
unlabel_data_samples : ListData
Unlabeled data samples to concatenate.
label_data_samples : Optional[ListData]
Labeled data samples to concatenate, if available.
Returns
-------
ListData
Concatenated data samples.
"""
if label_data_samples is None:
return unlabel_data_samples

Expand All @@ -89,11 +195,38 @@ def train(
Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
] = None,
loops: int = 50,
segment_size: Union[int, float] = -1,
segment_size: Union[int, float] = 1.0,
eval_interval: int = 1,
save_interval: Optional[int] = None,
save_dir: Optional[str] = None,
):
"""
A typical training pipeline of Abuductive Learning.
Parameters
----------
train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Training data.
label_data : Optional[Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]]
Data with ``gt_pseudo_label`` that can be used to train the model, by
default None.
val_data : Optional[Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]]
Validation data, by default None.
loops : int
Machine Learning part and Reasoning part will be iteratively optimized
for ``loops`` times, by default 50.
segment_size : Union[int, float]
Data will be split into segments of this size and data in each segment
will be used together to train the model, by default 1.0.
eval_interval : int
The model will be evaluated every ``eval_interval`` loops during training,
by default 1.
save_interval : Optional[int]
The model will be saved every ``eval_interval`` loops during training, by
default None.
save_dir : Optional[str]
Directory to save the model, by default None.
"""
data_samples = self.data_preprocess("train", train_data)

if label_data is not None:
Expand Down Expand Up @@ -147,6 +280,14 @@ def train(
)

def _valid(self, data_samples: ListData) -> None:
"""
Internal method for validating the model with given data samples.
Parameters
----------
data_samples : ListData
Data samples to be used for validation.
"""
self.predict(data_samples)
self.idx_to_pseudo_label(data_samples)

Expand All @@ -165,12 +306,28 @@ def valid(
self,
val_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
) -> None:
"""
Validate the model with the given validation data.
Parameters
----------
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Validation data to be used for model evaluation.
"""
val_data_samples = self.data_preprocess(val_data)
self._valid(val_data_samples)

def test(
self,
test_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
) -> None:
"""
Test the model with the given test data.
Parameters
----------
test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Test data to be used for model evaluation.
"""
test_data_samples = self.data_preprocess("test", test_data)
self._valid(test_data_samples)

0 comments on commit c8f537a

Please sign in to comment.