diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 8815ac4eb..2fff0a04c 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -467,6 +467,15 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val COMET_EXEC_MEMORY_POOL_TYPE: ConfigEntry[String] = conf("spark.comet.exec.memoryPool") + .doc( + "The type of memory pool to be used for Comet native execution. " + + "Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', " + + "'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, " + + "this config is 'greedy_task_shared'.") + .stringConf + .createWithDefault("greedy_task_shared") + val COMET_SCAN_PREFETCH_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.scan.preFetch.enabled") .doc("Whether to enable pre-fetching feature of CometScan.") diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 7881f0763..ecea70254 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -48,6 +48,7 @@ Comet provides the following configuration settings. | spark.comet.exec.hashJoin.enabled | Whether to enable hashJoin by default. | true | | spark.comet.exec.localLimit.enabled | Whether to enable localLimit by default. | true | | spark.comet.exec.memoryFraction | The fraction of memory from Comet memory overhead that the native memory manager can use for execution. The purpose of this config is to set aside memory for untracked data structures, as well as imprecise size estimation during memory acquisition. | 0.7 | +| spark.comet.exec.memoryPool | The type of memory pool to be used for Comet native execution. Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, this config is 'greedy_task_shared'. | greedy_task_shared | | spark.comet.exec.project.enabled | Whether to enable project by default. | true | | spark.comet.exec.replaceSortMergeJoin | Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin for improved performance. This feature is not stable yet. For more information, refer to the Comet Tuning Guide (https://datafusion.apache.org/comet/user-guide/tuning.html). | false | | spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. Compression can be disabled by setting spark.shuffle.compress=false. | zstd | diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 09caf5e27..b1190d905 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -24,6 +24,9 @@ use datafusion::{ physical_plan::{display::DisplayableExecutionPlan, SendableRecordBatchStream}, prelude::{SessionConfig, SessionContext}, }; +use datafusion_execution::memory_pool::{ + FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, +}; use futures::poll; use jni::{ errors::Result as JNIResult, @@ -51,20 +54,26 @@ use datafusion_comet_proto::spark_operator::Operator; use datafusion_common::ScalarValue; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use futures::stream::StreamExt; +use jni::sys::JNI_FALSE; use jni::{ objects::GlobalRef, sys::{jboolean, jdouble, jintArray, jobjectArray, jstring}, }; +use std::num::NonZeroUsize; +use std::sync::Mutex; use tokio::runtime::Runtime; use crate::execution::operators::ScanExec; use crate::execution::spark_plan::SparkPlan; use log::info; +use once_cell::sync::{Lazy, OnceCell}; /// Comet native execution context. Kept alive across JNI calls. struct ExecutionContext { /// The id of the execution context. pub id: i64, + /// Task attempt id + pub task_attempt_id: i64, /// The deserialized Spark plan pub spark_plan: Operator, /// The DataFusion root operator converted from the `spark_plan` @@ -89,6 +98,51 @@ struct ExecutionContext { pub explain_native: bool, /// Map of metrics name -> jstring object to cache jni_NewStringUTF calls. pub metrics_jstrings: HashMap>, + /// Memory pool config + pub memory_pool_config: MemoryPoolConfig, +} + +#[derive(PartialEq, Eq)] +enum MemoryPoolType { + Unified, + Greedy, + FairSpill, + GreedyTaskShared, + FairSpillTaskShared, + GreedyGlobal, + FairSpillGlobal, +} + +struct MemoryPoolConfig { + pool_type: MemoryPoolType, + pool_size: usize, +} + +impl MemoryPoolConfig { + fn new(pool_type: MemoryPoolType, pool_size: usize) -> Self { + Self { + pool_type, + pool_size, + } + } +} + +/// The per-task memory pools keyed by task attempt id. +static TASK_SHARED_MEMORY_POOLS: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); + +struct PerTaskMemoryPool { + memory_pool: Arc, + num_plans: usize, +} + +impl PerTaskMemoryPool { + fn new(memory_pool: Arc) -> Self { + Self { + memory_pool, + num_plans: 0, + } + } } /// Accept serialized query plan and return the address of the native query plan. @@ -105,8 +159,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( comet_task_memory_manager_obj: JObject, batch_size: jint, use_unified_memory_manager: jboolean, + memory_pool_type: jstring, memory_limit: jlong, + memory_limit_per_task: jlong, memory_fraction: jdouble, + task_attempt_id: jlong, debug_native: jboolean, explain_native: jboolean, worker_threads: jint, @@ -145,21 +202,27 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( let task_memory_manager = Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?); + let memory_pool_type = env.get_string(&JString::from_raw(memory_pool_type))?.into(); + let memory_pool_config = parse_memory_pool_config( + use_unified_memory_manager != JNI_FALSE, + memory_pool_type, + memory_limit, + memory_limit_per_task, + memory_fraction, + )?; + let memory_pool = + create_memory_pool(&memory_pool_config, task_memory_manager, task_attempt_id); + // We need to keep the session context alive. Some session state like temporary // dictionaries are stored in session context. If it is dropped, the temporary // dictionaries will be dropped as well. - let session = prepare_datafusion_session_context( - batch_size as usize, - use_unified_memory_manager == 1, - memory_limit as usize, - memory_fraction, - task_memory_manager, - )?; + let session = prepare_datafusion_session_context(batch_size as usize, memory_pool)?; let plan_creation_time = start.elapsed(); let exec_context = Box::new(ExecutionContext { id, + task_attempt_id, spark_plan, root_op: None, scans: vec![], @@ -172,6 +235,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( debug_native: debug_native == 1, explain_native: explain_native == 1, metrics_jstrings: HashMap::new(), + memory_pool_config, }); Ok(Box::into_raw(exec_context) as i64) @@ -181,22 +245,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( /// Configure DataFusion session context. fn prepare_datafusion_session_context( batch_size: usize, - use_unified_memory_manager: bool, - memory_limit: usize, - memory_fraction: f64, - comet_task_memory_manager: Arc, + memory_pool: Arc, ) -> CometResult { let mut rt_config = RuntimeEnvBuilder::new().with_disk_manager(DiskManagerConfig::NewOs); - - // Check if we are using unified memory manager integrated with Spark. - if use_unified_memory_manager { - // Set Comet memory pool for native - let memory_pool = CometMemoryPool::new(comet_task_memory_manager); - rt_config = rt_config.with_memory_pool(Arc::new(memory_pool)); - } else { - // Use the memory pool from DF - rt_config = rt_config.with_memory_limit(memory_limit, memory_fraction) - } + rt_config = rt_config.with_memory_pool(memory_pool); // Get Datafusion configuration from Spark Execution context // can be configured in Comet Spark JVM using Spark --conf parameters @@ -224,6 +276,107 @@ fn prepare_datafusion_session_context( Ok(session_ctx) } +fn parse_memory_pool_config( + use_unified_memory_manager: bool, + memory_pool_type: String, + memory_limit: i64, + memory_limit_per_task: i64, + memory_fraction: f64, +) -> CometResult { + let memory_pool_config = if use_unified_memory_manager { + MemoryPoolConfig::new(MemoryPoolType::Unified, 0) + } else { + // Use the memory pool from DF + let pool_size = (memory_limit as f64 * memory_fraction) as usize; + let pool_size_per_task = (memory_limit_per_task as f64 * memory_fraction) as usize; + match memory_pool_type.as_str() { + "fair_spill_task_shared" => { + MemoryPoolConfig::new(MemoryPoolType::FairSpillTaskShared, pool_size_per_task) + } + "greedy_task_shared" => { + MemoryPoolConfig::new(MemoryPoolType::GreedyTaskShared, pool_size_per_task) + } + "fair_spill_global" => { + MemoryPoolConfig::new(MemoryPoolType::FairSpillGlobal, pool_size) + } + "greedy_global" => MemoryPoolConfig::new(MemoryPoolType::GreedyGlobal, pool_size), + "fair_spill" => MemoryPoolConfig::new(MemoryPoolType::FairSpill, pool_size_per_task), + "greedy" => MemoryPoolConfig::new(MemoryPoolType::Greedy, pool_size_per_task), + _ => { + return Err(CometError::Config(format!( + "Unsupported memory pool type: {}", + memory_pool_type + ))) + } + } + }; + Ok(memory_pool_config) +} + +fn create_memory_pool( + memory_pool_config: &MemoryPoolConfig, + comet_task_memory_manager: Arc, + task_attempt_id: i64, +) -> Arc { + const NUM_TRACKED_CONSUMERS: usize = 10; + match memory_pool_config.pool_type { + MemoryPoolType::Unified => { + // Set Comet memory pool for native + let memory_pool = CometMemoryPool::new(comet_task_memory_manager); + Arc::new(memory_pool) + } + MemoryPoolType::Greedy => Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(memory_pool_config.pool_size), + NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), + )), + MemoryPoolType::FairSpill => Arc::new(TrackConsumersPool::new( + FairSpillPool::new(memory_pool_config.pool_size), + NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), + )), + MemoryPoolType::GreedyGlobal => { + static GLOBAL_MEMORY_POOL_GREEDY: OnceCell> = OnceCell::new(); + let memory_pool = GLOBAL_MEMORY_POOL_GREEDY.get_or_init(|| { + Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(memory_pool_config.pool_size), + NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), + )) + }); + Arc::clone(memory_pool) + } + MemoryPoolType::FairSpillGlobal => { + static GLOBAL_MEMORY_POOL_FAIR: OnceCell> = OnceCell::new(); + let memory_pool = GLOBAL_MEMORY_POOL_FAIR.get_or_init(|| { + Arc::new(TrackConsumersPool::new( + FairSpillPool::new(memory_pool_config.pool_size), + NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), + )) + }); + Arc::clone(memory_pool) + } + MemoryPoolType::GreedyTaskShared | MemoryPoolType::FairSpillTaskShared => { + let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap(); + let per_task_memory_pool = + memory_pool_map.entry(task_attempt_id).or_insert_with(|| { + let pool: Arc = + if memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared { + Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(memory_pool_config.pool_size), + NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), + )) + } else { + Arc::new(TrackConsumersPool::new( + FairSpillPool::new(memory_pool_config.pool_size), + NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(), + )) + }; + PerTaskMemoryPool::new(pool) + }); + per_task_memory_pool.num_plans += 1; + Arc::clone(&per_task_memory_pool.memory_pool) + } + } +} + /// Prepares arrow arrays for output. fn prepare_output( env: &mut JNIEnv, @@ -407,6 +560,22 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan( ) { try_unwrap_or_throw(&e, |_| unsafe { let execution_context = get_execution_context(exec_context); + if execution_context.memory_pool_config.pool_type == MemoryPoolType::FairSpillTaskShared + || execution_context.memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared + { + // Decrement the number of native plans using the per-task shared memory pool, and + // remove the memory pool if the released native plan is the last native plan using it. + let task_attempt_id = execution_context.task_attempt_id; + let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap(); + if let Some(per_task_memory_pool) = memory_pool_map.get_mut(&task_attempt_id) { + per_task_memory_pool.num_plans -= 1; + if per_task_memory_pool.num_plans == 0 { + // Drop the memory pool from the per-task memory pool map if there are no + // more native plans using it. + memory_pool_map.remove(&task_attempt_id); + } + } + } let _: Box = Box::from_raw(execution_context); Ok(()) }) diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 04d930695..0b90a91c7 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -23,7 +23,7 @@ import org.apache.spark._ import org.apache.spark.sql.comet.CometMetricNode import org.apache.spark.sql.vectorized._ -import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_BLOCKING_THREADS, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION, COMET_EXPLAIN_NATIVE_ENABLED, COMET_WORKER_THREADS} +import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_BLOCKING_THREADS, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION, COMET_EXEC_MEMORY_POOL_TYPE, COMET_EXPLAIN_NATIVE_ENABLED, COMET_WORKER_THREADS} import org.apache.comet.vector.NativeUtil /** @@ -72,8 +72,11 @@ class CometExecIterator( new CometTaskMemoryManager(id), batchSize = COMET_BATCH_SIZE.get(), use_unified_memory_manager = conf.getBoolean("spark.memory.offHeap.enabled", false), + memory_pool_type = COMET_EXEC_MEMORY_POOL_TYPE.get(), memory_limit = CometSparkSessionExtensions.getCometMemoryOverhead(conf), + memory_limit_per_task = getMemoryLimitPerTask(conf), memory_fraction = COMET_EXEC_MEMORY_FRACTION.get(), + task_attempt_id = TaskContext.get().taskAttemptId, debug = COMET_DEBUG_ENABLED.get(), explain = COMET_EXPLAIN_NATIVE_ENABLED.get(), workerThreads = COMET_WORKER_THREADS.get(), @@ -84,6 +87,30 @@ class CometExecIterator( private var currentBatch: ColumnarBatch = null private var closed: Boolean = false + private def getMemoryLimitPerTask(conf: SparkConf): Long = { + val numCores = numDriverOrExecutorCores(conf).toFloat + val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf) + val coresPerTask = conf.get("spark.task.cpus", "1").toFloat + // example 16GB maxMemory * 16 cores with 4 cores per task results + // in memory_limit_per_task = 16 GB * 4 / 16 = 16 GB / 4 = 4GB + (maxMemory.toFloat * coresPerTask / numCores).toLong + } + + private def numDriverOrExecutorCores(conf: SparkConf): Int = { + def convertToInt(threads: String): Int = { + if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt + } + val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r + val master = conf.get("spark.master") + master match { + case "local" => 1 + case LOCAL_N_REGEX(threads) => convertToInt(threads) + case LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads) + case _ => conf.get("spark.executor.cores", "1").toInt + } + } + def getNextBatch(): Option[ColumnarBatch] = { assert(partitionIndex >= 0 && partitionIndex < numParts) diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 083c0f2b5..5fd84989b 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -52,8 +52,11 @@ class Native extends NativeBase { taskMemoryManager: CometTaskMemoryManager, batchSize: Int, use_unified_memory_manager: Boolean, + memory_pool_type: String, memory_limit: Long, + memory_limit_per_task: Long, memory_fraction: Double, + task_attempt_id: Long, debug: Boolean, explain: Boolean, workerThreads: Int,