Skip to content

Commit

Permalink
test v0.3.2
Browse files Browse the repository at this point in the history
  • Loading branch information
andyjslee committed Sep 10, 2024
1 parent 0526537 commit 73bef0f
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 91 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "vstolib"
version = "0.1.9"
version = "0.3.2"
edition = "2021"

[package.metadata.maturin]
Expand Down
2 changes: 2 additions & 0 deletions python/vstolib/cli/cli_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,14 @@ def run_cli_score_from_parsed_args(args: argparse.Namespace):
variants_list = VariantsList.read_tsv_file(tsv_file=args.tsv_file)

# Step 2. Calculate average alignment score for each breakpoint
logger.info("Started calculating average alignment score for each breakpoint.")
variants_list = score(
variants_list=variants_list,
bam_file=args.bam_file,
window=args.window,
num_threads=args.num_threads
)
logger.info("Finished calculating average alignment score for each breakpoint.")

# Step 3. Write to a TSV file
df_variants = variants_list.to_dataframe()
Expand Down
93 changes: 61 additions & 32 deletions python/vstolib/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@

import copy
import multiprocessing as mp
import numpy as np
import pandas as pd
import pysam
from collections import defaultdict
from functools import partial
from typing import List, Literal, Tuple
from vstolib import vstolibrs
from .annotator import Annotator
from .constants import CollapseStrategies, VariantCallingMethods
from .default import *
from .genomic_ranges_list import GenomicRangesList
from .logging import get_logger
from .metrics import calculate_average_alignment_score
from .utilities import is_repeated_sequence
from .variant import Variant
from .variants_list import VariantsList
Expand Down Expand Up @@ -439,28 +441,6 @@ def overlap(
return variants_list_overlapping


def _score_helper(
        bam_file: str,
        window: int,
        variant: Variant
) -> Variant:
    """
    Annotate every variant call of a variant with the average alignment
    score around each of its two breakpoints.

    Parameters:
        bam_file    :   BAM file.
        window      :   Window passed to calculate_average_alignment_score
                        and recorded on each variant call.
        variant     :   Variant object (its variant calls are updated in place).

    Returns:
        Variant
    """
    for call in variant.variant_calls:
        # (attribute to set, chromosome, position) for both breakpoints
        breakpoints = (
            ("position_1_average_alignment_score", call.chromosome_1, call.position_1),
            ("position_2_average_alignment_score", call.chromosome_2, call.position_2),
        )
        for attribute_name, chromosome, position in breakpoints:
            average_score = calculate_average_alignment_score(
                bam_file=bam_file,
                chromosome=chromosome,
                position=position,
                window=window
            )
            setattr(call, attribute_name, average_score)
        call.average_alignment_score_window = window
    return variant


def score(
variants_list: VariantsList,
bam_file: str,
Expand All @@ -474,19 +454,68 @@ def score(
variants_list : VariantsList object.
bam_file : BAM file.
window : Window (will be applied both upstream and downstream).
num_threads : Number of threads.
Returns:
VariantsList
"""
pool = mp.Pool(processes=num_threads)
func = partial(_score_helper, bam_file, window)
variants = pool.map(func, variants_list.variants)
pool.close()
variants_list_ = VariantsList()
for variant in variants:
variants_list_.add_variant(variant=variant)
return variants_list_
# Step 1. Get the regions
regions = []
bamfile = pysam.AlignmentFile(bam_file, "rb")
for variant in variants_list.variants:
for variant_call in variant.variant_calls:
# Position 1
chromosome_length = bamfile.get_reference_length(variant_call.chromosome_1)
start = variant_call.position_1 - window
end = variant_call.position_1 + window
if start < 0:
start = 0
if end > chromosome_length:
end = chromosome_length
regions.append((variant_call.chromosome_1,start,end))

# Position 2
chromosome_length = bamfile.get_reference_length(variant_call.chromosome_2)
start = variant_call.position_2 - window
end = variant_call.position_2 + window
if start < 0:
start = 0
if end > chromosome_length:
end = chromosome_length
regions.append((variant_call.chromosome_2,start,end))

# Step 2. Calculate the average alignment scores
regions_scores = vstolibrs.calculate_average_alignment_scores(
bam_file=bam_file,
regions=regions,
num_threads=num_threads
)

# Step 3. Store the average alignment scores
for variant in variants_list.variants:
for variant_call in variant.variant_calls:
# Position 1
chromosome_length = bamfile.get_reference_length(variant_call.chromosome_1)
start = variant_call.position_1 - window
end = variant_call.position_1 + window
if start < 0:
start = 0
if end > chromosome_length:
end = chromosome_length
variant_call.position_1_average_alignment_score = regions_scores[(variant_call.chromosome_1,start,end)]

# Position 2
chromosome_length = bamfile.get_reference_length(variant_call.chromosome_2)
start = variant_call.position_2 - window
end = variant_call.position_2 + window
if start < 0:
start = 0
if end > chromosome_length:
end = chromosome_length
variant_call.position_2_average_alignment_score = regions_scores[(variant_call.chromosome_2,start,end)]

variant_call.average_alignment_score_window = window

return variants_list


def vcf2tsv(
Expand Down
56 changes: 0 additions & 56 deletions python/vstolib/metrics.py

This file was deleted.

29 changes: 27 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ extern crate pyo3;
extern crate serde_json;
use chrono::Local;
use env_logger::{Builder, Env};
use log::{info, LevelFilter};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use std::collections::HashMap;
use std::io::Write;
mod constants;
mod genomic_range;
mod genomic_ranges_list;
mod metrics;
mod utilities;
mod variant;
mod variant_call;
Expand All @@ -35,6 +35,7 @@ mod variant_filter;
mod variants_list;
use genomic_range::GenomicRange;
use genomic_ranges_list::GenomicRangesList;
use metrics::calculate_average_alignment_scores as calculate_average_alignment_scores_;
use variant::Variant;
use variant_call::VariantCall;
use variant_call_annotation::VariantCallAnnotation;
Expand Down Expand Up @@ -125,11 +126,34 @@ fn deserialize_variant_filter(json_str: &str) -> VariantFilter {
}
}

/// Calculates average alignment scores for a given list of regions.
///
/// The GIL is released while the scores are computed, so other Python threads
/// are not blocked for the duration of the (multi-threaded) BAM scan.
///
/// # Arguments
/// * `py` - Python GIL token (supplied by PyO3; not part of the Python signature).
/// * `bam_file` - BAM file.
/// * `regions` - vector of tuples where each tuple is (chromosome,start,end).
/// * `num_threads` - number of threads.
///
/// # Returns
/// * HashMap where key = (chromosome,start,end) and value = average alignment score.
#[pyfunction]
fn calculate_average_alignment_scores(
    py: Python,
    bam_file: String,
    regions: Vec<(String, u32, u32)>,
    num_threads: usize,
) -> PyResult<HashMap<(String, u32, u32), f64>> {
    // The computation touches no Python objects, so it is safe to run it
    // without holding the GIL.
    let scores: HashMap<(String, u32, u32), f64> = py.allow_threads(|| {
        calculate_average_alignment_scores_(bam_file.as_str(), &regions, num_threads)
    });
    Ok(scores)
}

/// This function filters a serialized VariantsList object and returns a filtered VariantsList.
///
/// # Arguments
/// * `py_str` - serialized VariantsList object.
/// * `py_list` - a list of serialized VariantFilter objects.
/// * `py_list` - list of serialized VariantFilter objects.
/// * `num_threads` - number of threads.
///
/// # Returns
Expand Down Expand Up @@ -284,6 +308,7 @@ fn vstolibrs(_py: Python, m: &PyModule) -> PyResult<()> {
)
}).init();

m.add_function(wrap_pyfunction!(calculate_average_alignment_scores, m)?);
m.add_function(wrap_pyfunction!(filter_variants_list, m)?);
m.add_function(wrap_pyfunction!(find_overlapping_variant_calls, m)?);
m.add_function(wrap_pyfunction!(intersect_variants_lists, m)?);
Expand Down
85 changes: 85 additions & 0 deletions src/metrics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


use rayon::prelude::*;
use rayon::ThreadPoolBuilder;
use bam::IndexedReader;
use std::collections::HashMap;


pub fn calculate_average_alignment_scores(
bam_file: &str,
regions: &Vec<(String, u32, u32)>,
num_threads: usize,
) -> HashMap<(String, u32, u32), f64> {
// Step 1. Read the BAM file
let reader = bam::IndexedReader::from_path(bam_file).unwrap();
let header = reader.header();

// Step 2. Get the regions with chromosome IDs
let mut regions_reformatted: Vec<(u32, u32, u32)> = Vec::new();
let mut chromosomes_map: HashMap<u32, String> = HashMap::new();
for (chromosome, start, end) in regions.iter() {
if let Some(chromosome_id) = header.reference_id(chromosome) {
regions_reformatted.push((chromosome_id, *start, *end));
chromosomes_map.insert(chromosome_id, chromosome.to_string());
} else {
panic!("The chromosome ID cannot be found for {}", chromosome);
}
}

// Step 3. Split the regions into roughly equal-sized chunks
let chunk_size = (regions_reformatted.len() + num_threads - 1) / num_threads;
let regions_reformatted_chunks: Vec<_> = regions_reformatted.chunks(chunk_size).collect();

// Step 4. Calculate the average alignment score for each position
let thread_pool = ThreadPoolBuilder::new()
.num_threads(num_threads)
.build()
.unwrap();
let scores: Vec<(u32, u32, u32, f64)> = thread_pool.install(|| {
regions_reformatted_chunks
.par_iter()
.flat_map(|chunk| {
let mut local_reader = bam::IndexedReader::from_path(bam_file).unwrap();
chunk
.iter()
.map(|(chromosome, start, end)| {
let mut total: u32 = 0;
let mut count: u32 = 0;
for record in local_reader.fetch(&bam::Region::new(*chromosome, *start, *end)).unwrap() {
total += record.unwrap().mapq() as u32;
count += 1;
}
let score = if count > 0 {
total as f64 / count as f64
} else {
-1.0
};
(*chromosome, *start, *end, score)
})
.collect::<Vec<(u32, u32, u32, f64)>>()
})
.collect()
});

// Step 5. Store the average alignment scores as a HashMap
let mut scores_map: HashMap<(String, u32, u32), f64> = HashMap::new();
for (chromosome_id, start, end, score) in scores.iter() {
if let Some(chromosome_name) = chromosomes_map.get(chromosome_id) {
scores_map.insert((chromosome_name.clone(), *start, *end), *score);
}
}

scores_map
}

0 comments on commit 73bef0f

Please sign in to comment.