From 25abdf31aa035301aa988262a4f0ae43622ec70b Mon Sep 17 00:00:00 2001 From: Sotaro Katayama Date: Tue, 1 Nov 2022 19:40:24 +0900 Subject: [PATCH] add methods to get cost list --- .../python/robotoc/cost/cost_function.cpp | 4 ++- include/robotoc/cost/cost_function.hpp | 18 +++++++++++ src/cost/cost_function.cpp | 30 +++++++++++++++++++ test/cost/cost_function_test.cpp | 4 +++ 4 files changed, 55 insertions(+), 1 deletion(-) diff --git a/bindings/python/robotoc/cost/cost_function.cpp b/bindings/python/robotoc/cost/cost_function.cpp index 18be77df7..cc236d488 100644 --- a/bindings/python/robotoc/cost/cost_function.cpp +++ b/bindings/python/robotoc/cost/cost_function.cpp @@ -57,7 +57,9 @@ PYBIND11_MODULE(cost_function, m) { py::arg("robot"), py::arg("impact_status"), py::arg("data"), py::arg("grid_info"), py::arg("s"), py::arg("kkt_residual"), py::arg("kkt_matrix")) - DEFINE_ROBOTOC_PYBIND11_CLASS_CLONE(CostFunction); + .def("get_cost_component_list", &CostFunction::getCostComponentList) + DEFINE_ROBOTOC_PYBIND11_CLASS_CLONE(CostFunction) + DEFINE_ROBOTOC_PYBIND11_CLASS_PRINT(CostFunction); } } // namespace python diff --git a/include/robotoc/cost/cost_function.hpp b/include/robotoc/cost/cost_function.hpp index ee9200760..b1ed9c02c 100644 --- a/include/robotoc/cost/cost_function.hpp +++ b/include/robotoc/cost/cost_function.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include "Eigen/Core" @@ -278,6 +279,23 @@ class CostFunction { SplitKKTResidual& kkt_residual, SplitKKTMatrix& kkt_matrix) const; + /// + /// @brief Gets a list of the cost components. + /// @return Name list of cost components. + /// + std::vector getCostComponentList() const; + + /// + /// @brief Displays the cost function onto a ostream. + /// + void disp(std::ostream& os) const; + + friend std::ostream& operator<<(std::ostream& os, + const CostFunction& cost_function); + + friend std::ostream& operator<<(std::ostream& os, + const std::shared_ptr& cost_function); + private: std::vector costs_; std::unordered_map cost_names_; diff --git a/src/cost/cost_function.cpp b/src/cost/cost_function.cpp index cd7831aaf..12cdf2a6e 100644 --- a/src/cost/cost_function.cpp +++ b/src/cost/cost_function.cpp @@ -326,4 +326,34 @@ double CostFunction::quadratizeImpactCost(Robot& robot, return l; } + +std::vector CostFunction::getCostComponentList() const { + std::vector cost_component_list; + for (std::pair e : cost_names_) { + cost_component_list.push_back(e.first); + } + return cost_component_list; +} + + +void CostFunction::disp(std::ostream& os) const { + os << "CostFunction:" << "\n"; + for (const auto& e : getCostComponentList()) { + os << " - " << e << "\n"; + } +} + + +std::ostream& operator<<(std::ostream& os, const CostFunction& cost_function) { + cost_function.disp(os); + return os; +} + + +std::ostream& operator<<(std::ostream& os, + const std::shared_ptr& cost_function) { + cost_function->disp(os); + return os; +} + } // namespace robotoc \ No newline at end of file diff --git a/test/cost/cost_function_test.cpp b/test/cost/cost_function_test.cpp index b27591c68..01638fe82 100644 --- a/test/cost/cost_function_test.cpp +++ b/test/cost/cost_function_test.cpp @@ -99,6 +99,10 @@ void CostFunctionTest::testStageCost(Robot& robot) { cost_component->as_shared_ptr(), std::runtime_error ); + EXPECT_NO_THROW( + std::cout << non_discounted_cost << std::endl; + std::cout << *non_discounted_cost.get() << std::endl; + ); }