Skip to content

Commit

Permalink
[SPARK-50274][CORE] Guard against use-after-close in DirectByteBuffer…
Browse files Browse the repository at this point in the history
…OutputStream

### What changes were proposed in this pull request?

`DirectByteBufferOutputStream#close()` calls `StorageUtils.dispose()` to free its direct byte buffer. This puts the object into an unspecified and dangerous state after being closed, and can cause unpredictable JVM crashes if the object is used after close.

This PR makes this safer by modifying `close()` to place the object into a known-closed state, and modifying all methods to assert not closed.

To minimize the performance impact from the extra checks, this PR also changes `DirectByteBufferOutputStream#buffer` from `private` to `private[this]`, which should produce more efficient direct field accesses.

### Why are the changes needed?

Improves debuggability for users of DirectByteBufferOutputStream such as PythonRunner.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a test in DirectByteBufferOutputStreamSuite to verify that use after close throws IllegalStateException.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#48807 from ankurdave/SPARK-50274-DirectByteBufferOutputStream-checkNotClosed.

Authored-by: Ankur Dave <ankurdave@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
ankurdave authored and HyukjinKwon committed Nov 10, 2024
1 parent 2cdbede commit e490dd7
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.util
import java.io.OutputStream
import java.nio.ByteBuffer

import org.apache.spark.SparkException
import org.apache.spark.storage.StorageUtils
import org.apache.spark.unsafe.Platform

Expand All @@ -29,16 +30,18 @@ import org.apache.spark.unsafe.Platform
* @param capacity The initial capacity of the direct byte buffer
*/
private[spark] class DirectByteBufferOutputStream(capacity: Int) extends OutputStream {
private var buffer = Platform.allocateDirectBuffer(capacity)
private[this] var buffer = Platform.allocateDirectBuffer(capacity)

def this() = this(32)

override def write(b: Int): Unit = {
  checkNotClosed()
  // Grow the backing buffer first if one more byte would overflow it.
  val neededSize = buffer.position() + 1
  ensureCapacity(neededSize)
  buffer.put(b.toByte)
}

override def write(b: Array[Byte], off: Int, len: Int): Unit = {
  checkNotClosed()
  // Make sure there is room for all `len` bytes before the bulk copy.
  val neededSize = buffer.position() + len
  ensureCapacity(neededSize)
  buffer.put(b, off, len)
}
Expand All @@ -63,15 +66,29 @@ private[spark] class DirectByteBufferOutputStream(capacity: Int) extends OutputS
buffer = newBuffer
}

def reset(): Unit = buffer.clear()
/**
 * Fails fast if this stream has already been closed.
 *
 * After `close()` the field `buffer` is null and its native memory has been freed;
 * letting any other method proceed past that point could touch freed memory and
 * crash the JVM, so every public method calls this first.
 */
private def checkNotClosed(): Unit = {
  val isClosed = buffer == null
  if (isClosed) {
    throw SparkException.internalError(
      "Cannot call methods on a closed DirectByteBufferOutputStream")
  }
}

// Discards all written data by rewinding the buffer's position to 0.
// The buffer's capacity and previously written bytes are untouched; they
// will simply be overwritten by subsequent writes.
def reset(): Unit = {
checkNotClosed()
buffer.clear()
}

def size(): Int = buffer.position()
// Returns the number of bytes written so far, i.e. the buffer's current
// write position.
def size(): Int = {
checkNotClosed()
buffer.position()
}

/**
* Any subsequent call to [[close()]], [[write()]], [[reset()]] will invalidate the buffer
* returned by this method.
*/
def toByteBuffer: ByteBuffer = {
checkNotClosed()
val outputBuffer = buffer.duplicate()
outputBuffer.flip()
outputBuffer
Expand All @@ -80,6 +97,7 @@ private[spark] class DirectByteBufferOutputStream(capacity: Int) extends OutputS
/**
 * Closes this stream, eagerly freeing its direct byte buffer instead of waiting for
 * GC to reduce native memory pressure. Nulling out `buffer` places the stream in a
 * known-closed state that every other method detects via `checkNotClosed()`.
 *
 * Idempotent, per the `java.io.Closeable` contract: a second `close()` is a no-op
 * rather than passing a null buffer to `StorageUtils.dispose`.
 */
override def close(): Unit = {
  if (buffer != null) {
    // Eagerly free the direct byte buffer without waiting for GC to reduce memory pressure.
    StorageUtils.dispose(buffer)
    buffer = null
  }
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util

import org.apache.spark.{SparkException, SparkFunSuite}

class DirectByteBufferOutputStreamSuite extends SparkFunSuite {
  test("use after close") {
    val numBytes = 1000
    val stream = new DirectByteBufferOutputStream()
    stream.write(new Array[Byte](numBytes), 0, numBytes)
    val exposedBuffer = stream.toByteBuffer
    stream.close()

    // Every operation on a closed stream must fail with an exception instead of
    // touching freed native memory and potentially crashing the JVM.
    assertThrows[SparkException] { stream.write(123) }
    assertThrows[SparkException] { stream.write(new Array[Byte](numBytes), 0, numBytes) }
    assertThrows[SparkException] { stream.reset() }
    assertThrows[SparkException] { stream.size() }
    assertThrows[SparkException] { stream.toByteBuffer }

    // Reading from the previously returned buffer after close may crash, since its
    // backing memory has been freed:
    // val arr = new Array[Byte](numBytes)
    // exposedBuffer.get(arr)
  }
}

0 comments on commit e490dd7

Please sign in to comment.