-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: One step in finding all the pairs of vertices which share faces is a simple calculation but annoying to parallelize. It was implemented in pure Python. We move it to C++. We still pull the data to the CPU and put the answer back on the device. Reviewed By: nikhilaravi, gkioxari Differential Revision: D26073475 fbshipit-source-id: ffbf4e2c347a511ab5084bceff600465812b6a52
- Loading branch information
1 parent
5ac2f42
commit 4bfe715
Showing
4 changed files
with
84 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
24 changes: 24 additions & 0 deletions
24
pytorch3d/csrc/mesh_normal_consistency/mesh_normal_consistency.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
|
||
#pragma once | ||
#include <torch/extension.h> | ||
#include "utils/pytorch3d_cutils.h" | ||
|
||
// For mesh_normal_consistency, find pairs of vertices opposite the same edge. | ||
// | ||
// Args: | ||
// edge_num: int64 Tensor of shape (E,) giving the number of vertices | ||
// corresponding to each edge. | ||
// | ||
// Returns: | ||
// pairs: int64 Tensor of shape (N,2) | ||
|
||
at::Tensor MeshNormalConsistencyFindVerticesCpu(const at::Tensor& edge_num); | ||
|
||
// Exposed implementation. | ||
at::Tensor MeshNormalConsistencyFindVertices(const at::Tensor& edge_num) { | ||
if (edge_num.is_cuda()) { | ||
AT_ERROR("This function needs a CPU tensor."); | ||
} | ||
return MeshNormalConsistencyFindVerticesCpu(edge_num); | ||
} |
47 changes: 47 additions & 0 deletions
47
pytorch3d/csrc/mesh_normal_consistency/mesh_normal_consistency_cpu.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
|
||
#include <ATen/ATen.h> | ||
#include <utility> | ||
#include <vector> | ||
|
||
at::Tensor MeshNormalConsistencyFindVerticesCpu(const at::Tensor& edge_num) { | ||
// We take a LongTensor of shape (E,) giving the number of things intersecting | ||
// each edge. The things are taken to be numbered in order. | ||
// (In fact, the "things" are opposite vertices to edges, renumbered). | ||
// We return a tensor of shape (?, 2) where for every pair of things which | ||
// intersect the same edge there is a row of their numbers in the output. | ||
|
||
// Example possible inputs and outputs (order of output is not specified): | ||
// [1,0,1,1,0] => [[]] | ||
// [3] => [[0,1], [0,2], [1,2]] | ||
// [0,3] => [[0,1], [0,2], [1,2]] | ||
// [1,3] => [[1,2], [1,3], [2,3]] | ||
//[1,0,2,1,0,2] => [[1,2], [4,5]] | ||
|
||
const auto num_edges = edge_num.size(0); | ||
auto edges_a = edge_num.accessor<int64_t, 1>(); | ||
|
||
int64_t vert_idx = 0; | ||
std::vector<std::pair<int64_t, int64_t>> pairs; | ||
for (int64_t i_edge = 0; i_edge < num_edges; ++i_edge) { | ||
int64_t e = edges_a[i_edge]; | ||
for (int64_t j = 0; j < e; ++j) { | ||
for (int64_t i = 0; i < j; ++i) { | ||
pairs.emplace_back(vert_idx + i, vert_idx + j); | ||
} | ||
} | ||
vert_idx += e; | ||
} | ||
|
||
// Convert from std::vector by copying over the items to a new empty torch | ||
// tensor. | ||
auto pairs_tensor = at::empty({(int64_t)pairs.size(), 2}, edge_num.options()); | ||
auto pairs_a = pairs_tensor.accessor<int64_t, 2>(); | ||
for (int64_t i_pair = 0; i_pair < pairs.size(); ++i_pair) { | ||
auto accessor = pairs_a[i_pair]; | ||
accessor[0] = pairs[i_pair].first; | ||
accessor[1] = pairs[i_pair].second; | ||
} | ||
|
||
return pairs_tensor; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters