Merge pull request #465 from theislab/feature/perturbation_space_improvements

Docs improvements for perturbation space
Lilly-May authored Dec 14, 2023
2 parents a720782 + 5e1344c commit 4f22a90
Showing 3 changed files with 47 additions and 20 deletions.
48 changes: 29 additions & 19 deletions pertpy/tools/_perturbation_space/_discriminator_classifier.py
@@ -22,9 +22,9 @@
class DiscriminatorClassifierSpace(PerturbationSpace):
"""Leveraging discriminator classifier. Fit a regressor model to the data and take the feature space.
See here https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (Dose-response analysis) and Sup 17-19)
See here https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (Dose-response analysis) and Sup 17-19.
We use either the coefficients of the model for each perturbation as a feature or train a classifier example
(simple MLP or logistic regression and take the penultimate layer as feature space and apply pseudobulking approach).
(simple MLP or logistic regression) and take the penultimate layer as feature space and apply pseudobulking approach.
"""

def load( # type: ignore
@@ -39,22 +39,25 @@ def load( # type: ignore
test_split_size: float = 0.2,
validation_split_size: float = 0.25,
):
"""Creates a model with the specified parameters (hidden_dim, dropout, batch_norm).
"""Creates a neural network model using the specified parameters (hidden_dim, dropout, batch_norm). Further
parameters such as the number of classes to predict (number of perturbations) are obtained from the provided
AnnData object directly.
It further creates dataloaders and corrects for the class imbalance caused by the control group.
Sets the device to a GPU if available.
Args:
adata: AnnData object of size cells x genes
target_col: .obs column that stores the perturbations. Defaults to "perturbations".
layer_key: Layer to use. Defaults to None.
layer_key: Layer in adata to use. Defaults to None.
hidden_dim: list of hidden layers of the neural network. For instance: [512, 256].
dropout: amount of dropout applied, constant for all layers. Defaults to 0.
batch_norm: Whether to apply batch normalization. Defaults to True.
batch_size: The batch size. Defaults to 256.
test_split_size: Default to 0.2.
validation_split_size: Size of the validation split taking into account that is taking with respect to the resultant train split.
Defaults to 0.25.
batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass. Defaults to 256.
test_split_size: Fraction of the data to put in the test set. Defaults to 0.2.
validation_split_size: Fraction of the resulting train split to put in the validation set.
E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data
will be used for validation, as sketched below. Defaults to 0.25.
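A minimal sketch of how the two fractions compose, using sklearn's train_test_split purely for illustration (the actual splitting code in pertpy may differ):

from sklearn.model_selection import train_test_split

indices = list(range(1000))                                   # stand-in for cell indices
train_val, test = train_test_split(indices, test_size=0.2)    # 800 train+val cells, 200 test cells
train, val = train_test_split(train_val, test_size=0.25)      # 600 train cells, 200 validation cells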
Examples:
>>> import pertpy as pt
@@ -121,12 +124,13 @@ def load( # type: ignore
return self

def train(self, max_epochs: int = 40, val_epochs_check: int = 5, patience: int = 2):
"""Trains and test the defined model in the load step.
"""Trains and tests the neural network model defined in the load step.
Args:
max_epochs: max epochs for training. Default to 40
val_epochs_check: check in validation dataset each val_epochs_check epochs
patience: patience before the early stopping flag is activated
max_epochs: Maximum number of training epochs. Defaults to 40.
val_epochs_check: Test performance on the validation dataset after every val_epochs_check training epochs. Defaults to 5.
patience: Number of validation performance checks without improvement, after which the early stopping flag
is activated and training is stopped (see the sketch below). Defaults to 2.
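A minimal sketch of how these arguments typically map onto a PyTorch Lightning Trainer; the monitored metric name and the exact callback wiring inside pertpy are assumptions:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Validation runs every `val_epochs_check` epochs; training stops early once the
# monitored metric has not improved for `patience` consecutive validation checks.
early_stopping = EarlyStopping(monitor="val_loss", patience=2)  # metric name assumed
trainer = Trainer(max_epochs=40, check_val_every_n_epoch=5, callbacks=[early_stopping])
# trainer.fit(model, train_dataloader, val_dataloader)  # objects created in the load step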
Examples:
>>> import pertpy as pt
@@ -152,7 +156,7 @@ def train(self, max_epochs: int = 40, val_epochs_check: int = 5, patience: int =
self.trainer.test(model=self.model, dataloaders=self.test_dataloader)

def get_embeddings(self) -> AnnData:
"""Access to the embeddings of the last layer.
"""Obtain the embeddings of the data, i.e., the values in the last layer of the MLP.
Returns:
AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
@@ -194,10 +198,10 @@ def __init__(
) -> None:
"""
Args:
sizes: size of layers
sizes: size of layers.
dropout: Dropout probability. Defaults to 0.0.
batch_norm: batch norm. Defaults to True.
layer_norm: layern norm, common in Transformers. Defaults to False.
batch_norm: specifies if batch norm should be applied. Defaults to True.
layer_norm: specifies if layer norm should be applied, as commonly used in Transformers. Defaults to False.
last_layer_act: activation function of last layer. Defaults to "linear".
"""
super().__init__()
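For illustration, a hypothetical helper showing how such a stack of layers could be assembled from the parameters documented above; the real module may differ in details such as the choice of activation function:

import torch.nn as nn

def build_mlp(sizes, dropout=0.0, batch_norm=True, layer_norm=False, last_layer_act="linear"):
    layers = []
    for i in range(len(sizes) - 1):
        layers.append(nn.Linear(sizes[i], sizes[i + 1]))
        if i < len(sizes) - 2:  # hidden layers get normalization, activation and dropout
            if batch_norm:
                layers.append(nn.BatchNorm1d(sizes[i + 1]))
            elif layer_norm:
                layers.append(nn.LayerNorm(sizes[i + 1]))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
        elif last_layer_act == "ReLU":  # "linear" leaves the output untouched
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

net = build_mlp([2000, 512, 256, 10], dropout=0.1)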
@@ -301,8 +305,14 @@ def __init__(
seed=42,
):
"""
Inputs:
layers - list: layers of the MLP
Args:
layers: list of layers of the MLP
dropout: dropout probability
batch_norm: whether to apply batch norm
layer_norm: whether to apply layer norm
last_layer_act: activation function of last layer
lr: learning rate
seed: random seed
"""
super().__init__()
self.save_hyperparameters()
@@ -368,7 +378,7 @@ def test_step(self, batch, batch_idx):
def embedding(self, x):
"""
Inputs:
x - Input features of shape [Batch, SeqLen, 1]
x: Input features of shape [Batch, SeqLen, 1]
"""
x = self.net.embedding(x)
return x
2 changes: 1 addition & 1 deletion pertpy/tools/_perturbation_space/_perturbation_space.py
@@ -15,7 +15,7 @@ class PerturbationSpace:
"""Implements various ways of interacting with PerturbationSpaces.
We differentiate between a cell space and a perturbation space.
Visually speaking, in cell spaces single dota points in an embeddings summarize a cell,
Visually speaking, in cell spaces single data points in an embedding summarize a cell,
whereas in a perturbation space, data points summarize whole perturbations.
"""

17 changes: 17 additions & 0 deletions pertpy/tools/_perturbation_space/_simple.py
@@ -30,6 +30,10 @@ def compute(
keep_obs: Whether .obs columns in the input AnnData should be kept in the output pseudobulk AnnData. Only .obs columns with the same value for
each cell of one perturbation are kept. Defaults to True.
Returns:
AnnData object with one observation per perturbation, storing the embedding data of the
centroid of the respective perturbation.
Examples:
Compute the centroids of a UMAP embedding of the papalexi_2021 dataset:
Expand Down Expand Up @@ -123,6 +127,9 @@ def compute(
embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
**kwargs: Are passed to decoupler's get_pseudobulk.
Returns:
AnnData object with one observation per perturbation.
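Since the keyword arguments are forwarded to decoupler, a direct call might look roughly as follows; the column names and threshold values are placeholders, and whether pertpy sets them exactly this way is an assumption:

import decoupler as dc

# Aggregate all cells of each perturbation into a single pseudobulk profile
pb_adata = dc.get_pseudobulk(
    adata,                       # AnnData with a perturbation column in .obs
    sample_col="perturbations",  # column that defines the groups to aggregate
    groups_col=None,
    mode="sum",                  # sum counts within each group
    min_cells=0,
    min_counts=0,
)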
Examples:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
Expand Down Expand Up @@ -179,6 +186,11 @@ def compute( # type: ignore
return_object: if True returns the clustering object
**kwargs: Are passed to sklearn's KMeans.
Returns:
If return_object is True, both the adata and the clustering object are returned.
Otherwise, only the adata is returned. The adata is updated with a new .obs column, named as specified by cluster_key,
which stores the cluster labels.
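An equivalent operation written directly with scikit-learn, assuming the embedding is stored in .obsm["X_umap"] and the labels should end up in .obs under the chosen cluster_key:

from sklearn.cluster import KMeans

embedding = adata.obsm["X_umap"]                   # embedding used for clustering (assumed key)
kmeans = KMeans(n_clusters=5, random_state=0).fit(embedding)
adata.obs["k-means"] = kmeans.labels_.astype(str)  # cluster labels stored under cluster_key

The DBSCAN variant documented further below follows the same pattern with sklearn.cluster.DBSCAN.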
Examples:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
Expand Down Expand Up @@ -239,6 +251,11 @@ def compute( # type: ignore
return_object: if True returns the clustering object
**kwargs: Are passed to sklearn's DBSCAN.
Returns:
If return_object is True, both the adata and the clustering object are returned.
Otherwise, only the adata is returned. The adata is updated with a new .obs column, named as specified by cluster_key,
which stores the cluster labels.
Examples:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
