Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removing custom serialization as GTSAM types are pickleable #111

Merged
merged 2 commits into from
Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/scripts/python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ echo "Running .github/scripts/python.sh..."
conda init
conda info --envs

wget -O 2020_01_13_gtsam_python38_wheel.zip --no-check-certificate "https://drive.google.com/uc?export=download&id=1b7zoYopU7jN3D62fuZMqwQgZdhZ4cH6P"
unzip 2020_01_13_gtsam_python38_wheel.zip
wget -O 2020_03_08_gtsam_python38_wheel.zip --no-check-certificate "https://drive.google.com/uc?export=download&id=1lpU4TFqh5puH41h2kcqfIgX7_br_a0nb"

unzip 2020_03_08_gtsam_python38_wheel.zip
pip install gtsam-4.1.1-cp38-cp38-manylinux2014_x86_64.whl

##########################################################
Expand Down
5 changes: 1 addition & 4 deletions gtsfm/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@

from dask.delayed import Delayed

import gtsfm.utils.serialization # import needed to register serialization fns
from gtsfm.frontend.detector_descriptor.detector_descriptor_base import (
DetectorDescriptorBase,
)
from gtsfm.frontend.detector_descriptor.detector_descriptor_base import DetectorDescriptorBase


class FeatureExtractor:
Expand Down
17 changes: 4 additions & 13 deletions gtsfm/multi_view_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,8 @@

import gtsfm.utils.io as io
import gtsfm.utils.metrics as metrics
import gtsfm.utils.serialization # import needed to register serialization fns
from gtsfm.averaging.rotation.rotation_averaging_base import (
RotationAveragingBase,
)
from gtsfm.averaging.translation.translation_averaging_base import (
TranslationAveragingBase,
)
from gtsfm.averaging.rotation.rotation_averaging_base import RotationAveragingBase
from gtsfm.averaging.translation.translation_averaging_base import TranslationAveragingBase
from gtsfm.bundle.bundle_adjustment import BundleAdjustmentOptimizer
from gtsfm.data_association.data_assoc import DataAssociation

Expand Down Expand Up @@ -105,10 +100,8 @@ def create_computation_graph(
return ba_input_graph, ba_result_graph, saved_metrics_graph



def select_largest_connected_component(
rotations: Dict[Tuple[int, int], Optional[Rot3]],
unit_translations: Dict[Tuple[int, int], Optional[Unit3]],
rotations: Dict[Tuple[int, int], Optional[Rot3]], unit_translations: Dict[Tuple[int, int], Optional[Unit3]],
) -> Tuple[Dict[Tuple[int, int], Rot3], Dict[Tuple[int, int], Unit3]]:
"""Process the graph of image indices with Rot3s/Unit3s defining edges, and select the largest connected component.

Expand Down Expand Up @@ -149,9 +142,7 @@ def select_largest_connected_component(


def init_cameras(
wRi_list: List[Optional[Rot3]],
wti_list: List[Optional[Point3]],
intrinsics_list: List[Cal3Bundler],
wRi_list: List[Optional[Rot3]], wti_list: List[Optional[Point3]], intrinsics_list: List[Cal3Bundler],
) -> Dict[int, PinholeCameraCal3Bundler]:
"""Generate camera from valid rotations and unit-translations.

Expand Down
66 changes: 15 additions & 51 deletions gtsfm/scene_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import gtsfm.utils.geometry_comparisons as comp_utils
import gtsfm.utils.io as io_utils
import gtsfm.utils.logger as logger_utils
import gtsfm.utils.serialization # import needed to register serialization fns
import gtsfm.utils.viz as viz_utils
from gtsfm.common.image import Image
from gtsfm.common.keypoints import Keypoints
Expand Down Expand Up @@ -102,10 +101,7 @@ def create_computation_graph(
keypoints_graph_list = []
descriptors_graph_list = []
for delayed_image in image_graph:
(
delayed_dets,
delayed_descs,
) = self.feature_extractor.create_computation_graph(delayed_image)
(delayed_dets, delayed_descs,) = self.feature_extractor.create_computation_graph(delayed_image)
keypoints_graph_list += [delayed_dets]
descriptors_graph_list += [delayed_descs]

Expand Down Expand Up @@ -168,10 +164,7 @@ def create_computation_graph(
auxiliary_graph_list.append(dask.delayed(persist_frontend_metrics_full)(frontend_metrics_dict))

auxiliary_graph_list.append(
dask.delayed(aggregate_frontend_metrics)(
frontend_metrics_dict,
self._pose_angular_error_thresh,
)
dask.delayed(aggregate_frontend_metrics)(frontend_metrics_dict, self._pose_angular_error_thresh,)
)

# as visualization tasks are not to be provided to the user, we create a
Expand All @@ -181,21 +174,19 @@ def create_computation_graph(
keypoints_graph_list = dask.delayed(lambda x, y: (x, y))(keypoints_graph_list, auxiliary_graph_list)[0]
auxiliary_graph_list = []

(ba_input_graph, ba_output_graph, optimizer_metrics_graph, ) = self.multiview_optimizer.create_computation_graph(
(ba_input_graph, ba_output_graph, optimizer_metrics_graph,) = self.multiview_optimizer.create_computation_graph(
num_images,
keypoints_graph_list,
i2Ri1_graph_dict,
i2Ui1_graph_dict,
v_corr_idxs_graph_dict,
camera_intrinsics_graph,
gt_pose_graph
gt_pose_graph,
)

# aggregate metrics for multiview optimizer
if optimizer_metrics_graph is not None:
auxiliary_graph_list.append(
optimizer_metrics_graph
)
auxiliary_graph_list.append(optimizer_metrics_graph)

filtered_sfm_data_graph = dask.delayed(ba_output_graph.filter_landmarks)(
self.multiview_optimizer.data_association_module.reproj_error_thresh
Expand Down Expand Up @@ -258,13 +249,7 @@ def visualize_twoview_correspondences(
corr_idxs_i1i2: correspondence indices.
file_path: file path to save the visualization.
"""
plot_img = viz_utils.plot_twoview_correspondences(
image_i1,
image_i2,
keypoints_i1,
keypoints_i2,
corr_idxs_i1i2,
)
plot_img = viz_utils.plot_twoview_correspondences(image_i1, image_i2, keypoints_i1, keypoints_i2, corr_idxs_i1i2,)

io_utils.save_image(plot_img, file_path)

Expand Down Expand Up @@ -294,10 +279,7 @@ def visualize_sfm_data(sfm_data: SfmData, folder_name: str) -> None:


def visualize_camera_poses(
pre_ba_sfm_data: SfmData,
post_ba_sfm_data: SfmData,
gt_pose_graph: Optional[List[Pose3]],
folder_name: str,
pre_ba_sfm_data: SfmData, post_ba_sfm_data: SfmData, gt_pose_graph: Optional[List[Pose3]], folder_name: str,
) -> None:
"""Visualize the camera pose and save to disk.

Expand Down Expand Up @@ -352,9 +334,7 @@ def write_sfmdata_to_disk(sfm_data: SfmData, save_fpath: str) -> None:
gtsam.writeBAL(save_fpath, sfm_data)


def persist_frontend_metrics_full(
metrics: Dict[Tuple[int, int], FRONTEND_METRICS_FOR_PAIR],
) -> None:
def persist_frontend_metrics_full(metrics: Dict[Tuple[int, int], FRONTEND_METRICS_FOR_PAIR],) -> None:
"""Persist the front-end metrics for every pair on disk.

Args:
Expand Down Expand Up @@ -404,10 +384,7 @@ def aggregate_frontend_metrics(
all_correct = np.count_nonzero(metrics_array[:, 3] == 1.0)

logger.debug(
"[Two view optimizer] [Summary] Rotation success: %d/%d/%d",
success_count_rot3,
num_valid_entries,
num_entries,
"[Two view optimizer] [Summary] Rotation success: %d/%d/%d", success_count_rot3, num_valid_entries, num_entries,
)

logger.debug(
Expand All @@ -418,34 +395,21 @@ def aggregate_frontend_metrics(
)

logger.debug(
"[Two view optimizer] [Summary] Pose success: %d/%d/%d",
success_count_pose,
num_valid_entries,
num_entries,
"[Two view optimizer] [Summary] Pose success: %d/%d/%d", success_count_pose, num_valid_entries, num_entries,
)

logger.debug(
"[Two view optimizer] [Summary] Image pairs with 100%% inlier ratio:: %d/%d",
all_correct,
num_entries,
"[Two view optimizer] [Summary] Image pairs with 100%% inlier ratio:: %d/%d", all_correct, num_entries,
)

front_end_result_info = {
"angular_err_threshold_deg": angular_err_threshold_deg,
"num_valid_entries": int(num_valid_entries),
"num_total_entries": int(num_entries),
"rotation": {
"success_count": int(success_count_rot3),
},
"translation": {
"success_count": int(success_count_unit3),
},
"pose": {
"success_count": int(success_count_pose),
},
"correspondences": {
"all_inliers": int(all_correct),
},
"rotation": {"success_count": int(success_count_rot3),},
"translation": {"success_count": int(success_count_unit3),},
"pose": {"success_count": int(success_count_pose),},
"correspondences": {"all_inliers": int(all_correct),},
}

io_utils.save_json_file(os.path.join(METRICS_PATH, "frontend_summary.json"), front_end_result_info)
5 changes: 1 addition & 4 deletions gtsfm/two_view_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import gtsfm.utils.geometry_comparisons as comp_utils
import gtsfm.utils.logger as logger_utils
import gtsfm.utils.metrics as metric_utils
import gtsfm.utils.serialization # import needed to register serialization fns
from gtsfm.common.keypoints import Keypoints
from gtsfm.frontend.matcher.matcher_base import MatcherBase
from gtsfm.frontend.verifier.verifier_base import VerifierBase
Expand Down Expand Up @@ -153,9 +152,7 @@ def compute_correspondence_metrics(


def compute_relative_pose_metrics(
i2Ri1_computed: Optional[Rot3],
i2Ui1_computed: Optional[Unit3],
i2Ti1_expected: Pose3,
i2Ri1_computed: Optional[Rot3], i2Ui1_computed: Optional[Unit3], i2Ti1_expected: Pose3,
) -> Tuple[Optional[float], Optional[float]]:
"""Compute the metrics on relative camera pose.

Expand Down
34 changes: 17 additions & 17 deletions tests/utils/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import gtsam
import numpy as np
from distributed.protocol.serialize import serialize, deserialize
from gtsam import (
Cal3Bundler,
PinholeCameraCal3Bundler,
Expand All @@ -15,7 +16,6 @@
Unit3,
)

import gtsfm.utils.serialization as serialization_utils
from gtsfm.common.sfm_result import SfmResult

GTSAM_EXAMPLE_FILE = "dubrovnik-3-7-pre"
Expand All @@ -28,36 +28,36 @@ class TestSerialization(unittest.TestCase):
def test_point3_roundtrip(self):
"""Test the round-trip on Point3 object."""
expected = Point3(np.random.randn(3))
header, frames = serialization_utils.serialize_Point3(expected)
recovered = serialization_utils.deserialize_Point3(header, frames)
header, frames = serialize(expected)
recovered = deserialize(header, frames)
np.testing.assert_allclose(expected, recovered)

def test_pose3_roundtrip(self):
"""Test the round-trip on Point3 object."""
expected = Pose3(Rot3.RzRyRx(0, 0.1, 0.2), np.random.randn(3))
header, frames = serialization_utils.serialize_Pose3(expected)
recovered = serialization_utils.deserialize_Pose3(header, frames)
header, frames = serialize(expected)
recovered = deserialize(header, frames)
self.assertTrue(recovered.equals(expected, 1e-5))

def test_rot3_roundtrip(self):
"""Test the round-trip on Rot3 object."""
expected = Rot3.RzRyRx(0, 0.05, 0.1)
header, frames = serialization_utils.serialize_Rot3(expected)
recovered = serialization_utils.deserialize_Rot3(header, frames)
header, frames = serialize(expected)
recovered = deserialize(header, frames)
self.assertTrue(expected.equals(recovered, 1e-5))

def test_unit3_roundtrip(self):
"""Test the round-trip on Unit3 object."""
expected = Unit3(np.random.randn(3))
header, frames = serialization_utils.serialize_Unit3(expected)
recovered = serialization_utils.deserialize_Unit3(header, frames)
header, frames = serialize(expected)
recovered = deserialize(header, frames)
self.assertTrue(expected.equals(recovered, 1e-5))

def test_cal3Bundler_roundtrip(self):
"""Test the round-trip on Cal3Bundler object."""
expected = Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70)
header, frames = serialization_utils.serialize_Cal3Bundler(expected)
recovered = serialization_utils.deserialize_Cal3Bundler(header, frames)
header, frames = serialize(expected)
recovered = deserialize(header, frames)
self.assertTrue(expected.equals(recovered, 1e-5))

def test_pinholeCameraCal3Bundler_roundtrip(self):
Expand All @@ -67,25 +67,25 @@ def test_pinholeCameraCal3Bundler_roundtrip(self):
Pose3(Rot3.RzRyRx(0, 0.1, -0.05), np.random.randn(3, 1)),
Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70),
)
header, frames = serialization_utils.serialize_PinholeCameraCal3Bundler(expected)
recovered = serialization_utils.deserialize_PinholeCameraCal3Bundler(header, frames)
header, frames = serialize(expected)
recovered = deserialize(header, frames)

self.assertTrue(expected.equals(recovered, 1e-5))

def test_sfmData_roundtrip(self):
"""Test for equality after serializing and then de-serializing an SfmData instance."""
expected = EXAMPLE_DATA
header, frames = serialization_utils.serialize_SfmData(expected)
recovered = serialization_utils.deserialize_SfmData(header, frames)
header, frames = serialize(expected)
recovered = deserialize(header, frames)

# comparing tracks in an order-sensitive fashion.
self.assertTrue(recovered.equals(expected, 1e-9))

def test_sfmResult_roundtrip(self):
"""Test for equality after serializing and then de-serializing an SfmResult instance."""
expected = SfmResult(EXAMPLE_DATA, total_reproj_error=1.5)
header, frames = serialization_utils.serialize_SfmResult(expected)
recovered = serialization_utils.deserialize_SfmResult(header, frames)
header, frames = serialize(expected)
recovered = deserialize(header, frames)

# comparing cameras and total reprojection error
self.assertEqual(recovered.total_reproj_error, expected.total_reproj_error)
Expand Down