Generalizing dataloader and loading multiple species #88

Draft: mtvector wants to merge 3 commits into main

Conversation

@mtvector commented on Jan 2, 2025

Hi all,

I wanted to start this pull request as a discussion about some extensions to CREsted that I've been considering and need, and that I've built out a preliminary version of in my fork. I haven't written tests yet, and the changes here almost certainly break other parts of CREsted that I haven't checked.

I'm working on building models that train on data from multiple species and that also take additional information, like gene expression vectors, as inputs.

To support this, my fork includes the following changes:

  • Altering AnnDataSet and AnnDataLoader so that you can load multiple fields from obs, obsm, var, and varp and deliver them to your trainer as a dict of tensors.
  • Adding a MetaAnnDataset and MetaSampler so that you can collate AnnData objects from multiple species and sample from them randomly in minibatches. (A rough illustrative sketch of both ideas follows this list.)
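
To make that concrete, here is a minimal, hypothetical sketch of the two ideas written against plain PyTorch. This is not the code in my fork: DictAnnDataset and SimpleMetaSampler are illustrative names, the real AnnDataSet also handles one-hot sequence encoding, stochastic shifting, and reverse complementing, and details like size-weighted dataset selection are assumptions for the sketch only.

# Hypothetical sketch of the two ideas above (illustration only, not the fork's code).
import numpy as np
import torch
from torch.utils.data import Dataset, Sampler


class DictAnnDataset(Dataset):
    """Return each region as a dict of tensors, including extra obs/obsm fields."""

    def __init__(self, adata, obs_columns=(), obsm_keys=()):
        self.adata = adata
        self.obs_columns = list(obs_columns)
        self.obsm_keys = list(obsm_keys)

    def __len__(self):
        # In CREsted, regions live in var and cell types in obs.
        return self.adata.n_vars

    def __getitem__(self, region_idx):
        targets = self.adata.X[:, region_idx]
        if hasattr(targets, "toarray"):  # handle sparse X
            targets = targets.toarray()
        item = {
            # The real AnnDataSet would also return the one-hot encoded sequence here.
            "y": torch.as_tensor(np.asarray(targets, dtype=np.float32).ravel()),
        }
        for col in self.obs_columns:
            item[col] = torch.as_tensor(self.adata.obs[col].to_numpy(dtype=np.float32))
        for key in self.obsm_keys:
            item[key] = torch.as_tensor(np.asarray(self.adata.obsm[key], dtype=np.float32))
        return item


class SimpleMetaSampler(Sampler):
    """Yield (dataset_idx, region_idx) pairs so minibatches mix regions across species."""

    def __init__(self, dataset_sizes, epoch_size):
        self.sizes = np.asarray(dataset_sizes)
        self.probs = self.sizes / self.sizes.sum()
        self.epoch_size = epoch_size

    def __iter__(self):
        for _ in range(self.epoch_size):
            d = int(np.random.choice(len(self.sizes), p=self.probs))
            yield d, int(np.random.randint(self.sizes[d]))

    def __len__(self):
        return self.epoch_size

The point is just that each __getitem__ returns a dict keyed by field name, and the sampler decides which species' AnnData each draw comes from.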

I'm still working out some bugs, but I figured I shouldn't go any further before reaching out to see what you already have in the works, and whether this code is useful to you and the kind of thing you'd consider merging once it matures. Otherwise, I'll continue developing it as an independent extension to CREsted.

Thanks so much for developing this great package!

Matthew

P.S. Low-level usage of the extended classes looks like this:

# `atac_bigwig_dir`, `bin_path`, `chromsizes_file`, and `genome` are placeholders defined
# elsewhere; MetaAnnDataModule is the new class from my fork.
import numpy as np
import crested

# Two AnnData objects standing in for two species (imported from the same bigwigs here
# purely for illustration).
adata = crested.import_bigwigs(
    bigwigs_folder=atac_bigwig_dir,
    regions_file=bin_path,
    chromsizes_file=chromsizes_file,
    target='raw',
)
bdata = crested.import_bigwigs(
    bigwigs_folder=atac_bigwig_dir,
    regions_file=bin_path,
    chromsizes_file=chromsizes_file,
    target='raw',
)

# Attach some dummy per-cell-type metadata to obs and obsm.
adata.obs['imaginary'] = np.random.randint(0, 10, adata.shape[0])
adata.obsm['test'] = np.random.randn(adata.shape[0], 3)
bdata.obs['imaginary'] = np.random.randint(-5, -1, bdata.shape[0])
bdata.obsm['test'] = np.random.randn(bdata.shape[0], 3) - 2

# Per-region sampling probabilities: give the first region 0.5 and spread the rest evenly.
p = np.full(adata.n_vars, 0.5 / (adata.n_vars - 1), dtype=float)
p[0] = 0.5
adata.var["sample_prob"] = p
bdata.var["sample_prob"] = p

crested.pp.train_val_test_split(
    adata, strategy="chr", val_chroms=["chr8"], test_chroms=["chr9"]
)
crested.pp.train_val_test_split(
    bdata, strategy="chr", val_chroms=["chr8"], test_chroms=["chr9"]
)

datamodule = MetaAnnDataModule(
    adatas=[adata, bdata],
    genomes=[genome, genome],
    batch_size=32,
    epoch_size=5000,
    max_stochastic_shift=3,
    always_reverse_complement=True,
    obs_columns=['imaginary'],
    obsm_keys=['test'],
)

datamodule.setup('fit')

# Each batch is a dict of tensors; print the keys and shapes as a quick check.
for x in datamodule.train_dataloader.data:
    print(x)
    for k, v in x.items():
        print(k, v.shape)
