[SPARK-17876] Write StructuredStreaming WAL to a stream instead of materializing all at once

## What changes were proposed in this pull request?

The CompactibleFileStreamLog materializes the whole metadata log in memory as a single String. This can cause problems when many files are being committed, especially during a compaction batch.
You may come across stack traces like the following:
```
java.lang.OutOfMemoryError: Requested array size exceeds VM limit
at java.lang.StringCoding.encode(StringCoding.java:350)
at java.lang.String.getBytes(String.java:941)
at org.apache.spark.sql.execution.streaming.FileStreamSinkLog.serialize(FileStreamSinkLog.scala:127)

```
The safer way is to write to an output stream so that we don't have to materialize a huge string.
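As a rough illustration of the difference (these are simplified stand-ins, not the actual Spark classes; `metadataLogVersion` and `serializeData` below are placeholders), the old path builds one giant String and then encodes it, while the new path writes each entry to the stream as it is produced:

```scala
import java.io.{ByteArrayOutputStream, OutputStream}
import java.nio.charset.StandardCharsets.UTF_8

object SerializeSketch {
  val metadataLogVersion = "v1"                    // stand-in for the real log version
  def serializeData(entry: String): String = entry // stand-in for the real per-entry serializer

  // Old approach (simplified): the entire log is materialized as one String,
  // and getBytes then allocates a second copy of it as a byte array.
  def serializeMaterialized(logData: Array[String]): Array[Byte] =
    (metadataLogVersion +: logData.map(serializeData)).mkString("\n").getBytes(UTF_8)

  // New approach (simplified): each entry is encoded and written individually,
  // so peak memory stays proportional to a single entry.
  def serializeStreaming(logData: Array[String], out: OutputStream): Unit = {
    out.write(metadataLogVersion.getBytes(UTF_8))
    logData.foreach { data =>
      out.write('\n')
      out.write(serializeData(data).getBytes(UTF_8))
    }
  }

  def main(args: Array[String]): Unit = {
    val entries = Array("/a/b/x", "/a/b/y", "/a/b/z")
    val baos = new ByteArrayOutputStream()
    serializeStreaming(entries, baos)
    // Both approaches produce the same bytes; only the memory profile differs.
    assert(java.util.Arrays.equals(serializeMaterialized(entries), baos.toByteArray))
  }
}
```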

## How was this patch tested?

Existing unit tests

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #15437 from brkyvz/ser-to-stream.
brkyvz authored and zsxwing committed Oct 13, 2016
1 parent 21cb59f commit edeb51a
Showing 3 changed files with 38 additions and 27 deletions.

**CompactibleFileStreamLog.scala**

```diff
@@ -17,9 +17,10 @@

 package org.apache.spark.sql.execution.streaming

-import java.io.IOException
+import java.io.{InputStream, IOException, OutputStream}
 import java.nio.charset.StandardCharsets.UTF_8

+import scala.io.{Source => IOSource}
 import scala.reflect.ClassTag

 import org.apache.hadoop.fs.{Path, PathFilter}
@@ -93,20 +94,25 @@ abstract class CompactibleFileStreamLog[T: ClassTag](
     }
   }

-  override def serialize(logData: Array[T]): Array[Byte] = {
-    (metadataLogVersion +: logData.map(serializeData)).mkString("\n").getBytes(UTF_8)
+  override def serialize(logData: Array[T], out: OutputStream): Unit = {
+    // called inside a try-finally where the underlying stream is closed in the caller
+    out.write(metadataLogVersion.getBytes(UTF_8))
+    logData.foreach { data =>
+      out.write('\n')
+      out.write(serializeData(data).getBytes(UTF_8))
+    }
   }

-  override def deserialize(bytes: Array[Byte]): Array[T] = {
-    val lines = new String(bytes, UTF_8).split("\n")
-    if (lines.length == 0) {
+  override def deserialize(in: InputStream): Array[T] = {
+    val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
+    if (!lines.hasNext) {
       throw new IllegalStateException("Incomplete log file")
     }
-    val version = lines(0)
+    val version = lines.next()
     if (version != metadataLogVersion) {
       throw new IllegalStateException(s"Unknown log version: ${version}")
     }
-    lines.slice(1, lines.length).map(deserializeData)
+    lines.map(deserializeData).toArray
   }

   override def add(batchId: Long, logs: Array[T]): Boolean = {
```
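
On the deserialize side, `IOSource.fromInputStream(in, UTF_8.name()).getLines()` returns an iterator, so lines are decoded as they are consumed rather than the whole log being split into one in-memory array. A minimal round-trip sketch under the same simplifications as above (`metadataLogVersion` is a placeholder and entries are plain strings):

```scala
import java.io.{ByteArrayInputStream, InputStream}
import java.nio.charset.StandardCharsets.UTF_8
import scala.io.{Source => IOSource}

object DeserializeSketch {
  val metadataLogVersion = "v1" // stand-in for the real log version

  // Mirrors the new deserialize shape: check the version line first, then
  // consume the remaining lines lazily and only materialize the final Array.
  def deserialize(in: InputStream): Array[String] = {
    val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
    if (!lines.hasNext) {
      throw new IllegalStateException("Incomplete log file")
    }
    val version = lines.next()
    if (version != metadataLogVersion) {
      throw new IllegalStateException(s"Unknown log version: $version")
    }
    lines.toArray
  }

  def main(args: Array[String]): Unit = {
    val bytes = s"$metadataLogVersion\n/a/b/x\n/a/b/y".getBytes(UTF_8)
    val entries = deserialize(new ByteArrayInputStream(bytes))
    assert(entries.sameElements(Array("/a/b/x", "/a/b/y")))
  }
}
```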

**HDFSMetadataLog.scala**

```diff
@@ -17,8 +17,7 @@

 package org.apache.spark.sql.execution.streaming

-import java.io.{FileNotFoundException, IOException}
-import java.nio.ByteBuffer
+import java.io.{FileNotFoundException, InputStream, IOException, OutputStream}
 import java.util.{ConcurrentModificationException, EnumSet, UUID}

 import scala.reflect.ClassTag
@@ -29,7 +28,6 @@ import org.apache.hadoop.fs._
 import org.apache.hadoop.fs.permission.FsPermission

 import org.apache.spark.internal.Logging
-import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.util.UninterruptibleThread
@@ -88,12 +86,16 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String)
     }
   }

-  protected def serialize(metadata: T): Array[Byte] = {
-    JavaUtils.bufferToArray(serializer.serialize(metadata))
+  protected def serialize(metadata: T, out: OutputStream): Unit = {
+    // called inside a try-finally where the underlying stream is closed in the caller
+    val outStream = serializer.serializeStream(out)
+    outStream.writeObject(metadata)
   }

-  protected def deserialize(bytes: Array[Byte]): T = {
-    serializer.deserialize[T](ByteBuffer.wrap(bytes))
+  protected def deserialize(in: InputStream): T = {
+    // called inside a try-finally where the underlying stream is closed in the caller
+    val inStream = serializer.deserializeStream(in)
+    inStream.readObject[T]()
   }

   /**
@@ -114,7 +116,7 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String)
       // Only write metadata when the batch has not yet been written
       Thread.currentThread match {
         case ut: UninterruptibleThread =>
-          ut.runUninterruptibly { writeBatch(batchId, serialize(metadata)) }
+          ut.runUninterruptibly { writeBatch(batchId, metadata, serialize) }
         case _ =>
           throw new IllegalStateException(
             "HDFSMetadataLog.add() must be executed on a o.a.spark.util.UninterruptibleThread")
@@ -129,17 +131,17 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String)
   * There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a
   * valid behavior, we still need to prevent it from destroying the files.
   */
-  private def writeBatch(batchId: Long, bytes: Array[Byte]): Unit = {
+  private def writeBatch(batchId: Long, metadata: T, writer: (T, OutputStream) => Unit): Unit = {
    // Use nextId to create a temp file
    var nextId = 0
    while (true) {
      val tempPath = new Path(metadataPath, s".${UUID.randomUUID.toString}.tmp")
      try {
        val output = fileManager.create(tempPath)
        try {
-          output.write(bytes)
+          writer(metadata, output)
        } finally {
-          output.close()
+          IOUtils.closeQuietly(output)
        }
        try {
          // Try to commit the batch
@@ -193,10 +195,9 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String)
     if (fileManager.exists(batchMetadataFile)) {
       val input = fileManager.open(batchMetadataFile)
       try {
-        val bytes = IOUtils.toByteArray(input)
-        Some(deserialize(bytes))
+        Some(deserialize(input))
       } finally {
-        input.close()
+        IOUtils.closeQuietly(input)
       }
     } else {
       logDebug(s"Unable to find batch $batchMetadataFile")
```
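
Note that `writeBatch` now receives the metadata plus a writer callback of type `(T, OutputStream) => Unit` instead of pre-serialized bytes, so serialization happens directly against the file stream, and `IOUtils.closeQuietly` keeps a failure during close from masking an exception thrown by the writer. A small generic sketch of this callback pattern (the `writeWith` helper is hypothetical, not HDFSMetadataLog's actual API, and the quiet close is only approximated):

```scala
import java.io.{ByteArrayOutputStream, OutputStream}
import java.nio.charset.StandardCharsets.UTF_8

object WriteBatchSketch {
  // Hypothetical writeBatch-style helper: the caller's serializer streams
  // directly into the output, with no intermediate byte array.
  def writeWith[T](metadata: T, out: OutputStream)(writer: (T, OutputStream) => Unit): Unit = {
    try {
      writer(metadata, out)
    } finally {
      // approximate IOUtils.closeQuietly: swallow close failures so they
      // do not hide an exception thrown by the writer itself
      try out.close() catch { case _: java.io.IOException => }
    }
  }

  def main(args: Array[String]): Unit = {
    val out = new ByteArrayOutputStream()
    writeWith("batch-0 metadata", out)((m, o) => o.write(m.getBytes(UTF_8)))
    assert(out.toString(UTF_8.name()) == "batch-0 metadata")
  }
}
```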

**FileStreamSinkLogSuite.scala**

```diff
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.execution.streaming

+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 import java.nio.charset.StandardCharsets.UTF_8

 import org.apache.spark.SparkFunSuite
@@ -133,9 +134,12 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext {
          |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"}
          |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin
       // scalastyle:on
-      assert(expected === new String(sinkLog.serialize(logs), UTF_8))
-
-      assert(VERSION === new String(sinkLog.serialize(Array()), UTF_8))
+      val baos = new ByteArrayOutputStream()
+      sinkLog.serialize(logs, baos)
+      assert(expected === baos.toString(UTF_8.name()))
+      baos.reset()
+      sinkLog.serialize(Array(), baos)
+      assert(VERSION === baos.toString(UTF_8.name()))
     }
   }
@@ -174,9 +178,9 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext {
         blockSize = 30000L,
         action = FileStreamSinkLog.ADD_ACTION))

-      assert(expected === sinkLog.deserialize(logs.getBytes(UTF_8)))
+      assert(expected === sinkLog.deserialize(new ByteArrayInputStream(logs.getBytes(UTF_8))))

-      assert(Nil === sinkLog.deserialize(VERSION.getBytes(UTF_8)))
+      assert(Nil === sinkLog.deserialize(new ByteArrayInputStream(VERSION.getBytes(UTF_8))))
     }
   }
```
