diff --git a/backend/src/apiserver/visualization/requirements.txt b/backend/src/apiserver/visualization/requirements.txt
index a48eedb1798..f141137396d 100644
--- a/backend/src/apiserver/visualization/requirements.txt
+++ b/backend/src/apiserver/visualization/requirements.txt
@@ -4,12 +4,14 @@ gcsfs==0.2.3
google-api-python-client==1.7.9
itables==0.1.0
ipykernel==5.1.1
+ipython==7.12.0
jupyter_client==5.2.4
nbconvert==5.5.0
nbformat==4.4.0
pandas==0.24.2
pyarrow==0.15.1
scikit_learn==0.21.2
+tensorflow-metadata==0.21.1
tensorflow-model-analysis==0.21.1
tensorflow-data-validation==0.21.1
tornado==6.0.2
\ No newline at end of file
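
The two new pins back the imports introduced in types/tfdv.py below: ipython supplies IPython.display, and tensorflow-metadata supplies the statistics_pb2 protos, pinned to the same 0.21.1 release line as tensorflow-model-analysis and tensorflow-data-validation. A quick smoke test for the rebuilt image, a sketch rather than part of the change:

```python
# Smoke test: run inside the visualization server image after
# `pip install -r requirements.txt`. Not part of this diff.
import IPython
from tensorflow_metadata.proto.v0 import statistics_pb2

assert IPython.__version__ == '7.12.0', IPython.__version__
# Building a proto proves tensorflow-metadata's generated code is importable.
stats = statistics_pb2.DatasetFeatureStatisticsList()
stats.datasets.add(name='smoke')
print('ok:', stats.datasets[0].name)
```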
diff --git a/backend/src/apiserver/visualization/types/tfdv.py b/backend/src/apiserver/visualization/types/tfdv.py
index 0190f434e6d..ac3e303b60d 100644
--- a/backend/src/apiserver/visualization/types/tfdv.py
+++ b/backend/src/apiserver/visualization/types/tfdv.py
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import base64
+from typing import Text
import tensorflow_data_validation as tfdv
+from IPython.display import display
+from IPython.display import HTML
+from tensorflow_metadata.proto.v0 import statistics_pb2
# The following variables are provided through dependency injection. These
# variables come from the specified input path and arguments provided by the
@@ -20,6 +25,68 @@
#
# source
-train_stats = tfdv.generate_statistics_from_csv(data_location=source)
-tfdv.visualize_statistics(train_stats)
+# train_stats = tfdv.generate_statistics_from_csv(data_location=source)
+# tfdv.visualize_statistics(train_stats)
+def get_statistics_html(
+    lhs_statistics: statistics_pb2.DatasetFeatureStatisticsList
+) -> Text:
+  """Build the HTML for visualizing the input statistics using Facets.
+
+  Args:
+    lhs_statistics: A DatasetFeatureStatisticsList protocol buffer.
+
+  Returns:
+    HTML to be embedded for visualization.
+
+  Raises:
+    TypeError: If the input argument is not of the expected type.
+    ValueError: If the input statistics proto does not contain exactly one
+      dataset.
+  """
+
+  lhs_name = 'lhs_statistics'
+
+ if not isinstance(lhs_statistics,
+ statistics_pb2.DatasetFeatureStatisticsList):
+ raise TypeError(
+ 'lhs_statistics is of type %s, should be '
+ 'a DatasetFeatureStatisticsList proto.' % type(lhs_statistics).__name__)
+
+  if len(lhs_statistics.datasets) != 1:
+    raise ValueError('lhs_statistics proto must contain exactly one dataset; '
+                     'found %d.' % len(lhs_statistics.datasets))
+
+ if lhs_statistics.datasets[0].name:
+ lhs_name = lhs_statistics.datasets[0].name
+
+ # Add lhs stats.
+ combined_statistics = statistics_pb2.DatasetFeatureStatisticsList()
+ lhs_stats_copy = combined_statistics.datasets.add()
+ lhs_stats_copy.MergeFrom(lhs_statistics.datasets[0])
+ lhs_stats_copy.name = lhs_name
+
+ protostr = base64.b64encode(
+ combined_statistics.SerializeToString()).decode('utf-8')
+
+ # pylint: disable=line-too-long
+ # Note that in the html template we currently assign a temporary id to the
+ # facets element and then remove it once we have appended the serialized proto
+ # string to the element. We do this to avoid any collision of ids when
+ # displaying multiple facets output in the notebook.
+  # The template body below is reconstructed from tfdv's display_util and is
+  # an approximation; the exact markup shipped with TFDV 0.21 may differ.
+  html_template = """<link rel="import" href="https://raw.githubusercontent.com/PAIR-code/facets/master/facets-dist/facets-jupyter.html">
+<facets-overview id="facets-temp-id"></facets-overview>
+<script>
+  document.querySelector("#facets-temp-id").protoInput = "protostr";
+  document.querySelector("#facets-temp-id").id = "";
+</script>"""
+ # pylint: enable=line-too-long
+ html = html_template.replace('protostr', protostr)
+
+ return html
+
+stats = tfdv.load_statistics(source)
+html = get_statistics_html(stats)
+display(HTML(html))
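
For context, the snippet above runs inside the visualization service, which injects `source` (the artifact URI) before execution and captures the `display(HTML(...))` output. To exercise `get_statistics_html` by hand, something like this works; the stats path is a placeholder and the function is assumed to be in scope from the module above:

```python
# Manual driver for get_statistics_html; the path is hypothetical and the
# function is assumed to come from types/tfdv.py above.
import tensorflow_data_validation as tfdv

stats = tfdv.load_statistics('/tmp/eval/stats_tfrecord')  # hypothetical path
page = get_statistics_html(stats)
with open('/tmp/stats.html', 'w') as f:
  f.write(page)  # open in a browser to check the Facets overview renders
```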
diff --git a/frontend/src/lib/OutputArtifactLoader.ts b/frontend/src/lib/OutputArtifactLoader.ts
index 61d967ac22a..0c5911ef263 100644
--- a/frontend/src/lib/OutputArtifactLoader.ts
+++ b/frontend/src/lib/OutputArtifactLoader.ts
@@ -14,18 +14,6 @@
* limitations under the License.
*/
-import WorkflowParser, { StoragePath } from './WorkflowParser';
-import { Apis } from '../lib/Apis';
-import { ConfusionMatrixConfig } from '../components/viewers/ConfusionMatrix';
-import { HTMLViewerConfig } from '../components/viewers/HTMLViewer';
-import { MarkdownViewerConfig } from '../components/viewers/MarkdownViewer';
-import { PagedTableConfig } from '../components/viewers/PagedTable';
-import { PlotType, ViewerConfig } from '../components/viewers/Viewer';
-import { ROCCurveConfig } from '../components/viewers/ROCCurve';
-import { TensorboardViewerConfig } from '../components/viewers/Tensorboard';
-import { csvParseRows } from 'd3-dsv';
-import { logger, errorToMessage } from './Utils';
-import { ApiVisualization, ApiVisualizationType } from '../apis/visualization';
import {
Api,
Artifact,
@@ -33,17 +21,29 @@ import {
Context,
Event,
Execution,
- GetArtifactTypesRequest,
- GetArtifactTypesResponse,
GetArtifactsByIDRequest,
GetArtifactsByIDResponse,
+ GetArtifactTypesRequest,
+ GetArtifactTypesResponse,
GetContextByTypeAndNameRequest,
GetContextByTypeAndNameResponse,
- GetExecutionsByContextRequest,
- GetExecutionsByContextResponse,
GetEventsByExecutionIDsRequest,
GetEventsByExecutionIDsResponse,
+ GetExecutionsByContextRequest,
+ GetExecutionsByContextResponse,
} from '@kubeflow/frontend';
+import { csvParseRows } from 'd3-dsv';
+import { ApiVisualization, ApiVisualizationType } from '../apis/visualization';
+import { ConfusionMatrixConfig } from '../components/viewers/ConfusionMatrix';
+import { HTMLViewerConfig } from '../components/viewers/HTMLViewer';
+import { MarkdownViewerConfig } from '../components/viewers/MarkdownViewer';
+import { PagedTableConfig } from '../components/viewers/PagedTable';
+import { ROCCurveConfig } from '../components/viewers/ROCCurve';
+import { TensorboardViewerConfig } from '../components/viewers/Tensorboard';
+import { PlotType, ViewerConfig } from '../components/viewers/Viewer';
+import { Apis } from '../lib/Apis';
+import { errorToMessage, logger } from './Utils';
+import WorkflowParser, { StoragePath } from './WorkflowParser';
export interface PlotMetadata {
format?: 'csv';
@@ -279,12 +279,7 @@ export class OutputArtifactLoader {
const trainUri = uri + '/train/stats_tfrecord';
viewers = viewers.concat(
[evalUri, trainUri].map(async specificUri => {
- const script = [
- 'import tensorflow_data_validation as tfdv',
- `stats = tfdv.load_statistics('${specificUri}')`,
- 'tfdv.visualize_statistics(stats)',
- ];
- return buildArtifactViewer(script);
+ return buildArtifactViewerTfdvStatistics(specificUri);
}),
);
});
@@ -537,6 +532,21 @@ async function buildArtifactViewer(script: string[]): Promise<HTMLViewerConfig> {
};
}
+async function buildArtifactViewerTfdvStatistics(url: string): Promise<HTMLViewerConfig> {
+ const visualizationData: ApiVisualization = {
+ source: url,
+ type: ApiVisualizationType.TFDV,
+ };
+ const visualization = await Apis.buildPythonVisualizationConfig(visualizationData);
+ if (!visualization.htmlContent) {
+ throw new Error('Failed to build artifact viewer, no value in visualization.htmlContent');
+ }
+ return {
+ htmlContent: visualization.htmlContent,
+ type: PlotType.WEB_APP,
+ };
+}
+
// TODO: add tfma back
// function filterTfmaArtifactsPaths(
// artifactTypes: ArtifactType[],