Skip to content

Commit

Permalink
add methods to get cost list
Browse files Browse the repository at this point in the history
  • Loading branch information
mayataka committed Nov 1, 2022
1 parent c860022 commit 25abdf3
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 1 deletion.
4 changes: 3 additions & 1 deletion bindings/python/robotoc/cost/cost_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ PYBIND11_MODULE(cost_function, m) {
py::arg("robot"), py::arg("impact_status"), py::arg("data"),
py::arg("grid_info"), py::arg("s"), py::arg("kkt_residual"),
py::arg("kkt_matrix"))
DEFINE_ROBOTOC_PYBIND11_CLASS_CLONE(CostFunction);
.def("get_cost_component_list", &CostFunction::getCostComponentList)
DEFINE_ROBOTOC_PYBIND11_CLASS_CLONE(CostFunction)
DEFINE_ROBOTOC_PYBIND11_CLASS_PRINT(CostFunction);
}

} // namespace python
Expand Down
18 changes: 18 additions & 0 deletions include/robotoc/cost/cost_function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <cmath>
#include <vector>
#include <unordered_map>
#include <iostream>

#include "Eigen/Core"

Expand Down Expand Up @@ -278,6 +279,23 @@ class CostFunction {
SplitKKTResidual& kkt_residual,
SplitKKTMatrix& kkt_matrix) const;

///
/// @brief Gets a list of the cost components.
/// @return Name list of cost components.
///
std::vector<std::string> getCostComponentList() const;

///
/// @brief Displays the cost function onto a ostream.
///
void disp(std::ostream& os) const;

friend std::ostream& operator<<(std::ostream& os,
const CostFunction& cost_function);

friend std::ostream& operator<<(std::ostream& os,
const std::shared_ptr<CostFunction>& cost_function);

private:
std::vector<CostFunctionComponentBasePtr> costs_;
std::unordered_map<std::string, size_t> cost_names_;
Expand Down
30 changes: 30 additions & 0 deletions src/cost/cost_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,4 +326,34 @@ double CostFunction::quadratizeImpactCost(Robot& robot,
return l;
}


std::vector<std::string> CostFunction::getCostComponentList() const {
std::vector<std::string> cost_component_list;
for (std::pair<std::string, size_t> e : cost_names_) {
cost_component_list.push_back(e.first);
}
return cost_component_list;
}


void CostFunction::disp(std::ostream& os) const {
os << "CostFunction:" << "\n";
for (const auto& e : getCostComponentList()) {
os << " - " << e << "\n";
}
}


std::ostream& operator<<(std::ostream& os, const CostFunction& cost_function) {
cost_function.disp(os);
return os;
}


std::ostream& operator<<(std::ostream& os,
const std::shared_ptr<CostFunction>& cost_function) {
cost_function->disp(os);
return os;
}

} // namespace robotoc
4 changes: 4 additions & 0 deletions test/cost/cost_function_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ void CostFunctionTest::testStageCost(Robot& robot) {
cost_component->as_shared_ptr<CoMCost>(),
std::runtime_error
);
EXPECT_NO_THROW(
std::cout << non_discounted_cost << std::endl;
std::cout << *non_discounted_cost.get() << std::endl;
);
}


Expand Down

0 comments on commit 25abdf3

Please sign in to comment.