Add s3 support (with custom endpoints) #1789

Merged
merged 8 commits into from
Aug 22, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
3 changes: 3 additions & 0 deletions protos/feast/core/DataSource.proto
@@ -69,6 +69,9 @@ message DataSource {
// gs://path/to/file for GCP GCS storage
// file:///path/to/file for local storage
string file_url = 2;

// override AWS S3 storage endpoint with custom S3 endpoint
string s3_endpoint_override = 3;
}

// Defines options for DataSource that sources features from a BigQuery Query
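
For context, a rough sketch (not part of this diff) of how the new field surfaces through the Python SDK: a FileSource pointing at an S3-compatible endpoint such as MinIO. The bucket, object key, column names, and endpoint URL below are illustrative placeholders.

from feast import FileSource
from feast.data_format import ParquetFormat

driver_stats = FileSource(
    path="s3://feast-test/driver_stats.parquet",
    file_format=ParquetFormat(),
    event_timestamp_column="event_timestamp",
    created_timestamp_column="created",
    s3_endpoint_override="http://localhost:9000",  # placeholder custom S3 endpoint
)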
11 changes: 9 additions & 2 deletions sdk/python/feast/infra/offline_stores/file.py
@@ -112,7 +112,11 @@ def evaluate_historical_retrieval()
)

# Read offline parquet data in pyarrow format.
table = pyarrow.parquet.read_table(feature_view.batch_source.path)
filesystem, path = FileSource.prepare_path(
feature_view.batch_source.path,
feature_view.batch_source.file_options.s3_endpoint_override,
)
table = pyarrow.parquet.read_table(path, filesystem=filesystem)

# Rename columns by the field mapping dictionary if it exists
if feature_view.batch_source.field_mapping is not None:
@@ -238,7 +242,10 @@ def pull_latest_from_table_or_query(

# Create lazy function that is only called from the RetrievalJob object
def evaluate_offline_job():
source_df = pd.read_parquet(data_source.path)
filesystem, path = FileSource.prepare_path(
data_source.path, data_source.file_options.s3_endpoint_override
)
source_df = pd.read_parquet(path, filesystem=filesystem)
# Make sure all timestamp fields are tz-aware. We default tz-naive fields to UTC
source_df[event_timestamp_column] = source_df[event_timestamp_column].apply(
lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc)
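
The override is handed straight to pyarrow, so the retrieval path above behaves roughly like the following sketch (assuming the S3-compatible endpoint is reachable and credentials resolve through the usual AWS credential chain; the bucket, key, and URL are illustrative):

import pyarrow.parquet
from pyarrow import fs

# Build a filesystem against a custom S3-compatible endpoint (e.g. MinIO).
s3 = fs.S3FileSystem(endpoint_override="http://localhost:9000")

# Note: no "s3://" scheme on the path when an explicit filesystem is passed.
table = pyarrow.parquet.read_table("feast-test/driver_stats.parquet", filesystem=s3)
source_df = table.to_pandas()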
59 changes: 56 additions & 3 deletions sdk/python/feast/infra/offline_stores/file_source.py
@@ -1,5 +1,6 @@
from typing import Callable, Dict, Iterable, Optional, Tuple

from pyarrow import fs
from pyarrow.parquet import ParquetFile

from feast import type_map
@@ -20,6 +21,7 @@ def __init__(
created_timestamp_column: Optional[str] = "",
field_mapping: Optional[Dict[str, str]] = None,
date_partition_column: Optional[str] = "",
s3_endpoint_override: Optional[str] = None,
):
"""Create a FileSource from a file containing feature data. Only Parquet format supported.

@@ -33,6 +35,7 @@ def __init__(
file_format (optional): Explicitly set the file format. Allows Feast to bypass inferring the file format.
field_mapping: A dictionary mapping of column names in this data source to feature names in a feature table
or view. Only used for feature columns, not entities or timestamp columns.
s3_endpoint_override (optional): Overrides the AWS S3 endpoint with a custom S3 endpoint

Examples:
>>> from feast import FileSource
@@ -51,7 +54,11 @@ def __init__(
else:
file_url = path

self._file_options = FileOptions(file_format=file_format, file_url=file_url)
self._file_options = FileOptions(
file_format=file_format,
file_url=file_url,
s3_endpoint_override=s3_endpoint_override,
)

super().__init__(
event_timestamp_column,
@@ -70,6 +77,8 @@ def __eq__(self, other):
and self.event_timestamp_column == other.event_timestamp_column
and self.created_timestamp_column == other.created_timestamp_column
and self.field_mapping == other.field_mapping
and self.file_options.s3_endpoint_override
== other.file_options.s3_endpoint_override
)

@property
@@ -102,6 +111,7 @@ def from_proto(data_source: DataSourceProto):
event_timestamp_column=data_source.event_timestamp_column,
created_timestamp_column=data_source.created_timestamp_column,
date_partition_column=data_source.date_partition_column,
s3_endpoint_override=data_source.file_options.s3_endpoint_override,
)

def to_proto(self) -> DataSourceProto:
@@ -128,20 +138,47 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
schema = ParquetFile(self.path).schema_arrow
filesystem, path = FileSource.prepare_path(
self.path, self._file_options.s3_endpoint_override
)
schema = ParquetFile(
path if filesystem is None else filesystem.open_input_file(path)
).schema_arrow
return zip(schema.names, map(str, schema.types))

@staticmethod
def prepare_path(path: str, s3_endpoint_override: str):

Member: nit, type annotations for the output of this method?

if path.startswith("s3://"):
s3 = fs.S3FileSystem(
endpoint_override=s3_endpoint_override if s3_endpoint_override else None
)
return s3, path.replace("s3://", "")
else:
return None, path


class FileOptions:
"""
DataSource File options used to source features from a file
"""

def __init__(
self, file_format: Optional[FileFormat], file_url: Optional[str],
self,
file_format: Optional[FileFormat],
file_url: Optional[str],
s3_endpoint_override: Optional[str],
):
"""
FileOptions initialization method

Args:
file_format (FileFormat, optional): file source format eg. parquet
file_url (str, optional): file source url eg. s3:// or local file
s3_endpoint_override (str, optional): custom s3 endpoint (used only with s3 file_url)
"""
self._file_format = file_format
self._file_url = file_url
self._s3_endpoint_override = s3_endpoint_override

@property
def file_format(self):
@@ -171,6 +208,20 @@ def file_url(self, file_url):
"""
self._file_url = file_url

@property
def s3_endpoint_override(self):
"""
Returns the s3 endpoint override
"""
return None if self._s3_endpoint_override == "" else self._s3_endpoint_override

@s3_endpoint_override.setter
def s3_endpoint_override(self, s3_endpoint_override):
"""
Sets the s3 endpoint override
"""
self._s3_endpoint_override = s3_endpoint_override

@classmethod
def from_proto(cls, file_options_proto: DataSourceProto.FileOptions):
"""
@@ -185,6 +236,7 @@ def from_proto(cls, file_options_proto: DataSourceProto.FileOptions):
file_options = cls(
file_format=FileFormat.from_proto(file_options_proto.file_format),
file_url=file_options_proto.file_url,
s3_endpoint_override=file_options_proto.s3_endpoint_override,
)
return file_options

@@ -201,6 +253,7 @@ def to_proto(self) -> DataSourceProto.FileOptions:
None if self.file_format is None else self.file_format.to_proto()
),
file_url=self.file_url,
s3_endpoint_override=self.s3_endpoint_override,
)

return file_options_proto
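
To tie the pieces in this file together, a hypothetical illustration of what the new prepare_path helper returns for S3 versus local paths (the values are illustrative, not taken from the diff):

from feast import FileSource

# S3 path: a pyarrow S3FileSystem plus the path with the scheme stripped.
filesystem, path = FileSource.prepare_path(
    "s3://feast-test/driver_stats.parquet", "http://localhost:9000"
)
# filesystem -> fs.S3FileSystem(endpoint_override="http://localhost:9000")
# path       -> "feast-test/driver_stats.parquet"

# Local path: no filesystem object, path returned unchanged.
filesystem, path = FileSource.prepare_path("/tmp/driver_stats.parquet", None)
# filesystem -> None
# path       -> "/tmp/driver_stats.parquet"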
2 changes: 2 additions & 0 deletions sdk/python/setup.py
@@ -82,6 +82,7 @@
"isort>=5",
"grpcio-tools==1.34.0",
"grpcio-testing==1.34.0",
"minio==7.1.0",
"mock==2.0.0",
"moto",
"mypy==0.790",
@@ -99,6 +100,7 @@
"pytest-mock==1.10.4",
"Sphinx!=4.0.0",
"sphinx-rtd-theme",
"testcontainers==3.4.2",
"adlfs==0.5.9",
"firebase-admin==4.5.2",
"pre-commit",
@@ -115,7 +115,7 @@ def customer_feature_view(self) -> FeatureView:
customer_table_id = self.data_source_creator.get_prefixed_table_name(
self.name, "customer_profile"
)
ds = self.data_source_creator.create_data_sources(
ds = self.data_source_creator.create_data_source(
customer_table_id,
self.customer_df,
event_timestamp_column="event_timestamp",
Expand All @@ -129,7 +129,7 @@ def driver_stats_feature_view(self) -> FeatureView:
driver_table_id = self.data_source_creator.get_prefixed_table_name(
self.name, "driver_hourly"
)
ds = self.data_source_creator.create_data_sources(
ds = self.data_source_creator.create_data_source(
driver_table_id,
self.driver_df,
event_timestamp_column="event_timestamp",
@@ -145,7 +145,7 @@ def orders_table(self) -> Optional[str]:
orders_table_id = self.data_source_creator.get_prefixed_table_name(
self.name, "orders"
)
ds = self.data_source_creator.create_data_sources(
ds = self.data_source_creator.create_data_source(
orders_table_id,
self.orders_df,
event_timestamp_column="event_timestamp",
@@ -221,7 +221,7 @@ def construct_test_environment(
offline_creator: DataSourceCreator = importer.get_class_from_type(
module_name, config_class_name, "DataSourceCreator"
)(project)
ds = offline_creator.create_data_sources(
ds = offline_creator.create_data_source(
project, df, field_mapping={"ts_1": "ts", "id": "driver_id"}
)
offline_store = offline_creator.create_offline_store_config()
@@ -9,7 +9,7 @@

class DataSourceCreator(ABC):
@abstractmethod
def create_data_sources(
def create_data_source(
self,
destination: str,
df: pd.DataFrame,
@@ -40,7 +40,7 @@ def teardown(self):
def create_offline_store_config(self):
return BigQueryOfflineStoreConfig()

def create_data_sources(
def create_data_source(
self,
destination: str,
df: pd.DataFrame,
@@ -2,6 +2,9 @@
from typing import Any, Dict

import pandas as pd
from minio import Minio
from testcontainers.core.generic import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs

from feast import FileSource
from feast.data_format import ParquetFormat
@@ -19,7 +22,7 @@ class FileDataSourceCreator(DataSourceCreator):
def __init__(self, _: str):
pass

def create_data_sources(
def create_data_source(
self,
destination: str,
df: pd.DataFrame,
@@ -46,3 +49,79 @@ def create_offline_store_config(self) -> FeastConfigBaseModel:

def teardown(self):
self.f.close()


class S3FileDataSourceCreator(DataSourceCreator):
f: Any
minio: DockerContainer
bucket = "feast-test"
access_key = "AKIAIOSFODNN7EXAMPLE"

Member Author: Can't really get around hardcoding access_key and secret in two places right now. The test suite doesn't allow us to inject configuration. cc @achals; it would be nice if we could propagate a config (context-like) object.

Member: Agreed, I'll make a task for that.

secret = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
minio_image = "minio/minio:RELEASE.2021-08-17T20-53-08Z"

def __init__(self, _: str):
self._setup_minio()

def _setup_minio(self):
self.minio = DockerContainer(self.minio_image)
self.minio.with_exposed_ports(9000).with_exposed_ports(9001).with_env(
"MINIO_ROOT_USER", self.access_key
).with_env("MINIO_ROOT_PASSWORD", self.secret).with_command(
'server /data --console-address ":9001"'
)
self.minio.start()
log_string_to_wait_for = (
"API" # The minio container will print "API: ..." when ready.
)
wait_for_logs(container=self.minio, predicate=log_string_to_wait_for, timeout=5)

def _upload_parquet_file(self, df, file_name, minio_endpoint):
self.f = tempfile.NamedTemporaryFile(suffix=".parquet", delete=False)
df.to_parquet(self.f.name)

client = Minio(
minio_endpoint,
access_key=self.access_key,
secret_key=self.secret,
secure=False,
)
if not client.bucket_exists(self.bucket):
client.make_bucket(self.bucket)
client.fput_object(
self.bucket, file_name, self.f.name,
)

def create_data_source(
self,
destination: str,
df: pd.DataFrame,
event_timestamp_column="ts",
created_timestamp_column="created_ts",
field_mapping: Dict[str, str] = None,
) -> DataSource:
filename = f"{destination}.parquet"
port = self.minio.get_exposed_port("9000")
host = self.minio.get_container_host_ip()
minio_endpoint = f"{host}:{port}"

self._upload_parquet_file(df, filename, minio_endpoint)

return FileSource(
file_format=ParquetFormat(),
path=f"s3://{self.bucket}/{filename}",
event_timestamp_column=event_timestamp_column,
created_timestamp_column=created_timestamp_column,
date_partition_column="",
field_mapping=field_mapping or {"ts_1": "ts"},
s3_endpoint_override=f"http://{host}:{port}",
)

def get_prefixed_table_name(self, name: str, suffix: str) -> str:
return f"{suffix}"

def create_offline_store_config(self) -> FeastConfigBaseModel:
return FileOfflineStoreConfig()

def teardown(self):
self.minio.stop()
self.f.close()
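
A rough sketch of how a test might drive this creator end to end. It assumes pyarrow resolves the MinIO credentials from the standard AWS environment variables (nothing in this diff sets them for you), and the DataFrame columns are illustrative:

import os
import pandas as pd

# Let pyarrow's S3FileSystem pick up the MinIO credentials via the default AWS chain.
os.environ["AWS_ACCESS_KEY_ID"] = S3FileDataSourceCreator.access_key
os.environ["AWS_SECRET_ACCESS_KEY"] = S3FileDataSourceCreator.secret

driver_df = pd.DataFrame(
    {
        "driver_id": [1001],
        "conv_rate": [0.5],
        "ts": [pd.Timestamp.now(tz="UTC")],
        "created_ts": [pd.Timestamp.now(tz="UTC")],
    }
)

creator = S3FileDataSourceCreator("test_project")  # starts the MinIO container
source = creator.create_data_source("driver_hourly", driver_df)  # uploads the parquet file
# source.path is an s3:// URI; source.file_options.s3_endpoint_override points at MinIO.
creator.teardown()  # stops the container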
@@ -31,7 +31,7 @@ def __init__(self, project_name: str):
iam_role="arn:aws:iam::402087665549:role/redshift_s3_access_role",
)

def create_data_sources(
def create_data_source(
self,
destination: str,
df: pd.DataFrame,
@@ -20,11 +20,6 @@ def create_driver_hourly_stats_feature_view(source):
driver_stats_feature_view = FeatureView(
name="driver_stats",
entities=["driver"],
features=[
Feature(name="conv_rate", dtype=ValueType.FLOAT),
Feature(name="acc_rate", dtype=ValueType.FLOAT),
Feature(name="avg_daily_trips", dtype=ValueType.INT32),
],
batch_source=source,
ttl=timedelta(hours=2),
)