Skip to content

Commit

Permalink
Allow for custom optimization function for in silico evolution
Browse files Browse the repository at this point in the history
  • Loading branch information
SeppeDeWinter committed Sep 5, 2024
1 parent 58bcd63 commit 0e43b43
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 16 deletions.
44 changes: 29 additions & 15 deletions src/crested/tl/_crested.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from anndata import AnnData
from loguru import logger
from tqdm import tqdm
from typing import Callable, Any

from crested._logging import log_and_raise
from crested.tl import TaskConfig
from crested.tl._utils import (
_weighted_difference,
EnhancerOptimizer,
generate_motif_insertions,
generate_mutagenesis,
hot_encoding_to_sequence,
Expand Down Expand Up @@ -1136,47 +1138,58 @@ def enhancer_design_motif_implementation(

def enhancer_design_in_silico_evolution(
self,
target_class: str,
n_mutations: int,
n_sequences: int,
target_class: str | None = None,
return_intermediate: bool = False,
class_penalty_weights: np.ndarray | None = None,
no_mutation_flanks: tuple | None = None,
target_len: int | None = None,
enhancer_optimizer: EnhancerOptimizer | None = None,
**kwargs: dict[str, Any]
) -> tuple[list[dict], list] | list:
"""
Create synthetic enhancers for a specified class using in silico evolution (ISE).
Parameters
----------
target_class
Class name for which the enhancers will be designed for.
n_mutations
Number of mutations per sequence
n_sequences
Number of enhancers to design
target_class
Class name for which the enhancers will be designed for. If this value is set to None a custom target can be
defined using kwargs.
return_intermediate
If True, returns a dictionary with predictions and changes made in intermediate steps for selected
sequences
class_penalty_weights
Array with a value per class, determining the penalty weight for that class to be used in scoring
function for sequence selection.
no_mutation_flanks
A tuple of integers which determine the regions in each flank to not do implementations.
target_len
Length of the area in the center of the sequence to make implementations, ignored if no_mutation_flanks
is supplied.
enhancer_optimizer
An instance of EnhancerOptimizer, defining how sequences should be optimized.
If None, a default EnhancerOptimizer will be initialized using `_weighted_difference`
as optimization function.
kwargs
Keyword arguments that will be passed to the `get_best` function of the EnhancerOptimizer
Returns
-------
A list of designed sequences and if return_intermediate is True a list of dictionaries of intermediate
mutations and predictions
"""
self._check_contribution_scores_params([target_class])
if target_class is not None:
self._check_contribution_scores_params([target_class])

all_class_names = list(self.anndatamodule.adata.obs_names)
all_class_names = list(self.anndatamodule.adata.obs_names)

target = all_class_names.index(target_class)
target = all_class_names.index(target_class)

if enhancer_optimizer is None:
enhancer_optimizer = EnhancerOptimizer(
optimize_func = _weighted_difference
)

# get input sequence length of the model
seq_len = (
Expand Down Expand Up @@ -1237,11 +1250,12 @@ def enhancer_design_in_silico_evolution(
mutagenesis_predictions = self.model.predict(mutagenesis)

# determine the best mutation
best_mutation = _weighted_difference(
mutagenesis_predictions,
current_prediction,
target,
class_penalty_weights,

best_mutation = enhancer_optimizer.get_best(
mutated_predictions = mutagenesis_predictions,
original_prediction = current_prediction,
target = target,
**kwargs
)

sequence_onehot = mutagenesis[best_mutation : best_mutation + 1]
Expand Down
27 changes: 26 additions & 1 deletion src/crested/tl/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any, Callable

import numpy as np


Expand Down Expand Up @@ -92,9 +94,32 @@ def generate_motif_insertions(x, motif, flanks=(0, 0), masked_locations=None):

return np.concatenate(x_mut, axis=0), insertion_locations

class EnhancerOptimizer:
    """
    Thin wrapper around the function that selects the best candidate mutation.

    Parameters
    ----------
    optimize_func
        Callable invoked by :meth:`get_best` as
        ``optimize_func(mutated_predictions, original_prediction, target, **kwargs)``.
        Its return value is passed back to the caller unchanged.
    """

    def __init__(self, optimize_func: Callable[..., np.intp]) -> None:
        self.optimize_func = optimize_func

    def get_best(
        self,
        mutated_predictions: np.ndarray,
        original_prediction: np.ndarray,
        target: int | list[int],
        **kwargs: Any,
    ) -> np.intp:
        """
        Delegate the choice of the best mutation to ``optimize_func``.

        Parameters
        ----------
        mutated_predictions
            Predictions for the candidate (mutated) sequences.
        original_prediction
            Prediction for the sequence before mutation.
        target
            Target specification forwarded as the third positional argument.
        kwargs
            Extra keyword arguments forwarded verbatim to ``optimize_func``.

        Returns
        -------
        Whatever ``optimize_func`` returns (annotated as an integer index).
        """
        # The first three arguments are passed positionally, so a custom
        # optimize_func is free to name its own parameters differently.
        return self.optimize_func(
            mutated_predictions,
            original_prediction,
            target,
            **kwargs,
        )

def _weighted_difference(
mutated_predictions, original_prediction, target, class_penalty_weights=None
mutated_predictions: np.ndarray,
original_prediction: np.ndarray,
target: int,
class_penalty_weights: np.ndarray | None = None
):
n_classes = original_prediction.shape[1]
penalty_factor = 1 / n_classes
Expand Down

0 comments on commit 0e43b43

Please sign in to comment.