diff --git a/metadata-ingestion/examples/ml/mlflow_dh_client_sample.py b/metadata-ingestion/examples/ml/mlflow_dh_client_sample.py
index 2a6a6147d6d1dc..cb7f15345e363c 100644
--- a/metadata-ingestion/examples/ml/mlflow_dh_client_sample.py
+++ b/metadata-ingestion/examples/ml/mlflow_dh_client_sample.py
@@ -1,7 +1,5 @@
 import argparse
-
 from mlflow_dh_client import MLflowDatahubClient
-
 import datahub.metadata.schema_classes as models
 from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import RunResultType
 
@@ -14,11 +12,10 @@
     client = MLflowDatahubClient(token=args.token)
 
     # Create model group
-    # Using property classes directly
     model_group_urn = client.create_model_group(
-        group_id="airline_forecast_models_group_4",
+        group_id="airline_forecast_models_group",
         properties=models.MLModelGroupPropertiesClass(
-            name="Airline Forecast Models Group 4",
+            name="Airline Forecast Models Group",
             description="Group of models for airline passenger forecasting",
             created=models.TimeStampClass(
                 time=1628580000000, actor="urn:li:corpuser:datahub"
@@ -28,14 +25,30 @@
 
     # Creating a model with property classes
     model_urn = client.create_model(
-        model_id="arima_model_5",
+        model_id="arima_model",
         properties=models.MLModelPropertiesClass(
-            name="ARIMA Model 6",
+            name="ARIMA Model",
             description="ARIMA model for airline passenger forecasting",
             customProperties={"team": "forecasting"},
+            trainingMetrics=[
+                models.MLMetricClass(name="accuracy", value="0.9"),
+                models.MLMetricClass(name="precision", value="0.8"),
+            ],
+            hyperParams=[
+                models.MLHyperParamClass(name="learning_rate", value="0.01"),
+                models.MLHyperParamClass(name="batch_size", value="32"),
+            ],
+            externalUrl="https://localhost:5000",
+            created=models.TimeStampClass(
+                time=1628580000000, actor="urn:li:corpuser:datahub"
+            ),
+            lastModified=models.TimeStampClass(
+                time=1628580000000, actor="urn:li:corpuser:datahub"
+            ),
+            tags=["forecasting", "arima"],
         ),
-        version="6.0",
-        alias="arima_model_6_alias",
+        version="1.0",
+        alias="champion",
     )
 
     # Creating an experiment with property class
@@ -45,6 +58,12 @@
             name="Airline Forecast Experiment",
             description="Experiment to forecast airline passenger numbers",
             customProperties={"team": "forecasting"},
+            created=models.TimeStampClass(
+                time=1628580000000, actor="urn:li:corpuser:datahub"
+            ),
+            lastModified=models.TimeStampClass(
+                time=1628580000000, actor="urn:li:corpuser:datahub"
+            ),
         ),
     )
 
@@ -55,11 +74,14 @@
             created=models.AuditStampClass(
                 time=1628580000000, actor="urn:li:corpuser:datahub"
             ),
+            customProperties={"team": "forecasting"},
         ),
         training_run_properties=models.MLTrainingRunPropertiesClass(
             id="simple_training_run_4",
             outputUrls=["s3://my-bucket/output"],
             trainingMetrics=[models.MLMetricClass(name="accuracy", value="0.9")],
+            hyperParams=[models.MLHyperParamClass(name="learning_rate", value="0.01")],
+            externalUrl="https://localhost:5000",
         ),
         run_result=RunResultType.FAILURE,
         start_timestamp=1628580000000,