Merge pull request apache#208 from lythesia/master
[SPARKR-188] Add profiling of R execution on worker side
Conflicts:
	pkg/inst/worker/worker.R
shivaram authored and Davies Liu committed Apr 14, 2015
1 parent b317aa7 commit c9497a3
Showing 4 changed files with 132 additions and 79 deletions.
9 changes: 3 additions & 6 deletions R/pkg/R/serialize.R
@@ -69,8 +69,9 @@ writeJobj <- function(con, value) {
}

writeString <- function(con, value) {
writeInt(con, as.integer(nchar(value) + 1))
writeBin(value, con, endian = "big")
utfVal <- enc2utf8(value)
writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1))
writeBin(utfVal, con, endian = "big")
}

writeInt <- function(con, value) {
@@ -189,7 +190,3 @@ writeArgs <- function(con, args) {
}
}
}

writeStrings <- function(con, stringList) {
writeLines(unlist(stringList), con)
}
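A quick aside on the writeString change above: the length header must be a byte count, not a character count, because the JVM side reads raw bytes. A minimal R sketch (illustration only, not part of the diff):

# Character count vs. byte count for a UTF-8 string
s <- enc2utf8("h\u00e9llo")
nchar(s)                   # 5 characters
nchar(s, type = "bytes")   # 6 bytes, since "\u00e9" takes 2 bytes in UTF-8
# writeString's length header is this byte count plus 1 for the trailing
# NUL that writeBin() appends to a character vector.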
57 changes: 53 additions & 4 deletions R/pkg/inst/worker/worker.R
@@ -17,6 +17,23 @@

# Worker class

# Get the current system time in seconds
currentTimeSecs <- function() {
as.numeric(Sys.time())
}

# Get elapsed wall-clock time since the R process started
elapsedSecs <- function() {
proc.time()[3]
}

# Constants
specialLengths <- list(END_OF_STREAM = 0L, TIMING_DATA = -1L)

# Timing R process boot
bootTime <- currentTimeSecs()
bootElap <- elapsedSecs()

rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
# Set libPaths to include SparkR package as loadNamespace needs this
# TODO: Figure out if we can avoid this by not loading any objects that require
@@ -46,6 +63,9 @@ computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen))
env <- environment(computeFunc)
parent.env(env) <- .GlobalEnv # Attach under global environment.

# Timing initialization of the compute environment
initElap <- elapsedSecs()

# Read and set broadcast variables
numBroadcastVars <- SparkR:::readInt(inputCon)
if (numBroadcastVars > 0) {
@@ -56,6 +76,9 @@ if (numBroadcastVars > 0) {
}
}

# Timing broadcast
broadcastElap <- elapsedSecs()

# If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int
# as number of partitions to create.
numPartitions <- SparkR:::readInt(inputCon)
@@ -73,14 +96,23 @@ if (isEmpty != 0) {
} else if (deserializer == "row") {
data <- SparkR:::readDeserializeRows(inputCon)
}
# Timing reading input data for execution
inputElap <- elapsedSecs()

output <- computeFunc(partition, data)
# Timing computing
computeElap <- elapsedSecs()

if (serializer == "byte") {
SparkR:::writeRawSerialize(outputCon, output)
} else if (serializer == "row") {
SparkR:::writeRowSerialize(outputCon, output)
} else {
SparkR:::writeStrings(outputCon, output)
# Write lines one-by-one, each with its own length header
lapply(output, function(line) SparkR:::writeString(outputCon, line))
}
# Timing output
outputElap <- elapsedSecs()
} else {
if (deserializer == "byte") {
# Now read as many characters as described in funcLen
@@ -90,6 +122,8 @@ if (isEmpty != 0) {
} else if (deserializer == "row") {
data <- SparkR:::readDeserializeRows(inputCon)
}
# Timing reading input data for execution
inputElap <- elapsedSecs()

res <- new.env()

@@ -107,6 +141,8 @@ if (isEmpty != 0) {
res[[bucket]] <- acc
}
invisible(lapply(data, hashTupleToEnvir))
# Timing computing
computeElap <- elapsedSecs()

# Step 2: write out all of the environment as key-value pairs.
for (name in ls(res)) {
@@ -116,13 +152,26 @@ if (isEmpty != 0) {
length(res[[name]]$data) <- res[[name]]$counter
SparkR:::writeRawSerialize(outputCon, res[[name]]$data)
}
# Timing output
outputElap <- elapsedSecs()
}
} else {
inputElap <- broadcastElap
computeElap <- broadcastElap
outputElap <- broadcastElap
}

# Report timing
SparkR:::writeInt(outputCon, specialLengths$TIMING_DATA)
SparkR:::writeDouble(outputCon, bootTime)
SparkR:::writeDouble(outputCon, initElap - bootElap) # init
SparkR:::writeDouble(outputCon, broadcastElap - initElap) # broadcast
SparkR:::writeDouble(outputCon, inputElap - broadcastElap) # input
SparkR:::writeDouble(outputCon, computeElap - inputElap) # compute
SparkR:::writeDouble(outputCon, outputElap - computeElap) # output

# End of output
if (serializer %in% c("byte", "row")) {
SparkR:::writeInt(outputCon, 0L)
}
SparkR:::writeInt(outputCon, specialLengths$END_OF_STREAM)

close(outputCon)
close(inputCon)
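For reference, a self-contained sketch (not SparkR code) of the measurement pattern the worker uses above: snapshot the elapsed clock at every stage boundary and report only the differences, so each number sent to the JVM is the time spent in a single stage. The Sys.sleep() calls below merely stand in for the real work.

elapsedSecs <- function() proc.time()[["elapsed"]]

bootElap <- elapsedSecs()
Sys.sleep(0.1)                          # stand-in for loading the compute function
initElap <- elapsedSecs()
Sys.sleep(0.2)                          # stand-in for reading broadcast variables
broadcastElap <- elapsedSecs()

timings <- c(init = initElap - bootElap,
             broadcast = broadcastElap - initElap)
print(round(timings, 3))                # approximately init = 0.1, broadcast = 0.2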
131 changes: 67 additions & 64 deletions core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -42,10 +42,15 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
rLibDir: String,
broadcastVars: Array[Broadcast[Object]])
extends RDD[U](parent) with Logging {
protected var dataStream: DataInputStream = _
private var bootTime: Double = _
override def getPartitions: Array[Partition] = parent.partitions

override def compute(partition: Partition, context: TaskContext): Iterator[U] = {

// Timing start
bootTime = System.currentTimeMillis / 1000.0

// The parent may be also an RRDD, so we should launch it first.
val parentIterator = firstParent[T].iterator(partition, context)

@@ -69,7 +74,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
// the socket used to receive the output of task
val outSocket = serverSocket.accept()
val inputStream = new BufferedInputStream(outSocket.getInputStream)
val dataStream = openDataStream(inputStream)
dataStream = new DataInputStream(inputStream)
serverSocket.close()

try {
@@ -155,6 +160,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
} else if (deserializer == SerializationFormats.ROW) {
dataOut.write(elem.asInstanceOf[Array[Byte]])
} else if (deserializer == SerializationFormats.STRING) {
// Write strings (for StringRRDD)
printOut.println(elem)
}
}
@@ -180,9 +186,41 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
}.start()
}

protected def openDataStream(input: InputStream): Closeable
protected def readData(length: Int): U

protected def read(): U
protected def read(): U = {
try {
val length = dataStream.readInt()

length match {
case SpecialLengths.TIMING_DATA =>
// Timing data from R worker
val boot = dataStream.readDouble - bootTime
val init = dataStream.readDouble
val broadcast = dataStream.readDouble
val input = dataStream.readDouble
val compute = dataStream.readDouble
val output = dataStream.readDouble
logInfo(
("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
"read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
"total = %.3f s").format(
boot,
init,
broadcast,
input,
compute,
output,
boot + init + broadcast + input + compute + output))
read()
case length if length >= 0 =>
readData(length)
}
} catch {
case eof: EOFException =>
throw new SparkException("R worker exited unexpectedly (cranshed)", eof)
}
}
}

/**
@@ -202,31 +240,16 @@ private class PairwiseRRDD[T: ClassTag](
SerializationFormats.BYTE, packageNames, rLibDir,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {

private var dataStream: DataInputStream = _

override protected def openDataStream(input: InputStream): Closeable = {
dataStream = new DataInputStream(input)
dataStream
}

override protected def read(): (Int, Array[Byte]) = {
try {
val length = dataStream.readInt()

length match {
case length if length == 2 =>
val hashedKey = dataStream.readInt()
val contentPairsLength = dataStream.readInt()
val contentPairs = new Array[Byte](contentPairsLength)
dataStream.readFully(contentPairs)
(hashedKey, contentPairs)
case _ => null // End of input
}
} catch {
case eof: EOFException => {
throw new SparkException("R worker exited unexpectedly (crashed)", eof)
}
}
override protected def readData(length: Int): (Int, Array[Byte]) = {
length match {
case length if length == 2 =>
val hashedKey = dataStream.readInt()
val contentPairsLength = dataStream.readInt()
val contentPairs = new Array[Byte](contentPairsLength)
dataStream.readFully(contentPairs)
(hashedKey, contentPairs)
case _ => null
}
}

lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this)
@@ -247,28 +270,13 @@ private class RRDD[T: ClassTag](
parent, -1, func, deserializer, serializer, packageNames, rLibDir,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {

private var dataStream: DataInputStream = _

override protected def openDataStream(input: InputStream): Closeable = {
dataStream = new DataInputStream(input)
dataStream
}

override protected def read(): Array[Byte] = {
try {
val length = dataStream.readInt()

length match {
case length if length > 0 =>
val obj = new Array[Byte](length)
dataStream.readFully(obj, 0, length)
obj
case _ => null
}
} catch {
case eof: EOFException => {
throw new SparkException("R worker exited unexpectedly (crashed)", eof)
}
override protected def readData(length: Int): Array[Byte] = {
length match {
case length if length > 0 =>
val obj = new Array[Byte](length)
dataStream.readFully(obj)
obj
case _ => null
}
}

@@ -289,26 +297,21 @@ private class StringRRDD[T: ClassTag](
parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {

private var dataStream: BufferedReader = _

override protected def openDataStream(input: InputStream): Closeable = {
dataStream = new BufferedReader(new InputStreamReader(input))
dataStream
}

override protected def read(): String = {
try {
dataStream.readLine()
} catch {
case e: IOException => {
throw new SparkException("R worker exited unexpectedly (crashed)", e)
}
override protected def readData(length: Int): String = {
length match {
case length if length > 0 =>
SerDe.readStringBytes(dataStream, length)
case _ => null
}
}

lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this)
}

private object SpecialLengths {
val TIMING_DATA = -1
}

private[r] class BufferedStreamThread(
in: InputStream,
name: String,
14 changes: 9 additions & 5 deletions core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -85,13 +85,17 @@ private[spark] object SerDe {
in.readDouble()
}

def readStringBytes(in: DataInputStream, len: Int): String = {
val bytes = new Array[Byte](len)
in.readFully(bytes)
assert(bytes(len - 1) == 0)
val str = new String(bytes.dropRight(1), "UTF-8")
str
}

def readString(in: DataInputStream): String = {
val len = in.readInt()
val asciiBytes = new Array[Byte](len)
in.readFully(asciiBytes)
assert(asciiBytes(len - 1) == 0)
val str = new String(asciiBytes.dropRight(1).map(_.toChar))
str
readStringBytes(in, len)
}

def readBoolean(in: DataInputStream): Boolean = {
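To make the framing concrete, here is a hedged, self-contained R sketch of what a single string looks like on the wire after this change; readStringBytes above reads exactly len bytes, checks the trailing NUL, and decodes the remainder as UTF-8. (The raw connection stands in for the real socket.)

con <- rawConnection(raw(0), open = "wb")
value <- enc2utf8("h\u00e9llo")
writeBin(as.integer(nchar(value, type = "bytes") + 1L), con, endian = "big")  # length header: 7
writeBin(value, con, endian = "big")                                          # 6 UTF-8 bytes + trailing NUL
bytes <- rawConnectionValue(con)
close(con)
length(bytes)                      # 11 = 4-byte length header + 6 bytes + NUL
bytes[length(bytes)] == as.raw(0)  # TRUE: the NUL that readStringBytes asserts on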
