diff --git a/backend/src/apiserver/visualization/requirements.txt b/backend/src/apiserver/visualization/requirements.txt index a48eedb1798..f141137396d 100644 --- a/backend/src/apiserver/visualization/requirements.txt +++ b/backend/src/apiserver/visualization/requirements.txt @@ -4,12 +4,14 @@ gcsfs==0.2.3 google-api-python-client==1.7.9 itables==0.1.0 ipykernel==5.1.1 +ipython==7.12.0 jupyter_client==5.2.4 nbconvert==5.5.0 nbformat==4.4.0 pandas==0.24.2 pyarrow==0.15.1 scikit_learn==0.21.2 +tensorflow-metadata==0.21.1 tensorflow-model-analysis==0.21.1 tensorflow-data-validation==0.21.1 tornado==6.0.2 \ No newline at end of file diff --git a/backend/src/apiserver/visualization/types/tfdv.py b/backend/src/apiserver/visualization/types/tfdv.py index 0190f434e6d..ac3e303b60d 100644 --- a/backend/src/apiserver/visualization/types/tfdv.py +++ b/backend/src/apiserver/visualization/types/tfdv.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import tensorflow_data_validation as tfdv +from IPython.display import display +from IPython.display import HTML +from tensorflow_metadata.proto.v0 import statistics_pb2 +from typing import Text # The following variables are provided through dependency injection. These # variables come from the specified input path and arguments provided by the @@ -20,6 +25,68 @@ # # source -train_stats = tfdv.generate_statistics_from_csv(data_location=source) +# train_stats = tfdv.generate_statistics_from_csv(data_location=source) +# tfdv.visualize_statistics(train_stats) -tfdv.visualize_statistics(train_stats) +def get_statistics_html( + lhs_statistics: statistics_pb2.DatasetFeatureStatisticsList +) -> Text: + """Build the HTML for visualizing the input statistics using Facets. + Args: + lhs_statistics: A DatasetFeatureStatisticsList protocol buffer. + Returns: + HTML to be embedded for visualization. 
+ Raises: + TypeError: If the input argument is not of the expected type. + ValueError: If the input statistics proto does not contain exactly one dataset. + """ + + rhs_statistics = None + lhs_name = 'lhs_statistics' + rhs_name = 'rhs_statistics' + + if not isinstance(lhs_statistics, + statistics_pb2.DatasetFeatureStatisticsList): + raise TypeError( + 'lhs_statistics is of type %s, should be ' + 'a DatasetFeatureStatisticsList proto.' % type(lhs_statistics).__name__) + + if len(lhs_statistics.datasets) != 1: + raise ValueError('lhs_statistics proto contains multiple datasets. Only ' + 'one dataset is currently supported.') + + if lhs_statistics.datasets[0].name: + lhs_name = lhs_statistics.datasets[0].name + + # Add lhs stats. + combined_statistics = statistics_pb2.DatasetFeatureStatisticsList() + lhs_stats_copy = combined_statistics.datasets.add() + lhs_stats_copy.MergeFrom(lhs_statistics.datasets[0]) + lhs_stats_copy.name = lhs_name + + protostr = base64.b64encode( + combined_statistics.SerializeToString()).decode('utf-8') + + # pylint: disable=line-too-long + # Note that in the html template we currently assign a temporary id to the + # facets element and then remove it once we have appended the serialized proto + # string to the element. We do this to avoid any collision of ids when + # displaying multiple facets output in the notebook. + html_template = """ + """ + # pylint: enable=line-too-long + html = html_template.replace('protostr', protostr) + + return html + +stats = tfdv.load_statistics(source) +html = get_statistics_html(stats) +display(HTML(html)) diff --git a/frontend/src/lib/OutputArtifactLoader.ts b/frontend/src/lib/OutputArtifactLoader.ts index 61d967ac22a..0c5911ef263 100644 --- a/frontend/src/lib/OutputArtifactLoader.ts +++ b/frontend/src/lib/OutputArtifactLoader.ts @@ -14,18 +14,6 @@ * limitations under the License. 
*/ -import WorkflowParser, { StoragePath } from './WorkflowParser'; -import { Apis } from '../lib/Apis'; -import { ConfusionMatrixConfig } from '../components/viewers/ConfusionMatrix'; -import { HTMLViewerConfig } from '../components/viewers/HTMLViewer'; -import { MarkdownViewerConfig } from '../components/viewers/MarkdownViewer'; -import { PagedTableConfig } from '../components/viewers/PagedTable'; -import { PlotType, ViewerConfig } from '../components/viewers/Viewer'; -import { ROCCurveConfig } from '../components/viewers/ROCCurve'; -import { TensorboardViewerConfig } from '../components/viewers/Tensorboard'; -import { csvParseRows } from 'd3-dsv'; -import { logger, errorToMessage } from './Utils'; -import { ApiVisualization, ApiVisualizationType } from '../apis/visualization'; import { Api, Artifact, @@ -33,17 +21,29 @@ import { Context, Event, Execution, - GetArtifactTypesRequest, - GetArtifactTypesResponse, GetArtifactsByIDRequest, GetArtifactsByIDResponse, + GetArtifactTypesRequest, + GetArtifactTypesResponse, GetContextByTypeAndNameRequest, GetContextByTypeAndNameResponse, - GetExecutionsByContextRequest, - GetExecutionsByContextResponse, GetEventsByExecutionIDsRequest, GetEventsByExecutionIDsResponse, + GetExecutionsByContextRequest, + GetExecutionsByContextResponse, } from '@kubeflow/frontend'; +import { csvParseRows } from 'd3-dsv'; +import { ApiVisualization, ApiVisualizationType } from '../apis/visualization'; +import { ConfusionMatrixConfig } from '../components/viewers/ConfusionMatrix'; +import { HTMLViewerConfig } from '../components/viewers/HTMLViewer'; +import { MarkdownViewerConfig } from '../components/viewers/MarkdownViewer'; +import { PagedTableConfig } from '../components/viewers/PagedTable'; +import { ROCCurveConfig } from '../components/viewers/ROCCurve'; +import { TensorboardViewerConfig } from '../components/viewers/Tensorboard'; +import { PlotType, ViewerConfig } from '../components/viewers/Viewer'; +import { Apis } from '../lib/Apis'; 
+import { errorToMessage, logger } from './Utils'; +import WorkflowParser, { StoragePath } from './WorkflowParser'; export interface PlotMetadata { format?: 'csv'; @@ -279,12 +279,7 @@ export class OutputArtifactLoader { const trainUri = uri + '/train/stats_tfrecord'; viewers = viewers.concat( [evalUri, trainUri].map(async specificUri => { - const script = [ - 'import tensorflow_data_validation as tfdv', - `stats = tfdv.load_statistics('${specificUri}')`, - 'tfdv.visualize_statistics(stats)', - ]; - return buildArtifactViewer(script); + return buildArtifactViewerTfdvStatistics(specificUri); }), ); }); @@ -537,6 +532,21 @@ async function buildArtifactViewer(script: string[]): Promise<HTMLViewerConfig> }; } +async function buildArtifactViewerTfdvStatistics(url: string): Promise<HTMLViewerConfig> { + const visualizationData: ApiVisualization = { + source: url, + type: ApiVisualizationType.TFDV, + }; + const visualization = await Apis.buildPythonVisualizationConfig(visualizationData); + if (!visualization.htmlContent) { + throw new Error('Failed to build artifact viewer, no value in visualization.htmlContent'); + } + return { + htmlContent: visualization.htmlContent, + type: PlotType.WEB_APP, + }; +} + // TODO: add tfma back + // function filterTfmaArtifactsPaths( + // artifactTypes: ArtifactType[],