Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable tfdv, remove hardcode as a sample for following PRs #3089

Merged
merged 1 commit into from
Feb 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backend/src/apiserver/visualization/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ gcsfs==0.2.3
google-api-python-client==1.7.9
itables==0.1.0
ipykernel==5.1.1
ipython==7.12.0
jupyter_client==5.2.4
nbconvert==5.5.0
nbformat==4.4.0
pandas==0.24.2
pyarrow==0.15.1
scikit_learn==0.21.2
tensorflow-metadata==0.21.1
tensorflow-model-analysis==0.21.1
tensorflow-data-validation==0.21.1
tornado==6.0.2
71 changes: 69 additions & 2 deletions backend/src/apiserver/visualization/types/tfdv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,81 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import tensorflow_data_validation as tfdv
from IPython.display import display
from IPython.display import HTML
from tensorflow_metadata.proto.v0 import statistics_pb2
from typing import Text

# The following variables are provided through dependency injection. These
# variables come from the specified input path and arguments provided by the
# API post request.
#
# source

train_stats = tfdv.generate_statistics_from_csv(data_location=source)
# train_stats = tfdv.generate_statistics_from_csv(data_location=source)
# tfdv.visualize_statistics(train_stats)

tfdv.visualize_statistics(train_stats)
def get_statistics_html(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after TFDV fix got release, here it can be reverted and may move to a new type e.x. tfdvstats.py

lhs_statistics: statistics_pb2.DatasetFeatureStatisticsList
) -> Text:
"""Build the HTML for visualizing the input statistics using Facets.
Args:
lhs_statistics: A DatasetFeatureStatisticsList protocol buffer.
Returns:
HTML to be embedded for visualization.
Raises:
TypeError: If the input argument is not of the expected type.
ValueError: If the input statistics protos does not have only one dataset.
"""

rhs_statistics = None
lhs_name = 'lhs_statistics'
rhs_name = 'rhs_statistics'

if not isinstance(lhs_statistics,
statistics_pb2.DatasetFeatureStatisticsList):
raise TypeError(
'lhs_statistics is of type %s, should be '
'a DatasetFeatureStatisticsList proto.' % type(lhs_statistics).__name__)

if len(lhs_statistics.datasets) != 1:
raise ValueError('lhs_statistics proto contains multiple datasets. Only '
'one dataset is currently supported.')

if lhs_statistics.datasets[0].name:
lhs_name = lhs_statistics.datasets[0].name

# Add lhs stats.
combined_statistics = statistics_pb2.DatasetFeatureStatisticsList()
lhs_stats_copy = combined_statistics.datasets.add()
lhs_stats_copy.MergeFrom(lhs_statistics.datasets[0])
lhs_stats_copy.name = lhs_name

protostr = base64.b64encode(
combined_statistics.SerializeToString()).decode('utf-8')

# pylint: disable=line-too-long
# Note that in the html template we currently assign a temporary id to the
# facets element and then remove it once we have appended the serialized proto
# string to the element. We do this to avoid any collision of ids when
# displaying multiple facets output in the notebook.
html_template = """<iframe id='facets-iframe' width="100%" height="500px"></iframe>
<script>
facets_iframe = document.getElementById('facets-iframe');
facets_html = '<script src="https://cdnjs.cloudflare.com/ajax/libs/webcomponentsjs/1.3.3/webcomponents-lite.js"><\/script><link rel="import" href="https://raw.githubusercontent.com/PAIR-code/facets/master/facets-dist/facets-jupyter.html"><facets-overview proto-input="protostr"></facets-overview>';
facets_iframe.srcdoc = facets_html;
facets_iframe.id = "";
setTimeout(() => {
facets_iframe.setAttribute('height', facets_iframe.contentWindow.document.body.offsetHeight + 'px')
}, 1500)
</script>"""
# pylint: enable=line-too-long
html = html_template.replace('protostr', protostr)

return html

stats = tfdv.load_statistics(source)
html = get_statistics_html(stats)
display(HTML(html))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jingzhang36 you can try this way for TFMA. Essentially it just output htmls.
Putting Python codes in frontend side would make many troubles, please move to new typed visualization

54 changes: 32 additions & 22 deletions frontend/src/lib/OutputArtifactLoader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,36 @@
* limitations under the License.
*/

import WorkflowParser, { StoragePath } from './WorkflowParser';
import { Apis } from '../lib/Apis';
import { ConfusionMatrixConfig } from '../components/viewers/ConfusionMatrix';
import { HTMLViewerConfig } from '../components/viewers/HTMLViewer';
import { MarkdownViewerConfig } from '../components/viewers/MarkdownViewer';
import { PagedTableConfig } from '../components/viewers/PagedTable';
import { PlotType, ViewerConfig } from '../components/viewers/Viewer';
import { ROCCurveConfig } from '../components/viewers/ROCCurve';
import { TensorboardViewerConfig } from '../components/viewers/Tensorboard';
import { csvParseRows } from 'd3-dsv';
import { logger, errorToMessage } from './Utils';
import { ApiVisualization, ApiVisualizationType } from '../apis/visualization';
import {
Api,
Artifact,
ArtifactType,
Context,
Event,
Execution,
GetArtifactTypesRequest,
GetArtifactTypesResponse,
GetArtifactsByIDRequest,
GetArtifactsByIDResponse,
GetArtifactTypesRequest,
GetArtifactTypesResponse,
GetContextByTypeAndNameRequest,
GetContextByTypeAndNameResponse,
GetExecutionsByContextRequest,
GetExecutionsByContextResponse,
GetEventsByExecutionIDsRequest,
GetEventsByExecutionIDsResponse,
GetExecutionsByContextRequest,
GetExecutionsByContextResponse,
} from '@kubeflow/frontend';
import { csvParseRows } from 'd3-dsv';
import { ApiVisualization, ApiVisualizationType } from '../apis/visualization';
import { ConfusionMatrixConfig } from '../components/viewers/ConfusionMatrix';
import { HTMLViewerConfig } from '../components/viewers/HTMLViewer';
import { MarkdownViewerConfig } from '../components/viewers/MarkdownViewer';
import { PagedTableConfig } from '../components/viewers/PagedTable';
import { ROCCurveConfig } from '../components/viewers/ROCCurve';
import { TensorboardViewerConfig } from '../components/viewers/Tensorboard';
import { PlotType, ViewerConfig } from '../components/viewers/Viewer';
import { Apis } from '../lib/Apis';
import { errorToMessage, logger } from './Utils';
import WorkflowParser, { StoragePath } from './WorkflowParser';

export interface PlotMetadata {
format?: 'csv';
Expand Down Expand Up @@ -279,12 +279,7 @@ export class OutputArtifactLoader {
const trainUri = uri + '/train/stats_tfrecord';
viewers = viewers.concat(
[evalUri, trainUri].map(async specificUri => {
const script = [
'import tensorflow_data_validation as tfdv',
`stats = tfdv.load_statistics('${specificUri}')`,
'tfdv.visualize_statistics(stats)',
];
return buildArtifactViewer(script);
return buildArtifactViewerTfdvStatistics(specificUri);
}),
);
});
Expand Down Expand Up @@ -537,6 +532,21 @@ async function buildArtifactViewer(script: string[]): Promise<HTMLViewerConfig>
};
}

async function buildArtifactViewerTfdvStatistics(url: string): Promise<HTMLViewerConfig> {
const visualizationData: ApiVisualization = {
source: url,
type: ApiVisualizationType.TFDV,
};
const visualization = await Apis.buildPythonVisualizationConfig(visualizationData);
if (!visualization.htmlContent) {
throw new Error('Failed to build artifact viewer, no value in visualization.htmlContent');
}
return {
htmlContent: visualization.htmlContent,
type: PlotType.WEB_APP,
};
}

// TODO: add tfma back
// function filterTfmaArtifactsPaths(
// artifactTypes: ArtifactType[],
Expand Down