Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

no-op #6862

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open

no-op #6862

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 0 additions & 282 deletions tfx/orchestration/experimental/core/pipeline_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,288 +89,6 @@
metadata_store_pb2.Execution.State.CANCELED:
run_state_pb2.RunState.STOPPED,
}


@dataclasses.dataclass
class StateRecord(json_utils.Jsonable):
  """A single historical entry of a node's past state.

  Records are appended to `NodeState.state_history` when the node
  transitions to a different state (see `NodeState.update`).
  """

  # The node state string at the time of the record (e.g. 'started').
  state: str
  # Backfill token that was active for the recorded state; '' if none.
  backfill_token: str
  # Status code associated with the recorded state, if any.
  status_code: Optional[int]
  # Seconds since epoch; copied from `NodeState.last_updated_time` when the
  # record is created.
  update_time: float
  # TODO(b/242083811) Some status_msg have already been written into MLMD.
  # Keeping this field is for backward compatibility to avoid json failing to
  # parse existing status_msg. We can remove it once we are sure no status_msg
  # in MLMD is in use.
  status_msg: str = ''


# TODO(b/228198652): Stop using json_util.Jsonable. Before we do,
# this class MUST NOT be moved out of this module.
@attr.s(auto_attribs=True, kw_only=True)
class NodeState(json_utils.Jsonable):
  """Records node state.

  Attributes:
    state: Current state of the node; one of the class-level state constants.
    backfill_token: Token of the backfill the node is part of; '' if none.
    status_code: Status code accompanying the current state, if any.
    status_msg: Status message accompanying `status_code`.
    last_updated_time: Seconds since epoch of the last state transition.
    state_history: Past state transitions, oldest first, capped at
      _MAX_STATE_HISTORY_LEN entries.
    status: Status of the node in state STOPPING or STOPPED.
  """

  STARTED = 'started'  # Node is ready for execution.
  STOPPING = 'stopping'  # Pending work before state can change to STOPPED.
  STOPPED = 'stopped'  # Node execution is stopped.
  RUNNING = 'running'  # Node is under active execution (i.e. triggered).
  COMPLETE = 'complete'  # Node execution completed successfully.
  # Node execution skipped due to condition not satisfied when pipeline has
  # conditionals.
  SKIPPED = 'skipped'
  # Node execution skipped due to partial run.
  SKIPPED_PARTIAL_RUN = 'skipped_partial_run'
  FAILED = 'failed'  # Node execution failed due to errors.

  state: str = attr.ib(
      default=STARTED,
      validator=attr.validators.in_([
          STARTED,
          STOPPING,
          STOPPED,
          RUNNING,
          COMPLETE,
          SKIPPED,
          SKIPPED_PARTIAL_RUN,
          FAILED,
      ]),
      on_setattr=attr.setters.validate,
  )
  backfill_token: str = ''
  status_code: Optional[int] = None
  status_msg: str = ''
  # `time.time` is already a zero-arg callable; no lambda wrapper needed.
  last_updated_time: float = attr.ib(factory=time.time)

  state_history: List[StateRecord] = attr.ib(default=attr.Factory(list))

  @property
  def status(self) -> Optional[status_lib.Status]:
    """Returns the node status, or None if no status code is set."""
    if self.status_code is not None:
      return status_lib.Status(code=self.status_code, message=self.status_msg)
    return None

  def update(
      self,
      state: str,
      status: Optional[status_lib.Status] = None,
      backfill_token: str = '',
  ) -> None:
    """Transitions the node to `state`, archiving the previous state.

    Args:
      state: The new node state.
      status: Status accompanying the new state, if any.
      backfill_token: Backfill token for the new state; '' if not backfilling.
    """
    if self.state != state:
      self.state_history.append(
          StateRecord(
              state=self.state,
              backfill_token=self.backfill_token,
              status_code=self.status_code,
              update_time=self.last_updated_time,
          )
      )
      # Bound the history length to keep the serialized state small.
      if len(self.state_history) > _MAX_STATE_HISTORY_LEN:
        self.state_history = self.state_history[-_MAX_STATE_HISTORY_LEN:]
      self.last_updated_time = time.time()

    self.state = state
    self.backfill_token = backfill_token
    self.status_code = status.code if status is not None else None
    self.status_msg = (status.message or '') if status is not None else ''

  def is_startable(self) -> bool:
    """Returns True if the node can be started."""
    return self.state in {self.STOPPING, self.STOPPED, self.FAILED}

  def is_stoppable(self) -> bool:
    """Returns True if the node can be stopped."""
    return self.state in {self.STARTED, self.RUNNING}

  def is_backfillable(self) -> bool:
    """Returns True if the node can be backfilled."""
    return self.state in {self.STOPPED, self.FAILED}

  def is_programmatically_skippable(self) -> bool:
    """Returns True if the node can be skipped via programmatic operation."""
    return self.state in {self.STARTED, self.STOPPED}

  def is_success(self) -> bool:
    """Returns True if the node is in a successful terminal state."""
    return is_node_state_success(self.state)

  def is_failure(self) -> bool:
    """Returns True if the node is in a failed terminal state."""
    return is_node_state_failure(self.state)

  def to_run_state(self) -> run_state_pb2.RunState:
    """Returns this NodeState converted to a RunState."""
    status_code_value = None
    if self.status_code is not None:
      status_code_value = run_state_pb2.RunState.StatusCodeValue(
          value=self.status_code)
    return run_state_pb2.RunState(
        state=_NODE_STATE_TO_RUN_STATE_MAP.get(
            self.state, run_state_pb2.RunState.UNKNOWN
        ),
        status_code=status_code_value,
        status_msg=self.status_msg,
        # RunState timestamps are in milliseconds.
        update_time=int(self.last_updated_time * 1000),
    )

  def to_run_state_history(self) -> List[run_state_pb2.RunState]:
    """Returns the state history converted to a list of RunStates."""
    run_state_history = []
    for record in self.state_history:
      # STARTING, PAUSING and PAUSED have been deprecated but may still be
      # present in state_history.
      if record.state in ('starting', 'pausing', 'paused'):
        continue
      run_state_history.append(
          NodeState(
              state=record.state,
              status_code=record.status_code,
              last_updated_time=record.update_time).to_run_state())
    return run_state_history

  # By default, json_utils.Jsonable serializes and deserializes objects using
  # obj.__dict__, which prevents attr.ib from populating default fields.
  # Overriding this function to ensure default fields are populated.
  @classmethod
  def from_json_dict(cls, dict_data: Dict[str, Any]) -> Any:
    """Convert from dictionary data to an object."""
    return cls(**dict_data)

  def latest_predicate_time_s(self, predicate: Callable[[StateRecord], bool],
                              include_current_state: bool) -> Optional[int]:
    """Returns the latest time the StateRecord satisfies the given predicate.

    Args:
      predicate: Predicate that takes the state string.
      include_current_state: Whether to include the current node state when
        checking the node state history (the node state history doesn't include
        the current node state).

    Returns:
      The latest time (in the state history) the StateRecord satisfies the given
      predicate, or None if the predicate is never satisfied.
    """
    if include_current_state:
      current_record = StateRecord(
          state=self.state,
          backfill_token=self.backfill_token,
          status_code=self.status_code,
          update_time=self.last_updated_time,
      )
      if predicate(current_record):
        return int(current_record.update_time)

    # History is ordered oldest-first, so scan in reverse for the latest match.
    for s in reversed(self.state_history):
      if predicate(s):
        return int(s.update_time)
    return None

  def latest_running_time_s(self) -> Optional[int]:
    """Returns the latest time the node entered a RUNNING state.

    Returns:
      The latest time (in the state history) the node entered a RUNNING
      state, or None if the node never entered a RUNNING state.
    """
    return self.latest_predicate_time_s(
        lambda s: is_node_state_running(s.state), include_current_state=True)


class _NodeStatesProxy:
  """Proxy for reading and updating deserialized NodeState dicts from Execution.

  This proxy contains an internal write-back cache. Changes are not saved back
  to the `Execution` until `save()` is called; cache would not be updated if
  changes were made outside of the proxy, either. This is primarily used to
  reduce JSON serialization/deserialization overhead for getting node state
  execution property from pipeline execution.
  """

  def __init__(self, execution: metadata_store_pb2.Execution):
    # Held by reference so save() writes through to the execution proto.
    self._custom_properties = execution.custom_properties
    self._deserialized_cache: Dict[str, Dict[str, NodeState]] = {}
    self._changed_state_types: Set[str] = set()

  def get(self, state_type: str = _NODE_STATES) -> Dict[str, NodeState]:
    """Gets node states dict from pipeline execution with the specified type."""
    if state_type not in (_NODE_STATES, _PREVIOUS_NODE_STATES):
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.INVALID_ARGUMENT,
          message=(
              f'Expected state_type is {_NODE_STATES} or'
              f' {_PREVIOUS_NODE_STATES}, got {state_type}.'
          ),
      )
    cached = self._deserialized_cache.get(state_type)
    if cached is None:
      # Deserialize lazily on first access and memoize the result.
      serialized = _get_metadata_value(self._custom_properties.get(state_type))
      cached = json_utils.loads(serialized) if serialized else {}
      self._deserialized_cache[state_type] = cached
    return cached

  def set(
      self, node_states: Dict[str, NodeState], state_type: str = _NODE_STATES
  ) -> None:
    """Sets node states dict with the specified type."""
    self._deserialized_cache[state_type] = node_states
    # Mark dirty so save() knows to write this type back.
    self._changed_state_types.add(state_type)

  def save(self) -> None:
    """Saves all changed node states dicts to pipeline execution."""
    size_limit = env.get_env().max_mlmd_str_value_length()

    for state_type in self._changed_state_types:
      states = self._deserialized_cache[state_type]
      serialized = json_utils.dumps(states)

      # Removes state history from node states if it's too large to avoid
      # hitting MLMD limit.
      if size_limit and len(serialized) > size_limit:
        logging.info(
            'Node states length %d is too large (> %d); Removing state history'
            ' from it.',
            len(serialized),
            size_limit,
        )
        stripped = {}
        for node_id, node_state in states.items():
          # Deep-copy so the cached NodeState objects keep their history.
          node_state_copy = copy.deepcopy(node_state)
          node_state_copy.state_history.clear()
          stripped[node_id] = node_state_copy
        serialized = json_utils.dumps(stripped)
        logging.info(
            'Node states length after removing state history: %d',
            len(serialized),
        )

      data_types_utils.set_metadata_value(
          self._custom_properties[state_type], serialized
      )


def is_node_state_success(state: str) -> bool:
  """Returns True if `state` is a successful terminal node state."""
  return state in {
      NodeState.COMPLETE,
      NodeState.SKIPPED,
      NodeState.SKIPPED_PARTIAL_RUN,
  }


def is_node_state_failure(state: str) -> bool:
  """Returns True if `state` denotes a failed node execution."""
  return state == NodeState.FAILED


def is_node_state_running(state: str) -> bool:
  """Returns True if `state` denotes an actively executing node."""
  return state == NodeState.RUNNING


_NODE_STATE_TO_RUN_STATE_MAP = {
NodeState.STARTED: run_state_pb2.RunState.READY,
NodeState.STOPPING: run_state_pb2.RunState.UNKNOWN,
Expand Down