From a8c5dba1e1afa8562cacf6c01beeac0768a2d2f9 Mon Sep 17 00:00:00 2001
From: Frank Austin Nothaft
Date: Mon, 6 Jun 2016 11:58:01 -0700
Subject: [PATCH] [ADAM-952] Expose sorting by reference index.

Resolves #952. Adds the function `sortReadsByReferencePositionAndIndex` on
RDDs of `AlignmentRecord`. This sorts reads by their position on a contig,
where contigs are ordered by their index in the sequence dictionary. This
conforms to the SAM/BAM coordinate sort order.
---
 .../org/bdgenomics/adam/cli/Transform.scala     | 19 ++++--
 .../org/bdgenomics/adam/cli/ViewSuite.scala     |  5 +-
 .../adam/instrumentation/Timers.scala           |  1 +
 .../read/AlignmentRecordRDDFunctions.scala      | 38 +++++++++++
 .../AlignmentRecordRDDFunctionsSuite.scala      | 67 +++++++++++++++++--
 5 files changed, 119 insertions(+), 11 deletions(-)

diff --git a/adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala b/adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala
index 87df7156bb..5150e5d31a 100644
--- a/adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala
+++ b/adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala
@@ -59,6 +59,8 @@ class TransformArgs extends Args4jBase with ADAMSaveAnyArgs with ParquetArgs {
   var useAlignedReadPredicate: Boolean = false
   @Args4jOption(required = false, name = "-sort_reads", usage = "Sort the reads by referenceId and read position")
   var sortReads: Boolean = false
+  @Args4jOption(required = false, name = "-sort_lexicographically", usage = "Sort the reads lexicographically by contig name, instead of by index.")
+  var sortLexicographically: Boolean = false
   @Args4jOption(required = false, name = "-mark_duplicate_reads", usage = "Mark duplicate reads")
   var markDuplicates: Boolean = false
   @Args4jOption(required = false, name = "-recalibrate_base_qualities", usage = "Recalibrate the base quality scores (ILLUMINA only)")
@@ -123,6 +125,7 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
   val stringency = ValidationStringency.valueOf(args.stringency)

   def apply(rdd: RDD[AlignmentRecord],
+            sd: SequenceDictionary,
             rgd: RecordGroupDictionary): RDD[AlignmentRecord] = {

     var adamRecords = rdd
@@ -207,7 +210,11 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
       }

       log.info("Sorting reads")
-      adamRecords = oldRdd.sortReadsByReferencePosition()
+      if (args.sortLexicographically) {
+        adamRecords = oldRdd.sortReadsByReferencePosition()
+      } else {
+        adamRecords = oldRdd.sortReadsByReferencePositionAndIndex(sd)
+      }

       if (args.cache) {
         oldRdd.unpersist()
@@ -329,15 +336,19 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
     })

     // run our transformation
-    val outputRdd = this.apply(mergedRdd, mergedRgd)
+    val outputRdd = this.apply(mergedRdd, mergedSd, mergedRgd)

     // if we are sorting, we must strip the indices from the sequence dictionary
     // and sort the sequence dictionary
     //
     // we must do this because we do a lexicographic sort, not an index-based sort
     val sdFinal = if (args.sortReads) {
-      mergedSd.stripIndices
-        .sorted
+      if (args.sortLexicographically) {
+        mergedSd.stripIndices
+          .sorted
+      } else {
+        mergedSd
+      }
     } else {
       mergedSd
     }
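A note on how the two flags interact: index order and lexicographic order
disagree whenever contig names do not sort the same way as their header
indices, which is typical for human references. A minimal, self-contained
Scala sketch of the difference (illustrative only; the contig names are made
up):

    // BAM headers usually list contigs in karyotypic order
    val headerOrder = Seq("chr1", "chr2", "chr10")
    // a lexicographic sort interleaves them, since "chr10" < "chr2"
    assert(headerOrder.sorted == Seq("chr1", "chr10", "chr2"))

This is why the lexicographic path must strip the indices from the merged
sequence dictionary and re-sort it, while the index-based path can pass
mergedSd through unchanged.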
diff --git a/adam-cli/src/test/scala/org/bdgenomics/adam/cli/ViewSuite.scala b/adam-cli/src/test/scala/org/bdgenomics/adam/cli/ViewSuite.scala
index c68ea379ef..75d656ef5c 100644
--- a/adam-cli/src/test/scala/org/bdgenomics/adam/cli/ViewSuite.scala
+++ b/adam-cli/src/test/scala/org/bdgenomics/adam/cli/ViewSuite.scala
@@ -18,8 +18,9 @@
 package org.bdgenomics.adam.cli

 import org.apache.spark.rdd.RDD
-import org.bdgenomics.adam.util.ADAMFunSuite
+import org.bdgenomics.adam.models.SequenceDictionary
 import org.bdgenomics.adam.rdd.ADAMContext._
+import org.bdgenomics.adam.util.ADAMFunSuite
 import org.bdgenomics.formats.avro.AlignmentRecord
 import org.bdgenomics.utils.cli.Args4j

@@ -46,7 +47,7 @@ class ViewSuite extends ADAMFunSuite {
     val rdd = aRdd.rdd
     val rgd = aRdd.recordGroups

-    reads = transform.apply(rdd, rgd).collect()
+    reads = transform.apply(rdd, SequenceDictionary.empty, rgd).collect()
     readsCount = reads.size.toInt
   }

diff --git a/adam-core/src/main/scala/org/bdgenomics/adam/instrumentation/Timers.scala b/adam-core/src/main/scala/org/bdgenomics/adam/instrumentation/Timers.scala
index be8a41bf33..91874c6f67 100644
--- a/adam-core/src/main/scala/org/bdgenomics/adam/instrumentation/Timers.scala
+++ b/adam-core/src/main/scala/org/bdgenomics/adam/instrumentation/Timers.scala
@@ -69,6 +69,7 @@ object Timers extends Metrics {

   // Sort Reads
   val SortReads = timer("Sort Reads")
+  val SortByIndex = timer("Sort Reads By Index")

   // File Saving
   val SAMSave = timer("SAM Save")

diff --git a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctions.scala b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctions.scala
index c51205d0e1..c38d5c0571 100644
--- a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctions.scala
+++ b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctions.scala
@@ -44,8 +44,20 @@ import org.bdgenomics.utils.misc.Logging
 import org.seqdoop.hadoop_bam.SAMRecordWritable
 import scala.annotation.tailrec
 import scala.language.implicitConversions
+import scala.math.min
 import scala.reflect.ClassTag

+private object SequenceIndexOrdering extends Ordering[(Int, Long)] {
+  def compare(a: (Int, Long),
+              b: (Int, Long)): Int = {
+    if (a._1 == b._1) {
+      a._2.compareTo(b._2)
+    } else {
+      a._1.compareTo(b._1)
+    }
+  }
+}
+
 private[rdd] class AlignmentRecordRDDFunctions(val rdd: RDD[AlignmentRecord])
     extends ADAMRDDFunctions[AlignmentRecord] {

@@ -580,6 +592,32 @@ private[rdd] class AlignmentRecordRDDFunctions(val rdd: RDD[AlignmentRecord])
     }).sortByKey().map(_._2)
   }

+  def sortReadsByReferencePositionAndIndex(sd: SequenceDictionary): RDD[AlignmentRecord] = SortByIndex.time {
+    log.info("Sorting reads by reference index, using %s.".format(sd))
+
+    implicit val ordering: Ordering[(Int, Long)] = SequenceIndexOrdering
+
+    // NOTE: To keep unmapped reads from swamping a single partition, we key
+    // them by a non-negative hash of their read name, shifted past the max
+    // contig index so that they always sort after the mapped reads
+    val maxContigIndex = sd.records.flatMap(_.referenceIndex).max
+    rdd.keyBy(r => {
+      val key: (Int, Long) = if (r.getReadMapped) {
+        val sr = sd(r.getContigName)
+        require(sr.isDefined, "Read %s has contig name %s not in dictionary %s.".format(
+          r, r.getContigName, sd))
+        require(sr.get.referenceIndex.isDefined,
+          "Contig %s from sequence dictionary lacks an index.".format(sr.get))
+
+        (sr.get.referenceIndex.get, r.getStart)
+      } else {
+        (min(maxContigIndex.toLong + 1L + (r.getReadName.hashCode & Int.MaxValue), Int.MaxValue.toLong).toInt, 0L)
+      }
+
+      key
+    }).sortByKey().map(_._2)
+  }
+
   /**
    * Marks reads as possible fragment duplicates.
    *
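The new method needs only an RDD[AlignmentRecord] and a SequenceDictionary
whose records carry reference indices. A minimal usage sketch, assuming the
implicit conversions from ADAMContext._ are in scope and that the loaded
alignments expose rdd and sequences fields, as ViewSuite above suggests (the
input path is hypothetical):

    import org.bdgenomics.adam.rdd.ADAMContext._

    // load reads plus header metadata from a BAM whose header carries indices
    val alignments = sc.loadAlignments("/data/sample.bam")

    // sort by (reference index, start), i.e., the SAM/BAM coordinate order
    val sorted = alignments.rdd.sortReadsByReferencePositionAndIndex(alignments.sequences)

Since unmapped reads are keyed past the largest reference index, they always
land after the mapped reads, matching where a coordinate-sorted BAM places
them.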
diff --git a/adam-core/src/test/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctionsSuite.scala b/adam-core/src/test/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctionsSuite.scala
index 3301504936..f62350aa07 100644
--- a/adam-core/src/test/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctionsSuite.scala
+++ b/adam-core/src/test/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDDFunctionsSuite.scala
@@ -22,7 +22,11 @@ import java.nio.file.Files
 import htsjdk.samtools.ValidationStringency
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
-import org.bdgenomics.adam.models.{ RecordGroupDictionary, SequenceDictionary }
+import org.bdgenomics.adam.models.{
+  RecordGroupDictionary,
+  SequenceDictionary,
+  SequenceRecord
+}
 import org.bdgenomics.adam.rdd.ADAMContext._
 import org.bdgenomics.adam.rdd.TestSaveArgs
 import org.bdgenomics.adam.util.ADAMFunSuite
@@ -30,6 +34,17 @@ import org.bdgenomics.formats.avro._
 import scala.io.Source
 import scala.util.Random

+private object SequenceIndexWithReadOrdering extends Ordering[((Int, Long), (AlignmentRecord, Int))] {
+  def compare(a: ((Int, Long), (AlignmentRecord, Int)),
+              b: ((Int, Long), (AlignmentRecord, Int))): Int = {
+    if (a._1._1 == b._1._1) {
+      a._1._2.compareTo(b._1._2)
+    } else {
+      a._1._1.compareTo(b._1._1)
+    }
+  }
+}
+
 class AlignmentRecordRDDFunctionsSuite extends ADAMFunSuite {

   sparkTest("sorting reads") {
@@ -39,11 +54,9 @@
       val mapped = random.nextBoolean()
       val builder = AlignmentRecord.newBuilder().setReadMapped(mapped)
       if (mapped) {
-        val contig = Contig.newBuilder
-          .setContigName(random.nextInt(numReadsToCreate / 10).toString)
-          .build
+        val contigName = random.nextInt(numReadsToCreate / 10).toString
         val start = random.nextInt(1000000)
-        builder.setContigName(contig.getContigName).setStart(start).setEnd(start)
+        builder.setContigName(contigName).setStart(start).setEnd(start)
       }
       builder.setReadName((0 until 20).map(i => (random.nextInt(100) + 64)).mkString)
       builder.build()
     }
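The test added below checks both invariants: unmapped reads land at the end,
and mapped reads follow the (reference index, start) key. A pure-Scala sketch
of that key ordering on made-up keys, mirroring SequenceIndexWithReadOrdering:

    // keys are (reference index, alignment start)
    val keys = Seq((1, 200L), (0, 500L), (1, 100L))
    val sorted = keys.sortWith((a, b) =>
      if (a._1 == b._1) a._2 < b._2 else a._1 < b._1)
    // sorted == Seq((0, 500L), (1, 100L), (1, 200L)): index first, then start

Note that the test derives its dictionary from contigNames.toSet, so the
assigned indices generally will not match lexicographic name order; that is
exactly what distinguishes this sort from sortReadsByReferencePosition.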
@@ -59,6 +72,50 @@
     assert(expectedSortedReads === mapped)
   }

+  sparkTest("sorting reads by reference index") {
+    val random = new Random("sortingIndices".hashCode)
+    val numReadsToCreate = 1000
+    val reads = for (i <- 0 until numReadsToCreate) yield {
+      val mapped = random.nextBoolean()
+      val builder = AlignmentRecord.newBuilder().setReadMapped(mapped)
+      if (mapped) {
+        val contigName = random.nextInt(numReadsToCreate / 10).toString
+        val start = random.nextInt(1000000)
+        builder.setContigName(contigName).setStart(start).setEnd(start)
+      }
+      builder.setReadName((0 until 20).map(i => (random.nextInt(100) + 64)).mkString)
+      builder.build()
+    }
+    val contigNames = reads.filter(_.getReadMapped).map(_.getContigName).toSet
+    val sd = new SequenceDictionary(contigNames.toSeq
+      .zipWithIndex
+      .map(kv => {
+        val (name, index) = kv
+        SequenceRecord(name, Int.MaxValue, referenceIndex = Some(index))
+      }).toVector)
+
+    val rdd = sc.parallelize(reads)
+    val sortedReads = rdd.sortReadsByReferencePositionAndIndex(sd).collect().zipWithIndex
+    val (mapped, unmapped) = sortedReads.partition(_._1.getReadMapped)
+
+    // Make sure that all the unmapped reads are placed at the end
+    assert(unmapped.forall(p => p._2 > mapped.last._2))
+
+    def toIndex(r: AlignmentRecord): Int = {
+      sd(r.getContigName).get.referenceIndex.get
+    }
+
+    // Make sure that we appropriately sorted the reads
+    implicit val ordering: Ordering[((Int, Long), (AlignmentRecord, Int))] = SequenceIndexWithReadOrdering
+    val expectedSortedReads = mapped.map(kv => {
+      val (r, idx) = kv
+      val start: Long = r.getStart
+      ((toIndex(r), start), (r, idx))
+    }).sorted
+      .map(_._2)
+    assert(expectedSortedReads === mapped)
+  }
+
   sparkTest("characterizeTags counts integer tag values correctly") {
     val tagCounts: Map[String, Long] = Map("XT" -> 10L, "XU" -> 9L, "XV" -> 8L)
     val readItr: Iterable[AlignmentRecord] =