Expose host store spill (#9189)
* Expose host store spill from catalog

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>

---------

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>
abellina authored Sep 7, 2023
1 parent d6a8338 commit 22cfc2e
Showing 2 changed files with 91 additions and 4 deletions.
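The main change is to expose host-store spilling through the `RapidsBufferCatalog` companion object: a `getHostStorage` accessor, a test-only `setDiskStorage` setter, and an object-level `synchronousSpill` that delegates to the singleton catalog. As a quick orientation before the diff, here is a minimal sketch (not part of this commit, and assuming a catalog and its device/host/disk stores are already registered, as the new test below does) of driving a spill through the exposed API:

// Minimal sketch, not from this commit: push the device store's contents to host,
// then the host store's contents to disk, through the companion-object API.
// A targetTotalSize of 0 asks the store to spill everything it can.
val spilledToHost: Option[Long] =
  RapidsBufferCatalog.synchronousSpill(RapidsBufferCatalog.getDeviceStorage, 0)
val spilledToDisk: Option[Long] =
  RapidsBufferCatalog.synchronousSpill(RapidsBufferCatalog.getHostStorage, 0)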
@@ -544,8 +544,8 @@ class RapidsBufferCatalog(
* Free memory in `store` by spilling buffers to the spill store synchronously.
* @param store store to spill from
* @param targetTotalSize maximum total size of this store after spilling completes
* @param stream CUDA stream to use or null for default stream
* @return optionally number of bytes that were spilled, or None if this called
* @param stream CUDA stream to use or omit for default stream
* @return optionally number of bytes that were spilled, or None if this call
* made no attempt to spill due to a detected spill race
*/
def synchronousSpill(
@@ -806,14 +806,24 @@ object RapidsBufferCatalog extends Logging {
deviceStorage = rdms
}

/**
* Set a `RapidsDiskStore` instance to use when instantiating our
* catalog.
*
* @note This should only be called from tests!
*/
def setDiskStorage(rdms: RapidsDiskStore): Unit = {
diskStorage = rdms
}

/**
* Set a `RapidsHostMemoryStore` instance to use when instantiating our
* catalog.
*
* @note This should only be called from tests!
*/
def setHostStorage(rdhs: RapidsHostMemoryStore): Unit = {
hostStorage = rdhs
def setHostStorage(rhms: RapidsHostMemoryStore): Unit = {
hostStorage = rhms
}

/**
@@ -918,6 +928,8 @@ object RapidsBufferCatalog extends Logging {

def getDeviceStorage: RapidsDeviceMemoryStore = deviceStorage

def getHostStorage: RapidsHostMemoryStore = hostStorage

def shouldUnspill: Boolean = _shouldUnspill

/**
@@ -978,6 +990,21 @@ object RapidsBufferCatalog extends Logging {

def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager

/**
* Free memory in `store` by spilling buffers to its spill store synchronously.
* @param store store to spill from
* @param targetTotalSize maximum total size of this store after spilling completes
* @param stream CUDA stream to use or omit for default stream
* @return optionally number of bytes that were spilled, or None if this call
* made no attempt to spill due to a detected spill race
*/
def synchronousSpill(
store: RapidsBufferStore,
targetTotalSize: Long,
stream: Cuda.Stream = Cuda.DEFAULT_STREAM): Option[Long] = {
singleton.synchronousSpill(store, targetTotalSize, stream)
}

/**
* Given a `MemoryBuffer` find out if a `MemoryBuffer.EventHandler` is associated
* with it.
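Per the updated scaladoc above, a `None` result from `synchronousSpill` means the call backed off because it detected a concurrent spill, not that nothing was spillable. An illustrative sketch (not from this commit) of how a caller might distinguish the two outcomes:

// Illustrative only: interpreting the Option[Long] returned by synchronousSpill.
RapidsBufferCatalog.synchronousSpill(RapidsBufferCatalog.getHostStorage, 0) match {
  case Some(bytes) => println(s"spilled $bytes bytes from the host store")
  case None => println("skipped: a concurrent spill was detected (spill race)")
}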
@@ -193,6 +193,66 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
}
}

test("get memory buffer after host spill") {
RapidsBufferCatalog.close()
val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType,
DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5))
val spillPriority = -10
val hostStoreMaxSize = 1L * 1024 * 1024
try {
val bm = new RapidsDiskBlockManager(new SparkConf())
val (catalog, devStore, hostStore, diskStore) =
closeOnExcept(new RapidsDiskStore(bm)) { diskStore =>
closeOnExcept(new RapidsDeviceMemoryStore()) { devStore =>
closeOnExcept(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
devStore.setSpillStore(hostStore)
hostStore.setSpillStore(diskStore)
val catalog = closeOnExcept(
new RapidsBufferCatalog(devStore, hostStore)) { catalog => catalog }
(catalog, devStore, hostStore, diskStore)
}
}
}

RapidsBufferCatalog.setDeviceStorage(devStore)
RapidsBufferCatalog.setHostStorage(hostStore)
RapidsBufferCatalog.setDiskStorage(diskStore)
RapidsBufferCatalog.setCatalog(catalog)

var expectedBatch: ColumnarBatch = null
val handle = withResource(buildContiguousTable()) { ct =>
// make a copy of the table so we can compare it later to the
// one reconstituted after the spill
withResource(ct.getTable.contiguousSplit()) { copied =>
expectedBatch = GpuColumnVector.from(copied(0).getTable, sparkTypes)
}
RapidsBufferCatalog.addContiguousTable(
ct,
spillPriority)
}
withResource(expectedBatch) { _ =>
val spilledToHost =
RapidsBufferCatalog.synchronousSpill(
RapidsBufferCatalog.getDeviceStorage, 0)
assert(spilledToHost.isDefined && spilledToHost.get > 0)

val spilledToDisk =
RapidsBufferCatalog.synchronousSpill(
RapidsBufferCatalog.getHostStorage, 0)
assert(spilledToDisk.isDefined && spilledToDisk.get > 0)

withResource(RapidsBufferCatalog.acquireBuffer(handle)) { buffer =>
assertResult(StorageTier.DISK)(buffer.storageTier)
withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch =>
TestUtils.compareBatches(expectedBatch, actualBatch)
}
}
}
} finally {
RapidsBufferCatalog.close()
}
}

test("host buffer originated: get host memory buffer") {
val spillPriority = -10
val hostStoreMaxSize = 1L * 1024 * 1024
