Skip to content

Commit

Permalink
fix: minor memory optimisation for minimizer extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
Danderson123 committed Aug 14, 2024
1 parent 3ae42d7 commit 3dde066
Showing 1 changed file with 51 additions and 28 deletions.
79 changes: 51 additions & 28 deletions amira_prototype/construct_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2242,7 +2242,7 @@ def filter_paths_between_bubble_starts(self, unique_paths):
filtered_paths.append((p_list, self.calculate_path_coverage(p_list)))
return filtered_paths

def get_minhash_of_nodes(self, batch, node_minhashes, fastq_data):
def old_get_minhash_of_nodes(self, batch, node_minhashes, fastq_data):
for node_hash in batch:
node = self.get_node_by_hash(node_hash)
minhash = sourmash.MinHash(n=0, ksize=13, scaled=10)
Expand All @@ -2260,13 +2260,13 @@ def get_minhash_of_nodes(self, batch, node_minhashes, fastq_data):
minhash.add_sequence(entire_read_sequence[p[0] : p[1] + 1], force=True)
node_minhashes[node_hash] = minhash

def get_minhash_of_path(self, batch, path_minimizers, node_minhashes):
def old_get_minhash_of_path(self, batch, path_minimizers, node_minhashes):
for path_tuple in batch:
for node_hash in path_tuple:
assert node_minhashes[node_hash] is not None
path_minimizers[path_tuple].update(node_minhashes[node_hash].hashes)

def get_minhashes_for_paths(self, sorted_filtered_paths, fastq_data, cores):
def old_get_minhashes_for_paths(self, sorted_filtered_paths, fastq_data, cores):
path_minimizers = {}
node_minhashes = {}
for path_tuple, path_coverage in sorted_filtered_paths:
Expand All @@ -2290,28 +2290,54 @@ def get_minhashes_for_paths(self, sorted_filtered_paths, fastq_data, cores):
assert not any(v is None for v in path_minimizers.values())
return path_minimizers

def calculate_minimum_path_threshold(self, sorted_filtered_paths):
counts = {}
path_coverages = []
for terminals in sorted_filtered_paths:
if len(sorted_filtered_paths[terminals]) < 2:
continue
for p in sorted_filtered_paths[terminals]:
# round the coverage up to the nearest 10
nearest_tenth = math.ceil(p[1] / 10) * 10
if nearest_tenth not in counts:
counts[nearest_tenth] = 0
counts[nearest_tenth] += 1
path_coverages.append(p[1])
# get the highest count
sorted_counts = sorted([(k, v) for k, v in counts.items()], key=lambda x: x[0])
highest_count = max([x[1] for x in sorted_counts])
# get the threshold
previous_coverage = 10
for coverage, count in sorted_counts:
if count == highest_count:
return previous_coverage
previous_coverage = coverage
def get_minhash_of_nodes(self, batch, node_minhashes, fastq_data):
for node_hash in batch:
node = self.get_node_by_hash(node_hash)
minhash = sourmash.MinHash(n=0, ksize=13, scaled=10)
for read in node.get_reads():
indices = [i for i, n in enumerate(self.get_readNodes()[read]) if n == node_hash]
positions = [self.get_readNodePositions()[read][i] for i in indices]
entire_read_sequence = fastq_data[read.split("_")[0]]["sequence"]
for p in positions:
minhash.add_sequence(entire_read_sequence[p[0]: p[1] + 1], force=True)
node_minhashes[node_hash] = minhash

def get_minhash_of_path(self, batch, path_minimizers, node_minhashes):
for path_tuple in batch:
for node_hash in path_tuple:
path_minimizers[path_tuple].update(node_minhashes[node_hash].hashes)

def get_minhashes_for_paths(self, sorted_filtered_paths, fastq_data, cores):
from collections import defaultdict
path_minimizers = defaultdict(set)
node_minhashes = {}

for path_tuple, path_coverage in sorted_filtered_paths:
path = [p[0] for p in path_tuple]
for node_hash in path:
if node_hash not in node_minhashes:
node_minhashes[node_hash] = None

path_minimizers[tuple(path)] = set()

# Parallel computation of node minhashes
keys = list(node_minhashes.keys())
batches = [keys[i::cores] for i in range(cores)]
Parallel(n_jobs=cores, prefer="threads")(
delayed(self.get_minhash_of_nodes)(batch, node_minhashes, fastq_data)
for batch in batches
)

# Parallel computation of path minhashes
keys = list(path_minimizers.keys())
batches = [keys[i::cores] for i in range(cores)]
Parallel(n_jobs=cores, prefer="threads")(
delayed(self.get_minhash_of_path)(batch, path_minimizers, node_minhashes)
for batch in batches
)
# Ensure that all path minimizers are populated
assert not any(v is None for v in path_minimizers.values())
return path_minimizers

def correct_low_coverage_paths(
self, fastq_data, genesOfInterest, cores, min_path_coverage, use_minimizers=False
Expand Down Expand Up @@ -2346,9 +2372,6 @@ def correct_low_coverage_paths(
path_minimizers = None
# bin the paths based on their terminal nodes
sorted_filtered_paths = self.separate_paths_by_terminal_nodes(sorted_filtered_paths)
# dynamically determine the minimum path coverage
# if min_path_coverage is None:
# min_path_coverage = self.calculate_minimum_path_threshold(sorted_filtered_paths)
# clean the paths
path_coverages += self.correct_bubble_paths(
sorted_filtered_paths,
Expand Down

0 comments on commit 3dde066

Please sign in to comment.