Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TF2.0 support for the mesh plugin #2443

Merged
merged 7 commits into from
Jul 30, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions tensorboard/plugins/mesh/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,21 @@ py_library(
visibility = [
"//visibility:public",
],
deps = [
":metadata",
":summary_v2",
":protos_all_py_pb2",
"//tensorboard/compat:tensorflow",
podlipensky marked this conversation as resolved.
Show resolved Hide resolved
],
)

py_library(
name = "summary_v2",
podlipensky marked this conversation as resolved.
Show resolved Hide resolved
srcs = ["summary_v2.py"],
srcs_version = "PY2AND3",
visibility = [
"//visibility:public",
],
deps = [
":metadata",
":protos_all_py_pb2",
Expand All @@ -120,6 +135,19 @@ py_test(
],
)

py_test(
name = "summary_v2_test",
size = "small",
srcs = ["summary_v2_test.py"],
srcs_version = "PY2AND3",
deps = [
podlipensky marked this conversation as resolved.
Show resolved Hide resolved
":summary",
":test_utils",
"//tensorboard:expect_tensorflow_installed",
"//tensorboard/util:test_util",
],
)

tb_proto_library(
name = "protos_all",
srcs = ["plugin_data.proto"],
Expand Down Expand Up @@ -167,6 +195,20 @@ py_binary(
],
)

py_binary(
name = "mesh_demo_v2",
srcs = ["mesh_demo_v2.py"],
srcs_version = "PY2AND3",
visibility = [
"//visibility:public",
],
deps = [
":demo_utils",
":summary_v2",
"//tensorboard/compat:tensorflow",
podlipensky marked this conversation as resolved.
Show resolved Hide resolved
],
)

filegroup(
name = "test_data",
srcs = [
Expand Down
90 changes: 90 additions & 0 deletions tensorboard/plugins/mesh/mesh_demo_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple demo which displays constant 3D mesh."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v2 as tf

from tensorboard.plugins.mesh import summary_v2 as mesh_summary
from tensorboard.plugins.mesh import demo_utils


flags.DEFINE_string('logdir', '/tmp/mesh_demo',
'Directory to write event logs to.')
flags.DEFINE_string('mesh_path', None, 'Path to PLY file to visualize.')

FLAGS = flags.FLAGS

tf.enable_v2_behavior()

# Max number of steps to run training with.
_MAX_STEPS = 10


def train_step(vertices, faces, colors, config_dict, step):
"""Executes summary as a train step."""
# Change colors over time.
t = float(step) / _MAX_STEPS
transformed_colors = t * (255 - colors) + (1 - t) * colors
mesh_summary.mesh(
'mesh_color_tensor', vertices=vertices, faces=faces,
colors=transformed_colors, config_dict=config_dict, step=step)


def run():
"""Runs training steps with a mesh summary."""
# Mesh summaries only work on TensorFlow 2.x.
if int(tf.__version__.split('.')[0]) < 1:
raise ImportError('TensorFlow 2.x is required to run this demo.')
# Flag mesh_path is required.
if FLAGS.mesh_path is None:
raise ValueError(
'Flag --mesh_path is required and must contain path to PLY file.')
# Camera and scene configuration.
config_dict = {
'camera': {'cls': 'PerspectiveCamera', 'fov': 75}
}

# Read sample PLY file.
vertices, colors, faces = demo_utils.read_ascii_ply(FLAGS.mesh_path)

# Add batch dimension.
vertices = np.expand_dims(vertices, 0)
faces = np.expand_dims(faces, 0)
colors = np.expand_dims(colors, 0)

# Create summary writer.
writer = tf.summary.create_file_writer(FLAGS.logdir)

with writer.as_default():
for step in range(_MAX_STEPS):
train_step(vertices, faces, colors, config_dict, step)


def main(unused_argv):
print('Saving output to %s.' % FLAGS.logdir)
run()
print('Done. Output saved to %s.' % FLAGS.logdir)


if __name__ == '__main__':
app.run(main)
5 changes: 5 additions & 0 deletions tensorboard/plugins/mesh/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@

from tensorboard.plugins.mesh import metadata
from tensorboard.plugins.mesh import plugin_data_pb2
from tensorboard.plugins.mesh import summary_v2

PLUGIN_NAME = 'mesh'

# Export V2 versions.
mesh = summary_v2.mesh
mesh_pb = summary_v2.mesh_pb


def _get_tensor_summary(
name, display_name, description, tensor, content_type, components,
Expand Down
208 changes: 208 additions & 0 deletions tensorboard/plugins/mesh/summary_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mesh summaries and TensorFlow operations to create them. V2 versions"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json

from tensorboard.compat import tf2 as tf
from tensorboard.compat.proto import summary_pb2
from tensorboard.plugins.mesh import metadata
from tensorboard.plugins.mesh import plugin_data_pb2
from tensorboard.util import tensor_util

PLUGIN_NAME = 'mesh'
podlipensky marked this conversation as resolved.
Show resolved Hide resolved


def _write_summary(
name, display_name, description, tensor, content_type, components,
json_config, step):
"""Creates a tensor summary with summary metadata.

Args:
name: Uniquely identifiable name of the summary op. Could be replaced by
combination of name and type to make it unique even outside of this
summary.
display_name: Will be used as the display name in TensorBoard.
Defaults to `tag`.
description: A longform readable description of the summary data. Markdown
is supported.
tensor: Tensor to display in summary.
content_type: Type of content inside the Tensor.
components: Bitmask representing present parts (vertices, colors, etc.) that
belong to the summary.
json_config: A string, JSON-serialized dictionary of ThreeJS classes
configuration.
step: Explicit `int64`-castable monotonic step value for this summary. If
omitted, this defaults to `tf.summary.experimental.get_step()`, which must
not be None.

Returns:
A boolean indicating if summary was saved successfully or not.
"""
tensor = tf.convert_to_tensor(value=tensor)
shape = tensor.shape.as_list()
shape = [dim if dim is not None else -1 for dim in shape]
tensor_metadata = metadata.create_summary_metadata(
name,
display_name,
content_type,
components,
shape,
description,
json_config=json_config)
return tf.summary.write(
tag=metadata.get_instance_name(name, content_type),
tensor=tensor,
step=step,
metadata=tensor_metadata)


def _get_display_name(name, display_name):
"""Returns display_name from display_name and name."""
if display_name is None:
return name
return display_name


def _get_json_config(config_dict):
"""Parses and returns JSON string from python dictionary."""
json_config = '{}'
if config_dict is not None:
json_config = json.dumps(config_dict, sort_keys=True)
return json_config


def mesh(name, vertices, faces=None, colors=None, display_name=None,
description=None, config_dict=None, step=None):
"""Writes a TensorFlow mesh summary.

Args:
name: A name for this summary operation.
podlipensky marked this conversation as resolved.
Show resolved Hide resolved
vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D
coordinates of vertices.
faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of
vertices within each triangle.
colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each
vertex.
display_name: If set, will be used as the display name in TensorBoard.
podlipensky marked this conversation as resolved.
Show resolved Hide resolved
Defaults to `name`.
description: A longform readable description of the summary data. Markdown
podlipensky marked this conversation as resolved.
Show resolved Hide resolved
is supported.
config_dict: Dictionary with ThreeJS classes names and configuration.
step: Explicit `int64`-castable monotonic step value for this summary. If
omitted, this defaults to `tf.summary.experimental.get_step()`, which must
not be None.

Returns:
True if all components of the mesh were saved successfully and False
otherwise.
"""
display_name = _get_display_name(name, display_name)
json_config = _get_json_config(config_dict)

# All tensors representing a single mesh will be represented as separate
# summaries internally. Those summaries will be regrouped on the client before
# rendering.
tensors = [
metadata.MeshTensor(
vertices, plugin_data_pb2.MeshPluginData.VERTEX, tf.float32),
metadata.MeshTensor(faces, plugin_data_pb2.MeshPluginData.FACE, tf.int32),
metadata.MeshTensor(
colors, plugin_data_pb2.MeshPluginData.COLOR, tf.uint8)
]
tensors = [tensor for tensor in tensors if tensor.data is not None]

components = metadata.get_components_bitmask([
tensor.content_type for tensor in tensors])

summary_scope = (
getattr(tf.summary.experimental, 'summary_scope', None) or
tf.summary.summary_scope)
all_success = True
with summary_scope(name, 'mesh_summary', values=tensors):
for tensor in tensors:
all_success = all_success and _write_summary(
name, display_name, description, tensor.data, tensor.content_type,
components, json_config, step)

return all_success


def mesh_pb(name,
podlipensky marked this conversation as resolved.
Show resolved Hide resolved
vertices,
faces=None,
colors=None,
display_name=None,
description=None,
config_dict=None):
"""Create a mesh summary to save in pb format.

Args:
name: A name for this summary operation.
vertices: numpy array of shape `[dim_1, ..., dim_n, 3]` representing the 3D
coordinates of vertices.
faces: numpy array of shape `[dim_1, ..., dim_n, 3]` containing indices of
vertices within each triangle.
colors: numpy array of shape `[dim_1, ..., dim_n, 3]` containing colors for
each vertex.
display_name: If set, will be used as the display name in TensorBoard.
Defaults to `name`.
description: A longform readable description of the summary data. Markdown
is supported.
config_dict: Dictionary with ThreeJS classes names and configuration.

Returns:
Instance of tf.Summary class.
"""
display_name = _get_display_name(name, display_name)
json_config = _get_json_config(config_dict)

summaries = []
tensors = [
metadata.MeshTensor(
vertices, plugin_data_pb2.MeshPluginData.VERTEX, tf.float32),
metadata.MeshTensor(faces, plugin_data_pb2.MeshPluginData.FACE, tf.int32),
metadata.MeshTensor(
colors, plugin_data_pb2.MeshPluginData.COLOR, tf.uint8)
]
tensors = [tensor for tensor in tensors if tensor.data is not None]
components = metadata.get_components_bitmask([
tensor.content_type for tensor in tensors])
for tensor in tensors:
shape = tensor.data.shape
shape = [dim if dim is not None else -1 for dim in shape]
tensor_proto = tensor_util.make_tensor_proto(
tensor.data, dtype=tensor.data_type)
summary_metadata = metadata.create_summary_metadata(
name,
display_name,
tensor.content_type,
components,
shape,
description,
json_config=json_config)
tag = metadata.get_instance_name(name, tensor.content_type)
summaries.append((tag, summary_metadata, tensor_proto))

summary = summary_pb2.Summary()
for tag, summary_metadata, tensor_proto in summaries:
tf_summary_metadata = summary_pb2.SummaryMetadata.FromString(
podlipensky marked this conversation as resolved.
Show resolved Hide resolved
summary_metadata.SerializeToString())
summary.value.add(
tag=tag, metadata=tf_summary_metadata, tensor=tensor_proto)
return summary
Loading