Skip to content

Commit

Permalink
[ADAM-599] Eliminate shuffle issues by writing metadata to avro files.
Browse files Browse the repository at this point in the history
Resolves bigdatagenomics#599. Since we have added the RecordGroupMetadata fields in
bdg-formats:0.7.0, we can read/write our metadata as separate Avro files. We
process these files when loading/writing the Parquet files where the alignment
data is stored. This allows us to both eliminate the bulky metadata that we are
currently storing in the AlignmentRecord, while maintaining the Sequence and
RecordGroup dictionaries that we need to keep around.
  • Loading branch information
fnothaft committed Dec 29, 2015
1 parent 82689b4 commit 19562db
Show file tree
Hide file tree
Showing 28 changed files with 402 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.bdgenomics.adam.apis.java

import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaSparkContext
import org.bdgenomics.adam.models.{ RecordGroupDictionary, SequenceDictionary }
import org.bdgenomics.adam.rdd.ADAMContext
import org.bdgenomics.adam.rdd.ADAMContext._
import org.bdgenomics.formats.avro._
Expand Down Expand Up @@ -60,7 +61,10 @@ class JavaADAMContext(val ac: ADAMContext) extends Serializable {
* @param filePath Path to load the file from.
* @return Returns a read RDD.
*/
def adamRecordLoad(filePath: java.lang.String): JavaAlignmentRecordRDD = {
new JavaAlignmentRecordRDD(ac.loadAlignments(filePath).toJavaRDD())
def adamRecordLoad(filePath: java.lang.String): (JavaAlignmentRecordRDD, SequenceDictionary, RecordGroupDictionary) = {
val (rdd, sd, rgd) = ac.loadAlignments(filePath)
(new JavaAlignmentRecordRDD(rdd.toJavaRDD()),
sd,
rgd)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,21 @@
*/
package org.bdgenomics.adam.apis.java

import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.apache.spark.api.java.JavaRDD
import org.bdgenomics.adam.models.{ RecordGroupDictionary, SequenceDictionary }
import org.bdgenomics.adam.rdd.ADAMContext._
import org.bdgenomics.adam.rdd.ADAMSaveAnyArgs
import org.bdgenomics.formats.avro._
import org.apache.parquet.hadoop.metadata.CompressionCodecName

private class JavaSaveArgs(var outputPath: String,
var blockSize: Int = 128 * 1024 * 1024,
var pageSize: Int = 1 * 1024 * 1024,
var compressionCodec: CompressionCodecName = CompressionCodecName.GZIP,
var disableDictionaryEncoding: Boolean = false,
var asSingleFile: Boolean = false) extends ADAMSaveAnyArgs {
var sortFastqOutput = false
}

class JavaAlignmentRecordRDD(val jrdd: JavaRDD[AlignmentRecord]) extends Serializable {

Expand All @@ -32,47 +43,63 @@ class JavaAlignmentRecordRDD(val jrdd: JavaRDD[AlignmentRecord]) extends Seriali
* @param pageSize Size per page.
* @param compressCodec Name of the compression codec to use.
* @param disableDictionaryEncoding Whether or not to disable bit-packing.
* @param sd A dictionary describing the contigs this file is aligned against.
* @param rgd A dictionary describing the read groups in this file.
*/
def adamSave(filePath: java.lang.String,
blockSize: java.lang.Integer,
pageSize: java.lang.Integer,
compressCodec: CompressionCodecName,
disableDictionaryEncoding: java.lang.Boolean) {
jrdd.rdd.adamParquetSave(
filePath,
blockSize,
pageSize,
compressCodec,
disableDictionaryEncoding
)
disableDictionaryEncoding: java.lang.Boolean,
sd: SequenceDictionary,
rgd: RecordGroupDictionary) {
jrdd.rdd.saveAsParquet(
new JavaSaveArgs(filePath,
blockSize = blockSize,
pageSize = pageSize,
compressionCodec = compressCodec,
disableDictionaryEncoding = disableDictionaryEncoding),
sd,
rgd)
}

/**
* Saves this RDD to disk as a Parquet file.
*
* @param filePath Path to save the file at.
* @param sd A dictionary describing the contigs this file is aligned against.
* @param rgd A dictionary describing the read groups in this file.
*/
def adamSave(filePath: java.lang.String) {
jrdd.rdd.adamParquetSave(filePath)
def adamSave(filePath: java.lang.String,
sd: SequenceDictionary,
rgd: RecordGroupDictionary) {
jrdd.rdd.saveAsParquet(
new JavaSaveArgs(filePath),
sd,
rgd)
}

/**
* Saves this RDD to disk as a SAM/BAM file.
*
* @param filePath Path to save the file at.
* @param sd A dictionary describing the contigs this file is aligned against.
* @param rgd A dictionary describing the read groups in this file.
* @param asSam If true, saves as SAM. If false, saves as BAM.
* @param asSingleFile If true, saves output as a single file.
* @param isSorted If the output is sorted, this will modify the header.
*/
def adamSAMSave(filePath: java.lang.String,
asSam: java.lang.Boolean) {
jrdd.rdd.adamSAMSave(filePath, asSam)
}

/**
* Saves this RDD to disk as a SAM file.
*
* @param filePath Path to save the file at.
*/
def adamSAMSave(filePath: java.lang.String) {
jrdd.rdd.adamSAMSave(filePath)
sd: SequenceDictionary,
rgd: RecordGroupDictionary,
asSam: java.lang.Boolean,
asSingleFile: java.lang.Boolean,
isSorted: java.lang.Boolean) {
jrdd.rdd.adamSAMSave(filePath,
sd,
rgd,
asSam = asSam,
asSingleFile = asSingleFile,
isSorted = isSorted)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,27 @@
import java.nio.file.Path;
import org.apache.spark.api.java.JavaRDD;
import org.bdgenomics.adam.apis.java.JavaADAMContext;
import org.bdgenomics.adam.models.RecordGroupDictionary;
import org.bdgenomics.adam.models.SequenceDictionary;
import org.bdgenomics.formats.avro.AlignmentRecord;
import scala.Tuple3;

/**
* A simple test class for the JavaADAMRDD/Context. Writes an RDD to
* disk and reads it back.
*/
public class JavaADAMConduit {
public static JavaAlignmentRecordRDD conduit(JavaRDD<AlignmentRecord> rdd) throws IOException {
public static Tuple3<JavaAlignmentRecordRDD,
SequenceDictionary,
RecordGroupDictionary> conduit(JavaRDD<AlignmentRecord> rdd,
SequenceDictionary sd,
RecordGroupDictionary rgd) throws IOException {
JavaAlignmentRecordRDD recordRdd = new JavaAlignmentRecordRDD(rdd);

// make temp directory and save file
Path tempDir = Files.createTempDirectory("javaAC");
String fileName = tempDir.toString() + "/testRdd.adam";
recordRdd.adamSave(fileName);
recordRdd.adamSave(fileName, sd, rgd);

// create a new adam context and load the file
JavaADAMContext jac = new JavaADAMContext(rdd.context());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ class JavaADAMContextSuite extends ADAMFunSuite {
sparkTest("can read a small .SAM file") {
val path = resourcePath("small.sam")
val ctx = new JavaADAMContext(sc)
val reads: JavaAlignmentRecordRDD = ctx.adamRecordLoad(path)
val reads = ctx.adamRecordLoad(path)._1
assert(reads.jrdd.count() === 20)
}

sparkTest("can read a small .SAM file inside of java") {
ignore("can read a small .SAM file inside of java") {
val path = resourcePath("small.sam")
val reads: RDD[AlignmentRecord] = sc.loadAlignments(path)
val (reads, sd, rgd) = sc.loadAlignments(path)

val newReads: JavaAlignmentRecordRDD = JavaADAMConduit.conduit(reads)
val newReads = JavaADAMConduit.conduit(reads, sd, rgd)

assert(newReads.jrdd.count() === 20)
assert(newReads._1.jrdd.count() === 20)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Adam2Fastq(val args: Adam2FastqArgs) extends BDGSparkCommand[Adam2FastqArg
else
None

var reads: RDD[AlignmentRecord] = sc.loadAlignments(args.inputPath, projection = projectionOpt)
var reads: RDD[AlignmentRecord] = sc.loadAlignments(args.inputPath, projection = projectionOpt)._1

if (args.repartition != -1) {
log.info("Repartitioning reads to to '%d' partitions".format(args.repartition))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class CalculateDepth(protected val args: CalculateDepthArgs) extends BDGSparkCom

val proj = Projection(contig, start, cigar, readMapped)

val adamRDD: RDD[AlignmentRecord] = sc.loadAlignments(args.adamInputPath, projection = Some(proj))
val adamRDD: RDD[AlignmentRecord] = sc.loadAlignments(args.adamInputPath, projection = Some(proj))._1
val mappedRDD = adamRDD.filter(_.getReadMapped)

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class CountReadKmers(protected val args: CountReadKmersArgs) extends BDGSparkCom
// read from disk
var adamRecords: RDD[AlignmentRecord] = sc.loadAlignments(
args.inputPath,
projection = Some(Projection(AlignmentRecordField.sequence)))
projection = Some(Projection(AlignmentRecordField.sequence)))._1

if (args.repartition != -1) {
log.info("Repartitioning reads to '%d' partitions".format(args.repartition))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class FlagStat(protected val args: FlagStatArgs) extends BDGSparkCommand[FlagSta
AlignmentRecordField.mapq,
AlignmentRecordField.failedVendorQualityChecks)

val adamFile: RDD[AlignmentRecord] = sc.loadAlignments(args.inputPath, projection = Some(projection))
val adamFile: RDD[AlignmentRecord] = sc.loadAlignments(args.inputPath, projection = Some(projection))._1

val (failedVendorQuality, passedVendorQuality) = adamFile.adamFlagStat()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class PluginExecutor(protected val args: PluginExecutorArgs) extends BDGSparkCom
}
}

val firstRdd: RDD[AlignmentRecord] = sc.loadAlignments(args.input, projection = plugin.projection)
val firstRdd: RDD[AlignmentRecord] = sc.loadAlignments(args.input, projection = plugin.projection)._1

val input = filter match {
case None => firstRdd
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class PrintTags(protected val args: PrintTagsArgs) extends BDGSparkCommand[Print
val toCount = if (args.count != null) args.count.split(",").toSet else Set()

val proj = Projection(attributes, primaryAlignment, readMapped, readPaired, failedVendorQualityChecks)
val rdd: RDD[AlignmentRecord] = sc.loadAlignments(args.inputPath, projection = Some(proj))
val rdd: RDD[AlignmentRecord] = sc.loadAlignments(args.inputPath, projection = Some(proj))._1
val filtered = rdd.filter(rec => !rec.getFailedVendorQualityChecks)

if (args.list != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ class Reads2Fragments(protected val args: Reads2FragmentsArgs) extends BDGSparkC
val companion = Reads2Fragments

def run(sc: SparkContext) {
sc.loadAlignments(args.inputPath).toFragments.adamParquetSave(args)
sc.loadAlignments(args.inputPath)._1.toFragments.adamParquetSave(args)
}
}
73 changes: 28 additions & 45 deletions adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ class TransformArgs extends Args4jBase with ADAMSaveAnyArgs with ParquetArgs {
var inputPath: String = null
@Argument(required = true, metaVar = "OUTPUT", usage = "Location to write the transformed data in ADAM/Parquet format", index = 1)
var outputPath: String = null
@Args4jOption(required = false, name = "-limit_projection", usage = "Only project necessary fields. Only works for Parquet files.")
var limitProjection: Boolean = false
@Args4jOption(required = false, name = "-aligned_read_predicate", usage = "Only load aligned reads. Only works for Parquet files.")
var useAlignedReadPredicate: Boolean = false
@Args4jOption(required = false, name = "-sort_reads", usage = "Sort the reads by referenceId and read position")
Expand Down Expand Up @@ -229,10 +227,10 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans

def run(sc: SparkContext) {
// throw exception if aligned read predicate or projection flags are used improperly
if ((args.useAlignedReadPredicate || args.limitProjection) &&
if (args.useAlignedReadPredicate &&
(args.forceLoadBam || args.forceLoadFastq || args.forceLoadIFastq)) {
throw new IllegalArgumentException(
"-aligned_read_predicate and -limit_projection only apply to Parquet files, but a non-Parquet force load flag was passed.")
"-aligned_read_predicate only applies to Parquet files, but a non-Parquet force load flag was passed.")
}

val (rdd, sd, rgd) =
Expand All @@ -245,59 +243,34 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
(sc.loadInterleavedFastq(args.inputPath),
SequenceDictionary.empty, RecordGroupDictionary.empty)
} else if (args.forceLoadParquet ||
args.limitProjection ||
args.useAlignedReadPredicate) {
val pred = if (args.useAlignedReadPredicate) {
Some((BooleanColumn("readMapped") === true))
} else {
None
}
val proj = if (args.limitProjection) {
Some(Projection(AlignmentRecordField.contig,
AlignmentRecordField.start,
AlignmentRecordField.end,
AlignmentRecordField.mapq,
AlignmentRecordField.readName,
AlignmentRecordField.sequence,
AlignmentRecordField.cigar,
AlignmentRecordField.qual,
AlignmentRecordField.recordGroupId,
AlignmentRecordField.recordGroupName,
AlignmentRecordField.readPaired,
AlignmentRecordField.readMapped,
AlignmentRecordField.readNegativeStrand,
AlignmentRecordField.firstOfPair,
AlignmentRecordField.secondOfPair,
AlignmentRecordField.primaryAlignment,
AlignmentRecordField.duplicateRead,
AlignmentRecordField.mismatchingPositions,
AlignmentRecordField.secondaryAlignment,
AlignmentRecordField.supplementaryAlignment))
} else {
None
}
(sc.loadParquetAlignments(args.inputPath,
predicate = pred,
projection = proj),
SequenceDictionary.empty, RecordGroupDictionary.empty)

sc.loadParquetAlignments(args.inputPath,
predicate = pred)
} else {
(sc.loadAlignments(
sc.loadAlignments(
args.inputPath,
filePath2Opt = Option(args.pairedFastqFile),
recordGroupOpt = Option(args.fastqRecordGroup),
stringency = stringency
), SequenceDictionary.empty, RecordGroupDictionary.empty)
stringency = stringency)
}

// Optionally load a second RDD and concatenate it with the first.
// Paired-FASTQ loading is avoided here because that wouldn't make sense
// given that it's already happening above.
val concatRddOpt =
val concatOpt =
Option(args.concatFilename).map(concatFilename =>
if (args.forceLoadBam) {
sc.loadBam(concatFilename)._1
sc.loadBam(concatFilename)
} else if (args.forceLoadIFastq) {
sc.loadInterleavedFastq(concatFilename)
(sc.loadInterleavedFastq(concatFilename),
SequenceDictionary.empty,
RecordGroupDictionary.empty)
} else if (args.forceLoadParquet) {
sc.loadParquetAlignments(concatFilename)
} else {
Expand All @@ -308,17 +281,27 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
}
)

// if we have a second rdd that we are merging in, process the merger here
val (mergedRdd, mergedSd, mergedRgd) = concatOpt.fold((rdd, sd, rgd))(t => {
val (concatRdd, concatSd, concatRgd) = t
(rdd ++ concatRdd, sd ++ concatSd, rgd ++ concatRgd)
})

// run our transformation
val outputRdd = this.apply(mergedRdd, mergedRgd)

// if we are sorting, we must strip the indices from the sequence dictionary
// and sort the sequence dictionary
//
// we must do this because we do a lexicographic sort, not an index-based sort
val sdFinal = if (args.sortReads) {
sd.stripIndices
mergedSd.stripIndices
.sorted
} else {
sd
mergedSd
}

this.apply(concatRddOpt match {
case Some(concatRdd) => rdd ++ concatRdd
case None => rdd
}, rgd).adamSave(args, sdFinal, rgd, args.sortReads)
outputRdd.adamSave(args, sdFinal, mergedRgd, args.sortReads)
}

private def createKnownSnpsTable(sc: SparkContext): SnpTable = CreateKnownSnpsTable.time {
Expand Down
2 changes: 1 addition & 1 deletion adam-cli/src/main/scala/org/bdgenomics/adam/cli/View.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class View(val args: ViewArgs) extends BDGSparkCommand[ViewArgs] {

def run(sc: SparkContext) = {

val reads: RDD[AlignmentRecord] = applyFilters(sc.loadAlignments(args.inputPath))
val reads: RDD[AlignmentRecord] = applyFilters(sc.loadAlignments(args.inputPath)._1)

if (args.outputPath != null)
reads.adamAlignedRecordSave(args, SequenceDictionary.empty, RecordGroupDictionary.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Adam2FastqSuite extends ADAMFunSuite {
// Picard allows unmapped reads to set the negative strand flag and therefore reverse-complemented on output
val reads: RDD[AlignmentRecord] =
sc
.loadAlignments(readsFilepath)
.loadAlignments(readsFilepath)._1
.filter(r => r.getReadMapped != null && r.getReadMapped)

reads.adamSaveAsFastq(outputFastqR1File, Some(outputFastqR2File), sort = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class FlagStatSuite extends ADAMFunSuite {
AlignmentRecordField.mapq,
AlignmentRecordField.failedVendorQualityChecks)

val adamFile: RDD[AlignmentRecord] = sc.loadAlignments(args.inputPath, projection = Some(projection))
val adamFile: RDD[AlignmentRecord] = sc.loadAlignments(args.inputPath, projection = Some(projection))._1

val (failedVendorQuality, passedVendorQuality) = apply(adamFile)

Expand Down
Loading

0 comments on commit 19562db

Please sign in to comment.