Skip to content

Commit

Permalink
[wip] update scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
yoonhyejin committed Jan 27, 2025
1 parent c571d73 commit bc19996
Showing 1 changed file with 31 additions and 9 deletions.
40 changes: 31 additions & 9 deletions metadata-ingestion/examples/ml/mlflow_dh_client_sample.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import argparse

from mlflow_dh_client import MLflowDatahubClient

import datahub.metadata.schema_classes as models
from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import RunResultType

Expand All @@ -14,11 +12,10 @@
client = MLflowDatahubClient(token=args.token)

# Create model group
# Using property classes directly
model_group_urn = client.create_model_group(
group_id="airline_forecast_models_group_4",
group_id="airline_forecast_models_group",
properties=models.MLModelGroupPropertiesClass(
name="Airline Forecast Models Group 4",
name="Airline Forecast Models Group",
description="Group of models for airline passenger forecasting",
created=models.TimeStampClass(
time=1628580000000, actor="urn:li:corpuser:datahub"
Expand All @@ -28,14 +25,30 @@

# Creating a model with property classes
model_urn = client.create_model(
model_id="arima_model_5",
model_id="arima_model",
properties=models.MLModelPropertiesClass(
name="ARIMA Model 6",
name="ARIMA Model",
description="ARIMA model for airline passenger forecasting",
customProperties={"team": "forecasting"},
trainingMetrics=[
models.MLMetricClass(name="accuracy", value="0.9"),
models.MLMetricClass(name="precision", value="0.8"),
],
hyperParams=[
models.MLHyperParamClass(name="learning_rate", value="0.01"),
models.MLHyperParamClass(name="batch_size", value="32"),
],
externalUrl="https:localhost:5000",
created=models.TimeStampClass(
time=1628580000000, actor="urn:li:corpuser:datahub"
),
lastModified=models.TimeStampClass(
time=1628580000000, actor="urn:li:corpuser:datahub"
),
tags=["forecasting", "arima"],
),
version="6.0",
alias="arima_model_6_alias",
version="1.0",
alias="champion",
)

# Creating an experiment with property class
Expand All @@ -45,6 +58,12 @@
name="Airline Forecast Experiment",
description="Experiment to forecast airline passenger numbers",
customProperties={"team": "forecasting"},
created=models.TimeStampClass(
time=1628580000000, actor="urn:li:corpuser:datahub"
),
lastModified=models.TimeStampClass(
time=1628580000000, actor="urn:li:corpuser:datahub"
),
),
)

Expand All @@ -55,11 +74,14 @@
created=models.AuditStampClass(
time=1628580000000, actor="urn:li:corpuser:datahub"
),
customProperties={"team": "forecasting"},
),
training_run_properties=models.MLTrainingRunPropertiesClass(
id="simple_training_run_4",
outputUrls=["s3://my-bucket/output"],
trainingMetrics=[models.MLMetricClass(name="accuracy", value="0.9")],
hyperParams=[models.MLHyperParamClass(name="learning_rate", value="0.01")],
externalUrl="https:localhost:5000",
),
run_result=RunResultType.FAILURE,
start_timestamp=1628580000000,
Expand Down

0 comments on commit bc19996

Please sign in to comment.