Skip to content

Commit

Permalink
Support Forecasting (#209)
Browse files Browse the repository at this point in the history
* add forecasting to yaml

* support forecasting

* fixups
  • Loading branch information
jamesbchao authored Feb 1, 2024
1 parent bdbc772 commit 0252bd6
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 7 deletions.
2 changes: 1 addition & 1 deletion azuredevops/Build-Update-Dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ parameters:

variables:
- name: SubscriptionName
value: "Project Vienna INT (589c7ae9-223e-45e3-a191-98433e0821a9) - RAI"
value: "Interpretability-Automation"
- name: ConfigFileArtifact
value: WorkspaceConfiguration

Expand Down
2 changes: 1 addition & 1 deletion azuredevops/PR-Gate.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ parameters:

variables:
- name: SubscriptionName
value: "Project Vienna INT (589c7ae9-223e-45e3-a191-98433e0821a9) - RAI"
value: "Interpretability-Automation"
- name: ConfigFileArtifact
value: WorkspaceConfiguration

Expand Down
4 changes: 2 additions & 2 deletions src/responsibleai/component_rai_insights.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ inputs:
title:
type: string
task_type:
type: string # [classification, regression]
enum: ['classification', 'regression']
type: string # [classification, regression, forecasting]
enum: ['classification', 'regression', 'forecasting']
model_info_path:
type: path # model_info.json
optional: true
Expand Down
23 changes: 20 additions & 3 deletions src/responsibleai/rai_analyse/rai_component_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import time
import traceback
import uuid
from pathlib import Path
from typing import Any, Dict, Optional

import mlflow
Expand Down Expand Up @@ -126,14 +125,31 @@ def load_mlflow_model(
)
try:
if use_separate_conda_env:
tmp_model_path = "./mlflow_model"
if (not model_path and model_id):
model_path = Model.get_model_path(model_name=model.name, version=model.version)
shutil.copytree(model_path, tmp_model_path)
model_uri = tmp_model_path

_logger.info("MODEL URI: {}".format(
model_uri
))

for root, _, files in os.walk(model_uri):
for f in files:
full_path = os.path.join(root, f)
_logger.info("FILE: {}".format(
full_path
))

conda_install_command = ["mlflow", "models", "prepare-env",
"-m", model_uri,
"--env-manager", "conda"]
else:
# mlflow model input mount as read only. Conda need write access.
local_conda_dep = "./conda_dep.yaml"
shutil.copyfile(conda_file, local_conda_dep)
conda_prefix = str(Path(sys.executable).parents[1])
conda_prefix = str(pathlib.Path(sys.executable).parents[1])
conda_install_command = ["conda", "env", "update",
"--prefix", conda_prefix,
"-f", local_conda_dep]
Expand Down Expand Up @@ -164,7 +180,7 @@ def load_mlflow_model(
return model

# Serve model from separate conda env using mlflow
mlflow_models_serve_logfile_name = "mlflow_models_serve.log"
mlflow_models_serve_logfile_name = "./logs/azureml/mlflow_models_serve.log"
try:
# run mlflow model server in background
with open(mlflow_models_serve_logfile_name, "w") as logfile:
Expand Down Expand Up @@ -231,6 +247,7 @@ def load_mlflow_model(
)
_logger.info("Successfully started mlflow model server.")
model = ServedModelWrapper(port=MLFLOW_MODEL_SERVER_PORT)
_logger.info("Successfully loaded model.")
return model
except Exception as e:
raise UserConfigError(
Expand Down

0 comments on commit 0252bd6

Please sign in to comment.