From 95cceb5a61dee6963845e1144778b8ccce9e9379 Mon Sep 17 00:00:00 2001 From: mtvector Date: Thu, 2 Jan 2025 11:28:26 -0800 Subject: [PATCH 1/2] first round of changes --- .../.ipynb_checkpoints/_io-checkpoint.py | 650 ++++++++++++++++++ src/crested/_io.py | 20 +- .../_anndatamodule-checkpoint.py | 451 ++++++++++++ .../_dataloader-checkpoint.py | 143 ++++ .../.ipynb_checkpoints/_dataset-checkpoint.py | 568 +++++++++++++++ src/crested/tl/data/_anndatamodule.py | 316 +++++++-- src/crested/tl/data/_dataloader.py | 80 ++- src/crested/tl/data/_dataset.py | 199 +++++- 8 files changed, 2336 insertions(+), 91 deletions(-) create mode 100644 src/crested/.ipynb_checkpoints/_io-checkpoint.py create mode 100644 src/crested/tl/data/.ipynb_checkpoints/_anndatamodule-checkpoint.py create mode 100644 src/crested/tl/data/.ipynb_checkpoints/_dataloader-checkpoint.py create mode 100644 src/crested/tl/data/.ipynb_checkpoints/_dataset-checkpoint.py diff --git a/src/crested/.ipynb_checkpoints/_io-checkpoint.py b/src/crested/.ipynb_checkpoints/_io-checkpoint.py new file mode 100644 index 0000000..431f1ac --- /dev/null +++ b/src/crested/.ipynb_checkpoints/_io-checkpoint.py @@ -0,0 +1,650 @@ +"""I/O functions for importing beds and bigWigs into AnnData objects.""" + +from __future__ import annotations + +import os +import re +import tempfile +from concurrent.futures import ProcessPoolExecutor +from os import PathLike +from pathlib import Path + +import anndata as ad +import numpy as np +import pandas as pd +import pybigtools +from anndata import AnnData +from loguru import logger +from scipy.sparse import csr_matrix + +from crested import _conf as conf + + +def _sort_files(filename: PathLike): + """Sorts files. + + Prioritizes numeric extraction from filenames of the format 'Class_X.bed' (X=int). + Other filenames are sorted alphabetically, with 'Class_' files coming last if numeric extraction fails. + """ + filename = Path(filename) + parts = filename.stem.split("_") + + if len(parts) > 1: + try: + return (False, int(parts[1])) + except ValueError: + # If the numeric part is not an integer, handle gracefully + return (True, filename.stem) + + # Return True for the first element to sort non-'Class_X' filenames alphabetically after 'Class_X' + return ( + True, + filename.stem, + ) + + +def _custom_region_sort(region: str) -> tuple[int, int, int]: + """Sort regions in the format chr:start-end.""" + chrom, pos = region.split(":") + start, _ = map(int, pos.split("-")) + + # check if the chromosome part contains digits + numeric_match = re.match(r"chr(\d+)|chrom(\d+)", chrom, re.IGNORECASE) + + if numeric_match: + chrom_num = int(numeric_match.group(1) or numeric_match.group(2)) + return (0, chrom_num, start) + else: + return (1, chrom, start) + + +def _read_chromsizes(chromsizes_file: PathLike) -> dict[str, int]: + """Read chromsizes file into a dictionary.""" + chromsizes = pd.read_csv( + chromsizes_file, sep="\t", header=None, names=["chr", "size"] + ) + chromsizes_dict = chromsizes.set_index("chr")["size"].to_dict() + return chromsizes_dict + + +def _extract_values_from_bigwig( + bw_file: PathLike, bed_file: PathLike, target: str +) -> np.ndarray: + """Extract target values from a bigWig file for regions specified in a BED file.""" + if isinstance(bed_file, Path): + bed_file = str(bed_file) + if isinstance(bw_file, Path): + bw_file = str(bw_file) + + # Get chromosomes available in bigWig file. 
+ with pybigtools.open(bw_file, "r") as bw: + chromosomes_in_bigwig = set(bw.chroms()) + + # Create temporary BED file with only BED entries that are in the bigWig file. + temp_bed_file = tempfile.NamedTemporaryFile() + bed_entries_to_keep_idx = [] + + with open(bed_file) as fh: + for idx, line in enumerate(fh): + chrom = line.split("\t", 1)[0] + if chrom in chromosomes_in_bigwig: + temp_bed_file.file.write(line.encode("utf-8")) + bed_entries_to_keep_idx.append(idx) + # Make sure all content is written to temporary BED file. + temp_bed_file.file.flush() + + total_bed_entries = idx + 1 + bed_entries_to_keep_idx = np.array(bed_entries_to_keep_idx, np.intp) + + if target == "mean": + with pybigtools.open(bw_file, "r") as bw: + values = np.fromiter( + bw.average_over_bed(bed=temp_bed_file.name, names=None, stats="mean0"), + dtype=np.float32, + ) + elif target == "max": + with pybigtools.open(bw_file, "r") as bw: + values = np.fromiter( + bw.average_over_bed(bed=temp_bed_file.name, names=None, stats="max"), + dtype=np.float32, + ) + elif target == "raw": #Could probably use _extract_tracks_from_bigwig instead if you wanted to figure it out + with pybigtools.open(bw_file, "r") as bw: + values = np.vstack( + [ + np.array( + bw.values(chrom, int(start), int(end), missing=np.nan, exact=True) + ) + for chrom, start, end in [ + line.split("\t")[:3] + for line in open(temp_bed_file.name).readlines() + ] + ] + ) + + elif target == "count": + with pybigtools.open(bw_file, "r") as bw: + values = np.fromiter( + bw.average_over_bed(bed=temp_bed_file.name, names=None, stats="sum"), + dtype=np.float32, + ) + elif target == "logcount": + with pybigtools.open(bw_file, "r") as bw: + values = np.log1p( + np.fromiter( + bw.average_over_bed( + bed=temp_bed_file.name, names=None, stats="sum" + ), + dtype=np.float32, + ) + ) + else: + raise ValueError(f"Unsupported target '{target}'") + + # Remove temporary BED file. + temp_bed_file.close() + + if values.shape[0] != total_bed_entries: + # Set all values for BED entries for which the chromosome was not in in the bigWig file to NaN. + all_values = np.full(total_bed_entries, np.nan, dtype=np.float32) + all_values[bed_entries_to_keep_idx] = values + return all_values + else: + return values + + +def _extract_tracks_from_bigwig( + bw_file: PathLike, + coordinates: list[tuple[str, int, int]], + bin_size: int | None = None, + target: str = "mean", + missing: float = 0.0, + oob: float = 0.0, + exact: bool = True, +) -> np.ndarray: + """ + Extract per-base or binned pair values of a list of genomic ranges from a bigWig file. + + Expects all coordinate pairs to be the same length. + + bigwig_file + Path to the bigWig file. + coordinates + A list of tuples looking like (chr, start, end). + bin_size + If set, the returned values are mean-binned at this resolution. + target + How to summarize the values per bin, when binning. Can be 'mean', 'min', or 'max'. + missing + Fill-in value for unreported data in valid regions. Default is 0. + oob + Fill-in value for out-of-bounds regions. + exact + Whether to always return the exact values, or to use the built-in zoom levels to interpolate, when binning. + Setting exact = False leads to a slight speed advantage, but slight loss in accuracy. + + Returns a numpy array of values from the bigWig file of shape [n_coordinates, n_base_pairs] or [n_coordinates, n_base_pairs//bin_size] if bin_size is set. + """ + # Wrapper around pybigtools.BBIRead.values(). 
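+    # Illustrative usage only (the file name is a placeholder): for two 2 kb regions
+    # binned at 10 bp,
+    #     tracks = _extract_tracks_from_bigwig(
+    #         "signal.bw",
+    #         [("chr1", 1000, 3000), ("chr2", 5000, 7000)],
+    #         bin_size=10,
+    #         target="mean",
+    #     )
+    # returns an array of shape (2, 200); without bin_size the shape would be (2, 2000).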
+ + # Check that all are same size by iterating and checking with predecessor + prev_region_length = coordinates[0][2] - coordinates[0][1] + for region in coordinates: + region_length = region[2] - region[1] + if region_length != prev_region_length: + raise ValueError( + f"All coordinate pairs should be the same length. Coordinate pair {region[0]}:{region[1]}-{region[2]} is not {prev_region_length}bp, but {region_length}bp." + ) + prev_region_length = region_length + + # Check that length is divisible by bin size + if bin_size and (region_length % bin_size != 0): + raise ValueError( + f"All region lengths must be divisible by bin_size. Region length {region_length} is not divisible by bin size {bin_size}." + ) + + # Calculate length (for array creation) and bins (for argument to bw.values) + binned_length = region_length // bin_size if bin_size else region_length + bins = region_length // bin_size if bin_size else None + + # Open the bigWig file + with pybigtools.open(bw_file, "r") as bw: + results = [] + for region in coordinates: + arr = np.empty( + binned_length, dtype="float64" + ) # pybigtools returns values in float64 + chrom, start, end = region + + # Extract values + results.append( + bw.values( + chrom, + start, + end, + bins=bins, + summary=target, + exact=exact, + missing=missing, + oob=oob, + arr=arr, + ) + ) + + return np.vstack(results) + + +def _read_consensus_regions( + regions_file: PathLike, chromsizes_file: PathLike | None = None +) -> pd.DataFrame: + """Read consensus regions BED file and filter out regions not within chromosomes.""" + if chromsizes_file is not None: + chromsizes_file = Path(chromsizes_file) + if not chromsizes_file.is_file(): + raise FileNotFoundError(f"File '{chromsizes_file}' not found") + if chromsizes_file is None and not conf.genome: + logger.warning( + "Chromsizes file not provided. 
Will not check if regions are within chromosomes", + stacklevel=1, + ) + consensus_peaks = pd.read_csv( + regions_file, + sep="\t", + header=None, + usecols=[0, 1, 2], + dtype={0: str, 1: "Int32", 2: "Int32"}, + ) + consensus_peaks["region"] = ( + consensus_peaks[0].astype(str) + + ":" + + consensus_peaks[1].astype(str) + + "-" + + consensus_peaks[2].astype(str) + ) + if chromsizes_file: + chromsizes_dict = _read_chromsizes(chromsizes_file) + elif conf.genome: + chromsizes_dict = conf.genome.chrom_sizes + else: + return consensus_peaks + + valid_mask = consensus_peaks.apply( + lambda row: row[0] in chromsizes_dict + and row[1] >= 0 + and row[2] <= chromsizes_dict[row[0]], + axis=1, + ) + consensus_peaks_filtered = consensus_peaks[valid_mask] + + if len(consensus_peaks) != len(consensus_peaks_filtered): + logger.warning( + f"Filtered {len(consensus_peaks) - len(consensus_peaks_filtered)} consensus regions (not within chromosomes)", + ) + return consensus_peaks_filtered + + +def _create_temp_bed_file( + consensus_peaks: pd.DataFrame, target_region_width: int +) -> str: + """Adjust consensus regions to a target width and create a temporary BED file.""" + adjusted_peaks = consensus_peaks.copy() + adjusted_peaks[1] = adjusted_peaks.apply( + lambda row: max(0, row[1] - (target_region_width - (row[2] - row[1])) // 2), + axis=1, + ) + adjusted_peaks[2] = adjusted_peaks[1] + target_region_width + adjusted_peaks[1] = adjusted_peaks[1].astype(int) + adjusted_peaks[2] = adjusted_peaks[2].astype(int) + + # Create a temporary BED file + temp_bed_file = "temp_adjusted_regions.bed" + adjusted_peaks.to_csv(temp_bed_file, sep="\t", header=False, index=False) + return temp_bed_file + + +def _check_bed_file_format(bed_file: PathLike) -> None: + """Check if the BED file is in the correct format.""" + with open(bed_file) as f: + first_line = f.readline().strip() + # check if at least three columns are found + if len(first_line.split("\t")) < 3: + raise ValueError( + f"BED file '{bed_file}' is not in the correct format. " + "Expected at least three tab-seperated columns." + ) + pattern = r".*\t\d+\t\d+.*" + if not re.match(pattern, first_line): + raise ValueError( + f"BED file '{bed_file}' is not in the correct format. " + "Expected columns 2 and 3 to contain integers." + ) + + +def import_beds( + beds_folder: PathLike, + regions_file: PathLike | None = None, + chromsizes_file: PathLike | None = None, + classes_subset: list | None = None, + remove_empty_regions: bool = True, + compress: bool = False, +) -> AnnData: + """ + Import beds and optionally consensus regions BED files into AnnData format. + + Expects the folder with BED files where each file is named {class_name}.bed + The result is an AnnData object with classes as rows and the regions as columns, + with the .X values indicating whether a region is open in a class. + + Note + ---- + This is the default function to import topic BED files coming from running pycisTopic + (https://pycistopic.readthedocs.io/en/latest/) on your data. + The result is an AnnData object with topics as rows and consensus region as columns, + with binary values indicating whether a region is present in a topic. + + Parameters + ---------- + beds_folder + Folder path containing the BED files. + regions_file + File path of the consensus regions BED file to use as columns in the AnnData object. + If None, the regions will be extracted from the files. + classes_subset + List of classes to include in the AnnData object. If None, all files + will be included. 
+ Classes should be named after the file name without the extension. + chromsizes_file + File path of the chromsizes file. Used for checking if the new regions are within the chromosome boundaries. + If not provided, will look for a registered genome object. + remove_empty_regions + Remove regions that are not open in any class (only possible if regions_file is provided) + compress + Compress the AnnData.X matrix. If True, the matrix will be stored as + a sparse matrix. If False, the matrix will be stored as a dense matrix. + + WARNING: Compressing the matrix currently makes training very slow and is never recommended. + We're still investigating a way around. + + Returns + ------- + AnnData object with classes as rows and peaks as columns. + + Example + ------- + >>> anndata = crested.import_beds( + ... beds_folder="path/to/beds/folder/", + ... regions_file="path/to/regions.bed", + ... chromsizes_file="path/to/chrom.sizes", + ... classes_subset=["Topic_1", "Topic_2"], + ... ) + """ + beds_folder = Path(beds_folder) + regions_file = Path(regions_file) if regions_file else None + + # Input checks + if not beds_folder.is_dir(): + raise FileNotFoundError(f"Directory '{beds_folder}' not found") + if (regions_file is not None) and (not regions_file.is_file()): + raise FileNotFoundError(f"File '{regions_file}' not found") + if classes_subset is not None: + for classname in classes_subset: + if not any(beds_folder.glob(f"{classname}.bed")): + raise FileNotFoundError(f"'{classname}' not found in '{beds_folder}'") + + if regions_file: + # Read consensus regions BED file and filter out regions not within chromosomes + _check_bed_file_format(regions_file) + consensus_peaks = _read_consensus_regions(regions_file, chromsizes_file) + + binary_matrix = pd.DataFrame(0, index=[], columns=consensus_peaks["region"]) + file_paths = [] + + # Which regions are present in the consensus regions + logger.info( + f"Reading bed files from {beds_folder} and using {regions_file} as var_names..." + ) + for file in sorted(beds_folder.glob("*.bed"), key=_sort_files): + class_name = file.stem + if classes_subset is None or class_name in classes_subset: + class_regions = pd.read_csv( + file, sep="\t", header=None, usecols=[0, 1, 2] + ) + class_regions["region"] = ( + class_regions[0].astype(str) + + ":" + + class_regions[1].astype(str) + + "-" + + class_regions[2].astype(str) + ) + + # Create binary row for the current topic + binary_row = binary_matrix.columns.isin(class_regions["region"]).astype( + int + ) + binary_matrix.loc[class_name] = binary_row + file_paths.append(str(file)) + + # else, get regions from the bed files + else: + file_paths = [] + all_regions = set() + + # Collect all regions from the BED files + logger.info( + f"Reading bed files from {beds_folder} without consensus regions..." 
+ ) + for file in sorted(beds_folder.glob("*.bed"), key=_sort_files): + class_name = file.stem + if classes_subset is None or class_name in classes_subset: + _check_bed_file_format(file) + class_regions = pd.read_csv( + file, sep="\t", header=None, usecols=[0, 1, 2] + ) + class_regions["region"] = ( + class_regions[0].astype(str) + + ":" + + class_regions[1].astype(str) + + "-" + + class_regions[2].astype(str) + ) + all_regions.update(class_regions["region"].tolist()) + file_paths.append(str(file)) + + # Convert set to sorted list + all_regions = sorted(all_regions, key=_custom_region_sort) + binary_matrix = pd.DataFrame(0, index=[], columns=all_regions) + + # Populate the binary matrix + for file in file_paths: + class_name = Path(file).stem + class_regions = pd.read_csv(file, sep="\t", header=None, usecols=[0, 1, 2]) + class_regions["region"] = ( + class_regions[0].astype(str) + + ":" + + class_regions[1].astype(str) + + "-" + + class_regions[2].astype(str) + ) + binary_row = binary_matrix.columns.isin(class_regions["region"]).astype(int) + binary_matrix.loc[class_name] = binary_row + + ann_data = AnnData( + binary_matrix, + ) + + ann_data.obs["file_path"] = file_paths + ann_data.obs["n_open_regions"] = ann_data.X.sum(axis=1) + ann_data.var["n_classes"] = ann_data.X.sum(axis=0) + ann_data.var["chr"] = ann_data.var.index.str.split(":").str[0] + ann_data.var["start"] = ( + ann_data.var.index.str.split(":").str[1].str.split("-").str[0] + ).astype(int) + ann_data.var["end"] = ( + ann_data.var.index.str.split(":").str[1].str.split("-").str[1] + ).astype(int) + + if compress: + ann_data.X = csr_matrix(ann_data.X) + + # Output checks + classes_no_open_regions = ann_data.obs[ann_data.obs["n_open_regions"] == 0] + if not classes_no_open_regions.empty: + raise ValueError( + f"{classes_no_open_regions.index} have 0 open regions in the consensus peaks" + ) + regions_no_classes = ann_data.var[ann_data.var["n_classes"] == 0] + if not regions_no_classes.empty: + if remove_empty_regions: + logger.warning( + f"{len(regions_no_classes.index)} consensus regions are not open in any class. Removing them from the AnnData object. Disable this behavior by setting 'remove_empty_regions=False'", + ) + ann_data = ann_data[:, ann_data.var["n_classes"] > 0] + + return ann_data + + +def import_bigwigs( + bigwigs_folder: PathLike, + regions_file: PathLike, + chromsizes_file: PathLike | None = None, + target: str = "mean", + target_region_width: int | None = None, + compress: bool = False, +) -> AnnData: + """ + Import bigWig files and consensus regions BED file into AnnData format. + + This format is required to be able to train a peak prediction model. + The bigWig files target values are calculated for each region and and imported into an AnnData object, + with the bigWig file names as .obs and the consensus regions as .var. + Optionally, the target region width can be specified to extract values from a wider/narrower region around the consensus region, + where the original region will still be used as the index. + This is often useful to extract sequence information around the actual peak region. + + Parameters + ---------- + bigwigs_folder + Folder name containing the bigWig files. + regions_file + File name of the consensus regions BED file. + chromsizes_file + File name of the chromsizes file. Used for checking if the new regions are within the chromosome boundaries. + If not provided, will look for a registered genome object. + target + Target value to extract from bigwigs. 
Can be 'raw', 'mean', 'max', 'count', or 'logcount' + target_region_width + Width of region that the bigWig target value will be extracted from. If None, the + consensus region width will be used. + compress + Compress the AnnData.X matrix. If True, the matrix will be stored as + a sparse matrix. If False, the matrix will be stored as a dense matrix. + + Returns + ------- + AnnData object with bigWigs as rows and peaks as columns. + + Example + ------- + >>> anndata = crested.import_bigwigs( + ... bigwigs_folder="path/to/bigwigs", + ... regions_file="path/to/peaks.bed", + ... chromsizes_file="path/to/chrom.sizes", + ... target="max", + ... target_region_width=500, + ... ) + """ + bigwigs_folder = Path(bigwigs_folder) + regions_file = Path(regions_file) + + # Input checks + if not bigwigs_folder.is_dir(): + raise FileNotFoundError(f"Directory '{bigwigs_folder}' not found") + if not regions_file.is_file(): + raise FileNotFoundError(f"File '{regions_file}' not found") + + # Read consensus regions BED file and filter out regions not within chromosomes + _check_bed_file_format(regions_file) + consensus_peaks = _read_consensus_regions(regions_file, chromsizes_file) + + if target_region_width is not None: + bed_file = _create_temp_bed_file(consensus_peaks, target_region_width) + else: + bed_file = regions_file + + bw_files = [] + for file in os.listdir(bigwigs_folder): + file_path = os.path.join(bigwigs_folder, file) + try: + # Validate using pyBigTools (add specific validation if available) + bw = pybigtools.open(file_path, "r") + bw_files.append(file_path) + bw.close() + except ValueError: + pass + except pybigtools.BBIReadError: + pass + + bw_files = sorted(bw_files) + if not bw_files: + raise FileNotFoundError(f"No valid bigWig files found in '{bigwigs_folder}'") + + # Process bigWig files in parallel and collect the results + logger.info(f"Extracting values from {len(bw_files)} bigWig files...") + all_results = [] + with ProcessPoolExecutor() as executor: + futures = [ + executor.submit( + _extract_values_from_bigwig, + bw_file, + bed_file, + target, + ) + for bw_file in bw_files + ] + for future in futures: + all_results.append(future.result()) + + # for bw_file in bw_files: + # result = _extract_values_from_bigwig(bw_file, bed_file, target=target) + # all_results.append(result) + + if target_region_width is not None: + os.remove(bed_file) + + data_matrix = np.vstack(all_results) + + # Prepare obs and var for AnnData + obs_df = pd.DataFrame( + data={"file_path": bw_files}, + index=[ + os.path.basename(file).rpartition(".")[0].replace(".", "_") + for file in bw_files + ], + ) + var_df = pd.DataFrame( + { + "region": consensus_peaks["region"], + "chr": consensus_peaks["region"].str.split(":").str[0], + "start": ( + consensus_peaks["region"].str.split(":").str[1].str.split("-").str[0] + ).astype(int), + "end": ( + consensus_peaks["region"].str.split(":").str[1].str.split("-").str[1] + ).astype(int), + } + ).set_index("region") + + # Create AnnData object + adata = ad.AnnData(data_matrix, obs=obs_df, var=var_df) + + if compress: + adata.X = csr_matrix(adata.X) + + # Output checks + regions_no_values = adata.var[adata.X.sum(axis=0) == 0] + if not regions_no_values.empty: + logger.warning( + f"{len(regions_no_values.index)} consensus regions have no values in any bigWig file", + ) + + return adata diff --git a/src/crested/_io.py b/src/crested/_io.py index 0ffb249..431f1ac 100644 --- a/src/crested/_io.py +++ b/src/crested/_io.py @@ -108,6 +108,20 @@ def _extract_values_from_bigwig( 
bw.average_over_bed(bed=temp_bed_file.name, names=None, stats="max"), dtype=np.float32, ) + elif target == "raw": #Could probably use _extract_tracks_from_bigwig instead if you wanted to figure it out + with pybigtools.open(bw_file, "r") as bw: + values = np.vstack( + [ + np.array( + bw.values(chrom, int(start), int(end), missing=np.nan, exact=True) + ) + for chrom, start, end in [ + line.split("\t")[:3] + for line in open(temp_bed_file.name).readlines() + ] + ] + ) + elif target == "count": with pybigtools.open(bw_file, "r") as bw: values = np.fromiter( @@ -517,7 +531,7 @@ def import_bigwigs( File name of the chromsizes file. Used for checking if the new regions are within the chromosome boundaries. If not provided, will look for a registered genome object. target - Target value to extract from bigwigs. Can be 'mean', 'max', 'count', or 'logcount' + Target value to extract from bigwigs. Can be 'raw', 'mean', 'max', 'count', or 'logcount' target_region_width Width of region that the bigWig target value will be extracted from. If None, the consensus region width will be used. @@ -590,6 +604,10 @@ def import_bigwigs( for future in futures: all_results.append(future.result()) + # for bw_file in bw_files: + # result = _extract_values_from_bigwig(bw_file, bed_file, target=target) + # all_results.append(result) + if target_region_width is not None: os.remove(bed_file) diff --git a/src/crested/tl/data/.ipynb_checkpoints/_anndatamodule-checkpoint.py b/src/crested/tl/data/.ipynb_checkpoints/_anndatamodule-checkpoint.py new file mode 100644 index 0000000..4747ef7 --- /dev/null +++ b/src/crested/tl/data/.ipynb_checkpoints/_anndatamodule-checkpoint.py @@ -0,0 +1,451 @@ +"""Anndatamodule which acts as a wrapper around AnnDataset and AnnDataLoader.""" + +from __future__ import annotations + +from os import PathLike +from torch.utils.data import Sampler +import numpy as np + +from crested._genome import Genome, _resolve_genome + +from ._dataloader import AnnDataLoader +from ._dataset import AnnDataset + + +class AnnDataModule: + """ + DataModule class which defines how dataloaders should be loaded in each stage. + + Required input for the `tl.Crested` class. + + Note + ---- + Expects a `split` column in the `.var` DataFrame of the AnnData object. + Run `pp.train_val_test_split` first to add the `split` column to the AnnData object if not yet done. + + Example + ------- + >>> data_module = AnnDataModule( + ... adata, + ... genome=my_genome, + ... always_reverse_complement=True, + ... max_stochastic_shift=50, + ... batch_size=256, + ... ) + + Parameters + ---------- + adata + An instance of AnnData containing the data to be loaded. + genome + Instance of Genome or Path to the fasta file. + If None, will look for a registered genome object. + chromsizes_file + Path to the chromsizes file. Not required if genome is a Genome object. + If genome is a path and chromsizes is not provided, will deduce the chromsizes from the fasta file. + in_memory + If True, the train and val sequences will be loaded into memory. Default is True. + always_reverse_complement + If True, all sequences will be augmented with their reverse complement during training. + Effectively increases the training dataset size by a factor of 2. Default is True. + random_reverse_complement + If True, the sequences will be randomly reverse complemented during training. Default is False. + max_stochastic_shift + Maximum stochastic shift (n base pairs) to apply randomly to each sequence during training. Default is 0. 
+ deterministic_shift + If true, each region will be shifted twice with stride 50bp to each side. Default is False. + This is our legacy shifting, we recommend using max_stochastic_shift instead. + shuffle + If True, the data will be shuffled at the end of each epoch during training. Default is True. + batch_size + Number of samples per batch to load. Default is 256. + """ + + def __init__( + self, + adata, + genome: PathLike | Genome | None = None, + chromsizes_file: PathLike | None = None, + in_memory: bool = True, + always_reverse_complement: bool = True, + random_reverse_complement: bool = False, + max_stochastic_shift: int = 0, + deterministic_shift: bool = False, + shuffle: bool = True, + batch_size: int = 256, + obs_columns: list[str] | None = None, + obsm_keys: list[str] | None = None, + varp_keys: list[str] | None = None, + ): + """Initialize the DataModule with the provided dataset and options.""" + self.adata = adata + self.genome = _resolve_genome(genome, chromsizes_file) # Function assumed available + self.always_reverse_complement = always_reverse_complement + self.in_memory = in_memory + self.random_reverse_complement = random_reverse_complement + self.max_stochastic_shift = max_stochastic_shift + self.deterministic_shift = deterministic_shift + self.shuffle = shuffle + self.batch_size = batch_size + self.obs_columns = obs_columns + self.obsm_keys = obsm_keys + self.varp_keys = varp_keys + + self._validate_init_args(random_reverse_complement, always_reverse_complement) + + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + self.predict_dataset = None + + @staticmethod + def _validate_init_args( + random_reverse_complement: bool, always_reverse_complement: bool + ): + if random_reverse_complement and always_reverse_complement: + raise ValueError( + "Only one of `random_reverse_complement` and `always_reverse_complement` can be True." + ) + + def setup(self, stage: str) -> None: + """ + Set up the Anndatasets for a given stage. + + Generates the train, val, test or predict dataset based on the provided stage. + Should always be called before accessing the dataloaders. + Generally, you don't need to call this directly, as this is called inside the `tl.Crested` trainer class. + + Parameters + ---------- + stage + Stage for which to setup the dataloader. Either 'fit', 'test' or 'predict'. 
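+
+        Example
+        -------
+        Using the `data_module` from the class-level example:
+
+        >>> data_module.setup("fit")
+        >>> train_loader = data_module.train_dataloader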
+ """ + args = { + "anndata": self.adata, + "genome": self.genome, + "in_memory": self.in_memory, + "always_reverse_complement": self.always_reverse_complement, + "random_reverse_complement": self.random_reverse_complement, + "max_stochastic_shift": self.max_stochastic_shift, + "deterministic_shift": self.deterministic_shift, + "obs_columns": self.obs_columns, + "obsm_keys": self.obsm_keys, + "varp_keys": self.varp_keys, + } + if stage == "fit": + # Training dataset + train_args = args.copy() + train_args["split"] = "train" + + val_args = args.copy() + val_args["split"] = "val" + val_args["always_reverse_complement"] = False + val_args["random_reverse_complement"] = False + val_args["max_stochastic_shift"] = 0 + + self.train_dataset = AnnDataset(**train_args) + self.val_dataset = AnnDataset(**val_args) + + elif stage == "test": + test_args = args.copy() + test_args["split"] = "test" + test_args["in_memory"] = False + test_args["always_reverse_complement"] = False + test_args["random_reverse_complement"] = False + test_args["max_stochastic_shift"] = 0 + + self.test_dataset = AnnDataset(**test_args) + + elif stage == "predict": + predict_args = args.copy() + predict_args["split"] = None + predict_args["in_memory"] = False + predict_args["always_reverse_complement"] = False + predict_args["random_reverse_complement"] = False + predict_args["max_stochastic_shift"] = 0 + + self.predict_dataset = AnnDataset(**predict_args) + + else: + raise ValueError(f"Invalid stage: {stage}") + + + @property + def train_dataloader(self): + """:obj:`crested.tl.data.AnnDataLoader`: Training dataloader.""" + if self.train_dataset is None: + raise ValueError("train_dataset is not set. Run setup('fit') first.") + return AnnDataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + drop_remainder=False, + ) + + @property + def val_dataloader(self): + """:obj:`crested.tl.data.AnnDataLoader`: Validation dataloader.""" + if self.val_dataset is None: + raise ValueError("val_dataset is not set. Run setup('fit') first.") + return AnnDataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + drop_remainder=False, + ) + + @property + def test_dataloader(self): + """:obj:`crested.tl.data.AnnDataLoader`: Test dataloader.""" + if self.test_dataset is None: + raise ValueError("test_dataset is not set. Run setup('test') first.") + return AnnDataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=False, + drop_remainder=False, + ) + + @property + def predict_dataloader(self): + """:obj:`crested.tl.data.AnnDataLoader`: Prediction dataloader.""" + if self.predict_dataset is None: + raise ValueError("predict_dataset is not set. Run setup('predict') first.") + return AnnDataLoader( + self.predict_dataset, + batch_size=self.batch_size, + shuffle=False, + drop_remainder=False, + ) + + def __repr__(self): + """Return a string representation of the AnndataModule.""" + return ( + f"AnndataModule(adata={self.adata}, genome={self.genome}, " + f"in_memory={self.in_memory}, " + f"always_reverse_complement={self.always_reverse_complement}, " + f"random_reverse_complement={self.random_reverse_complement}, " + f"max_stochastic_shift={self.max_stochastic_shift}, shuffle={self.shuffle}, " + f"batch_size={self.batch_size}" + ) + + +class MetaSampler(Sampler): + """ + A Sampler that yields indices in proportion to meta_dataset.global_probs. 
+ """ + + def __init__(self, meta_dataset: MetaAnnDataset, epoch_size: int = 100_000): + """ + Parameters + ---------- + meta_dataset : MetaAnnDataset + The combined dataset with global_indices and global_probs. + epoch_size : int + How many samples we consider in one epoch of training. + """ + super().__init__(data_source=meta_dataset) + self.meta_dataset = meta_dataset + self.epoch_size = epoch_size + + # Check that sum of global_probs ~ 1.0 + s = self.meta_dataset.global_probs.sum() + if not np.isclose(s, 1.0, atol=1e-6): + raise ValueError( + "global_probs do not sum to 1 after final normalization. sum = {}".format(s) + ) + + def __iter__(self): + """ + For each epoch, yield 'epoch_size' random draws from + [0..len(meta_dataset)-1], weighted by global_probs. + """ + n = len(self.meta_dataset) + p = self.meta_dataset.global_probs + for _ in qrange(self.epoch_size): + yield np.random.choice(n, p=p) + + def __len__(self): + """ + The DataLoader uses len(sampler) to figure out how many samples per epoch. + """ + return self.epoch_size + +class MetaAnnDataModule: + """ + A DataModule for multiple AnnData objects (one per species), + merging them into a single MetaAnnDataset for each stage. + Then we rely on the MetaSampler to do globally weighted sampling. + + This replaces the old random-per-dataset approach in MultiAnnDataset. + """ + + def __init__( + self, + adatas: list[AnnData], + genomes: list[Genome], + in_memory: bool = True, + always_reverse_complement: bool = True, + random_reverse_complement: bool = False, + max_stochastic_shift: int = 0, + deterministic_shift: bool = False, + shuffle: bool = True, + batch_size: int = 256, + obs_columns: list[str] | None = None, + obsm_keys: list[str] | None = None, + varp_keys: list[str] | None = None, + epoch_size: int = 100_000, + ): + if len(adatas) != len(genomes): + raise ValueError("Must provide as many `adatas` as `genomes`.") + + self.adatas = adatas + self.genomes = genomes + self.in_memory = in_memory + self.always_reverse_complement = always_reverse_complement + self.random_reverse_complement = random_reverse_complement + self.max_stochastic_shift = max_stochastic_shift + self.deterministic_shift = deterministic_shift + self.shuffle = shuffle + self.batch_size = batch_size + self.obs_columns = obs_columns + self.obsm_keys = obsm_keys + self.varp_keys = varp_keys + self.epoch_size = epoch_size + + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + self.predict_dataset = None + + def setup(self, stage: str) -> None: + """ + Create the AnnDataset objects for each adata+genome, then unify them + into a MetaAnnDataset for the given stage. 
+ """ + def dataset_args(split): + return { + "in_memory": self.in_memory, + "always_reverse_complement": self.always_reverse_complement, + "random_reverse_complement": self.random_reverse_complement, + "max_stochastic_shift": self.max_stochastic_shift, + "deterministic_shift": self.deterministic_shift, + "obs_columns": self.obs_columns, + "obsm_keys": self.obsm_keys, + "varp_keys": self.varp_keys, + "split": split, + } + + if stage == "fit": + train_datasets = [] + val_datasets = [] + for adata, genome in zip(self.adatas, self.genomes): + # Training + args = dataset_args("train") + ds_train = AnnDataset(anndata=adata, genome=genome, **args) + train_datasets.append(ds_train) + + # Validation (no shifting, no RC) + val_args = dataset_args("val") + val_args["always_reverse_complement"] = False + val_args["random_reverse_complement"] = False + val_args["max_stochastic_shift"] = 0 + ds_val = AnnDataset(anndata=adata, genome=genome, **val_args) + val_datasets.append(ds_val) + + # Merge them with MetaAnnDataset + self.train_dataset = MetaAnnDataset(train_datasets) + self.val_dataset = MetaAnnDataset(val_datasets) + + elif stage == "test": + test_datasets = [] + for adata, genome in zip(self.adatas, self.genomes): + args = dataset_args("test") + args["in_memory"] = False + args["always_reverse_complement"] = False + args["random_reverse_complement"] = False + args["max_stochastic_shift"] = 0 + + ds_test = AnnDataset(anndata=adata, genome=genome, **args) + test_datasets.append(ds_test) + + self.test_dataset = MetaAnnDataset(test_datasets) + + elif stage == "predict": + predict_datasets = [] + for adata, genome in zip(self.adatas, self.genomes): + args = dataset_args(None) + args["in_memory"] = False + args["always_reverse_complement"] = False + args["random_reverse_complement"] = False + args["max_stochastic_shift"] = 0 + + ds_pred = AnnDataset(anndata=adata, genome=genome, **args) + predict_datasets.append(ds_pred) + + self.predict_dataset = MetaAnnDataset(predict_datasets) + + else: + raise ValueError(f"Invalid stage: {stage}") + + @property + def train_dataloader(self): + if self.train_dataset is None: + raise ValueError("train_dataset is not set. Run setup('fit') first.") + return AnnDataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + drop_remainder=False, + epoch_size=self.epoch_size, + ) + + @property + def val_dataloader(self): + if self.val_dataset is None: + raise ValueError("val_dataset is not set. Run setup('fit') first.") + return AnnDataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + drop_remainder=False, + epoch_size=self.epoch_size, + ) + + @property + def test_dataloader(self): + if self.test_dataset is None: + raise ValueError("test_dataset is not set. Run setup('test') first.") + return AnnDataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=False, + drop_remainder=False, + epoch_size=self.epoch_size, + ) + + @property + def predict_dataloader(self): + if self.predict_dataset is None: + raise ValueError("predict_dataset is not set. 
Run setup('predict') first.") + return AnnDataLoader( + self.predict_dataset, + batch_size=self.batch_size, + shuffle=False, + drop_remainder=False, + epoch_size=self.epoch_size, + ) + + def __repr__(self): + return ( + f"MetaAnnDataModule(" + f"num_species={len(self.adatas)}, " + f"batch_size={self.batch_size}, shuffle={self.shuffle}, " + f"max_stochastic_shift={self.max_stochastic_shift}, " + f"random_reverse_complement={self.random_reverse_complement}, " + f"always_reverse_complement={self.always_reverse_complement}, " + f"in_memory={self.in_memory}, " + f"deterministic_shift={self.deterministic_shift}, " + f"epoch_size={self.epoch_size}" + f")" + ) diff --git a/src/crested/tl/data/.ipynb_checkpoints/_dataloader-checkpoint.py b/src/crested/tl/data/.ipynb_checkpoints/_dataloader-checkpoint.py new file mode 100644 index 0000000..e64fbf1 --- /dev/null +++ b/src/crested/tl/data/.ipynb_checkpoints/_dataloader-checkpoint.py @@ -0,0 +1,143 @@ +"""Dataloader for batching, shuffling, and one-hot encoding of AnnDataset.""" + +from __future__ import annotations + +import os +from collections import defaultdict + +if os.environ["KERAS_BACKEND"] == "torch": + import torch + from torch.utils.data import DataLoader +else: + import tensorflow as tf + +from ._dataset import AnnDataset + + +class AnnDataLoader: + """ + Pytorch-like DataLoader class for AnnDataset with options for batching, shuffling, and one-hot encoding. + + Parameters + ---------- + dataset + The dataset instance provided. + batch_size + Number of samples per batch. + shuffle + Indicates whether shuffling is enabled. + drop_remainder + Indicates whether to drop the last incomplete batch. + + Examples + -------- + >>> dataset = AnnDataset(...) # Your dataset instance + >>> batch_size = 32 + >>> dataloader = AnnDataLoader( + ... dataset, batch_size, shuffle=True, drop_remainder=True + ... ) + >>> for x, y in dataloader.data: + ... # Your training loop here + """ + def __init__( + self, + dataset, # can be AnnDataset or MetaAnnDataset + batch_size: int, + shuffle: bool = False, + drop_remainder: bool = True, + epoch_size: int = 100_000, + ): + self.dataset = dataset + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_remainder = drop_remainder + self.epoch_size = epoch_size + + if os.environ.get("KERAS_BACKEND", "") == "torch": + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = None + + self.sampler = None + + # Decide if we should use MetaSampler + if isinstance(dataset, MetaAnnDataset): + # This merges many AnnDataset objects, so let's use MetaSampler + self.sampler = MetaSampler(dataset, epoch_size=self.epoch_size) + else: + # Single AnnDataset => possibly fallback to WeightedRegionSampler or uniform + # We'll do uniform shuffle if asked. WeightedRegionSampler is not shown here, + # but you could do: + # if dataset.augmented_probs is not None: self.sampler = WeightedRegionSampler(...) + if self.shuffle and hasattr(self.dataset, "shuffle"): + self.dataset.shuffle = True + + def _collate_fn(self, batch): + """ + Collate function to gather list of sample-dicts into a single batched dict of tensors. 
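+
+        For example, two samples of the form {"sequence": (seq_len, 4), "y": (num_outputs,)}
+        are stacked into {"sequence": tensor of shape (2, seq_len, 4), "y": tensor of
+        shape (2, num_outputs)} and moved to `self.device` when one is set.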
+ """ + x = defaultdict(list) + for sample_dict in batch: + for key, val in sample_dict.items(): + x[key].append(torch.tensor(val, dtype=torch.float32)) + + # Stack and move to device + for key in x.keys(): + x[key] = torch.stack(x[key], dim=0) + if self.device is not None: + x[key] = x[key].to(self.device) + return x + + def _create_dataset(self): + if os.environ.get("KERAS_BACKEND", "") == "torch": + if self.sampler is not None: + return DataLoader( + self.dataset, + batch_size=self.batch_size, + sampler=self.sampler, + drop_last=self.drop_remainder, + num_workers=0, + collate_fn=self._collate_fn, + ) + else: + return DataLoader( + self.dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + drop_last=self.drop_remainder, + num_workers=0, + collate_fn=self._collate_fn, + ) + elif os.environ["KERAS_BACKEND"] == "tensorflow": + ds = tf.data.Dataset.from_generator( + self.dataset, + output_signature=( + tf.TensorSpec(shape=(self.dataset.seq_len, 4), dtype=tf.float32), + tf.TensorSpec(shape=(self.dataset.num_outputs,), dtype=tf.float32), + ), + ) + ds = ( + ds.batch(self.batch_size, drop_remainder=self.drop_remainder) + .repeat() + .prefetch(tf.data.AUTOTUNE) + ) + return ds + + @property + def data(self): + """Return the dataset as a tf.data.Dataset instance.""" + return self._create_dataset() + + def __len__(self): + """Return the number of batches in the DataLoader based on the dataset size and batch size.""" + if self.sampler is not None: + return (self.epoch_size + self.batch_size - 1) // self.batch_size + else: + return (len(self.dataset) + self.batch_size - 1) // self.batch_size + + def __repr__(self): + """Return the string representation of the DataLoader.""" + return ( + f"AnnDataLoader(dataset={self.dataset}, batch_size={self.batch_size}, " + f"shuffle={self.shuffle}, drop_remainder={self.drop_remainder})" + ) diff --git a/src/crested/tl/data/.ipynb_checkpoints/_dataset-checkpoint.py b/src/crested/tl/data/.ipynb_checkpoints/_dataset-checkpoint.py new file mode 100644 index 0000000..a740ed4 --- /dev/null +++ b/src/crested/tl/data/.ipynb_checkpoints/_dataset-checkpoint.py @@ -0,0 +1,568 @@ +"""Dataset class for combining genome files and AnnData objects.""" + +from __future__ import annotations + +import os +import re + +import numpy as np +from anndata import AnnData +from loguru import logger +from scipy.sparse import spmatrix +from tqdm import tqdm + +from crested._genome import Genome +from crested.utils import one_hot_encode_sequence + + +def _flip_region_strand(region: str) -> str: + """Reverse the strand of a region.""" + strand_reverser = {"+": "-", "-": "+"} + return region[:-1] + strand_reverser[region[-1]] + + +def _check_strandedness(region: str) -> bool: + """Check the strandedness of a region, raising an error if the formatting isn't recognised.""" + if re.fullmatch(r".+:\d+-\d+:[-+]", region): + return True + elif re.fullmatch(r".+:\d+-\d+", region): + return False + else: + raise ValueError( + f"Region {region} was not recognised as a valid coordinate set (chr:start-end or chr:start-end:strand)." + "If provided, strand must be + or -." + ) + + +def _deterministic_shift_region( + region: str, stride: int = 50, n_shifts: int = 2 +) -> list[str]: + """ + Shift each region by a deterministic stride to each side. Will increase the number of regions by n_shifts times two. + + This is a legacy function, it's recommended to use stochastic shifting instead. 
+ """ + new_regions = [] + chrom, start_end, strand = region.split(":") + start, end = map(int, start_end.split("-")) + for i in range(-n_shifts, n_shifts + 1): + new_start = start + i * stride + new_end = end + i * stride + new_regions.append(f"{chrom}:{new_start}-{new_end}:{strand}") + return new_regions + + +class SequenceLoader: + """ + Load sequences from a genome file. + + Options for reverse complementing and stochastic shifting are available. + + Parameters + ---------- + genome + Genome instance. + in_memory + If True, the sequences of supplied regions will be loaded into memory. + always_reverse_complement + If True, all sequences will be augmented with their reverse complement. + Doubles the dataset size. + max_stochastic_shift + Maximum stochastic shift (n base pairs) to apply randomly to each sequence. + regions + List of regions to load into memory. Required if in_memory is True. + """ + + def __init__( + self, + genome: Genome, + in_memory: bool = False, + always_reverse_complement: bool = False, + deterministic_shift: bool = False, + max_stochastic_shift: int = 0, + regions: list[str] | None = None, + ): + """Initialize the SequenceLoader with the provided genome file and options.""" + self.genome = genome.fasta + self.chromsizes = genome.chrom_sizes + self.in_memory = in_memory + self.always_reverse_complement = always_reverse_complement + self.deterministic_shift = deterministic_shift + self.max_stochastic_shift = max_stochastic_shift + self.sequences = {} + self.complement = str.maketrans("ACGT", "TGCA") + self.regions = regions + if self.in_memory: + self._load_sequences_into_memory(self.regions) + + def _load_sequences_into_memory(self, regions: list[str]): + """Load all sequences into memory (dict).""" + logger.info("Loading sequences into memory...") + # Check region formatting + stranded = _check_strandedness(regions[0]) + + for region in tqdm(regions): + # Make region stranded if not + if not stranded: + strand = "+" + region = f"{region}:{strand}" + if region[-4] == ":": + raise ValueError( + f"You are double-adding strand ids to your region {region}. Check if all regions are stranded or unstranded." 
+ ) + + # Add deterministic shift regions + if self.deterministic_shift: + regions = _deterministic_shift_region(region) + else: + regions = [region] + + for region in regions: + # Parse region + chrom, start_end, strand = region.split(":") + start, end = map(int, start_end.split("-")) + + # Add region to self.sequences + extended_sequence = self._get_extended_sequence( + chrom, start, end, strand + ) + self.sequences[region] = extended_sequence + + # Add reverse-complemented region to self.sequences if always_reverse_complement + if self.always_reverse_complement: + self.sequences[ + _flip_region_strand(region) + ] = self._reverse_complement(extended_sequence) + + def _get_extended_sequence( + self, chrom: str, start: int, end: int, strand: str + ) -> str: + """Get sequence from genome file, extended for stochastic shifting.""" + extended_start = max(0, start - self.max_stochastic_shift) + extended_end = extended_start + (end - start) + (self.max_stochastic_shift * 2) + + if self.chromsizes and chrom in self.chromsizes: + chrom_size = self.chromsizes[chrom] + if extended_end > chrom_size: + extended_start = chrom_size - ( + end - start + self.max_stochastic_shift * 2 + ) + extended_end = chrom_size + + seq = self.genome.fetch(chrom, extended_start, extended_end).upper() + if strand == "-": + seq = self._reverse_complement(seq) + return seq + + def _reverse_complement(self, sequence: str) -> str: + """Reverse complement a sequence.""" + return sequence.translate(self.complement)[::-1] + + def get_sequence( + self, region: str, stranded: bool | None = None, shift: int = 0 + ) -> str: + """ + Get sequence for a region, strand, and shift from memory or fasta. + + If no strand is given in region or strand, assumes positive strand. + + Parameters + ---------- + region + Region to get the sequence for. Either (chr:start-end) or (chr:start-end:strand). + stranded + Whether the input data is stranded. Default (None) infers from sequence (at a computational cost). + If not stranded, positive strand is assumed. + shift: + Shift of the sequence within the extended sequence, for use with the stochastic shift mechanism. + + Returns + ------- + The DNA sequence, as a string. + """ + if stranded is None: + stranded = _check_strandedness(region) + if not stranded: + region = f"{region}:+" + # Parse region + chrom, start_end, strand = region.split(":") + start, end = map(int, start_end.split("-")) + + # Get extended sequence + if self.in_memory: + sequence = self.sequences[region] + else: + sequence = self._get_extended_sequence(chrom, start, end, strand) + + # Extract from extended sequence + start_idx = self.max_stochastic_shift + shift + end_idx = start_idx + (end - start) + sub_sequence = sequence[start_idx:end_idx] + + # Pad with Ns if sequence is shorter than expected + if len(sub_sequence) < (end - start): + if strand == "+": + sub_sequence = sub_sequence.ljust(end - start, "N") + else: + sub_sequence = sub_sequence.rjust(end - start, "N") + + return sub_sequence + + +class IndexManager: + """ + Manage indices for the dataset. + + Augments indices with strand information if always reverse complement. + + Parameters + ---------- + indices + List of indices in format "chr:start-end" or "chr:start-end:strand". + always_reverse_complement + If True, all sequences will be augmented with their reverse complement. + deterministic_shift + If True, each region will be shifted twice with stride 50bp to each side. 
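+
+    Example
+    -------
+    A small illustration of the augmentation (no deterministic shifting):
+
+    >>> im = IndexManager(["chr1:1000-2000"], always_reverse_complement=True)
+    >>> im.augmented_indices
+    ['chr1:1000-2000:+', 'chr1:1000-2000:-']
+    >>> im.augmented_indices_map
+    {'chr1:1000-2000:+': 'chr1:1000-2000', 'chr1:1000-2000:-': 'chr1:1000-2000'}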
+ """ + + def __init__( + self, + indices: list[str], + always_reverse_complement: bool, + deterministic_shift: bool = False, + ): + """Initialize the IndexManager with the provided indices.""" + self.indices = indices + self.always_reverse_complement = always_reverse_complement + self.deterministic_shift = deterministic_shift + self.augmented_indices, self.augmented_indices_map = self._augment_indices( + indices + ) + + def shuffle_indices(self): + """Shuffle indices. Managed by wrapping class AnnDataLoader.""" + np.random.shuffle(self.augmented_indices) + + def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str]]: + """Augment indices with strand information. Necessary if always reverse complement to map sequences back to targets.""" + augmented_indices = [] + augmented_indices_map = {} + for region in indices: + if not _check_strandedness( + region + ): # If slow, can use AnnDataset stranded argument - but this validates every region's formatting as well + stranded_region = f"{region}:+" + else: + stranded_region = region + + if self.deterministic_shift: + shifted_regions = _deterministic_shift_region(stranded_region) + for shifted_region in shifted_regions: + augmented_indices.append(shifted_region) + augmented_indices_map[shifted_region] = region + if self.always_reverse_complement: + augmented_indices.append(_flip_region_strand(shifted_region)) + augmented_indices_map[ + _flip_region_strand(shifted_region) + ] = region + else: + augmented_indices.append(stranded_region) + augmented_indices_map[stranded_region] = region + if self.always_reverse_complement: + augmented_indices.append(_flip_region_strand(stranded_region)) + augmented_indices_map[_flip_region_strand(stranded_region)] = region + return augmented_indices, augmented_indices_map + + + +class AnnDataset(BaseClass): + """ + Dataset class for combining genome files and AnnData objects. + + Called by the by the AnnDataModule class. + + Parameters + ---------- + anndata + AnnData object containing the data. + genome + Genome instance + split + 'train', 'val', or 'test' split column in anndata.var. + in_memory + If True, the train and val sequences will be loaded into memory. + random_reverse_complement + If True, the sequences will be randomly reverse complemented during training. + always_reverse_complement + If True, all sequences will be augmented with their reverse complement during training. + max_stochastic_shift + Maximum stochastic shift (n base pairs) to apply randomly to each sequence during training. + deterministic_shift + If true, each region will be shifted twice with stride 50bp to each side. + This is our legacy shifting, we recommend using max_stochastic_shift instead. + obs_columns + Columns in obs that will be added to the dataset. + obsm_columns + Keys in obsm that will be added to the dataset. + varp_columns + Keys in varp that will be added to the dataset. 
+ """ + + def __init__( + self, + anndata: AnnData, + genome: Genome, + split: str = None, + in_memory: bool = True, + random_reverse_complement: bool = False, + always_reverse_complement: bool = False, + max_stochastic_shift: int = 0, + deterministic_shift: bool = False, + obs_columns: list[str] | None = None, # multiple obs columns + obsm_keys: list[str] | None = None, # multiple obsm keys + varp_keys: list[str] | None = None, # multiple varp keys + ): + + """Initialize the dataset with the provided AnnData object and options.""" + self.anndata = self._split_anndata(anndata, split) + self.split = split + self.indices = list(self.anndata.var_names) + self.in_memory = in_memory + self.compressed = isinstance(self.anndata.X, spmatrix) + self.index_map = {index: i for i, index in enumerate(self.indices)} + self.num_outputs = self.anndata.X.shape[0] + self.random_reverse_complement = random_reverse_complement + self.max_stochastic_shift = max_stochastic_shift + self.shuffle = False # managed by wrapping class AnnDataLoader + self.obs_columns = obs_columns if obs_columns is not None else [] + self.obsm_keys = obsm_keys if obsm_keys is not None else [] + self.varp_keys = varp_keys if varp_keys is not None else [] + + # Validate and store obs data + self.obs_data = {} + for col in self.obs_columns: + if col not in anndata.obs: + raise ValueError(f"obs column '{col}' not found.") + # Convert categorical to integer codes if needed + if pd.api.types.is_categorical_dtype(anndata.obs[col]): + self.obs_data[col] = anndata.obs[col].cat.codes.values + else: + self.obs_data[col] = anndata.obs[col].values + + # Validate and store obsm data + self.obsm_data = {} + for key in self.obsm_keys: + if key not in anndata.obsm: + raise ValueError(f"obsm key '{key}' not found.") + mat = anndata.obsm[key] + if mat.shape[0] != anndata.n_obs: + raise ValueError(f"Dimension mismatch for obsm key '{key}'.") + self.obsm_data[key] = mat + + # Validate and store varp data + self.varp_data = {} + for key in self.varp_keys: + if key not in anndata.varp: + raise ValueError(f"varp key '{key}' not found.") + mat = anndata.varp[key] + if mat.shape[0] != anndata.n_var: + raise ValueError(f"Dimension mismatch for varp key '{key}'.") + self.varp_data[key] = mat + + + # Check region formatting + stranded = _check_strandedness(self.indices[0]) + if stranded and (always_reverse_complement or random_reverse_complement): + logger.info( + "Setting always_reverse_complement=True or random_reverse_complement=True with stranded data.", + "This means both strands are used when training and the strand information is effectively disregarded.", + ) + + self.sequence_loader = SequenceLoader( + genome, + in_memory=in_memory, + always_reverse_complement=always_reverse_complement, + deterministic_shift=deterministic_shift, + max_stochastic_shift=max_stochastic_shift, + regions=self.indices, + ) + self.index_manager = IndexManager( + self.indices, + always_reverse_complement=always_reverse_complement, + deterministic_shift=deterministic_shift, + ) + self.seq_len = len( + self.sequence_loader.get_sequence(self.indices[0], stranded=stranded) + ) + + self.augmented_probs = None + if "sample_prob" in anndata.var.columns: + # 1) Extract raw sample_prob from adata.var + probs = anndata.var["sample_prob"].values.astype(float) + # 2) Ensure no negative values + probs = np.clip(probs, 0, None) + + # 3) For each augmented index, set unnormalized probability + self.augmented_probs = np.empty(len(self.index_manager.augmented_indices), dtype=float) + for i, 
aug_region in enumerate(self.index_manager.augmented_indices): + original_region = self.index_manager.augmented_indices_map[aug_region] + var_idx = self.index_map[original_region] + self.augmented_probs[i] = probs[var_idx] + else: + # If no sample_prob, we might default to 1.0 for each region + # or simply None to indicate uniform sampling + self.augmented_probs = None + + + @staticmethod + def _split_anndata(anndata: AnnData, split: str) -> AnnData: + """Return subset of anndata based on a given split column.""" + if split: + if "split" not in anndata.var.columns: + raise KeyError( + "No split column found in anndata.var. Run `pp.train_val_test_split` first." + ) + subset = ( + anndata[:, anndata.var["split"] == split].copy() + if split + else anndata.copy() + ) + return subset + + def __len__(self) -> int: + """Get number of (augmented) samples in the dataset.""" + return len(self.index_manager.augmented_indices) + + def _get_target(self, index: str) -> np.ndarray: + """Get target for a given index.""" + y_index = self.index_map[index] + return ( + self.anndata.X[:, y_index].toarray().flatten() + if self.compressed + else self.anndata.X[:, y_index] + ) + + def __getitem__(self, idx: int) -> dict: + """Return sequence and target for a given index.""" + augmented_index = self.index_manager.augmented_indices[idx] + original_index = self.index_manager.augmented_indices_map[augmented_index] + + # Get sequence and target as before + shift = 0 + if self.max_stochastic_shift > 0: + shift = np.random.randint(-self.max_stochastic_shift, self.max_stochastic_shift + 1) + + x = self.sequence_loader.get_sequence(augmented_index, stranded=True, shift=shift) + x = one_hot_encode_sequence(x, expand_dim=False) + y = self._get_target(original_index) + + # Random reverse complement if needed + if self.random_reverse_complement and np.random.rand() < 0.5: + x = self.sequence_loader._reverse_complement(x) + x = one_hot_encode_sequence(x, expand_dim=False) + + item = { + "sequence": x, + "y": y, + } + + # Add obsmp columns directly to the dictionary + for col in self.obs_columns: + item[col] = self.obs_data[col] + + for key in self.obsm_keys: + item[key] = self.obsm_data[key] + + for key in self.varp_keys: + item[key] = self.varp_data[key][idx]#.todense() + + return item + + def __call__(self): + """Call generator for the dataset.""" + for i in range(len(self)): + if i == 0: + if self.shuffle: + self.index_manager.shuffle_indices() + yield self.__getitem__(i) + + def __repr__(self) -> str: + """Get string representation of the dataset.""" + return f"AnnDataset(anndata_shape={self.anndata.shape}, n_samples={len(self)}, num_outputs={self.num_outputs}, split={self.split}, in_memory={self.in_memory})" + +class MetaAnnDataset: + """ + Combines multiple AnnDataset objects into a single dataset, + merging all their (augmented_index, probability) pairs into one global list. + + We do a final normalization across all sub-datasets so that + sample_prob from each dataset is treated as an unnormalized weight. + """ + + def __init__(self, datasets: list[AnnDataset]): + """ + Parameters + ---------- + datasets : list of AnnDataset + Each AnnDataset is for a different species or annotation set. 
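+
+        Per-dataset `augmented_probs` (derived from `sample_prob` in each
+        AnnData's `.var`, or 1.0 per region when absent) are used as
+        unnormalized weights and normalized once across all datasets. For
+        example, if one dataset's weights sum to 3 and another's to 1, roughly
+        75% of drawn samples come from the first dataset.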
+ """ + if not datasets: + raise ValueError("No AnnDataset provided to MetaAnnDataset.") + + self.datasets = datasets + + # global_indices will store tuples of (dataset_idx, local_idx) + # global_probs will store the merged, unnormalized probabilities + self.global_indices = [] + self.global_probs = [] + + for ds_idx, ds in enumerate(datasets): + ds_len = len(ds.index_manager.augmented_indices) + if ds_len == 0: + continue + + # If the dataset has augmented_probs, we use them as unnormalized weights + # If not, fallback to 1.0 for each region + if ds.augmented_probs is not None: + for local_i in range(ds_len): + self.global_indices.append((ds_idx, local_i)) + self.global_probs.append(ds.augmented_probs[local_i]) + else: + for local_i in range(ds_len): + self.global_indices.append((ds_idx, local_i)) + self.global_probs.append(1.0) + + # Convert to numpy arrays + self.global_indices = np.array(self.global_indices, dtype=object) + self.global_probs = np.array(self.global_probs, dtype=float) + + # Normalize across the entire set + total = self.global_probs.sum() + if total > 0: + self.global_probs /= total + else: + # fallback: uniform if everything is zero + n = len(self.global_probs) + if n > 0: + self.global_probs.fill(1.0 / n) + + def __len__(self): + """ + The total number of augmented indices across all sub-datasets. + """ + return len(self.global_indices) + + def __getitem__(self, global_idx: int): + """ + A DataLoader or sampler will pass a global_idx in [0..len(self)-1]. + We map that to (dataset_idx, local_i) and call the sub-dataset's __getitem__. + """ + ds_idx, local_i = self.global_indices[global_idx] + ds_idx = int(ds_idx) + local_i = int(local_i) + return self.datasets[ds_idx][local_i] + + def __repr__(self): + return (f"MetaAnnDataset(num_datasets={len(self.datasets)}, " + f"total_augmented_indices={len(self.global_indices)})") + diff --git a/src/crested/tl/data/_anndatamodule.py b/src/crested/tl/data/_anndatamodule.py index 924af06..4747ef7 100644 --- a/src/crested/tl/data/_anndatamodule.py +++ b/src/crested/tl/data/_anndatamodule.py @@ -3,6 +3,8 @@ from __future__ import annotations from os import PathLike +from torch.utils.data import Sampler +import numpy as np from crested._genome import Genome, _resolve_genome @@ -65,16 +67,19 @@ def __init__( genome: PathLike | Genome | None = None, chromsizes_file: PathLike | None = None, in_memory: bool = True, - always_reverse_complement=True, + always_reverse_complement: bool = True, random_reverse_complement: bool = False, max_stochastic_shift: int = 0, deterministic_shift: bool = False, shuffle: bool = True, batch_size: int = 256, + obs_columns: list[str] | None = None, + obsm_keys: list[str] | None = None, + varp_keys: list[str] | None = None, ): """Initialize the DataModule with the provided dataset and options.""" self.adata = adata - self.genome = _resolve_genome(genome, chromsizes_file) # backward compatibility + self.genome = _resolve_genome(genome, chromsizes_file) # Function assumed available self.always_reverse_complement = always_reverse_complement self.in_memory = in_memory self.random_reverse_complement = random_reverse_complement @@ -82,6 +87,9 @@ def __init__( self.deterministic_shift = deterministic_shift self.shuffle = shuffle self.batch_size = batch_size + self.obs_columns = obs_columns + self.obsm_keys = obsm_keys + self.varp_keys = varp_keys self._validate_init_args(random_reverse_complement, always_reverse_complement) @@ -105,56 +113,63 @@ def setup(self, stage: str) -> None: Generates the train, val, test or 
predict dataset based on the provided stage. Should always be called before accessing the dataloaders. - Generally you don't need to call this directly, as this is called inside the `tl.Crested` trainer class. + Generally, you don't need to call this directly, as this is called inside the `tl.Crested` trainer class. Parameters ---------- stage Stage for which to setup the dataloader. Either 'fit', 'test' or 'predict'. """ + args = { + "anndata": self.adata, + "genome": self.genome, + "in_memory": self.in_memory, + "always_reverse_complement": self.always_reverse_complement, + "random_reverse_complement": self.random_reverse_complement, + "max_stochastic_shift": self.max_stochastic_shift, + "deterministic_shift": self.deterministic_shift, + "obs_columns": self.obs_columns, + "obsm_keys": self.obsm_keys, + "varp_keys": self.varp_keys, + } if stage == "fit": - self.train_dataset = AnnDataset( - self.adata, - self.genome, - split="train", - in_memory=self.in_memory, - always_reverse_complement=self.always_reverse_complement, - random_reverse_complement=self.random_reverse_complement, - max_stochastic_shift=self.max_stochastic_shift, - deterministic_shift=self.deterministic_shift, - ) - self.val_dataset = AnnDataset( - self.adata, - self.genome, - split="val", - in_memory=self.in_memory, - always_reverse_complement=False, - random_reverse_complement=False, - max_stochastic_shift=0, - ) + # Training dataset + train_args = args.copy() + train_args["split"] = "train" + + val_args = args.copy() + val_args["split"] = "val" + val_args["always_reverse_complement"] = False + val_args["random_reverse_complement"] = False + val_args["max_stochastic_shift"] = 0 + + self.train_dataset = AnnDataset(**train_args) + self.val_dataset = AnnDataset(**val_args) + elif stage == "test": - self.test_dataset = AnnDataset( - self.adata, - self.genome, - split="test", - in_memory=False, - always_reverse_complement=False, - random_reverse_complement=False, - max_stochastic_shift=0, - ) + test_args = args.copy() + test_args["split"] = "test" + test_args["in_memory"] = False + test_args["always_reverse_complement"] = False + test_args["random_reverse_complement"] = False + test_args["max_stochastic_shift"] = 0 + + self.test_dataset = AnnDataset(**test_args) + elif stage == "predict": - self.predict_dataset = AnnDataset( - self.adata, - self.genome, - split=None, - in_memory=False, - always_reverse_complement=False, - random_reverse_complement=False, - max_stochastic_shift=0, - ) + predict_args = args.copy() + predict_args["split"] = None + predict_args["in_memory"] = False + predict_args["always_reverse_complement"] = False + predict_args["random_reverse_complement"] = False + predict_args["max_stochastic_shift"] = 0 + + self.predict_dataset = AnnDataset(**predict_args) + else: raise ValueError(f"Invalid stage: {stage}") + @property def train_dataloader(self): """:obj:`crested.tl.data.AnnDataLoader`: Training dataloader.""" @@ -213,3 +228,224 @@ def __repr__(self): f"max_stochastic_shift={self.max_stochastic_shift}, shuffle={self.shuffle}, " f"batch_size={self.batch_size}" ) + + +class MetaSampler(Sampler): + """ + A Sampler that yields indices in proportion to meta_dataset.global_probs. + """ + + def __init__(self, meta_dataset: MetaAnnDataset, epoch_size: int = 100_000): + """ + Parameters + ---------- + meta_dataset : MetaAnnDataset + The combined dataset with global_indices and global_probs. + epoch_size : int + How many samples we consider in one epoch of training. 
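+
+ Example
+ -------
+ Minimal sketch of how the sampler plugs into a torch DataLoader (normally
+ AnnDataLoader wires this up); ``meta_ds`` is an assumed, already-built
+ MetaAnnDataset.
+
+ >>> sampler = MetaSampler(meta_ds, epoch_size=10_000)
+ >>> loader = torch.utils.data.DataLoader(meta_ds, batch_size=256, sampler=sampler)
+ >>> len(sampler)  # draws per epoch
+ 10000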
+ """ + super().__init__(data_source=meta_dataset) + self.meta_dataset = meta_dataset + self.epoch_size = epoch_size + + # Check that sum of global_probs ~ 1.0 + s = self.meta_dataset.global_probs.sum() + if not np.isclose(s, 1.0, atol=1e-6): + raise ValueError( + "global_probs do not sum to 1 after final normalization. sum = {}".format(s) + ) + + def __iter__(self): + """ + For each epoch, yield 'epoch_size' random draws from + [0..len(meta_dataset)-1], weighted by global_probs. + """ + n = len(self.meta_dataset) + p = self.meta_dataset.global_probs + for _ in qrange(self.epoch_size): + yield np.random.choice(n, p=p) + + def __len__(self): + """ + The DataLoader uses len(sampler) to figure out how many samples per epoch. + """ + return self.epoch_size + +class MetaAnnDataModule: + """ + A DataModule for multiple AnnData objects (one per species), + merging them into a single MetaAnnDataset for each stage. + Then we rely on the MetaSampler to do globally weighted sampling. + + This replaces the old random-per-dataset approach in MultiAnnDataset. + """ + + def __init__( + self, + adatas: list[AnnData], + genomes: list[Genome], + in_memory: bool = True, + always_reverse_complement: bool = True, + random_reverse_complement: bool = False, + max_stochastic_shift: int = 0, + deterministic_shift: bool = False, + shuffle: bool = True, + batch_size: int = 256, + obs_columns: list[str] | None = None, + obsm_keys: list[str] | None = None, + varp_keys: list[str] | None = None, + epoch_size: int = 100_000, + ): + if len(adatas) != len(genomes): + raise ValueError("Must provide as many `adatas` as `genomes`.") + + self.adatas = adatas + self.genomes = genomes + self.in_memory = in_memory + self.always_reverse_complement = always_reverse_complement + self.random_reverse_complement = random_reverse_complement + self.max_stochastic_shift = max_stochastic_shift + self.deterministic_shift = deterministic_shift + self.shuffle = shuffle + self.batch_size = batch_size + self.obs_columns = obs_columns + self.obsm_keys = obsm_keys + self.varp_keys = varp_keys + self.epoch_size = epoch_size + + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + self.predict_dataset = None + + def setup(self, stage: str) -> None: + """ + Create the AnnDataset objects for each adata+genome, then unify them + into a MetaAnnDataset for the given stage. 
+ """ + def dataset_args(split): + return { + "in_memory": self.in_memory, + "always_reverse_complement": self.always_reverse_complement, + "random_reverse_complement": self.random_reverse_complement, + "max_stochastic_shift": self.max_stochastic_shift, + "deterministic_shift": self.deterministic_shift, + "obs_columns": self.obs_columns, + "obsm_keys": self.obsm_keys, + "varp_keys": self.varp_keys, + "split": split, + } + + if stage == "fit": + train_datasets = [] + val_datasets = [] + for adata, genome in zip(self.adatas, self.genomes): + # Training + args = dataset_args("train") + ds_train = AnnDataset(anndata=adata, genome=genome, **args) + train_datasets.append(ds_train) + + # Validation (no shifting, no RC) + val_args = dataset_args("val") + val_args["always_reverse_complement"] = False + val_args["random_reverse_complement"] = False + val_args["max_stochastic_shift"] = 0 + ds_val = AnnDataset(anndata=adata, genome=genome, **val_args) + val_datasets.append(ds_val) + + # Merge them with MetaAnnDataset + self.train_dataset = MetaAnnDataset(train_datasets) + self.val_dataset = MetaAnnDataset(val_datasets) + + elif stage == "test": + test_datasets = [] + for adata, genome in zip(self.adatas, self.genomes): + args = dataset_args("test") + args["in_memory"] = False + args["always_reverse_complement"] = False + args["random_reverse_complement"] = False + args["max_stochastic_shift"] = 0 + + ds_test = AnnDataset(anndata=adata, genome=genome, **args) + test_datasets.append(ds_test) + + self.test_dataset = MetaAnnDataset(test_datasets) + + elif stage == "predict": + predict_datasets = [] + for adata, genome in zip(self.adatas, self.genomes): + args = dataset_args(None) + args["in_memory"] = False + args["always_reverse_complement"] = False + args["random_reverse_complement"] = False + args["max_stochastic_shift"] = 0 + + ds_pred = AnnDataset(anndata=adata, genome=genome, **args) + predict_datasets.append(ds_pred) + + self.predict_dataset = MetaAnnDataset(predict_datasets) + + else: + raise ValueError(f"Invalid stage: {stage}") + + @property + def train_dataloader(self): + if self.train_dataset is None: + raise ValueError("train_dataset is not set. Run setup('fit') first.") + return AnnDataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + drop_remainder=False, + epoch_size=self.epoch_size, + ) + + @property + def val_dataloader(self): + if self.val_dataset is None: + raise ValueError("val_dataset is not set. Run setup('fit') first.") + return AnnDataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + drop_remainder=False, + epoch_size=self.epoch_size, + ) + + @property + def test_dataloader(self): + if self.test_dataset is None: + raise ValueError("test_dataset is not set. Run setup('test') first.") + return AnnDataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=False, + drop_remainder=False, + epoch_size=self.epoch_size, + ) + + @property + def predict_dataloader(self): + if self.predict_dataset is None: + raise ValueError("predict_dataset is not set. 
Run setup('predict') first.") + return AnnDataLoader( + self.predict_dataset, + batch_size=self.batch_size, + shuffle=False, + drop_remainder=False, + epoch_size=self.epoch_size, + ) + + def __repr__(self): + return ( + f"MetaAnnDataModule(" + f"num_species={len(self.adatas)}, " + f"batch_size={self.batch_size}, shuffle={self.shuffle}, " + f"max_stochastic_shift={self.max_stochastic_shift}, " + f"random_reverse_complement={self.random_reverse_complement}, " + f"always_reverse_complement={self.always_reverse_complement}, " + f"in_memory={self.in_memory}, " + f"deterministic_shift={self.deterministic_shift}, " + f"epoch_size={self.epoch_size}" + f")" + ) diff --git a/src/crested/tl/data/_dataloader.py b/src/crested/tl/data/_dataloader.py index 8a756a8..e64fbf1 100644 --- a/src/crested/tl/data/_dataloader.py +++ b/src/crested/tl/data/_dataloader.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from collections import defaultdict if os.environ["KERAS_BACKEND"] == "torch": import torch @@ -38,43 +39,75 @@ class AnnDataLoader: >>> for x, y in dataloader.data: ... # Your training loop here """ - def __init__( self, - dataset: AnnDataset, + dataset, # can be AnnDataset or MetaAnnDataset batch_size: int, shuffle: bool = False, drop_remainder: bool = True, + epoch_size: int = 100_000, ): - """Initialize the DataLoader with the provided dataset and options.""" self.dataset = dataset self.batch_size = batch_size self.shuffle = shuffle self.drop_remainder = drop_remainder - if os.environ["KERAS_BACKEND"] == "torch": + self.epoch_size = epoch_size + + if os.environ.get("KERAS_BACKEND", "") == "torch": self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = None - if self.shuffle: - self.dataset.shuffle = True + self.sampler = None + + # Decide if we should use MetaSampler + if isinstance(dataset, MetaAnnDataset): + # This merges many AnnDataset objects, so let's use MetaSampler + self.sampler = MetaSampler(dataset, epoch_size=self.epoch_size) + else: + # Single AnnDataset => possibly fallback to WeightedRegionSampler or uniform + # We'll do uniform shuffle if asked. WeightedRegionSampler is not shown here, + # but you could do: + # if dataset.augmented_probs is not None: self.sampler = WeightedRegionSampler(...) + if self.shuffle and hasattr(self.dataset, "shuffle"): + self.dataset.shuffle = True def _collate_fn(self, batch): - """Collate function to move tensors to the specified device if backend is torch.""" - inputs, targets = zip(*batch) - inputs = torch.stack([torch.tensor(input) for input in inputs]).to(self.device) - targets = torch.stack([torch.tensor(target) for target in targets]).to( - self.device - ) - return inputs, targets + """ + Collate function to gather list of sample-dicts into a single batched dict of tensors. 
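+
+ Example
+ -------
+ Illustrative sketch of the resulting batch layout; ``loader`` and ``dataset``
+ are assumed AnnDataLoader/AnnDataset instances, and any extra keys depend on
+ the requested obs/obsm/varp entries.
+
+ >>> batch = loader._collate_fn([dataset[0], dataset[1]])
+ >>> batch["sequence"].shape  # (batch, seq_len, 4)
+ >>> batch["y"].shape  # (batch, num_outputs)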
+ """ + x = defaultdict(list) + for sample_dict in batch: + for key, val in sample_dict.items(): + x[key].append(torch.tensor(val, dtype=torch.float32)) + + # Stack and move to device + for key in x.keys(): + x[key] = torch.stack(x[key], dim=0) + if self.device is not None: + x[key] = x[key].to(self.device) + return x def _create_dataset(self): - if os.environ["KERAS_BACKEND"] == "torch": - return DataLoader( - self.dataset, - batch_size=self.batch_size, - drop_last=self.drop_remainder, - num_workers=0, - collate_fn=self._collate_fn, - ) + if os.environ.get("KERAS_BACKEND", "") == "torch": + if self.sampler is not None: + return DataLoader( + self.dataset, + batch_size=self.batch_size, + sampler=self.sampler, + drop_last=self.drop_remainder, + num_workers=0, + collate_fn=self._collate_fn, + ) + else: + return DataLoader( + self.dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + drop_last=self.drop_remainder, + num_workers=0, + collate_fn=self._collate_fn, + ) elif os.environ["KERAS_BACKEND"] == "tensorflow": ds = tf.data.Dataset.from_generator( self.dataset, @@ -97,7 +130,10 @@ def data(self): def __len__(self): """Return the number of batches in the DataLoader based on the dataset size and batch size.""" - return (len(self.dataset) + self.batch_size - 1) // self.batch_size + if self.sampler is not None: + return (self.epoch_size + self.batch_size - 1) // self.batch_size + else: + return (len(self.dataset) + self.batch_size - 1) // self.batch_size def __repr__(self): """Return the string representation of the DataLoader.""" diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index 123af48..a740ed4 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -274,13 +274,6 @@ def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str return augmented_indices, augmented_indices_map -if os.environ["KERAS_BACKEND"] == "pytorch": - import torch - - BaseClass = torch.utils.data.Dataset -else: - BaseClass = object - class AnnDataset(BaseClass): """ @@ -307,6 +300,12 @@ class AnnDataset(BaseClass): deterministic_shift If true, each region will be shifted twice with stride 50bp to each side. This is our legacy shifting, we recommend using max_stochastic_shift instead. + obs_columns + Columns in obs that will be added to the dataset. + obsm_columns + Keys in obsm that will be added to the dataset. + varp_columns + Keys in varp that will be added to the dataset. 
""" def __init__( @@ -319,7 +318,11 @@ def __init__( always_reverse_complement: bool = False, max_stochastic_shift: int = 0, deterministic_shift: bool = False, + obs_columns: list[str] | None = None, # multiple obs columns + obsm_keys: list[str] | None = None, # multiple obsm keys + varp_keys: list[str] | None = None, # multiple varp keys ): + """Initialize the dataset with the provided AnnData object and options.""" self.anndata = self._split_anndata(anndata, split) self.split = split @@ -331,7 +334,42 @@ def __init__( self.random_reverse_complement = random_reverse_complement self.max_stochastic_shift = max_stochastic_shift self.shuffle = False # managed by wrapping class AnnDataLoader - + self.obs_columns = obs_columns if obs_columns is not None else [] + self.obsm_keys = obsm_keys if obsm_keys is not None else [] + self.varp_keys = varp_keys if varp_keys is not None else [] + + # Validate and store obs data + self.obs_data = {} + for col in self.obs_columns: + if col not in anndata.obs: + raise ValueError(f"obs column '{col}' not found.") + # Convert categorical to integer codes if needed + if pd.api.types.is_categorical_dtype(anndata.obs[col]): + self.obs_data[col] = anndata.obs[col].cat.codes.values + else: + self.obs_data[col] = anndata.obs[col].values + + # Validate and store obsm data + self.obsm_data = {} + for key in self.obsm_keys: + if key not in anndata.obsm: + raise ValueError(f"obsm key '{key}' not found.") + mat = anndata.obsm[key] + if mat.shape[0] != anndata.n_obs: + raise ValueError(f"Dimension mismatch for obsm key '{key}'.") + self.obsm_data[key] = mat + + # Validate and store varp data + self.varp_data = {} + for key in self.varp_keys: + if key not in anndata.varp: + raise ValueError(f"varp key '{key}' not found.") + mat = anndata.varp[key] + if mat.shape[0] != anndata.n_var: + raise ValueError(f"Dimension mismatch for varp key '{key}'.") + self.varp_data[key] = mat + + # Check region formatting stranded = _check_strandedness(self.indices[0]) if stranded and (always_reverse_complement or random_reverse_complement): @@ -356,7 +394,26 @@ def __init__( self.seq_len = len( self.sequence_loader.get_sequence(self.indices[0], stranded=stranded) ) + + self.augmented_probs = None + if "sample_prob" in anndata.var.columns: + # 1) Extract raw sample_prob from adata.var + probs = anndata.var["sample_prob"].values.astype(float) + # 2) Ensure no negative values + probs = np.clip(probs, 0, None) + + # 3) For each augmented index, set unnormalized probability + self.augmented_probs = np.empty(len(self.index_manager.augmented_indices), dtype=float) + for i, aug_region in enumerate(self.index_manager.augmented_indices): + original_region = self.index_manager.augmented_indices_map[aug_region] + var_idx = self.index_map[original_region] + self.augmented_probs[i] = probs[var_idx] + else: + # If no sample_prob, we might default to 1.0 for each region + # or simply None to indicate uniform sampling + self.augmented_probs = None + @staticmethod def _split_anndata(anndata: AnnData, split: str) -> AnnData: """Return subset of anndata based on a given split column.""" @@ -385,32 +442,41 @@ def _get_target(self, index: str) -> np.ndarray: else self.anndata.X[:, y_index] ) - def __getitem__(self, idx: int) -> tuple[str, np.ndarray]: + def __getitem__(self, idx: int) -> dict: """Return sequence and target for a given index.""" augmented_index = self.index_manager.augmented_indices[idx] original_index = self.index_manager.augmented_indices_map[augmented_index] - # stochastic shift + + # Get 
sequence and target as before + shift = 0 if self.max_stochastic_shift > 0: - shift = np.random.randint( - -self.max_stochastic_shift, self.max_stochastic_shift + 1 - ) - else: - shift = 0 - - # Get sequence - x = self.sequence_loader.get_sequence( - augmented_index, stranded=True, shift=shift - ) - - # random reverse complement (always_reverse_complement is done in the sequence loader) - if self.random_reverse_complement and np.random.rand() < 0.5: - x = self.sequence_loader._reverse_complement(x) - - # one hot encode sequence and convert to numpy array + shift = np.random.randint(-self.max_stochastic_shift, self.max_stochastic_shift + 1) + + x = self.sequence_loader.get_sequence(augmented_index, stranded=True, shift=shift) x = one_hot_encode_sequence(x, expand_dim=False) y = self._get_target(original_index) - - return x, y + + # Random reverse complement if needed + if self.random_reverse_complement and np.random.rand() < 0.5: + x = self.sequence_loader._reverse_complement(x) + x = one_hot_encode_sequence(x, expand_dim=False) + + item = { + "sequence": x, + "y": y, + } + + # Add obsmp columns directly to the dictionary + for col in self.obs_columns: + item[col] = self.obs_data[col] + + for key in self.obsm_keys: + item[key] = self.obsm_data[key] + + for key in self.varp_keys: + item[key] = self.varp_data[key][idx]#.todense() + + return item def __call__(self): """Call generator for the dataset.""" @@ -423,3 +489,80 @@ def __call__(self): def __repr__(self) -> str: """Get string representation of the dataset.""" return f"AnnDataset(anndata_shape={self.anndata.shape}, n_samples={len(self)}, num_outputs={self.num_outputs}, split={self.split}, in_memory={self.in_memory})" + +class MetaAnnDataset: + """ + Combines multiple AnnDataset objects into a single dataset, + merging all their (augmented_index, probability) pairs into one global list. + + We do a final normalization across all sub-datasets so that + sample_prob from each dataset is treated as an unnormalized weight. + """ + + def __init__(self, datasets: list[AnnDataset]): + """ + Parameters + ---------- + datasets : list of AnnDataset + Each AnnDataset is for a different species or annotation set. + """ + if not datasets: + raise ValueError("No AnnDataset provided to MetaAnnDataset.") + + self.datasets = datasets + + # global_indices will store tuples of (dataset_idx, local_idx) + # global_probs will store the merged, unnormalized probabilities + self.global_indices = [] + self.global_probs = [] + + for ds_idx, ds in enumerate(datasets): + ds_len = len(ds.index_manager.augmented_indices) + if ds_len == 0: + continue + + # If the dataset has augmented_probs, we use them as unnormalized weights + # If not, fallback to 1.0 for each region + if ds.augmented_probs is not None: + for local_i in range(ds_len): + self.global_indices.append((ds_idx, local_i)) + self.global_probs.append(ds.augmented_probs[local_i]) + else: + for local_i in range(ds_len): + self.global_indices.append((ds_idx, local_i)) + self.global_probs.append(1.0) + + # Convert to numpy arrays + self.global_indices = np.array(self.global_indices, dtype=object) + self.global_probs = np.array(self.global_probs, dtype=float) + + # Normalize across the entire set + total = self.global_probs.sum() + if total > 0: + self.global_probs /= total + else: + # fallback: uniform if everything is zero + n = len(self.global_probs) + if n > 0: + self.global_probs.fill(1.0 / n) + + def __len__(self): + """ + The total number of augmented indices across all sub-datasets. 
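+
+ Example
+ -------
+ Sketch of the invariant; ``meta`` is an assumed MetaAnnDataset built from the
+ list ``datasets``.
+
+ >>> len(meta) == sum(len(ds.index_manager.augmented_indices) for ds in datasets)
+ True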
+ """ + return len(self.global_indices) + + def __getitem__(self, global_idx: int): + """ + A DataLoader or sampler will pass a global_idx in [0..len(self)-1]. + We map that to (dataset_idx, local_i) and call the sub-dataset's __getitem__. + """ + ds_idx, local_i = self.global_indices[global_idx] + ds_idx = int(ds_idx) + local_i = int(local_i) + return self.datasets[ds_idx][local_i] + + def __repr__(self): + return (f"MetaAnnDataset(num_datasets={len(self.datasets)}, " + f"total_augmented_indices={len(self.global_indices)})") + From f6b5a2cc34b0c6329a28af16b4bbe2d93f0ce86a Mon Sep 17 00:00:00 2001 From: mtvector Date: Thu, 2 Jan 2025 11:29:23 -0800 Subject: [PATCH 2/2] annoying checkpoints --- .../_anndatamodule-checkpoint.py | 451 -------------- .../_dataloader-checkpoint.py | 143 ----- .../.ipynb_checkpoints/_dataset-checkpoint.py | 568 ------------------ 3 files changed, 1162 deletions(-) delete mode 100644 src/crested/tl/data/.ipynb_checkpoints/_anndatamodule-checkpoint.py delete mode 100644 src/crested/tl/data/.ipynb_checkpoints/_dataloader-checkpoint.py delete mode 100644 src/crested/tl/data/.ipynb_checkpoints/_dataset-checkpoint.py diff --git a/src/crested/tl/data/.ipynb_checkpoints/_anndatamodule-checkpoint.py b/src/crested/tl/data/.ipynb_checkpoints/_anndatamodule-checkpoint.py deleted file mode 100644 index 4747ef7..0000000 --- a/src/crested/tl/data/.ipynb_checkpoints/_anndatamodule-checkpoint.py +++ /dev/null @@ -1,451 +0,0 @@ -"""Anndatamodule which acts as a wrapper around AnnDataset and AnnDataLoader.""" - -from __future__ import annotations - -from os import PathLike -from torch.utils.data import Sampler -import numpy as np - -from crested._genome import Genome, _resolve_genome - -from ._dataloader import AnnDataLoader -from ._dataset import AnnDataset - - -class AnnDataModule: - """ - DataModule class which defines how dataloaders should be loaded in each stage. - - Required input for the `tl.Crested` class. - - Note - ---- - Expects a `split` column in the `.var` DataFrame of the AnnData object. - Run `pp.train_val_test_split` first to add the `split` column to the AnnData object if not yet done. - - Example - ------- - >>> data_module = AnnDataModule( - ... adata, - ... genome=my_genome, - ... always_reverse_complement=True, - ... max_stochastic_shift=50, - ... batch_size=256, - ... ) - - Parameters - ---------- - adata - An instance of AnnData containing the data to be loaded. - genome - Instance of Genome or Path to the fasta file. - If None, will look for a registered genome object. - chromsizes_file - Path to the chromsizes file. Not required if genome is a Genome object. - If genome is a path and chromsizes is not provided, will deduce the chromsizes from the fasta file. - in_memory - If True, the train and val sequences will be loaded into memory. Default is True. - always_reverse_complement - If True, all sequences will be augmented with their reverse complement during training. - Effectively increases the training dataset size by a factor of 2. Default is True. - random_reverse_complement - If True, the sequences will be randomly reverse complemented during training. Default is False. - max_stochastic_shift - Maximum stochastic shift (n base pairs) to apply randomly to each sequence during training. Default is 0. - deterministic_shift - If true, each region will be shifted twice with stride 50bp to each side. Default is False. - This is our legacy shifting, we recommend using max_stochastic_shift instead. 
- shuffle - If True, the data will be shuffled at the end of each epoch during training. Default is True. - batch_size - Number of samples per batch to load. Default is 256. - """ - - def __init__( - self, - adata, - genome: PathLike | Genome | None = None, - chromsizes_file: PathLike | None = None, - in_memory: bool = True, - always_reverse_complement: bool = True, - random_reverse_complement: bool = False, - max_stochastic_shift: int = 0, - deterministic_shift: bool = False, - shuffle: bool = True, - batch_size: int = 256, - obs_columns: list[str] | None = None, - obsm_keys: list[str] | None = None, - varp_keys: list[str] | None = None, - ): - """Initialize the DataModule with the provided dataset and options.""" - self.adata = adata - self.genome = _resolve_genome(genome, chromsizes_file) # Function assumed available - self.always_reverse_complement = always_reverse_complement - self.in_memory = in_memory - self.random_reverse_complement = random_reverse_complement - self.max_stochastic_shift = max_stochastic_shift - self.deterministic_shift = deterministic_shift - self.shuffle = shuffle - self.batch_size = batch_size - self.obs_columns = obs_columns - self.obsm_keys = obsm_keys - self.varp_keys = varp_keys - - self._validate_init_args(random_reverse_complement, always_reverse_complement) - - self.train_dataset = None - self.val_dataset = None - self.test_dataset = None - self.predict_dataset = None - - @staticmethod - def _validate_init_args( - random_reverse_complement: bool, always_reverse_complement: bool - ): - if random_reverse_complement and always_reverse_complement: - raise ValueError( - "Only one of `random_reverse_complement` and `always_reverse_complement` can be True." - ) - - def setup(self, stage: str) -> None: - """ - Set up the Anndatasets for a given stage. - - Generates the train, val, test or predict dataset based on the provided stage. - Should always be called before accessing the dataloaders. - Generally, you don't need to call this directly, as this is called inside the `tl.Crested` trainer class. - - Parameters - ---------- - stage - Stage for which to setup the dataloader. Either 'fit', 'test' or 'predict'. 
- """ - args = { - "anndata": self.adata, - "genome": self.genome, - "in_memory": self.in_memory, - "always_reverse_complement": self.always_reverse_complement, - "random_reverse_complement": self.random_reverse_complement, - "max_stochastic_shift": self.max_stochastic_shift, - "deterministic_shift": self.deterministic_shift, - "obs_columns": self.obs_columns, - "obsm_keys": self.obsm_keys, - "varp_keys": self.varp_keys, - } - if stage == "fit": - # Training dataset - train_args = args.copy() - train_args["split"] = "train" - - val_args = args.copy() - val_args["split"] = "val" - val_args["always_reverse_complement"] = False - val_args["random_reverse_complement"] = False - val_args["max_stochastic_shift"] = 0 - - self.train_dataset = AnnDataset(**train_args) - self.val_dataset = AnnDataset(**val_args) - - elif stage == "test": - test_args = args.copy() - test_args["split"] = "test" - test_args["in_memory"] = False - test_args["always_reverse_complement"] = False - test_args["random_reverse_complement"] = False - test_args["max_stochastic_shift"] = 0 - - self.test_dataset = AnnDataset(**test_args) - - elif stage == "predict": - predict_args = args.copy() - predict_args["split"] = None - predict_args["in_memory"] = False - predict_args["always_reverse_complement"] = False - predict_args["random_reverse_complement"] = False - predict_args["max_stochastic_shift"] = 0 - - self.predict_dataset = AnnDataset(**predict_args) - - else: - raise ValueError(f"Invalid stage: {stage}") - - - @property - def train_dataloader(self): - """:obj:`crested.tl.data.AnnDataLoader`: Training dataloader.""" - if self.train_dataset is None: - raise ValueError("train_dataset is not set. Run setup('fit') first.") - return AnnDataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - drop_remainder=False, - ) - - @property - def val_dataloader(self): - """:obj:`crested.tl.data.AnnDataLoader`: Validation dataloader.""" - if self.val_dataset is None: - raise ValueError("val_dataset is not set. Run setup('fit') first.") - return AnnDataLoader( - self.val_dataset, - batch_size=self.batch_size, - shuffle=False, - drop_remainder=False, - ) - - @property - def test_dataloader(self): - """:obj:`crested.tl.data.AnnDataLoader`: Test dataloader.""" - if self.test_dataset is None: - raise ValueError("test_dataset is not set. Run setup('test') first.") - return AnnDataLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=False, - drop_remainder=False, - ) - - @property - def predict_dataloader(self): - """:obj:`crested.tl.data.AnnDataLoader`: Prediction dataloader.""" - if self.predict_dataset is None: - raise ValueError("predict_dataset is not set. Run setup('predict') first.") - return AnnDataLoader( - self.predict_dataset, - batch_size=self.batch_size, - shuffle=False, - drop_remainder=False, - ) - - def __repr__(self): - """Return a string representation of the AnndataModule.""" - return ( - f"AnndataModule(adata={self.adata}, genome={self.genome}, " - f"in_memory={self.in_memory}, " - f"always_reverse_complement={self.always_reverse_complement}, " - f"random_reverse_complement={self.random_reverse_complement}, " - f"max_stochastic_shift={self.max_stochastic_shift}, shuffle={self.shuffle}, " - f"batch_size={self.batch_size}" - ) - - -class MetaSampler(Sampler): - """ - A Sampler that yields indices in proportion to meta_dataset.global_probs. 
- """ - - def __init__(self, meta_dataset: MetaAnnDataset, epoch_size: int = 100_000): - """ - Parameters - ---------- - meta_dataset : MetaAnnDataset - The combined dataset with global_indices and global_probs. - epoch_size : int - How many samples we consider in one epoch of training. - """ - super().__init__(data_source=meta_dataset) - self.meta_dataset = meta_dataset - self.epoch_size = epoch_size - - # Check that sum of global_probs ~ 1.0 - s = self.meta_dataset.global_probs.sum() - if not np.isclose(s, 1.0, atol=1e-6): - raise ValueError( - "global_probs do not sum to 1 after final normalization. sum = {}".format(s) - ) - - def __iter__(self): - """ - For each epoch, yield 'epoch_size' random draws from - [0..len(meta_dataset)-1], weighted by global_probs. - """ - n = len(self.meta_dataset) - p = self.meta_dataset.global_probs - for _ in qrange(self.epoch_size): - yield np.random.choice(n, p=p) - - def __len__(self): - """ - The DataLoader uses len(sampler) to figure out how many samples per epoch. - """ - return self.epoch_size - -class MetaAnnDataModule: - """ - A DataModule for multiple AnnData objects (one per species), - merging them into a single MetaAnnDataset for each stage. - Then we rely on the MetaSampler to do globally weighted sampling. - - This replaces the old random-per-dataset approach in MultiAnnDataset. - """ - - def __init__( - self, - adatas: list[AnnData], - genomes: list[Genome], - in_memory: bool = True, - always_reverse_complement: bool = True, - random_reverse_complement: bool = False, - max_stochastic_shift: int = 0, - deterministic_shift: bool = False, - shuffle: bool = True, - batch_size: int = 256, - obs_columns: list[str] | None = None, - obsm_keys: list[str] | None = None, - varp_keys: list[str] | None = None, - epoch_size: int = 100_000, - ): - if len(adatas) != len(genomes): - raise ValueError("Must provide as many `adatas` as `genomes`.") - - self.adatas = adatas - self.genomes = genomes - self.in_memory = in_memory - self.always_reverse_complement = always_reverse_complement - self.random_reverse_complement = random_reverse_complement - self.max_stochastic_shift = max_stochastic_shift - self.deterministic_shift = deterministic_shift - self.shuffle = shuffle - self.batch_size = batch_size - self.obs_columns = obs_columns - self.obsm_keys = obsm_keys - self.varp_keys = varp_keys - self.epoch_size = epoch_size - - self.train_dataset = None - self.val_dataset = None - self.test_dataset = None - self.predict_dataset = None - - def setup(self, stage: str) -> None: - """ - Create the AnnDataset objects for each adata+genome, then unify them - into a MetaAnnDataset for the given stage. 
- """ - def dataset_args(split): - return { - "in_memory": self.in_memory, - "always_reverse_complement": self.always_reverse_complement, - "random_reverse_complement": self.random_reverse_complement, - "max_stochastic_shift": self.max_stochastic_shift, - "deterministic_shift": self.deterministic_shift, - "obs_columns": self.obs_columns, - "obsm_keys": self.obsm_keys, - "varp_keys": self.varp_keys, - "split": split, - } - - if stage == "fit": - train_datasets = [] - val_datasets = [] - for adata, genome in zip(self.adatas, self.genomes): - # Training - args = dataset_args("train") - ds_train = AnnDataset(anndata=adata, genome=genome, **args) - train_datasets.append(ds_train) - - # Validation (no shifting, no RC) - val_args = dataset_args("val") - val_args["always_reverse_complement"] = False - val_args["random_reverse_complement"] = False - val_args["max_stochastic_shift"] = 0 - ds_val = AnnDataset(anndata=adata, genome=genome, **val_args) - val_datasets.append(ds_val) - - # Merge them with MetaAnnDataset - self.train_dataset = MetaAnnDataset(train_datasets) - self.val_dataset = MetaAnnDataset(val_datasets) - - elif stage == "test": - test_datasets = [] - for adata, genome in zip(self.adatas, self.genomes): - args = dataset_args("test") - args["in_memory"] = False - args["always_reverse_complement"] = False - args["random_reverse_complement"] = False - args["max_stochastic_shift"] = 0 - - ds_test = AnnDataset(anndata=adata, genome=genome, **args) - test_datasets.append(ds_test) - - self.test_dataset = MetaAnnDataset(test_datasets) - - elif stage == "predict": - predict_datasets = [] - for adata, genome in zip(self.adatas, self.genomes): - args = dataset_args(None) - args["in_memory"] = False - args["always_reverse_complement"] = False - args["random_reverse_complement"] = False - args["max_stochastic_shift"] = 0 - - ds_pred = AnnDataset(anndata=adata, genome=genome, **args) - predict_datasets.append(ds_pred) - - self.predict_dataset = MetaAnnDataset(predict_datasets) - - else: - raise ValueError(f"Invalid stage: {stage}") - - @property - def train_dataloader(self): - if self.train_dataset is None: - raise ValueError("train_dataset is not set. Run setup('fit') first.") - return AnnDataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - drop_remainder=False, - epoch_size=self.epoch_size, - ) - - @property - def val_dataloader(self): - if self.val_dataset is None: - raise ValueError("val_dataset is not set. Run setup('fit') first.") - return AnnDataLoader( - self.val_dataset, - batch_size=self.batch_size, - shuffle=False, - drop_remainder=False, - epoch_size=self.epoch_size, - ) - - @property - def test_dataloader(self): - if self.test_dataset is None: - raise ValueError("test_dataset is not set. Run setup('test') first.") - return AnnDataLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=False, - drop_remainder=False, - epoch_size=self.epoch_size, - ) - - @property - def predict_dataloader(self): - if self.predict_dataset is None: - raise ValueError("predict_dataset is not set. 
Run setup('predict') first.") - return AnnDataLoader( - self.predict_dataset, - batch_size=self.batch_size, - shuffle=False, - drop_remainder=False, - epoch_size=self.epoch_size, - ) - - def __repr__(self): - return ( - f"MetaAnnDataModule(" - f"num_species={len(self.adatas)}, " - f"batch_size={self.batch_size}, shuffle={self.shuffle}, " - f"max_stochastic_shift={self.max_stochastic_shift}, " - f"random_reverse_complement={self.random_reverse_complement}, " - f"always_reverse_complement={self.always_reverse_complement}, " - f"in_memory={self.in_memory}, " - f"deterministic_shift={self.deterministic_shift}, " - f"epoch_size={self.epoch_size}" - f")" - ) diff --git a/src/crested/tl/data/.ipynb_checkpoints/_dataloader-checkpoint.py b/src/crested/tl/data/.ipynb_checkpoints/_dataloader-checkpoint.py deleted file mode 100644 index e64fbf1..0000000 --- a/src/crested/tl/data/.ipynb_checkpoints/_dataloader-checkpoint.py +++ /dev/null @@ -1,143 +0,0 @@ -"""Dataloader for batching, shuffling, and one-hot encoding of AnnDataset.""" - -from __future__ import annotations - -import os -from collections import defaultdict - -if os.environ["KERAS_BACKEND"] == "torch": - import torch - from torch.utils.data import DataLoader -else: - import tensorflow as tf - -from ._dataset import AnnDataset - - -class AnnDataLoader: - """ - Pytorch-like DataLoader class for AnnDataset with options for batching, shuffling, and one-hot encoding. - - Parameters - ---------- - dataset - The dataset instance provided. - batch_size - Number of samples per batch. - shuffle - Indicates whether shuffling is enabled. - drop_remainder - Indicates whether to drop the last incomplete batch. - - Examples - -------- - >>> dataset = AnnDataset(...) # Your dataset instance - >>> batch_size = 32 - >>> dataloader = AnnDataLoader( - ... dataset, batch_size, shuffle=True, drop_remainder=True - ... ) - >>> for x, y in dataloader.data: - ... # Your training loop here - """ - def __init__( - self, - dataset, # can be AnnDataset or MetaAnnDataset - batch_size: int, - shuffle: bool = False, - drop_remainder: bool = True, - epoch_size: int = 100_000, - ): - self.dataset = dataset - self.batch_size = batch_size - self.shuffle = shuffle - self.drop_remainder = drop_remainder - self.epoch_size = epoch_size - - if os.environ.get("KERAS_BACKEND", "") == "torch": - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - else: - self.device = None - - self.sampler = None - - # Decide if we should use MetaSampler - if isinstance(dataset, MetaAnnDataset): - # This merges many AnnDataset objects, so let's use MetaSampler - self.sampler = MetaSampler(dataset, epoch_size=self.epoch_size) - else: - # Single AnnDataset => possibly fallback to WeightedRegionSampler or uniform - # We'll do uniform shuffle if asked. WeightedRegionSampler is not shown here, - # but you could do: - # if dataset.augmented_probs is not None: self.sampler = WeightedRegionSampler(...) - if self.shuffle and hasattr(self.dataset, "shuffle"): - self.dataset.shuffle = True - - def _collate_fn(self, batch): - """ - Collate function to gather list of sample-dicts into a single batched dict of tensors. 
- """ - x = defaultdict(list) - for sample_dict in batch: - for key, val in sample_dict.items(): - x[key].append(torch.tensor(val, dtype=torch.float32)) - - # Stack and move to device - for key in x.keys(): - x[key] = torch.stack(x[key], dim=0) - if self.device is not None: - x[key] = x[key].to(self.device) - return x - - def _create_dataset(self): - if os.environ.get("KERAS_BACKEND", "") == "torch": - if self.sampler is not None: - return DataLoader( - self.dataset, - batch_size=self.batch_size, - sampler=self.sampler, - drop_last=self.drop_remainder, - num_workers=0, - collate_fn=self._collate_fn, - ) - else: - return DataLoader( - self.dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - drop_last=self.drop_remainder, - num_workers=0, - collate_fn=self._collate_fn, - ) - elif os.environ["KERAS_BACKEND"] == "tensorflow": - ds = tf.data.Dataset.from_generator( - self.dataset, - output_signature=( - tf.TensorSpec(shape=(self.dataset.seq_len, 4), dtype=tf.float32), - tf.TensorSpec(shape=(self.dataset.num_outputs,), dtype=tf.float32), - ), - ) - ds = ( - ds.batch(self.batch_size, drop_remainder=self.drop_remainder) - .repeat() - .prefetch(tf.data.AUTOTUNE) - ) - return ds - - @property - def data(self): - """Return the dataset as a tf.data.Dataset instance.""" - return self._create_dataset() - - def __len__(self): - """Return the number of batches in the DataLoader based on the dataset size and batch size.""" - if self.sampler is not None: - return (self.epoch_size + self.batch_size - 1) // self.batch_size - else: - return (len(self.dataset) + self.batch_size - 1) // self.batch_size - - def __repr__(self): - """Return the string representation of the DataLoader.""" - return ( - f"AnnDataLoader(dataset={self.dataset}, batch_size={self.batch_size}, " - f"shuffle={self.shuffle}, drop_remainder={self.drop_remainder})" - ) diff --git a/src/crested/tl/data/.ipynb_checkpoints/_dataset-checkpoint.py b/src/crested/tl/data/.ipynb_checkpoints/_dataset-checkpoint.py deleted file mode 100644 index a740ed4..0000000 --- a/src/crested/tl/data/.ipynb_checkpoints/_dataset-checkpoint.py +++ /dev/null @@ -1,568 +0,0 @@ -"""Dataset class for combining genome files and AnnData objects.""" - -from __future__ import annotations - -import os -import re - -import numpy as np -from anndata import AnnData -from loguru import logger -from scipy.sparse import spmatrix -from tqdm import tqdm - -from crested._genome import Genome -from crested.utils import one_hot_encode_sequence - - -def _flip_region_strand(region: str) -> str: - """Reverse the strand of a region.""" - strand_reverser = {"+": "-", "-": "+"} - return region[:-1] + strand_reverser[region[-1]] - - -def _check_strandedness(region: str) -> bool: - """Check the strandedness of a region, raising an error if the formatting isn't recognised.""" - if re.fullmatch(r".+:\d+-\d+:[-+]", region): - return True - elif re.fullmatch(r".+:\d+-\d+", region): - return False - else: - raise ValueError( - f"Region {region} was not recognised as a valid coordinate set (chr:start-end or chr:start-end:strand)." - "If provided, strand must be + or -." - ) - - -def _deterministic_shift_region( - region: str, stride: int = 50, n_shifts: int = 2 -) -> list[str]: - """ - Shift each region by a deterministic stride to each side. Will increase the number of regions by n_shifts times two. - - This is a legacy function, it's recommended to use stochastic shifting instead. 
- """ - new_regions = [] - chrom, start_end, strand = region.split(":") - start, end = map(int, start_end.split("-")) - for i in range(-n_shifts, n_shifts + 1): - new_start = start + i * stride - new_end = end + i * stride - new_regions.append(f"{chrom}:{new_start}-{new_end}:{strand}") - return new_regions - - -class SequenceLoader: - """ - Load sequences from a genome file. - - Options for reverse complementing and stochastic shifting are available. - - Parameters - ---------- - genome - Genome instance. - in_memory - If True, the sequences of supplied regions will be loaded into memory. - always_reverse_complement - If True, all sequences will be augmented with their reverse complement. - Doubles the dataset size. - max_stochastic_shift - Maximum stochastic shift (n base pairs) to apply randomly to each sequence. - regions - List of regions to load into memory. Required if in_memory is True. - """ - - def __init__( - self, - genome: Genome, - in_memory: bool = False, - always_reverse_complement: bool = False, - deterministic_shift: bool = False, - max_stochastic_shift: int = 0, - regions: list[str] | None = None, - ): - """Initialize the SequenceLoader with the provided genome file and options.""" - self.genome = genome.fasta - self.chromsizes = genome.chrom_sizes - self.in_memory = in_memory - self.always_reverse_complement = always_reverse_complement - self.deterministic_shift = deterministic_shift - self.max_stochastic_shift = max_stochastic_shift - self.sequences = {} - self.complement = str.maketrans("ACGT", "TGCA") - self.regions = regions - if self.in_memory: - self._load_sequences_into_memory(self.regions) - - def _load_sequences_into_memory(self, regions: list[str]): - """Load all sequences into memory (dict).""" - logger.info("Loading sequences into memory...") - # Check region formatting - stranded = _check_strandedness(regions[0]) - - for region in tqdm(regions): - # Make region stranded if not - if not stranded: - strand = "+" - region = f"{region}:{strand}" - if region[-4] == ":": - raise ValueError( - f"You are double-adding strand ids to your region {region}. Check if all regions are stranded or unstranded." 
- ) - - # Add deterministic shift regions - if self.deterministic_shift: - regions = _deterministic_shift_region(region) - else: - regions = [region] - - for region in regions: - # Parse region - chrom, start_end, strand = region.split(":") - start, end = map(int, start_end.split("-")) - - # Add region to self.sequences - extended_sequence = self._get_extended_sequence( - chrom, start, end, strand - ) - self.sequences[region] = extended_sequence - - # Add reverse-complemented region to self.sequences if always_reverse_complement - if self.always_reverse_complement: - self.sequences[ - _flip_region_strand(region) - ] = self._reverse_complement(extended_sequence) - - def _get_extended_sequence( - self, chrom: str, start: int, end: int, strand: str - ) -> str: - """Get sequence from genome file, extended for stochastic shifting.""" - extended_start = max(0, start - self.max_stochastic_shift) - extended_end = extended_start + (end - start) + (self.max_stochastic_shift * 2) - - if self.chromsizes and chrom in self.chromsizes: - chrom_size = self.chromsizes[chrom] - if extended_end > chrom_size: - extended_start = chrom_size - ( - end - start + self.max_stochastic_shift * 2 - ) - extended_end = chrom_size - - seq = self.genome.fetch(chrom, extended_start, extended_end).upper() - if strand == "-": - seq = self._reverse_complement(seq) - return seq - - def _reverse_complement(self, sequence: str) -> str: - """Reverse complement a sequence.""" - return sequence.translate(self.complement)[::-1] - - def get_sequence( - self, region: str, stranded: bool | None = None, shift: int = 0 - ) -> str: - """ - Get sequence for a region, strand, and shift from memory or fasta. - - If no strand is given in region or strand, assumes positive strand. - - Parameters - ---------- - region - Region to get the sequence for. Either (chr:start-end) or (chr:start-end:strand). - stranded - Whether the input data is stranded. Default (None) infers from sequence (at a computational cost). - If not stranded, positive strand is assumed. - shift: - Shift of the sequence within the extended sequence, for use with the stochastic shift mechanism. - - Returns - ------- - The DNA sequence, as a string. - """ - if stranded is None: - stranded = _check_strandedness(region) - if not stranded: - region = f"{region}:+" - # Parse region - chrom, start_end, strand = region.split(":") - start, end = map(int, start_end.split("-")) - - # Get extended sequence - if self.in_memory: - sequence = self.sequences[region] - else: - sequence = self._get_extended_sequence(chrom, start, end, strand) - - # Extract from extended sequence - start_idx = self.max_stochastic_shift + shift - end_idx = start_idx + (end - start) - sub_sequence = sequence[start_idx:end_idx] - - # Pad with Ns if sequence is shorter than expected - if len(sub_sequence) < (end - start): - if strand == "+": - sub_sequence = sub_sequence.ljust(end - start, "N") - else: - sub_sequence = sub_sequence.rjust(end - start, "N") - - return sub_sequence - - -class IndexManager: - """ - Manage indices for the dataset. - - Augments indices with strand information if always reverse complement. - - Parameters - ---------- - indices - List of indices in format "chr:start-end" or "chr:start-end:strand". - always_reverse_complement - If True, all sequences will be augmented with their reverse complement. - deterministic_shift - If True, each region will be shifted twice with stride 50bp to each side. 
- """ - - def __init__( - self, - indices: list[str], - always_reverse_complement: bool, - deterministic_shift: bool = False, - ): - """Initialize the IndexManager with the provided indices.""" - self.indices = indices - self.always_reverse_complement = always_reverse_complement - self.deterministic_shift = deterministic_shift - self.augmented_indices, self.augmented_indices_map = self._augment_indices( - indices - ) - - def shuffle_indices(self): - """Shuffle indices. Managed by wrapping class AnnDataLoader.""" - np.random.shuffle(self.augmented_indices) - - def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str]]: - """Augment indices with strand information. Necessary if always reverse complement to map sequences back to targets.""" - augmented_indices = [] - augmented_indices_map = {} - for region in indices: - if not _check_strandedness( - region - ): # If slow, can use AnnDataset stranded argument - but this validates every region's formatting as well - stranded_region = f"{region}:+" - else: - stranded_region = region - - if self.deterministic_shift: - shifted_regions = _deterministic_shift_region(stranded_region) - for shifted_region in shifted_regions: - augmented_indices.append(shifted_region) - augmented_indices_map[shifted_region] = region - if self.always_reverse_complement: - augmented_indices.append(_flip_region_strand(shifted_region)) - augmented_indices_map[ - _flip_region_strand(shifted_region) - ] = region - else: - augmented_indices.append(stranded_region) - augmented_indices_map[stranded_region] = region - if self.always_reverse_complement: - augmented_indices.append(_flip_region_strand(stranded_region)) - augmented_indices_map[_flip_region_strand(stranded_region)] = region - return augmented_indices, augmented_indices_map - - - -class AnnDataset(BaseClass): - """ - Dataset class for combining genome files and AnnData objects. - - Called by the by the AnnDataModule class. - - Parameters - ---------- - anndata - AnnData object containing the data. - genome - Genome instance - split - 'train', 'val', or 'test' split column in anndata.var. - in_memory - If True, the train and val sequences will be loaded into memory. - random_reverse_complement - If True, the sequences will be randomly reverse complemented during training. - always_reverse_complement - If True, all sequences will be augmented with their reverse complement during training. - max_stochastic_shift - Maximum stochastic shift (n base pairs) to apply randomly to each sequence during training. - deterministic_shift - If true, each region will be shifted twice with stride 50bp to each side. - This is our legacy shifting, we recommend using max_stochastic_shift instead. - obs_columns - Columns in obs that will be added to the dataset. - obsm_columns - Keys in obsm that will be added to the dataset. - varp_columns - Keys in varp that will be added to the dataset. 
- """ - - def __init__( - self, - anndata: AnnData, - genome: Genome, - split: str = None, - in_memory: bool = True, - random_reverse_complement: bool = False, - always_reverse_complement: bool = False, - max_stochastic_shift: int = 0, - deterministic_shift: bool = False, - obs_columns: list[str] | None = None, # multiple obs columns - obsm_keys: list[str] | None = None, # multiple obsm keys - varp_keys: list[str] | None = None, # multiple varp keys - ): - - """Initialize the dataset with the provided AnnData object and options.""" - self.anndata = self._split_anndata(anndata, split) - self.split = split - self.indices = list(self.anndata.var_names) - self.in_memory = in_memory - self.compressed = isinstance(self.anndata.X, spmatrix) - self.index_map = {index: i for i, index in enumerate(self.indices)} - self.num_outputs = self.anndata.X.shape[0] - self.random_reverse_complement = random_reverse_complement - self.max_stochastic_shift = max_stochastic_shift - self.shuffle = False # managed by wrapping class AnnDataLoader - self.obs_columns = obs_columns if obs_columns is not None else [] - self.obsm_keys = obsm_keys if obsm_keys is not None else [] - self.varp_keys = varp_keys if varp_keys is not None else [] - - # Validate and store obs data - self.obs_data = {} - for col in self.obs_columns: - if col not in anndata.obs: - raise ValueError(f"obs column '{col}' not found.") - # Convert categorical to integer codes if needed - if pd.api.types.is_categorical_dtype(anndata.obs[col]): - self.obs_data[col] = anndata.obs[col].cat.codes.values - else: - self.obs_data[col] = anndata.obs[col].values - - # Validate and store obsm data - self.obsm_data = {} - for key in self.obsm_keys: - if key not in anndata.obsm: - raise ValueError(f"obsm key '{key}' not found.") - mat = anndata.obsm[key] - if mat.shape[0] != anndata.n_obs: - raise ValueError(f"Dimension mismatch for obsm key '{key}'.") - self.obsm_data[key] = mat - - # Validate and store varp data - self.varp_data = {} - for key in self.varp_keys: - if key not in anndata.varp: - raise ValueError(f"varp key '{key}' not found.") - mat = anndata.varp[key] - if mat.shape[0] != anndata.n_var: - raise ValueError(f"Dimension mismatch for varp key '{key}'.") - self.varp_data[key] = mat - - - # Check region formatting - stranded = _check_strandedness(self.indices[0]) - if stranded and (always_reverse_complement or random_reverse_complement): - logger.info( - "Setting always_reverse_complement=True or random_reverse_complement=True with stranded data.", - "This means both strands are used when training and the strand information is effectively disregarded.", - ) - - self.sequence_loader = SequenceLoader( - genome, - in_memory=in_memory, - always_reverse_complement=always_reverse_complement, - deterministic_shift=deterministic_shift, - max_stochastic_shift=max_stochastic_shift, - regions=self.indices, - ) - self.index_manager = IndexManager( - self.indices, - always_reverse_complement=always_reverse_complement, - deterministic_shift=deterministic_shift, - ) - self.seq_len = len( - self.sequence_loader.get_sequence(self.indices[0], stranded=stranded) - ) - - self.augmented_probs = None - if "sample_prob" in anndata.var.columns: - # 1) Extract raw sample_prob from adata.var - probs = anndata.var["sample_prob"].values.astype(float) - # 2) Ensure no negative values - probs = np.clip(probs, 0, None) - - # 3) For each augmented index, set unnormalized probability - self.augmented_probs = np.empty(len(self.index_manager.augmented_indices), dtype=float) - for i, 
-                original_region = self.index_manager.augmented_indices_map[aug_region]
-                var_idx = self.index_map[original_region]
-                self.augmented_probs[i] = probs[var_idx]
-        else:
-            # If no sample_prob, we could default to 1.0 for each region,
-            # or simply None to indicate uniform sampling
-            self.augmented_probs = None
-
-    @staticmethod
-    def _split_anndata(anndata: AnnData, split: str) -> AnnData:
-        """Return subset of anndata based on a given split column."""
-        if split:
-            if "split" not in anndata.var.columns:
-                raise KeyError(
-                    "No split column found in anndata.var. Run `pp.train_val_test_split` first."
-                )
-        subset = (
-            anndata[:, anndata.var["split"] == split].copy()
-            if split
-            else anndata.copy()
-        )
-        return subset
-
-    def __len__(self) -> int:
-        """Get number of (augmented) samples in the dataset."""
-        return len(self.index_manager.augmented_indices)
-
-    def _get_target(self, index: str) -> np.ndarray:
-        """Get target for a given index."""
-        y_index = self.index_map[index]
-        return (
-            self.anndata.X[:, y_index].toarray().flatten()
-            if self.compressed
-            else self.anndata.X[:, y_index]
-        )
-
-    def __getitem__(self, idx: int) -> dict:
-        """Return sequence and target for a given index."""
-        augmented_index = self.index_manager.augmented_indices[idx]
-        original_index = self.index_manager.augmented_indices_map[augmented_index]
-
-        # Sample a stochastic shift if enabled
-        shift = 0
-        if self.max_stochastic_shift > 0:
-            shift = np.random.randint(
-                -self.max_stochastic_shift, self.max_stochastic_shift + 1
-            )
-
-        x = self.sequence_loader.get_sequence(
-            augmented_index, stranded=True, shift=shift
-        )
-
-        # Randomly reverse complement the raw sequence before one-hot encoding
-        if self.random_reverse_complement and np.random.rand() < 0.5:
-            x = self.sequence_loader._reverse_complement(x)
-
-        x = one_hot_encode_sequence(x, expand_dim=False)
-        y = self._get_target(original_index)
-
-        item = {
-            "sequence": x,
-            "y": y,
-        }
-
-        # Add obs, obsm and varp data directly to the dictionary
-        for col in self.obs_columns:
-            item[col] = self.obs_data[col]
-
-        for key in self.obsm_keys:
-            item[key] = self.obsm_data[key]
-
-        for key in self.varp_keys:
-            item[key] = self.varp_data[key][idx]  # .todense()
-
-        return item
-
-    def __call__(self):
-        """Call generator for the dataset."""
-        for i in range(len(self)):
-            if i == 0:
-                if self.shuffle:
-                    self.index_manager.shuffle_indices()
-            yield self.__getitem__(i)
-
-    def __repr__(self) -> str:
-        """Get string representation of the dataset."""
-        return f"AnnDataset(anndata_shape={self.anndata.shape}, n_samples={len(self)}, num_outputs={self.num_outputs}, split={self.split}, in_memory={self.in_memory})"
-
-
-class MetaAnnDataset:
-    """
-    Combines multiple AnnDataset objects into a single dataset,
-    merging all their (augmented_index, probability) pairs into one global list.
-
-    A final normalization is done across all sub-datasets so that
-    sample_prob from each dataset is treated as an unnormalized weight.
-    """
-
-    def __init__(self, datasets: list[AnnDataset]):
-        """
-        Parameters
-        ----------
-        datasets : list of AnnDataset
-            Each AnnDataset is for a different species or annotation set.
- """ - if not datasets: - raise ValueError("No AnnDataset provided to MetaAnnDataset.") - - self.datasets = datasets - - # global_indices will store tuples of (dataset_idx, local_idx) - # global_probs will store the merged, unnormalized probabilities - self.global_indices = [] - self.global_probs = [] - - for ds_idx, ds in enumerate(datasets): - ds_len = len(ds.index_manager.augmented_indices) - if ds_len == 0: - continue - - # If the dataset has augmented_probs, we use them as unnormalized weights - # If not, fallback to 1.0 for each region - if ds.augmented_probs is not None: - for local_i in range(ds_len): - self.global_indices.append((ds_idx, local_i)) - self.global_probs.append(ds.augmented_probs[local_i]) - else: - for local_i in range(ds_len): - self.global_indices.append((ds_idx, local_i)) - self.global_probs.append(1.0) - - # Convert to numpy arrays - self.global_indices = np.array(self.global_indices, dtype=object) - self.global_probs = np.array(self.global_probs, dtype=float) - - # Normalize across the entire set - total = self.global_probs.sum() - if total > 0: - self.global_probs /= total - else: - # fallback: uniform if everything is zero - n = len(self.global_probs) - if n > 0: - self.global_probs.fill(1.0 / n) - - def __len__(self): - """ - The total number of augmented indices across all sub-datasets. - """ - return len(self.global_indices) - - def __getitem__(self, global_idx: int): - """ - A DataLoader or sampler will pass a global_idx in [0..len(self)-1]. - We map that to (dataset_idx, local_i) and call the sub-dataset's __getitem__. - """ - ds_idx, local_i = self.global_indices[global_idx] - ds_idx = int(ds_idx) - local_i = int(local_i) - return self.datasets[ds_idx][local_i] - - def __repr__(self): - return (f"MetaAnnDataset(num_datasets={len(self.datasets)}, " - f"total_augmented_indices={len(self.global_indices)})") -