Skip to content

Commit

Permalink
Merge branch 'master' into raw-container
Browse files Browse the repository at this point in the history
  • Loading branch information
Ketan Umare committed Jun 24, 2020
2 parents b9cee46 + a33f156 commit 60a1626
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 10 deletions.
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

import flytekit.plugins

__version__ = '0.9.3'
__version__ = '0.9.4'
2 changes: 1 addition & 1 deletion flytekit/common/mixins/launchable.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def launch(self, project, domain, inputs=None, name=None, notification_overrides
:rtype: T
"""
return self.execute_with_literals(
return self.launch_with_literals(
project,
domain,
self._python_std_input_map_to_literal_map(inputs or {}),
Expand Down
17 changes: 13 additions & 4 deletions flytekit/common/mixins/registerable.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,19 @@ def some_task()
m = _importlib.import_module(self.instantiated_in)

for k in dir(m):
if getattr(m, k) == self:
self._platform_valid_name = _utils.fqdn(m.__name__, k, entity_type=self.resource_type)
_logging.debug("Auto-assigning name to {}".format(self._platform_valid_name))
return
try:
if getattr(m, k) == self:
self._platform_valid_name = _utils.fqdn(m.__name__, k, entity_type=self.resource_type)
_logging.debug("Auto-assigning name to {}".format(self._platform_valid_name))
return
except ValueError as err:
# Empty pandas dataframes behave weirdly here such that calling `m.df` raises:
# ValueError: The truth value of a {type(self).__name__} is ambiguous. Use a.empty, a.bool(), a.item(),
# a.any() or a.all()
# Since dataframes aren't registrable entities to begin with we swallow any errors they raise and
# continue looping through m.
_logging.warning("Caught ValueError {} while attempting to auto-assign name".format(err))
pass

_logging.error("Could not auto-assign name")
raise _system_exceptions.FlyteSystemException("Error looking for object while auto-assigning name.")
62 changes: 61 additions & 1 deletion flytekit/common/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

import six as _six

from google.protobuf import json_format as _json_format, struct_pb2 as _struct

import hashlib as _hashlib
import json as _json

from flytekit.common import (
interface as _interfaces, nodes as _nodes, sdk_bases as _sdk_bases, workflow_execution as _workflow_execution
)
Expand All @@ -14,7 +19,7 @@
from flytekit.engines import loader as _engine_loader
from flytekit.models import common as _common_model, task as _task_model
from flytekit.models.core import workflow as _workflow_model, identifier as _identifier_model
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.exceptions import user as _user_exceptions, system as _system_exceptions
from flytekit.common.types import helpers as _type_helpers


Expand Down Expand Up @@ -268,6 +273,61 @@ def _python_std_input_map_to_literal_map(self, inputs):
for k, v in _six.iteritems(self.interface.inputs)
})

def _produce_deterministic_version(self, version=None):
"""
:param Text version:
:return Text:
"""

if self.container is not None and self.container.data_config is None:
# Only in the case of raw container tasks (which are the only valid tasks with container definitions that
# can assign a client-side task version) their data config will be None.
raise ValueError("Client-side task versions are not supported for {} task type".format(self.type))
if version is not None:
return version
custom = _json_format.Parse(_json.dumps(self.custom, sort_keys=True), _struct.Struct()) if self.custom else None

# The task body is the entirety of the task template MINUS the identifier. The identifier is omitted because
# 1) this method is used to compute the version portion of the identifier and
# 2 ) the SDK will actually generate a unique name on every task instantiation which is not great for
# the reproducibility this method attempts.
task_body = (self.type, self.metadata.to_flyte_idl().SerializeToString(deterministic=True),
self.interface.to_flyte_idl().SerializeToString(deterministic=True), custom)
return _hashlib.md5(str(task_body).encode('utf-8')).hexdigest()

@_exception_scopes.system_entry_point
def register_and_launch(self, project, domain, name=None, version=None, inputs=None):
"""
:param Text project: The project in which to register and launch this task.
:param Text domain: The domain in which to register and launch this task.
:param Text name: The name to give this task.
:param Text version: The version in which to register this task
:param dict[Text, Any] inputs: A dictionary of Python standard inputs that will be type-checked, then compiled
to a LiteralMap.
:rtype: flytekit.common.workflow_execution.SdkWorkflowExecution
"""
self.validate()
version = self._produce_deterministic_version(version)

if name is None:
try:
self.auto_assign_name()
generated_name = self._platform_valid_name
except _system_exceptions.FlyteSystemException:
# If we're not able to assign a platform valid name, use the deterministically-produced version instead.
generated_name = version
name = name if name else generated_name
id_to_register = _identifier.Identifier(_identifier_model.ResourceType.TASK, project, domain, name, version)
old_id = self.id
try:
self._id = id_to_register
_engine_loader.get_engine().get_task(self).register(id_to_register)
except:
self._id = old_id
raise
return self.launch(project, domain, inputs=inputs)

@_exception_scopes.system_entry_point
def launch_with_literals(self, project, domain, literal_inputs, name=None, notification_overrides=None,
label_overrides=None, annotation_overrides=None):
Expand Down
3 changes: 2 additions & 1 deletion flytekit/common/types/impl/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,8 @@ def cast_to(self, other_type):
additional_msg="Cannot cast because a required column '{}' was not found.".format(k),
received_value=self
)
if v != self.type.sdk_columns[k]:
if not isinstance(v, _base_sdk_types.FlyteSdkType) or \
v.to_flyte_literal_type() != self.type.sdk_columns[k].to_flyte_literal_type():
raise _user_exceptions.FlyteTypeException(
self.type.sdk_columns[k],
v,
Expand Down
2 changes: 1 addition & 1 deletion flytekit/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __str__(self):
return self.verbose_string()

def __hash__(self):
return hash(self.to_flyte_idl().SerializeToString())
return hash(self.to_flyte_idl().SerializeToString(deterministic=True))

def short_string(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"croniter>=0.3.20,<4.0.0",
"deprecated>=1.0,<2.0",
"boto3>=1.4.4,<2.0",
"python-dateutil<2.8.1,>=2.1",
"python-dateutil<=2.8.1,>=2.1",
"grpcio>=1.3.0,<2.0",
"protobuf>=3.6.1,<4",
"pytimeparse>=1.1.8,<2.0.0",
Expand Down
34 changes: 34 additions & 0 deletions tests/flytekit/unit/common_tests/tasks/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from flytekit.models.core import identifier as _identifier
from flytekit.sdk.tasks import python_task, inputs, outputs
from flyteidl.admin import task_pb2 as _admin_task_pb2
from flytekit.common.tasks.presto_task import SdkPrestoTask
from flytekit.sdk.types import Types


@_patch("flytekit.engines.loader.get_engine")
Expand Down Expand Up @@ -68,3 +70,35 @@ def test_task_serialization():
assert isinstance(s, _admin_task_pb2.TaskSpec)
assert s.template.id.name == 'tests.flytekit.unit.common_tests.tasks.test_task.my_task'
assert s.template.container.image == 'myflyteimage:v123'


schema = Types.Schema([("a", Types.String), ("b", Types.Integer)])


def test_task_produce_deterministic_version():
containerless_task = SdkPrestoTask(
task_inputs=inputs(ds=Types.String, rg=Types.String),
statement="SELECT * FROM flyte.widgets WHERE ds = '{{ .Inputs.ds}}' LIMIT 10",
output_schema=schema,
routing_group="{{ .Inputs.rg }}",
)
identical_containerless_task = SdkPrestoTask(
task_inputs=inputs(ds=Types.String, rg=Types.String),
statement="SELECT * FROM flyte.widgets WHERE ds = '{{ .Inputs.ds}}' LIMIT 10",
output_schema=schema,
routing_group="{{ .Inputs.rg }}",
)
different_containerless_task = SdkPrestoTask(
task_inputs=inputs(ds=Types.String, rg=Types.String),
statement="SELECT * FROM flyte.widgets WHERE ds = '{{ .Inputs.ds}}' LIMIT 100000",
output_schema=schema,
routing_group="{{ .Inputs.rg }}",
)
assert containerless_task._produce_deterministic_version() ==\
identical_containerless_task._produce_deterministic_version()

assert containerless_task._produce_deterministic_version() !=\
different_containerless_task._produce_deterministic_version()

with _pytest.raises(Exception):
get_sample_task()._produce_deterministic_version()
17 changes: 17 additions & 0 deletions tests/flytekit/unit/common_tests/types/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,20 @@ def test_typed_schema():
assert len(b.type.columns) == len(_ALL_COLUMN_TYPES)
assert list(b.type.sdk_columns.items()) == _ALL_COLUMN_TYPES
assert b.remote_location.startswith(t.name)


# Ensures that subclassing types works inside a schema.
def test_casting():
class MyDateTime(primitives.Datetime):
...

with test_utils.LocalTestFileSystem() as t:
test_columns_1 = [('altered', MyDateTime)]
test_columns_2 = [('altered', primitives.Datetime)]

instantiator_1 = schema.schema_instantiator(test_columns_1)
a = instantiator_1()

instantiator_2 = schema.schema_instantiator(test_columns_2)

a.cast_to(instantiator_2._schema_type)

0 comments on commit 60a1626

Please sign in to comment.