diff --git a/amira_prototype/construct_graph.py b/amira_prototype/construct_graph.py index 6882eb7e..c0e1456a 100644 --- a/amira_prototype/construct_graph.py +++ b/amira_prototype/construct_graph.py @@ -2242,7 +2242,7 @@ def filter_paths_between_bubble_starts(self, unique_paths): filtered_paths.append((p_list, self.calculate_path_coverage(p_list))) return filtered_paths - def get_minhash_of_nodes(self, batch, node_minhashes, fastq_data): + def old_get_minhash_of_nodes(self, batch, node_minhashes, fastq_data): for node_hash in batch: node = self.get_node_by_hash(node_hash) minhash = sourmash.MinHash(n=0, ksize=13, scaled=10) @@ -2260,13 +2260,13 @@ def get_minhash_of_nodes(self, batch, node_minhashes, fastq_data): minhash.add_sequence(entire_read_sequence[p[0] : p[1] + 1], force=True) node_minhashes[node_hash] = minhash - def get_minhash_of_path(self, batch, path_minimizers, node_minhashes): + def old_get_minhash_of_path(self, batch, path_minimizers, node_minhashes): for path_tuple in batch: for node_hash in path_tuple: assert node_minhashes[node_hash] is not None path_minimizers[path_tuple].update(node_minhashes[node_hash].hashes) - def get_minhashes_for_paths(self, sorted_filtered_paths, fastq_data, cores): + def old_get_minhashes_for_paths(self, sorted_filtered_paths, fastq_data, cores): path_minimizers = {} node_minhashes = {} for path_tuple, path_coverage in sorted_filtered_paths: @@ -2290,28 +2290,54 @@ def get_minhashes_for_paths(self, sorted_filtered_paths, fastq_data, cores): assert not any(v is None for v in path_minimizers.values()) return path_minimizers - def calculate_minimum_path_threshold(self, sorted_filtered_paths): - counts = {} - path_coverages = [] - for terminals in sorted_filtered_paths: - if len(sorted_filtered_paths[terminals]) < 2: - continue - for p in sorted_filtered_paths[terminals]: - # round the coverage up to the nearest 10 - nearest_tenth = math.ceil(p[1] / 10) * 10 - if nearest_tenth not in counts: - counts[nearest_tenth] = 0 - counts[nearest_tenth] += 1 - path_coverages.append(p[1]) - # get the highest count - sorted_counts = sorted([(k, v) for k, v in counts.items()], key=lambda x: x[0]) - highest_count = max([x[1] for x in sorted_counts]) - # get the threshold - previous_coverage = 10 - for coverage, count in sorted_counts: - if count == highest_count: - return previous_coverage - previous_coverage = coverage + def get_minhash_of_nodes(self, batch, node_minhashes, fastq_data): + for node_hash in batch: + node = self.get_node_by_hash(node_hash) + minhash = sourmash.MinHash(n=0, ksize=13, scaled=10) + for read in node.get_reads(): + indices = [i for i, n in enumerate(self.get_readNodes()[read]) if n == node_hash] + positions = [self.get_readNodePositions()[read][i] for i in indices] + entire_read_sequence = fastq_data[read.split("_")[0]]["sequence"] + for p in positions: + minhash.add_sequence(entire_read_sequence[p[0]: p[1] + 1], force=True) + node_minhashes[node_hash] = minhash + + def get_minhash_of_path(self, batch, path_minimizers, node_minhashes): + for path_tuple in batch: + for node_hash in path_tuple: + path_minimizers[path_tuple].update(node_minhashes[node_hash].hashes) + + def get_minhashes_for_paths(self, sorted_filtered_paths, fastq_data, cores): + from collections import defaultdict + path_minimizers = defaultdict(set) + node_minhashes = {} + + for path_tuple, path_coverage in sorted_filtered_paths: + path = [p[0] for p in path_tuple] + for node_hash in path: + if node_hash not in node_minhashes: + node_minhashes[node_hash] = None + + path_minimizers[tuple(path)] = set() + + # Parallel computation of node minhashes + keys = list(node_minhashes.keys()) + batches = [keys[i::cores] for i in range(cores)] + Parallel(n_jobs=cores, prefer="threads")( + delayed(self.get_minhash_of_nodes)(batch, node_minhashes, fastq_data) + for batch in batches + ) + + # Parallel computation of path minhashes + keys = list(path_minimizers.keys()) + batches = [keys[i::cores] for i in range(cores)] + Parallel(n_jobs=cores, prefer="threads")( + delayed(self.get_minhash_of_path)(batch, path_minimizers, node_minhashes) + for batch in batches + ) + # Ensure that all path minimizers are populated + assert not any(v is None for v in path_minimizers.values()) + return path_minimizers def correct_low_coverage_paths( self, fastq_data, genesOfInterest, cores, min_path_coverage, use_minimizers=False @@ -2346,9 +2372,6 @@ def correct_low_coverage_paths( path_minimizers = None # bin the paths based on their terminal nodes sorted_filtered_paths = self.separate_paths_by_terminal_nodes(sorted_filtered_paths) - # dynamically determine the minimum path coverage - # if min_path_coverage is None: - # min_path_coverage = self.calculate_minimum_path_threshold(sorted_filtered_paths) # clean the paths path_coverages += self.correct_bubble_paths( sorted_filtered_paths,