[core][state] Task backend - Port state API [6/n] (#31278)
**Previous PRs:**
- #30829
- #30953
- #30867
- #30979
- #30934
- #31207
**This PR:** 

With this change:
- `list_tasks` now returns tasks with the attempt number as an additional column.
- `get_task` may return multiple task attempt entries if the task was retried (see the usage sketch below).
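A minimal usage sketch of the changed surface (not part of this PR; the task id and the dict-style access below are illustrative, and the exact return shape may differ):

```python
# Illustrative only: assumes a running Ray cluster and the experimental state API.
import ray
from ray.experimental.state.api import get_task, list_tasks

ray.init()

# Each listed entry now also carries an `attempt_number` column
# alongside the usual task fields (entries shown here as dicts).
for task in list_tasks(limit=100):
    print(task["task_id"], task["attempt_number"], task["scheduling_state"])

# With retries, one task id can map to several attempts, so `get_task`
# may return a list of entries instead of a single one.
attempts = get_task("<task_id_hex>")  # placeholder id
```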


There is also some plumbing in core and in the tests (especially in the test logic) to accommodate the changes. The major changes in this PR are:
- Add limit support to `GcsTaskManager`
- Change the state aggregator to get tasks from GCS.
rickyyx authored and AmeerHajAli committed Jan 12, 2023
1 parent 7b8a421 commit ffb6711
Showing 10 changed files with 305 additions and 183 deletions.
126 changes: 71 additions & 55 deletions dashboard/state_aggregator.py
@@ -10,8 +10,8 @@

import ray.dashboard.memory_utils as memory_utils
import ray.dashboard.utils as dashboard_utils
from ray._private.utils import binary_to_hex
from ray.core.generated.common_pb2 import TaskStatus
import ray.core.generated.common_pb2 as common_pb2

from ray.experimental.state.common import (
ActorState,
ListApiOptions,
@@ -359,70 +359,86 @@ async def list_tasks(self, *, option: ListApiOptions) -> ListApiResponse:
{task_id -> task_data_in_dict}
task_data_in_dict's schema is in TaskState
"""
raylet_ids = self._client.get_all_registered_raylet_ids()
replies = await asyncio.gather(
*[
self._client.get_task_info(node_id, timeout=option.timeout)
for node_id in raylet_ids
],
return_exceptions=True,
)

unresponsive_nodes = 0
running_task_id = set()
successful_replies = []
total_tasks = 0
for reply in replies:
if isinstance(reply, DataSourceUnavailable):
unresponsive_nodes += 1
continue
elif isinstance(reply, Exception):
raise reply

successful_replies.append(reply)
total_tasks += reply.total
for task_id in reply.running_task_ids:
running_task_id.add(binary_to_hex(task_id))

partial_failure_warning = None
if len(raylet_ids) > 0 and unresponsive_nodes > 0:
warning_msg = NODE_QUERY_FAILURE_WARNING.format(
type="raylet",
total=len(raylet_ids),
network_failures=unresponsive_nodes,
log_command="raylet.out",
)
if unresponsive_nodes == len(raylet_ids):
raise DataSourceUnavailable(warning_msg)
partial_failure_warning = (
f"The returned data may contain incomplete result. {warning_msg}"
)
try:
reply = await self._client.get_all_task_info(timeout=option.timeout)
except DataSourceUnavailable:
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)

result = []
for reply in successful_replies:
assert not isinstance(reply, Exception)
tasks = reply.owned_task_info_entries
for task in tasks:
data = self._message_to_dict(
message=task,
fields_to_decode=["task_id", "job_id", "node_id", "actor_id"],
def _to_task_state(task_attempt: dict) -> dict:
"""
Convert a dict repr of `TaskEvents` to a dict repr of `TaskState`
"""
task_state = {}
task_info = task_attempt.get("task_info", {})
state_updates = task_attempt.get("state_updates", None)

if state_updates is None:
return {}

# Convert those settable fields
mappings = [
(
task_info,
[
"task_id",
"name",
"actor_id",
"type",
"func_or_class_name",
"language",
"required_resources",
"runtime_env_info",
],
),
(task_attempt, ["task_id", "attempt_number", "job_id"]),
(state_updates, ["node_id"]),
]
for src, keys in mappings:
for key in keys:
task_state[key] = src.get(key)

# Get the most updated scheduling_state by state transition ordering.
def _get_most_recent_status(task_state: dict) -> str:
# Reverse the order as defined in protobuf for the most recent state.
for status_name in reversed(common_pb2.TaskStatus.keys()):
key = f"{status_name.lower()}_ts"
if state_updates.get(key):
return status_name
return common_pb2.TaskStatus.Name(common_pb2.NIL)

task_state["scheduling_state"] = _get_most_recent_status(state_updates)

return task_state

result = [
_to_task_state(
self._message_to_dict(
message=message,
fields_to_decode=[
"task_id",
"job_id",
"node_id",
"actor_id",
"parent_task_id",
],
)
)
for message in reply.events_by_task
]
result = [e for e in result if len(e) > 0]

if data["task_id"] in running_task_id:
data["scheduling_state"] = TaskStatus.DESCRIPTOR.values_by_number[
TaskStatus.RUNNING
].name
result.append(data)
num_after_truncation = len(result)
num_total = num_after_truncation + reply.num_status_task_events_dropped

result = self._filter(result, option.filters, TaskState, option.detail)
num_filtered = len(result)
# Sort to make the output deterministic.

result.sort(key=lambda entry: entry["task_id"])
result = list(islice(result, option.limit))
return ListApiResponse(
result=result,
partial_failure_warning=partial_failure_warning,
total=total_tasks,
total=num_total,
num_after_truncation=num_after_truncation,
num_filtered=num_filtered,
)
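As an aside on the status resolution above, here is a small self-contained sketch of the same idea; the status names are simplified, assumed stand-ins for `common_pb2.TaskStatus`, not the real enum:

```python
# Self-contained sketch of the "most recent status" resolution used above.
# STATUS_NAMES is an assumed, simplified stand-in for common_pb2.TaskStatus.
STATUS_NAMES = ["NIL", "PENDING_ARGS_AVAIL", "SUBMITTED_TO_WORKER", "RUNNING", "FINISHED"]

def most_recent_status(state_updates: dict) -> str:
    # Walk statuses from most to least advanced; the first one whose
    # "<status>_ts" timestamp is present is the task's current state.
    for status_name in reversed(STATUS_NAMES):
        if state_updates.get(f"{status_name.lower()}_ts"):
            return status_name
    return "NIL"

# A task that has started running but not finished resolves to RUNNING.
print(most_recent_status({"pending_args_avail_ts": 1, "running_ts": 3}))
```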
13 changes: 10 additions & 3 deletions python/ray/experimental/state/api.py
@@ -320,6 +320,13 @@ def get(
# e.g. pinned as local variable, used as parameter
return result

if resource == StateResource.TASKS:
# There might be multiple task attempts given a task id due to
# task retries.
if len(result) == 1:
return result[0]
return result

# For the rest of the resources, there should only be a single entry
# for a particular id.
assert len(result) == 1
@@ -666,7 +673,7 @@ def get_task(
timeout: int = DEFAULT_RPC_TIMEOUT,
_explain: bool = False,
) -> Optional[Dict]:
"""Get a task by id.
"""Get task attempts of a task by id.
Args:
id: Id of the task
@@ -677,8 +684,8 @@
failed query information.
Returns:
None if actor not found, or dictionarified
:ref:`TaskState <state-api-schema-task>`.
None if task not found, or a list of dictionarified
:ref:`TaskState <state-api-schema-task>` from the task attempts.
Raises:
Exceptions: :ref:`RayStateApiException <state-api-exceptions>` if the CLI
2 changes: 2 additions & 0 deletions python/ray/experimental/state/common.py
@@ -476,6 +476,8 @@ class TaskState(StateSchema):

#: The id of the task.
task_id: str = state_column(filterable=True)
#: The attempt (retry) number of the task.
attempt_number: int = state_column(filterable=True)
#: The name of the task if it is given by the name argument.
name: str = state_column(filterable=True)
#: The state of the task.
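Since the new `attempt_number` column is declared filterable, it should be usable in state-API filters as well; a hypothetical example (filter key, predicate, and value written in the state API's tuple form):

```python
# Hypothetical: list only the second attempts (i.e. first retries) of tasks.
from ray.experimental.state.api import list_tasks

retried_once = list_tasks(filters=[("attempt_number", "=", "1")])
```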
15 changes: 15 additions & 0 deletions python/ray/experimental/state/state_manager.py
@@ -21,6 +21,8 @@
GetAllPlacementGroupRequest,
GetAllWorkerInfoReply,
GetAllWorkerInfoRequest,
GetTaskEventsReply,
GetTaskEventsRequest,
)
from ray.core.generated.node_manager_pb2 import (
GetObjectsInfoReply,
@@ -162,6 +164,9 @@ def register_gcs_client(self, gcs_channel: grpc.aio.Channel):
self._gcs_worker_info_stub = gcs_service_pb2_grpc.WorkerInfoGcsServiceStub(
gcs_channel
)
self._gcs_task_info_stub = gcs_service_pb2_grpc.TaskInfoGcsServiceStub(
gcs_channel
)

def register_raylet_client(self, node_id: str, address: str, port: int):
full_addr = f"{address}:{port}"
@@ -225,6 +230,16 @@ async def get_all_actor_info(
)
return reply

@handle_grpc_network_errors
async def get_all_task_info(
self, timeout: int = None, limit: int = None
) -> Optional[GetTaskEventsReply]:
if not limit:
limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE
request = GetTaskEventsRequest(limit=limit)
reply = await self._gcs_task_info_stub.GetTaskEvents(request, timeout=timeout)
return reply

@handle_grpc_network_errors
async def get_all_placement_group_info(
self, timeout: int = None, limit: int = None
