Add flag to parse query times from logs for serverless DPB services.
PiperOrigin-RevId: 719033039
dorellang authored and copybara-github committed Jan 23, 2025
1 parent a26965d commit 939aef5
Showing 7 changed files with 310 additions and 77 deletions.
36 changes: 28 additions & 8 deletions perfkitbenchmarker/dpb_service.py
@@ -20,14 +20,14 @@
"""

import abc
from collections.abc import MutableMapping
from collections.abc import Callable, MutableMapping
import dataclasses
import datetime
import logging
import os
import shutil
import tempfile
from typing import Dict, List, Type
from typing import Dict, List, Type, TypeAlias

from absl import flags
import jinja2
@@ -164,18 +164,37 @@ class JobSubmissionError(errors.Benchmarks.RunError):
pass


FetchOutputFn: TypeAlias = Callable[[], tuple[str | None, str | None]]


@dataclasses.dataclass
class JobResult:
"""Data class for the timing of a successful DPB job."""
"""Data class for the timing of a successful DPB job.
Attributes:
run_time: Service reported execution time.
pending_time: Service reported pending time (0 if service does not report).
stdout: Job's stdout. Call FetchOutput before to ensure it's populated.
stderr: Job's stderr. Call FetchOutput before to ensure it's populated.
fetch_output_fn: Callback expected to return a 2-tuple of str or None whose
values correspond to stdout and stderr respectively. This is called by
FetchOutput which updates stdout and stderr if their respective value in
this callback's return tuple is not None. Defaults to a no-op.
"""

# Service reported execution time
run_time: float
# Service reported pending time (0 if service does not report).
pending_time: float = 0
# Stdout of the job.
stdout: str = ''
# Stderr of the job.
stderr: str = ''
fetch_output_fn: FetchOutputFn = lambda: (None, None)

def FetchOutput(self):
"""Populates stdout and stderr according to fetch_output_fn callback."""
stdout, stderr = self.fetch_output_fn()
if stdout is not None:
self.stdout = stdout
if stderr is not None:
self.stderr = stderr

@property
def wall_time(self) -> float:
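
As an illustration of the new hook (not part of this commit), a concrete serverless DPB service could build its JobResult so that stdout/stderr are fetched lazily. The _MakeLazyJobResult helper, the client argument, and client.get_job_output below are hypothetical stand-ins for whatever API the real service exposes:

from perfkitbenchmarker import dpb_service

def _MakeLazyJobResult(client, job_id: str, run_time: float) -> dpb_service.JobResult:
  """Hypothetical helper: stdout/stderr are fetched only on FetchOutput()."""

  def _FetchOutput() -> tuple[str | None, str | None]:
    # `client.get_job_output` is an assumed service call returning
    # (stdout, stderr); a None entry leaves the current value untouched,
    # per JobResult.FetchOutput above.
    return client.get_job_output(job_id)

  return dpb_service.JobResult(run_time=run_time, fetch_output_fn=_FetchOutput)
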
@@ -795,7 +814,8 @@ def CheckPrerequisites(self):
if self.cloud == 'AWS' and not aws_flags.AWS_EC2_INSTANCE_PROFILE.value:
raise ValueError(
'EC2 based Spark and Hadoop services require '
'--aws_ec2_instance_profile.')
'--aws_ec2_instance_profile.'
)

def GetClusterCreateTime(self) -> float | None:
"""Returns the cluster creation time.
164 changes: 124 additions & 40 deletions perfkitbenchmarker/linux_benchmarks/dpb_sparksql_benchmark.py
@@ -48,8 +48,9 @@
import json
import logging
import os
import re
import time
from typing import List
from typing import Any, List

from absl import flags
from perfkitbenchmarker import configs
@@ -60,6 +61,7 @@
from perfkitbenchmarker import object_storage_service
from perfkitbenchmarker import sample
from perfkitbenchmarker import temp_dir
from perfkitbenchmarker import vm_util

BENCHMARK_NAME = 'dpb_sparksql_benchmark'

@@ -112,9 +114,34 @@
'The record format to use when connecting to BigQuery storage. See: '
'https://github.com/GoogleCloudDataproc/spark-bigquery-connector#properties',
)
_FETCH_RESULTS_FROM_LOGS = flags.DEFINE_bool(
'dpb_sparksql_fetch_results_from_logs',
False,
'Make the query runner script log query timings to stdout/stderr instead '
'of writing them to an object storage location. Reduces runner latency '
'(and hence total wall time), but it is not supported by all DPB '
'services.',
)

FLAGS = flags.FLAGS

LOG_RESULTS_PATTERN = (
r'----@spark_sql_runner:results_start@----'
r'(.*)'
r'----@spark_sql_runner:results_end@----'
)
POLL_LOGS_INTERVAL = 60
POLL_LOGS_TIMEOUT = 6 * 60
RESULTS_FROM_LOGS_SUPPORTED_DPB_SERVICES = (
dpb_constants.DATAPROC_SERVERLESS,
dpb_constants.EMR_SERVERLESS,
dpb_constants.GLUE,
)
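
For reference, LOG_RESULTS_PATTERN assumes the runner script prints the per-query timing records as one JSON array wrapped between the two sentinel markers somewhere in the driver logs. The excerpt below is fabricated for illustration; since the regex is applied with re.DOTALL, the array may also span multiple log lines:

----@spark_sql_runner:results_start@----[{"query_id": "q1", "duration": 12.3}, {"query_id": "q2", "duration": 8.7}]----@spark_sql_runner:results_end@----

Each record carries query_id and duration (plus stream when --dpb_sparksql_streams is used), matching the fields read in _GetQuerySamples below.
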


class QueryResultsNotReadyError(Exception):
"""Used to signal a job is still running."""


def GetConfig(user_config):
return configs.LoadConfig(BENCHMARK_CONFIG, user_config, BENCHMARK_NAME)
Expand All @@ -129,7 +156,6 @@ def CheckPrerequisites(benchmark_config):
Raises:
Config.InvalidValue: On encountering invalid configuration.
"""
del benchmark_config # unused
if not FLAGS.dpb_sparksql_data and FLAGS.dpb_sparksql_create_hive_tables:
raise errors.Config.InvalidValue(
'You must pass dpb_sparksql_data with dpb_sparksql_create_hive_tables'
@@ -160,7 +186,7 @@ def CheckPrerequisites(benchmark_config):
bool(_BIGQUERY_TABLES.value),
bool(FLAGS.dpb_sparksql_database),
])
== 1
> 1
):
logging.warning(
'You should only pass one of them: --dpb_sparksql_data,'
@@ -176,6 +202,16 @@
'--dpb_sparksql_simultaneous is not compatible with '
'--dpb_sparksql_streams.'
)
if (
_FETCH_RESULTS_FROM_LOGS.value
and benchmark_config.dpb_service.service_type
not in RESULTS_FROM_LOGS_SUPPORTED_DPB_SERVICES
):
raise errors.Config.InvalidValue(
f'Current dpb service {benchmark_config.dpb_service.service_type!r} is'
' not supported for --dpb_sparksql_fetch_results_from_logs. Supported'
f' dpb services are: {RESULTS_FROM_LOGS_SUPPORTED_DPB_SERVICES!r}'
)


def Prepare(benchmark_spec):
Expand Down Expand Up @@ -275,7 +311,7 @@ def Run(benchmark_spec):
# Run PySpark Spark SQL Runner
report_dir, job_result = _RunQueries(benchmark_spec)

results = _GetQuerySamples(storage_service, report_dir, metadata)
results = _GetQuerySamples(storage_service, report_dir, job_result, metadata)
results += _GetGlobalSamples(results, cluster, job_result, metadata)
results += _GetPrepareSamples(benchmark_spec, metadata)
return results
@@ -319,7 +355,10 @@ def _RunQueries(benchmark_spec) -> tuple[str, dpb_service.JobResult]:
else:
for stream in benchmark_spec.query_streams:
args += ['--sql-queries', ','.join(stream)]
args += ['--report-dir', report_dir]
if _FETCH_RESULTS_FROM_LOGS.value:
args += ['--log-results', 'True']
else:
args += ['--report-dir', report_dir]
if FLAGS.dpb_sparksql_database:
args += ['--database', FLAGS.dpb_sparksql_database]
if FLAGS.dpb_sparksql_create_hive_tables:
@@ -365,46 +404,33 @@ def _RunQueries(benchmark_spec) -> tuple[str, dpb_service.JobResult]:
def _GetQuerySamples(
storage_service: object_storage_service.ObjectStorageService,
report_dir: str,
job_result: dpb_service.JobResult,
base_metadata: MutableMapping[str, str],
) -> list[sample.Sample]:
"""Get Sample objects from metrics storage path."""
# Spark can only write data to directories not files. So do a recursive copy
# of that directory and then search it for the collection of JSON files with
# the results.
temp_run_dir = temp_dir.GetRunDirPath()
storage_service.Copy(report_dir, temp_run_dir, recursive=True)
report_files = []
for dir_name, _, files in os.walk(
os.path.join(temp_run_dir, os.path.basename(report_dir))
):
for filename in files:
if filename.endswith('.json'):
report_file = os.path.join(dir_name, filename)
report_files.append(report_file)
logging.info("Found report file '%s'.", report_file)
if not report_files:
raise errors.Benchmarks.RunError('Job report not found.')
"""Get Sample objects from job's logs."""

if _FETCH_RESULTS_FROM_LOGS.value:
query_results = _FetchResultsFromLogs(job_result)
else:
query_results = _FetchResultsFromStorage(storage_service, report_dir)

samples = []
for report_file in report_files:
with open(report_file) as file:
for line in file:
result = json.loads(line)
logging.info('Timing: %s', result)
query_id = result['query_id']
assert query_id
metadata_copy = base_metadata.copy()
metadata_copy['query'] = query_id
if FLAGS.dpb_sparksql_streams:
metadata_copy['stream'] = result['stream']
samples.append(
sample.Sample(
'sparksql_run_time',
result['duration'],
'seconds',
metadata_copy,
)
for result in query_results:
logging.info('Timing: %s', result)
query_id = result['query_id']
assert query_id
metadata_copy = dict(base_metadata)
metadata_copy['query'] = query_id
if FLAGS.dpb_sparksql_streams:
metadata_copy['stream'] = result['stream']
samples.append(
sample.Sample(
'sparksql_run_time',
result['duration'],
'seconds',
metadata_copy,
)
)
return samples


@@ -524,6 +550,64 @@ def _GetPrepareSamples(
return samples


def _FetchResultsFromStorage(
storage_service: object_storage_service.ObjectStorageService,
report_dir: str,
) -> list[dict[str, Any]]:
"""Get Sample objects from metrics storage path."""
# Spark can only write data to directories not files. So do a recursive copy
# of that directory and then search it for the collection of JSON files with
# the results.
temp_run_dir = temp_dir.GetRunDirPath()
storage_service.Copy(report_dir, temp_run_dir, recursive=True)
report_files = []
for dir_name, _, files in os.walk(
os.path.join(temp_run_dir, os.path.basename(report_dir))
):
for filename in files:
if filename.endswith('.json'):
report_file = os.path.join(dir_name, filename)
report_files.append(report_file)
logging.info("Found report file '%s'.", report_file)
if not report_files:
raise errors.Benchmarks.RunError('Job report not found.')
results = []
for report_file in report_files:
with open(report_file) as file:
for line in file:
results.append(json.loads(line))
return results


@vm_util.Retry(
timeout=POLL_LOGS_TIMEOUT,
poll_interval=POLL_LOGS_INTERVAL,
fuzz=0,
retryable_exceptions=(QueryResultsNotReadyError,),
)
def _FetchResultsFromLogs(job_result: dpb_service.JobResult):
"""Get samples from job results logs."""
job_result.FetchOutput()
logs = '\n'.join([job_result.stdout or '', job_result.stderr])
query_results = _ParseResultsFromLogs(logs)
if query_results is None:
raise QueryResultsNotReadyError
return query_results


def _ParseResultsFromLogs(logs: str) -> list[dict[str, Any]] | None:
json_str_match = re.search(LOG_RESULTS_PATTERN, logs, re.DOTALL)
if not json_str_match:
return None
try:
results = json.loads(json_str_match.group(1))
except ValueError as e:
raise errors.Benchmarks.RunError(
'Corrupted results from logs cannot be deserialized.'
) from e
return results
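
A quick, hypothetical sanity check of the parsing logic above; the log text is fabricated for illustration and is not produced by this commit:

# Illustrative only: exercising _ParseResultsFromLogs on a fabricated log.
fake_logs = (
    'INFO driver starting...\n'
    '----@spark_sql_runner:results_start@----'
    '[{"query_id": "q1", "duration": 10.5}]'
    '----@spark_sql_runner:results_end@----\n'
    'INFO driver done.'
)
assert _ParseResultsFromLogs(fake_logs) == [{'query_id': 'q1', 'duration': 10.5}]
# Without the markers there is nothing to parse yet; _FetchResultsFromLogs
# treats the resulting None as "results not ready" and retries, while corrupt
# JSON between the markers raises errors.Benchmarks.RunError instead.
assert _ParseResultsFromLogs('no markers yet') is None
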


def _GetDistCpMetadata(base_dir: str, subdirs: List[str], extra_metadata=None):
"""Compute list of table metadata for spark_sql_distcp metadata flags."""
metadata = []
(Diff for the remaining 5 changed files is not shown.)
