Skip to content

Commit

Permalink
fix(interpolation): fix bug of interpolation (#8903)
Browse files Browse the repository at this point in the history
* fix(interpolation): fix bug of interpolation

Signed-off-by: Y.Hisaki <[email protected]>

* add const

Signed-off-by: Y.Hisaki <[email protected]>

* auto -> int64_t

Signed-off-by: Y.Hisaki <[email protected]>

* add const

Signed-off-by: Y.Hisaki <[email protected]>

* add const

Signed-off-by: Y.Hisaki <[email protected]>

* add const

Signed-off-by: Y.Hisaki <[email protected]>

---------

Signed-off-by: Y.Hisaki <[email protected]>
  • Loading branch information
yhisaki authored Sep 19, 2024
1 parent c1aac51 commit 7819089
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 167 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "autoware/universe_utils/geometry/geometry.hpp"
#include "interpolation/interpolation_utils.hpp"

#include <Eigen/Core>

#include <algorithm>
#include <cmath>
#include <iostream>
Expand All @@ -26,25 +28,6 @@

namespace interpolation
{
// NOTE: X(s) = a_i (s - s_i)^3 + b_i (s - s_i)^2 + c_i (s - s_i) + d_i : (i = 0, 1, ... N-1)
struct MultiSplineCoef
{
MultiSplineCoef() = default;

explicit MultiSplineCoef(const size_t num_spline)
{
a.resize(num_spline);
b.resize(num_spline);
c.resize(num_spline);
d.resize(num_spline);
}

std::vector<double> a;
std::vector<double> b;
std::vector<double> c;
std::vector<double> d;
};

// static spline interpolation functions
std::vector<double> spline(
const std::vector<double> & base_keys, const std::vector<double> & base_values,
Expand Down Expand Up @@ -98,11 +81,17 @@ class SplineInterpolation
size_t getSize() const { return base_keys_.size(); }

private:
Eigen::VectorXd a_;
Eigen::VectorXd b_;
Eigen::VectorXd c_;
Eigen::VectorXd d_;

std::vector<double> base_keys_;
interpolation::MultiSplineCoef multi_spline_coef_;

void calcSplineCoefficients(
const std::vector<double> & base_keys, const std::vector<double> & base_values);

Eigen::Index get_index(const double & key) const;
};

#endif // INTERPOLATION__SPLINE_INTERPOLATION_HPP_
1 change: 1 addition & 0 deletions common/interpolation/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
<buildtool_depend>autoware_cmake</buildtool_depend>

<depend>autoware_universe_utils</depend>
<depend>eigen</depend>

<test_depend>ament_cmake_ros</test_depend>
<test_depend>ament_lint_auto</test_depend>
Expand Down
225 changes: 93 additions & 132 deletions common/interpolation/src/spline_interpolation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,65 +14,40 @@

#include "interpolation/spline_interpolation.hpp"

#include <cstdint>
#include <vector>

namespace
{
// solve Ax = d
// where A is tridiagonal matrix
// [b_0 c_0 ... ]
// [a_0 b_1 c_1 ... O ]
// A = [ ... ]
// [ O ... a_N-3 b_N-2 c_N-2]
// [ ... a_N-2 b_N-1]
struct TDMACoef
Eigen::VectorXd solve_tridiagonal_matrix_algorithm(
const Eigen::Ref<const Eigen::VectorXd> & a, const Eigen::Ref<const Eigen::VectorXd> & b,
const Eigen::Ref<const Eigen::VectorXd> & c, const Eigen::Ref<const Eigen::VectorXd> & d)
{
explicit TDMACoef(const size_t num_row)
{
a.resize(num_row - 1);
b.resize(num_row);
c.resize(num_row - 1);
d.resize(num_row);
const auto n = d.size();

if (n == 1) {
return d.array() / b.array();
}

std::vector<double> a;
std::vector<double> b;
std::vector<double> c;
std::vector<double> d;
};
Eigen::VectorXd c_prime = Eigen::VectorXd::Zero(n);
Eigen::VectorXd d_prime = Eigen::VectorXd::Zero(n);
Eigen::VectorXd x = Eigen::VectorXd::Zero(n);

inline std::vector<double> solveTridiagonalMatrixAlgorithm(const TDMACoef & tdma_coef)
{
const auto & a = tdma_coef.a;
const auto & b = tdma_coef.b;
const auto & c = tdma_coef.c;
const auto & d = tdma_coef.d;

const size_t num_row = b.size();

std::vector<double> x(num_row);
if (num_row != 1) {
// calculate p and q
std::vector<double> p;
std::vector<double> q;
p.push_back(-c[0] / b[0]);
q.push_back(d[0] / b[0]);

for (size_t i = 1; i < num_row; ++i) {
const double den = b[i] + a[i - 1] * p[i - 1];
p.push_back(-c[i - 1] / den);
q.push_back((d[i] - a[i - 1] * q[i - 1]) / den);
}
// Forward sweep
c_prime(0) = c(0) / b(0);
d_prime(0) = d(0) / b(0);

// calculate solution
x[num_row - 1] = q[num_row - 1];
for (auto i = 1; i < n; i++) {
const double m = 1.0 / (b(i) - a(i - 1) * c_prime(i - 1));
c_prime(i) = i < n - 1 ? c(i) * m : 0;
d_prime(i) = (d(i) - a(i - 1) * d_prime(i - 1)) * m;
}

for (size_t i = 1; i < num_row; ++i) {
const size_t j = num_row - 1 - i;
x[j] = p[j] * x[j + 1] + q[j];
}
} else {
x[0] = (d[0] / b[0]);
// Back substitution
x(n - 1) = d_prime(n - 1);

for (int64_t i = n - 2; i >= 0; i--) {
x(i) = d_prime(i) - c_prime(i) * x(i + 1);
}

return x;
Expand Down Expand Up @@ -166,53 +141,59 @@ std::vector<double> splineByAkima(
}
} // namespace interpolation

Eigen::Index SplineInterpolation::get_index(const double & key) const
{
const auto it = std::lower_bound(base_keys_.begin(), base_keys_.end(), key);
return std::clamp(
static_cast<int>(std::distance(base_keys_.begin(), it)) - 1, 0,
static_cast<int>(base_keys_.size()) - 2);
}

void SplineInterpolation::calcSplineCoefficients(
const std::vector<double> & base_keys, const std::vector<double> & base_values)
{
// throw exceptions for invalid arguments
interpolation_utils::validateKeysAndValues(base_keys, base_values);

const size_t num_base = base_keys.size(); // N+1

std::vector<double> diff_keys; // N
std::vector<double> diff_values; // N
for (size_t i = 0; i < num_base - 1; ++i) {
diff_keys.push_back(base_keys.at(i + 1) - base_keys.at(i));
diff_values.push_back(base_values.at(i + 1) - base_values.at(i));
}

std::vector<double> v = {0.0};
if (num_base > 2) {
// solve tridiagonal matrix algorithm
TDMACoef tdma_coef(num_base - 2); // N-1

for (size_t i = 0; i < num_base - 2; ++i) {
tdma_coef.b[i] = 2 * (diff_keys[i] + diff_keys[i + 1]);
if (i != num_base - 3) {
tdma_coef.a[i] = diff_keys[i + 1];
tdma_coef.c[i] = diff_keys[i + 1];
}
tdma_coef.d[i] =
6.0 * (diff_values[i + 1] / diff_keys[i + 1] - diff_values[i] / diff_keys[i]);
}

const std::vector<double> tdma_res = solveTridiagonalMatrixAlgorithm(tdma_coef);

// calculate v
v.insert(v.end(), tdma_res.begin(), tdma_res.end());
}
v.push_back(0.0);

// calculate a, b, c, d of spline coefficients
multi_spline_coef_ = interpolation::MultiSplineCoef{num_base - 1}; // N
for (size_t i = 0; i < num_base - 1; ++i) {
multi_spline_coef_.a[i] = (v[i + 1] - v[i]) / 6.0 / diff_keys[i];
multi_spline_coef_.b[i] = v[i] / 2.0;
multi_spline_coef_.c[i] =
diff_values[i] / diff_keys[i] - diff_keys[i] * (2 * v[i] + v[i + 1]) / 6.0;
multi_spline_coef_.d[i] = base_values[i];
const Eigen::VectorXd x = Eigen::Map<const Eigen::VectorXd>(
base_keys.data(), static_cast<Eigen::Index>(base_keys.size()));
const Eigen::VectorXd y = Eigen::Map<const Eigen::VectorXd>(
base_values.data(), static_cast<Eigen::Index>(base_values.size()));

const auto n = x.size();

if (n == 2) {
a_ = Eigen::VectorXd::Zero(1);
b_ = Eigen::VectorXd::Zero(1);
c_ = Eigen::VectorXd::Zero(1);
d_ = Eigen::VectorXd::Zero(1);
c_[0] = (y[1] - y[0]) / (x[1] - x[0]);
d_[0] = y[0];
base_keys_ = base_keys;
return;
}

// Create Tridiagonal matrix
Eigen::VectorXd v(n);
const Eigen::VectorXd h = x.segment(1, n - 1) - x.segment(0, n - 1);
const Eigen::VectorXd a = h.segment(1, n - 3);
const Eigen::VectorXd b = 2 * (h.segment(0, n - 2) + h.segment(1, n - 2));
const Eigen::VectorXd c = h.segment(1, n - 3);
const Eigen::VectorXd y_diff = y.segment(1, n - 1) - y.segment(0, n - 1);
const Eigen::VectorXd d = 6 * (y_diff.segment(1, n - 2).array() / h.tail(n - 2).array() -
y_diff.segment(0, n - 2).array() / h.head(n - 2).array());

// Solve tridiagonal matrix
v.segment(1, n - 2) = solve_tridiagonal_matrix_algorithm(a, b, c, d);
v[0] = 0;
v[n - 1] = 0;

// Calculate spline coefficients
a_ = (v.tail(n - 1) - v.head(n - 1)).array() / 6.0 / (x.tail(n - 1) - x.head(n - 1)).array();
b_ = v.segment(0, n - 1) / 2.0;
c_ = (y.tail(n - 1) - y.head(n - 1)).array() / (x.tail(n - 1) - x.head(n - 1)).array() -
(x.tail(n - 1) - x.head(n - 1)).array() *
(2 * v.segment(0, n - 1).array() + v.segment(1, n - 1).array()) / 6.0;
d_ = y.head(n - 1);
base_keys_ = base_keys;
}

Expand All @@ -221,69 +202,49 @@ std::vector<double> SplineInterpolation::getSplineInterpolatedValues(
{
// throw exceptions for invalid arguments
const auto validated_query_keys = interpolation_utils::validateKeys(base_keys_, query_keys);

const auto & a = multi_spline_coef_.a;
const auto & b = multi_spline_coef_.b;
const auto & c = multi_spline_coef_.c;
const auto & d = multi_spline_coef_.d;

std::vector<double> res;
size_t j = 0;
for (const auto & query_key : validated_query_keys) {
while (base_keys_.at(j + 1) < query_key) {
++j;
}

const double ds = query_key - base_keys_.at(j);
res.push_back(d.at(j) + (c.at(j) + (b.at(j) + a.at(j) * ds) * ds) * ds);
std::vector<double> interpolated_values;
interpolated_values.reserve(query_keys.size());

for (const auto & key : query_keys) {
const auto idx = get_index(key);
const auto dx = key - base_keys_[idx];
interpolated_values.emplace_back(
a_[idx] * dx * dx * dx + b_[idx] * dx * dx + c_[idx] * dx + d_[idx]);
}

return res;
return interpolated_values;
}

std::vector<double> SplineInterpolation::getSplineInterpolatedDiffValues(
const std::vector<double> & query_keys) const
{
// throw exceptions for invalid arguments
const auto validated_query_keys = interpolation_utils::validateKeys(base_keys_, query_keys);
std::vector<double> interpolated_diff_values;
interpolated_diff_values.reserve(query_keys.size());

const auto & a = multi_spline_coef_.a;
const auto & b = multi_spline_coef_.b;
const auto & c = multi_spline_coef_.c;

std::vector<double> res;
size_t j = 0;
for (const auto & query_key : validated_query_keys) {
while (base_keys_.at(j + 1) < query_key) {
++j;
}

const double ds = query_key - base_keys_.at(j);
res.push_back(c.at(j) + (2.0 * b.at(j) + 3.0 * a.at(j) * ds) * ds);
for (const auto & key : query_keys) {
const auto idx = get_index(key);
const auto dx = key - base_keys_[idx];
interpolated_diff_values.emplace_back(3 * a_[idx] * dx * dx + 2 * b_[idx] * dx + c_[idx]);
}

return res;
return interpolated_diff_values;
}

std::vector<double> SplineInterpolation::getSplineInterpolatedQuadDiffValues(
const std::vector<double> & query_keys) const
{
// throw exceptions for invalid arguments
const auto validated_query_keys = interpolation_utils::validateKeys(base_keys_, query_keys);
std::vector<double> interpolated_quad_diff_values;
interpolated_quad_diff_values.reserve(query_keys.size());

const auto & a = multi_spline_coef_.a;
const auto & b = multi_spline_coef_.b;

std::vector<double> res;
size_t j = 0;
for (const auto & query_key : validated_query_keys) {
while (base_keys_.at(j + 1) < query_key) {
++j;
}

const double ds = query_key - base_keys_.at(j);
res.push_back(2.0 * b.at(j) + 6.0 * a.at(j) * ds);
for (const auto & key : query_keys) {
const auto idx = get_index(key);
const auto dx = key - base_keys_[idx];
interpolated_quad_diff_values.emplace_back(6 * a_[idx] * dx + 2 * b_[idx]);
}

return res;
return interpolated_quad_diff_values;
}
10 changes: 5 additions & 5 deletions common/interpolation/test/src/test_spline_interpolation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ TEST(spline_interpolation, spline)
const std::vector<double> base_keys{-1.5, 1.0, 5.0, 10.0, 15.0, 20.0};
const std::vector<double> base_values{-1.2, 0.5, 1.0, 1.2, 2.0, 1.0};
const std::vector<double> query_keys{0.0, 8.0, 18.0};
const std::vector<double> ans{-0.075611, 0.997242, 1.573258};
const std::vector<double> ans{-0.076114, 1.001217, 1.573640};

const auto query_values = interpolation::spline(base_keys, base_values, query_keys);
for (size_t i = 0; i < query_values.size(); ++i) {
Expand Down Expand Up @@ -112,7 +112,7 @@ TEST(spline_interpolation, spline)
const std::vector<double> base_keys = {0.0, 1.0, 1.0001, 2.0, 3.0, 4.0};
const std::vector<double> base_values = {0.0, 0.0, 0.1, 0.1, 0.1, 0.1};
const std::vector<double> query_keys = {0.0, 1.0, 1.5, 2.0, 3.0, 4.0};
const std::vector<double> ans = {0.0, 0.0, 137.591789, 0.1, 0.1, 0.1};
const std::vector<double> ans = {0.0, 0.0, 158.738293, 0.1, 0.1, 0.1};

const auto query_values = interpolation::spline(base_keys, base_values, query_keys);
for (size_t i = 0; i < query_values.size(); ++i) {
Expand Down Expand Up @@ -227,7 +227,7 @@ TEST(spline_interpolation, SplineInterpolation)
const std::vector<double> base_keys{-1.5, 1.0, 5.0, 10.0, 15.0, 20.0};
const std::vector<double> base_values{-1.2, 0.5, 1.0, 1.2, 2.0, 1.0};
const std::vector<double> query_keys{0.0, 8.0, 18.0};
const std::vector<double> ans{-0.075611, 0.997242, 1.573258};
const std::vector<double> ans{-0.076114, 1.001217, 1.573640};

SplineInterpolation s(base_keys, base_values);
const std::vector<double> query_values = s.getSplineInterpolatedValues(query_keys);
Expand All @@ -242,7 +242,7 @@ TEST(spline_interpolation, SplineInterpolation)
const std::vector<double> base_keys{-1.5, 1.0, 5.0, 10.0, 15.0, 20.0};
const std::vector<double> base_values{-1.2, 0.5, 1.0, 1.2, 2.0, 1.0};
const std::vector<double> query_keys{0.0, 8.0, 12.0, 18.0};
const std::vector<double> ans{0.671301, 0.0509853, 0.209426, -0.253628};
const std::vector<double> ans{0.671343, 0.049289, 0.209471, -0.253746};

SplineInterpolation s(base_keys, base_values);
const std::vector<double> query_values = s.getSplineInterpolatedDiffValues(query_keys);
Expand All @@ -257,7 +257,7 @@ TEST(spline_interpolation, SplineInterpolation)
const std::vector<double> base_keys{-1.5, 1.0, 5.0, 10.0, 15.0, 20.0};
const std::vector<double> base_values{-1.2, 0.5, 1.0, 1.2, 2.0, 1.0};
const std::vector<double> query_keys{0.0, 8.0, 12.0, 18.0};
const std::vector<double> ans{-0.156582, 0.0440771, -0.0116873, -0.0495025};
const std::vector<double> ans{-0.155829, 0.043097, -0.011143, -0.049611};

SplineInterpolation s(base_keys, base_values);
const std::vector<double> query_values = s.getSplineInterpolatedQuadDiffValues(query_keys);
Expand Down
Loading

1 comment on commit 7819089

@SakuragiL
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, could you tell me what this bug of interpolation is?

Please sign in to comment.