Skip to content

Commit

Permalink
fix(interpolation): fix bug of interpolation (#8903)
Browse files Browse the repository at this point in the history
* fix(interpolation): fix bug of interpolation

Signed-off-by: Y.Hisaki <yhisaki31@gmail.com>

* add const

Signed-off-by: Y.Hisaki <yhisaki31@gmail.com>

* auto -> int64_t

Signed-off-by: Y.Hisaki <yhisaki31@gmail.com>

* add const

Signed-off-by: Y.Hisaki <yhisaki31@gmail.com>

* add const

Signed-off-by: Y.Hisaki <yhisaki31@gmail.com>

* add const

Signed-off-by: Y.Hisaki <yhisaki31@gmail.com>

---------

Signed-off-by: Y.Hisaki <yhisaki31@gmail.com>
  • Loading branch information
M. Fatih Cırıt committed Sep 20, 2024
1 parent 65dcc7f commit 4996cf7
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "autoware/interpolation/interpolation_utils.hpp"
#include "autoware/universe_utils/geometry/geometry.hpp"

#include <Eigen/Core>

#include <algorithm>
#include <cmath>
#include <iostream>
Expand All @@ -26,25 +28,6 @@

namespace autoware::interpolation
{
// NOTE: X(s) = a_i (s - s_i)^3 + b_i (s - s_i)^2 + c_i (s - s_i) + d_i : (i = 0, 1, ... N-1)
struct MultiSplineCoef
{
MultiSplineCoef() = default;

explicit MultiSplineCoef(const size_t num_spline)
{
a.resize(num_spline);
b.resize(num_spline);
c.resize(num_spline);
d.resize(num_spline);
}

std::vector<double> a;
std::vector<double> b;
std::vector<double> c;
std::vector<double> d;
};

// static spline interpolation functions
std::vector<double> spline(
const std::vector<double> & base_keys, const std::vector<double> & base_values,
Expand Down Expand Up @@ -97,11 +80,17 @@ class SplineInterpolation
size_t getSize() const { return base_keys_.size(); }

private:
Eigen::VectorXd a_;
Eigen::VectorXd b_;
Eigen::VectorXd c_;
Eigen::VectorXd d_;

std::vector<double> base_keys_;
MultiSplineCoef multi_spline_coef_;

void calcSplineCoefficients(
const std::vector<double> & base_keys, const std::vector<double> & base_values);

Eigen::Index get_index(const double & key) const;
};
} // namespace autoware::interpolation

Expand Down
1 change: 1 addition & 0 deletions common/autoware_interpolation/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
<buildtool_depend>autoware_cmake</buildtool_depend>

<depend>autoware_universe_utils</depend>
<depend>eigen</depend>

<test_depend>ament_cmake_ros</test_depend>
<test_depend>ament_lint_auto</test_depend>
Expand Down
237 changes: 99 additions & 138 deletions common/autoware_interpolation/src/spline_interpolation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,65 +14,40 @@

#include "autoware/interpolation/spline_interpolation.hpp"

#include <cstdint>
#include <vector>

namespace autoware::interpolation
{
// solve Ax = d
// where A is tridiagonal matrix
// [b_0 c_0 ... ]
// [a_0 b_1 c_1 ... O ]
// A = [ ... ]
// [ O ... a_N-3 b_N-2 c_N-2]
// [ ... a_N-2 b_N-1]
struct TDMACoef
Eigen::VectorXd solve_tridiagonal_matrix_algorithm(
const Eigen::Ref<const Eigen::VectorXd> & a, const Eigen::Ref<const Eigen::VectorXd> & b,
const Eigen::Ref<const Eigen::VectorXd> & c, const Eigen::Ref<const Eigen::VectorXd> & d)
{
explicit TDMACoef(const size_t num_row)
{
a.resize(num_row - 1);
b.resize(num_row);
c.resize(num_row - 1);
d.resize(num_row);
const auto n = d.size();

if (n == 1) {
return d.array() / b.array();
}

std::vector<double> a;
std::vector<double> b;
std::vector<double> c;
std::vector<double> d;
};
Eigen::VectorXd c_prime = Eigen::VectorXd::Zero(n);
Eigen::VectorXd d_prime = Eigen::VectorXd::Zero(n);
Eigen::VectorXd x = Eigen::VectorXd::Zero(n);

inline std::vector<double> solveTridiagonalMatrixAlgorithm(const TDMACoef & tdma_coef)
{
const auto & a = tdma_coef.a;
const auto & b = tdma_coef.b;
const auto & c = tdma_coef.c;
const auto & d = tdma_coef.d;

const size_t num_row = b.size();

std::vector<double> x(num_row);
if (num_row != 1) {
// calculate p and q
std::vector<double> p;
std::vector<double> q;
p.push_back(-c[0] / b[0]);
q.push_back(d[0] / b[0]);

for (size_t i = 1; i < num_row; ++i) {
const double den = b[i] + a[i - 1] * p[i - 1];
p.push_back(-c[i - 1] / den);
q.push_back((d[i] - a[i - 1] * q[i - 1]) / den);
}
// Forward sweep
c_prime(0) = c(0) / b(0);
d_prime(0) = d(0) / b(0);

// calculate solution
x[num_row - 1] = q[num_row - 1];
for (auto i = 1; i < n; i++) {
const double m = 1.0 / (b(i) - a(i - 1) * c_prime(i - 1));
c_prime(i) = i < n - 1 ? c(i) * m : 0;
d_prime(i) = (d(i) - a(i - 1) * d_prime(i - 1)) * m;
}

for (size_t i = 1; i < num_row; ++i) {
const size_t j = num_row - 1 - i;
x[j] = p[j] * x[j + 1] + q[j];
}
} else {
x[0] = (d[0] / b[0]);
// Back substitution
x(n - 1) = d_prime(n - 1);

for (int64_t i = n - 2; i >= 0; i--) {
x(i) = d_prime(i) - c_prime(i) * x(i + 1);
}

return x;
Expand Down Expand Up @@ -162,125 +137,111 @@ std::vector<double> splineByAkima(
return res;
}

Eigen::Index SplineInterpolation::get_index(const double & key) const
{
const auto it = std::lower_bound(base_keys_.begin(), base_keys_.end(), key);
return std::clamp(
static_cast<int>(std::distance(base_keys_.begin(), it)) - 1, 0,
static_cast<int>(base_keys_.size()) - 2);
}

void SplineInterpolation::calcSplineCoefficients(
const std::vector<double> & base_keys, const std::vector<double> & base_values)
{
// throw exceptions for invalid arguments
validateKeysAndValues(base_keys, base_values);

const size_t num_base = base_keys.size(); // N+1

std::vector<double> diff_keys; // N
std::vector<double> diff_values; // N
for (size_t i = 0; i < num_base - 1; ++i) {
diff_keys.push_back(base_keys.at(i + 1) - base_keys.at(i));
diff_values.push_back(base_values.at(i + 1) - base_values.at(i));
}

std::vector<double> v = {0.0};
if (num_base > 2) {
// solve tridiagonal matrix algorithm
TDMACoef tdma_coef(num_base - 2); // N-1

for (size_t i = 0; i < num_base - 2; ++i) {
tdma_coef.b[i] = 2 * (diff_keys[i] + diff_keys[i + 1]);
if (i != num_base - 3) {
tdma_coef.a[i] = diff_keys[i + 1];
tdma_coef.c[i] = diff_keys[i + 1];
}
tdma_coef.d[i] =
6.0 * (diff_values[i + 1] / diff_keys[i + 1] - diff_values[i] / diff_keys[i]);
}

const std::vector<double> tdma_res = solveTridiagonalMatrixAlgorithm(tdma_coef);

// calculate v
v.insert(v.end(), tdma_res.begin(), tdma_res.end());
}
v.push_back(0.0);

// calculate a, b, c, d of spline coefficients
multi_spline_coef_ = MultiSplineCoef{num_base - 1}; // N
for (size_t i = 0; i < num_base - 1; ++i) {
multi_spline_coef_.a[i] = (v[i + 1] - v[i]) / 6.0 / diff_keys[i];
multi_spline_coef_.b[i] = v[i] / 2.0;
multi_spline_coef_.c[i] =
diff_values[i] / diff_keys[i] - diff_keys[i] * (2 * v[i] + v[i + 1]) / 6.0;
multi_spline_coef_.d[i] = base_values[i];
interpolation_utils::validateKeysAndValues(base_keys, base_values);
const Eigen::VectorXd x = Eigen::Map<const Eigen::VectorXd>(
base_keys.data(), static_cast<Eigen::Index>(base_keys.size()));
const Eigen::VectorXd y = Eigen::Map<const Eigen::VectorXd>(
base_values.data(), static_cast<Eigen::Index>(base_values.size()));

const auto n = x.size();

if (n == 2) {
a_ = Eigen::VectorXd::Zero(1);
b_ = Eigen::VectorXd::Zero(1);
c_ = Eigen::VectorXd::Zero(1);
d_ = Eigen::VectorXd::Zero(1);
c_[0] = (y[1] - y[0]) / (x[1] - x[0]);
d_[0] = y[0];
base_keys_ = base_keys;
return;
}

// Create Tridiagonal matrix
Eigen::VectorXd v(n);
const Eigen::VectorXd h = x.segment(1, n - 1) - x.segment(0, n - 1);
const Eigen::VectorXd a = h.segment(1, n - 3);
const Eigen::VectorXd b = 2 * (h.segment(0, n - 2) + h.segment(1, n - 2));
const Eigen::VectorXd c = h.segment(1, n - 3);
const Eigen::VectorXd y_diff = y.segment(1, n - 1) - y.segment(0, n - 1);
const Eigen::VectorXd d = 6 * (y_diff.segment(1, n - 2).array() / h.tail(n - 2).array() -
y_diff.segment(0, n - 2).array() / h.head(n - 2).array());

// Solve tridiagonal matrix
v.segment(1, n - 2) = solve_tridiagonal_matrix_algorithm(a, b, c, d);
v[0] = 0;
v[n - 1] = 0;

// Calculate spline coefficients
a_ = (v.tail(n - 1) - v.head(n - 1)).array() / 6.0 / (x.tail(n - 1) - x.head(n - 1)).array();
b_ = v.segment(0, n - 1) / 2.0;
c_ = (y.tail(n - 1) - y.head(n - 1)).array() / (x.tail(n - 1) - x.head(n - 1)).array() -
(x.tail(n - 1) - x.head(n - 1)).array() *
(2 * v.segment(0, n - 1).array() + v.segment(1, n - 1).array()) / 6.0;
d_ = y.head(n - 1);
base_keys_ = base_keys;
}

std::vector<double> SplineInterpolation::getSplineInterpolatedValues(
const std::vector<double> & query_keys) const
{
// throw exceptions for invalid arguments
const auto validated_query_keys = validateKeys(base_keys_, query_keys);

const auto & a = multi_spline_coef_.a;
const auto & b = multi_spline_coef_.b;
const auto & c = multi_spline_coef_.c;
const auto & d = multi_spline_coef_.d;

std::vector<double> res;
size_t j = 0;
for (const auto & query_key : validated_query_keys) {
while (base_keys_.at(j + 1) < query_key) {
++j;
}

const double ds = query_key - base_keys_.at(j);
res.push_back(d.at(j) + (c.at(j) + (b.at(j) + a.at(j) * ds) * ds) * ds);
const auto validated_query_keys = interpolation_utils::validateKeys(base_keys_, query_keys);
std::vector<double> interpolated_values;
interpolated_values.reserve(query_keys.size());

for (const auto & key : query_keys) {
const auto idx = get_index(key);
const auto dx = key - base_keys_[idx];
interpolated_values.emplace_back(
a_[idx] * dx * dx * dx + b_[idx] * dx * dx + c_[idx] * dx + d_[idx]);
}

return res;
return interpolated_values;
}

std::vector<double> SplineInterpolation::getSplineInterpolatedDiffValues(
const std::vector<double> & query_keys) const
{
// throw exceptions for invalid arguments
const auto validated_query_keys = validateKeys(base_keys_, query_keys);

const auto & a = multi_spline_coef_.a;
const auto & b = multi_spline_coef_.b;
const auto & c = multi_spline_coef_.c;

std::vector<double> res;
size_t j = 0;
for (const auto & query_key : validated_query_keys) {
while (base_keys_.at(j + 1) < query_key) {
++j;
}

const double ds = query_key - base_keys_.at(j);
res.push_back(c.at(j) + (2.0 * b.at(j) + 3.0 * a.at(j) * ds) * ds);
const auto validated_query_keys = interpolation_utils::validateKeys(base_keys_, query_keys);
std::vector<double> interpolated_diff_values;
interpolated_diff_values.reserve(query_keys.size());

for (const auto & key : query_keys) {
const auto idx = get_index(key);
const auto dx = key - base_keys_[idx];
interpolated_diff_values.emplace_back(3 * a_[idx] * dx * dx + 2 * b_[idx] * dx + c_[idx]);
}

return res;
return interpolated_diff_values;
}

std::vector<double> SplineInterpolation::getSplineInterpolatedQuadDiffValues(
const std::vector<double> & query_keys) const
{
// throw exceptions for invalid arguments
const auto validated_query_keys = validateKeys(base_keys_, query_keys);

const auto & a = multi_spline_coef_.a;
const auto & b = multi_spline_coef_.b;

std::vector<double> res;
size_t j = 0;
for (const auto & query_key : validated_query_keys) {
while (base_keys_.at(j + 1) < query_key) {
++j;
}

const double ds = query_key - base_keys_.at(j);
res.push_back(2.0 * b.at(j) + 6.0 * a.at(j) * ds);
const auto validated_query_keys = interpolation_utils::validateKeys(base_keys_, query_keys);
std::vector<double> interpolated_quad_diff_values;
interpolated_quad_diff_values.reserve(query_keys.size());

for (const auto & key : query_keys) {
const auto idx = get_index(key);
const auto dx = key - base_keys_[idx];
interpolated_quad_diff_values.emplace_back(6 * a_[idx] * dx + 2 * b_[idx]);
}

return res;
return interpolated_quad_diff_values;
}
} // namespace autoware::interpolation
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ TEST(spline_interpolation, spline)
const std::vector<double> base_keys{-1.5, 1.0, 5.0, 10.0, 15.0, 20.0};
const std::vector<double> base_values{-1.2, 0.5, 1.0, 1.2, 2.0, 1.0};
const std::vector<double> query_keys{0.0, 8.0, 18.0};
const std::vector<double> ans{-0.075611, 0.997242, 1.573258};
const std::vector<double> ans{-0.076114, 1.001217, 1.573640};

const auto query_values = autoware::interpolation::spline(base_keys, base_values, query_keys);
for (size_t i = 0; i < query_values.size(); ++i) {
Expand Down Expand Up @@ -114,7 +114,7 @@ TEST(spline_interpolation, spline)
const std::vector<double> base_keys = {0.0, 1.0, 1.0001, 2.0, 3.0, 4.0};
const std::vector<double> base_values = {0.0, 0.0, 0.1, 0.1, 0.1, 0.1};
const std::vector<double> query_keys = {0.0, 1.0, 1.5, 2.0, 3.0, 4.0};
const std::vector<double> ans = {0.0, 0.0, 137.591789, 0.1, 0.1, 0.1};
const std::vector<double> ans = {0.0, 0.0, 158.738293, 0.1, 0.1, 0.1};

const auto query_values = autoware::interpolation::spline(base_keys, base_values, query_keys);
for (size_t i = 0; i < query_values.size(); ++i) {
Expand Down Expand Up @@ -237,7 +237,7 @@ TEST(spline_interpolation, SplineInterpolation)
const std::vector<double> base_keys{-1.5, 1.0, 5.0, 10.0, 15.0, 20.0};
const std::vector<double> base_values{-1.2, 0.5, 1.0, 1.2, 2.0, 1.0};
const std::vector<double> query_keys{0.0, 8.0, 18.0};
const std::vector<double> ans{-0.075611, 0.997242, 1.573258};
const std::vector<double> ans{-0.076114, 1.001217, 1.573640};

SplineInterpolation s(base_keys, base_values);
const std::vector<double> query_values = s.getSplineInterpolatedValues(query_keys);
Expand All @@ -252,7 +252,7 @@ TEST(spline_interpolation, SplineInterpolation)
const std::vector<double> base_keys{-1.5, 1.0, 5.0, 10.0, 15.0, 20.0};
const std::vector<double> base_values{-1.2, 0.5, 1.0, 1.2, 2.0, 1.0};
const std::vector<double> query_keys{0.0, 8.0, 12.0, 18.0};
const std::vector<double> ans{0.671301, 0.0509853, 0.209426, -0.253628};
const std::vector<double> ans{0.671343, 0.049289, 0.209471, -0.253746};

SplineInterpolation s(base_keys, base_values);
const std::vector<double> query_values = s.getSplineInterpolatedDiffValues(query_keys);
Expand All @@ -267,7 +267,7 @@ TEST(spline_interpolation, SplineInterpolation)
const std::vector<double> base_keys{-1.5, 1.0, 5.0, 10.0, 15.0, 20.0};
const std::vector<double> base_values{-1.2, 0.5, 1.0, 1.2, 2.0, 1.0};
const std::vector<double> query_keys{0.0, 8.0, 12.0, 18.0};
const std::vector<double> ans{-0.156582, 0.0440771, -0.0116873, -0.0495025};
const std::vector<double> ans{-0.155829, 0.043097, -0.011143, -0.049611};

SplineInterpolation s(base_keys, base_values);
const std::vector<double> query_values = s.getSplineInterpolatedQuadDiffValues(query_keys);
Expand Down
Loading

0 comments on commit 4996cf7

Please sign in to comment.