[serve] Clean up microbenchmark & don't pass raw `starlette.requests.Request` object (ray-project#37040) (ray-project#37057)

- Update the test to use the new API.
- Clean up the test output: disable the access log and print results with `print` instead of the logger (logger output wasn't reaching stdout).
- Don't pass the raw starlette request object through the handle (it isn't serializable); see the sketch below.
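
A raw starlette `Request` wraps the live ASGI connection (scope plus the `receive` callable), so it can't be pickled for a cross-replica handle call. For reference, the pattern this change adopts looks like the following in isolation (a minimal sketch; the `Ingress` and `Echo` names are illustrative, not from this file):

    from starlette.requests import Request

    from ray import serve
    from ray.serve.handle import RayServeHandle


    @serve.deployment
    class Echo:
        async def __call__(self, body: bytes) -> bytes:
            return body


    @serve.deployment
    class Ingress:
        def __init__(self, handle: RayServeHandle):
            self._handle = handle

        async def __call__(self, req: Request) -> bytes:
            # Read the body eagerly and send plain bytes through the handle;
            # passing `req` itself would fail to serialize.
            obj_ref = await self._handle.remote(await req.body())
            return await obj_ref


    app = Ingress.bind(Echo.bind())  # deploy with serve.run(app)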
edoakes authored Jul 3, 2023
1 parent c887c0b commit 6e7c7ad
148 changes: 80 additions & 68 deletions python/ray/serve/benchmarks/microbenchmark.py
@@ -5,36 +5,37 @@
 import aiohttp
 import asyncio
 import logging
+from pprint import pprint
 import time
-import requests
+from typing import Callable, Dict, Union
 
 import numpy as np
+from starlette.requests import Request
 
 import ray
 from ray import serve
-
-logger = logging.getLogger(__file__)
+from ray.serve.handle import RayServeHandle
 
 NUM_CLIENTS = 8
 CALLS_PER_BATCH = 100
 
 
-async def timeit(name, fn, multiplier=1):
+async def timeit(name: str, fn: Callable, multiplier: int = 1):
     # warmup
     start = time.time()
-    while time.time() - start < 1:
+    while time.time() - start < 0.1:
         await fn()
     # real run
     stats = []
-    for _ in range(4):
+    for _ in range(1):
         start = time.time()
         count = 0
-        while time.time() - start < 2:
+        while time.time() - start < 0.1:
             await fn()
             count += 1
         end = time.time()
         stats.append(multiplier * count / (end - start))
-    logger.info(
+    print(
         "\t{} {} +- {} requests/s".format(
             name, round(np.mean(stats), 2), round(np.std(stats), 2)
         )
@@ -43,7 +44,7 @@ async def timeit(name, fn, multiplier=1):


 async def fetch(session, data):
-    async with session.get("http://localhost:8000/api", data=data) as response:
+    async with session.get("http://localhost:8000/", data=data) as response:
         response = await response.text()
         assert response == "ok", response

@@ -61,68 +62,76 @@ async def do_queries(self, num, data):
         await fetch(self.session, data)
 
 
-async def trial(
-    result_json,
-    intermediate_handles,
-    num_replicas,
-    max_batch_size,
-    max_concurrent_queries,
-    data_size,
+def build_app(
+    intermediate_handles: bool,
+    num_replicas: int,
+    max_batch_size: int,
+    max_concurrent_queries: int,
 ):
-    trial_key_base = (
-        f"replica:{num_replicas}/batch_size:{max_batch_size}/"
-        f"concurrent_queries:{max_concurrent_queries}/"
-        f"data_size:{data_size}/intermediate_handle:{intermediate_handles}"
-    )
-
-    logger.info(
-        f"intermediate_handles={intermediate_handles},"
-        f"num_replicas={num_replicas},"
-        f"max_batch_size={max_batch_size},"
-        f"max_concurrent_queries={max_concurrent_queries},"
-        f"data_size={data_size}"
-    )
-
-    deployment_name = "api"
-    if intermediate_handles:
-        deployment_name = "downstream"
-
-        @serve.deployment(name="api", max_concurrent_queries=1000)
-        class ForwardActor:
-            def __init__(self):
-                self.handle = None
+    @serve.deployment(max_concurrent_queries=1000)
+    class Upstream:
+        def __init__(self, handle: RayServeHandle):
+            self._handle = handle
 
-            async def __call__(self, req):
-                if self.handle is None:
-                    self.handle = serve.get_deployment(deployment_name).get_handle(
-                        sync=False
-                    )
-                obj_ref = await self.handle.remote(req)
-                return await obj_ref
+            # Turn off access log.
+            logging.getLogger("ray.serve").setLevel(logging.WARNING)
 
-        ForwardActor.deploy()
-        routes = requests.get("http://localhost:8000/-/routes").json()
-        assert "/api" in routes, routes
+        async def __call__(self, req: Request):
+            obj_ref = await self._handle.remote(await req.body())
+            return await obj_ref
 
     @serve.deployment(
-        name=deployment_name,
         num_replicas=num_replicas,
         max_concurrent_queries=max_concurrent_queries,
     )
-    class D:
+    class Downstream:
+        def __init__(self):
+            # Turn off access log.
+            logging.getLogger("ray.serve").setLevel(logging.WARNING)
+
         @serve.batch(max_batch_size=max_batch_size)
        async def batch(self, reqs):
             return [b"ok"] * len(reqs)
 
-        async def __call__(self, req):
+        async def __call__(self, req: Union[bytes, Request]):
             if max_batch_size > 1:
                 return await self.batch(req)
             else:
                 return b"ok"
 
-    D.deploy()
-    routes = requests.get("http://localhost:8000/-/routes").json()
-    assert f"/{deployment_name}" in routes, routes
+    if intermediate_handles:
+        return Upstream.bind(Downstream.bind())
+    else:
+        return Downstream.bind()
+
+
+async def trial(
+    intermediate_handles: bool,
+    num_replicas: int,
+    max_batch_size: int,
+    max_concurrent_queries: int,
+    data_size: str,
+) -> Dict[str, float]:
+    results = {}
+
+    trial_key_base = (
+        f"replica:{num_replicas}/batch_size:{max_batch_size}/"
+        f"concurrent_queries:{max_concurrent_queries}/"
+        f"data_size:{data_size}/intermediate_handle:{intermediate_handles}"
+    )
+
+    print(
+        f"intermediate_handles={intermediate_handles},"
+        f"num_replicas={num_replicas},"
+        f"max_batch_size={max_batch_size},"
+        f"max_concurrent_queries={max_concurrent_queries},"
+        f"data_size={data_size}"
+    )
+
+    app = build_app(
+        intermediate_handles, num_replicas, max_batch_size, max_concurrent_queries
+    )
+    serve.run(app)
 
     if data_size == "small":
         data = None
@@ -142,8 +151,8 @@ async def single_client():
         single_client,
         multiplier=CALLS_PER_BATCH,
     )
-    key = "num_client:1/" + trial_key_base
-    result_json.update({key: single_client_avg_tps})
+    key = f"num_client:1/{trial_key_base}"
+    results[key] = single_client_avg_tps
 
     clients = [Client.remote() for _ in range(NUM_CLIENTS)]
     ray.get([client.ready.remote() for client in clients])
@@ -156,14 +165,13 @@ async def many_clients():
         many_clients,
         multiplier=CALLS_PER_BATCH * len(clients),
     )
-    key = f"num_client:{len(clients)}/" + trial_key_base
-    result_json.update({key: multi_client_avg_tps})
-
-    logger.info(result_json)
+    results[f"num_client:{len(clients)}/{trial_key_base}"] = multi_client_avg_tps
+    return results
 
 
 async def main():
-    result_json = {}
+    results = {}
     for intermediate_handles in [False, True]:
         for num_replicas in [1, 8]:
             for max_batch_size, max_concurrent_queries in [
@@ -173,15 +181,19 @@ async def main():
             ]:
                 # TODO(edoakes): large data causes broken pipe errors.
                 for data_size in ["small"]:
-                    await trial(
-                        result_json,
-                        intermediate_handles,
-                        num_replicas,
-                        max_batch_size,
-                        max_concurrent_queries,
-                        data_size,
+                    results.update(
+                        await trial(
+                            intermediate_handles,
+                            num_replicas,
+                            max_batch_size,
+                            max_concurrent_queries,
+                            data_size,
+                        )
                     )
-    return result_json
+
+    print("Results from all conditions:")
+    pprint(results)
+    return results
 
 
 if __name__ == "__main__":
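
Aside on the `@serve.batch` decorator used by `Downstream` above: concurrent calls to the decorated method are buffered and delivered as a single list, and each caller gets back its own element of the returned list. A minimal standalone sketch (the class name and `batch_wait_timeout_s` value are illustrative, not from this commit):

    from typing import List

    from ray import serve


    @serve.deployment
    class BatchedOk:
        # Up to max_batch_size concurrent calls are coalesced into one
        # invocation; the method must return one result per input, in order.
        @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.01)
        async def handle_batch(self, items: List[bytes]) -> List[bytes]:
            return [b"ok"] * len(items)

        async def __call__(self, item: bytes) -> bytes:
            # Each caller awaits a single result; batching is transparent.
            return await self.handle_batch(item)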
