Skip to content

Commit

Permalink
enable tfdv, remove hardcode as a sample for following PRs (#3089)
Browse files Browse the repository at this point in the history
Co-authored-by: renmingu <40223865+renmingu@users.noreply.github.com>
  • Loading branch information
rmgogogo and renmingu authored Feb 17, 2020
1 parent 231db91 commit d0ef0ae
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 24 deletions.
2 changes: 2 additions & 0 deletions backend/src/apiserver/visualization/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ gcsfs==0.2.3
google-api-python-client==1.7.9
itables==0.1.0
ipykernel==5.1.1
ipython==7.12.0
jupyter_client==5.2.4
nbconvert==5.5.0
nbformat==4.4.0
pandas==0.24.2
pyarrow==0.15.1
scikit_learn==0.21.2
tensorflow-metadata==0.21.1
tensorflow-model-analysis==0.21.1
tensorflow-data-validation==0.21.1
tornado==6.0.2
71 changes: 69 additions & 2 deletions backend/src/apiserver/visualization/types/tfdv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,81 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import tensorflow_data_validation as tfdv
from IPython.display import display
from IPython.display import HTML
from tensorflow_metadata.proto.v0 import statistics_pb2
from typing import Text

# The following variables are provided through dependency injection. These
# variables come from the specified input path and arguments provided by the
# API post request.
#
# source

train_stats = tfdv.generate_statistics_from_csv(data_location=source)
# train_stats = tfdv.generate_statistics_from_csv(data_location=source)
# tfdv.visualize_statistics(train_stats)

tfdv.visualize_statistics(train_stats)
def get_statistics_html(
    lhs_statistics: statistics_pb2.DatasetFeatureStatisticsList
) -> Text:
    """Build the HTML for visualizing the input statistics using Facets.

    Args:
        lhs_statistics: A DatasetFeatureStatisticsList protocol buffer that
            holds exactly one dataset.

    Returns:
        HTML to be embedded for visualization.

    Raises:
        TypeError: If the input argument is not of the expected type.
        ValueError: If the input statistics proto does not contain exactly one
            dataset.
    """
    if not isinstance(lhs_statistics,
                      statistics_pb2.DatasetFeatureStatisticsList):
        raise TypeError(
            'lhs_statistics is of type %s, should be '
            'a DatasetFeatureStatisticsList proto.' %
            type(lhs_statistics).__name__)

    # Anything other than exactly one dataset is rejected (this includes an
    # empty dataset list, even though the message only mentions "multiple").
    if len(lhs_statistics.datasets) != 1:
        raise ValueError('lhs_statistics proto contains multiple datasets. Only '
                         'one dataset is currently supported.')

    # Fall back to a generic display name when the dataset is unnamed.
    lhs_name = lhs_statistics.datasets[0].name or 'lhs_statistics'

    # Copy the single dataset into a fresh list so the caller's proto is not
    # modified when the display name is assigned.
    combined_statistics = statistics_pb2.DatasetFeatureStatisticsList()
    lhs_stats_copy = combined_statistics.datasets.add()
    lhs_stats_copy.MergeFrom(lhs_statistics.datasets[0])
    lhs_stats_copy.name = lhs_name

    # The serialized proto is base64-encoded so it can be embedded as an HTML
    # attribute value.
    protostr = base64.b64encode(
        combined_statistics.SerializeToString()).decode('utf-8')

    # pylint: disable=line-too-long
    # Note that in the html template we currently assign a temporary id to the
    # facets element and then remove it once we have appended the serialized proto
    # string to the element. We do this to avoid any collision of ids when
    # displaying multiple facets output in the notebook.
    #
    # The template is a raw string: it contains the JavaScript escape `<\/script>`,
    # and `\/` is not a valid Python escape sequence in a regular string literal.
    html_template = r"""<iframe id='facets-iframe' width="100%" height="500px"></iframe>
<script>
facets_iframe = document.getElementById('facets-iframe');
facets_html = '<script src="https://cdnjs.cloudflare.com/ajax/libs/webcomponentsjs/1.3.3/webcomponents-lite.js"><\/script><link rel="import" href="https://raw.githubusercontent.com/PAIR-code/facets/master/facets-dist/facets-jupyter.html"><facets-overview proto-input="protostr"></facets-overview>';
facets_iframe.srcdoc = facets_html;
facets_iframe.id = "";
setTimeout(() => {
facets_iframe.setAttribute('height', facets_iframe.contentWindow.document.body.offsetHeight + 'px')
}, 1500)
</script>"""
    # pylint: enable=line-too-long
    # Substitute the `protostr` placeholder with the encoded statistics.
    html = html_template.replace('protostr', protostr)

    return html

# Load the statistics from the injected `source` path, build the Facets HTML
# for them, and hand the result to IPython's display machinery.
stats = tfdv.load_statistics(source)
html = get_statistics_html(stats)
display(HTML(html))
54 changes: 32 additions & 22 deletions frontend/src/lib/OutputArtifactLoader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,36 @@
* limitations under the License.
*/

import WorkflowParser, { StoragePath } from './WorkflowParser';
import { Apis } from '../lib/Apis';
import { ConfusionMatrixConfig } from '../components/viewers/ConfusionMatrix';
import { HTMLViewerConfig } from '../components/viewers/HTMLViewer';
import { MarkdownViewerConfig } from '../components/viewers/MarkdownViewer';
import { PagedTableConfig } from '../components/viewers/PagedTable';
import { PlotType, ViewerConfig } from '../components/viewers/Viewer';
import { ROCCurveConfig } from '../components/viewers/ROCCurve';
import { TensorboardViewerConfig } from '../components/viewers/Tensorboard';
import { csvParseRows } from 'd3-dsv';
import { logger, errorToMessage } from './Utils';
import { ApiVisualization, ApiVisualizationType } from '../apis/visualization';
import {
Api,
Artifact,
ArtifactType,
Context,
Event,
Execution,
GetArtifactTypesRequest,
GetArtifactTypesResponse,
GetArtifactsByIDRequest,
GetArtifactsByIDResponse,
GetArtifactTypesRequest,
GetArtifactTypesResponse,
GetContextByTypeAndNameRequest,
GetContextByTypeAndNameResponse,
GetExecutionsByContextRequest,
GetExecutionsByContextResponse,
GetEventsByExecutionIDsRequest,
GetEventsByExecutionIDsResponse,
GetExecutionsByContextRequest,
GetExecutionsByContextResponse,
} from '@kubeflow/frontend';
import { csvParseRows } from 'd3-dsv';
import { ApiVisualization, ApiVisualizationType } from '../apis/visualization';
import { ConfusionMatrixConfig } from '../components/viewers/ConfusionMatrix';
import { HTMLViewerConfig } from '../components/viewers/HTMLViewer';
import { MarkdownViewerConfig } from '../components/viewers/MarkdownViewer';
import { PagedTableConfig } from '../components/viewers/PagedTable';
import { ROCCurveConfig } from '../components/viewers/ROCCurve';
import { TensorboardViewerConfig } from '../components/viewers/Tensorboard';
import { PlotType, ViewerConfig } from '../components/viewers/Viewer';
import { Apis } from '../lib/Apis';
import { errorToMessage, logger } from './Utils';
import WorkflowParser, { StoragePath } from './WorkflowParser';

export interface PlotMetadata {
format?: 'csv';
Expand Down Expand Up @@ -279,12 +279,7 @@ export class OutputArtifactLoader {
const trainUri = uri + '/train/stats_tfrecord';
viewers = viewers.concat(
[evalUri, trainUri].map(async specificUri => {
const script = [
'import tensorflow_data_validation as tfdv',
`stats = tfdv.load_statistics('${specificUri}')`,
'tfdv.visualize_statistics(stats)',
];
return buildArtifactViewer(script);
return buildArtifactViewerTfdvStatistics(specificUri);
}),
);
});
Expand Down Expand Up @@ -537,6 +532,21 @@ async function buildArtifactViewer(script: string[]): Promise<HTMLViewerConfig>
};
}

/**
 * Requests a TFDV visualization for the statistics at `url` and wraps the
 * returned HTML in a web-app viewer config.
 *
 * @param url Source location passed to the visualization API.
 * @returns A viewer config carrying the generated HTML.
 * @throws Error when the API response has no `htmlContent`.
 */
async function buildArtifactViewerTfdvStatistics(url: string): Promise<HTMLViewerConfig> {
  const request: ApiVisualization = { source: url, type: ApiVisualizationType.TFDV };
  const { htmlContent } = await Apis.buildPythonVisualizationConfig(request);
  if (htmlContent) {
    return { htmlContent, type: PlotType.WEB_APP };
  }
  throw new Error('Failed to build artifact viewer, no value in visualization.htmlContent');
}

// TODO: add tfma back
// function filterTfmaArtifactsPaths(
// artifactTypes: ArtifactType[],
Expand Down

0 comments on commit d0ef0ae

Please sign in to comment.