Skip to content

Commit

Permalink
In LeastSquaresEquationsSolver:
Browse files Browse the repository at this point in the history
 * solve(func, startingPoint) to rely on solve(func, startingPoint,
   bounds)
 * Rewrote the way to integrate a user defined optimization solver :
   default solver is instanciated using a fake least squares problem
   without bounds. An additionnal check / redefinition + warn is done in
   case bounds are defined and the solver incompatible
  • Loading branch information
sofianehaddad committed Nov 25, 2024
1 parent 78089ec commit 9397132
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 59 deletions.
92 changes: 38 additions & 54 deletions lib/src/Base/Solver/LeastSquaresEquationsSolver.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include "openturns/OptimizationAlgorithm.hxx"
#include "openturns/Log.hxx"
#include "openturns/PersistentObjectFactory.hxx"

#include "openturns/SymbolicFunction.hxx"

BEGIN_NAMESPACE_OPENTURNS

Expand All @@ -46,7 +46,13 @@ LeastSquaresEquationsSolver::LeastSquaresEquationsSolver(const Scalar absoluteEr
const UnsignedInteger maximumCallsNumber)
: SolverImplementation(absoluteError, relativeError, residualError, maximumCallsNumber)
{
// Nothing to do
// LeastSquares problem
LeastSquaresProblem problem(SymbolicFunction("x", "x"));
solver_ = OptimizationAlgorithm::Build(problem);
solver_.setMaximumCallsNumber(maximumCallsNumber);
solver_.setMaximumAbsoluteError(absoluteError);
solver_.setMaximumRelativeError(relativeError);
solver_.setMaximumResidualError(residualError);
}

/* Virtual constructor */
Expand All @@ -66,13 +72,12 @@ String LeastSquaresEquationsSolver::__repr__() const

void LeastSquaresEquationsSolver::setOptimizationAlgorithm(const OptimizationAlgorithm & algorithm)
{
useDefaultOptimizationAlgorithm_ = false;
algorithm_ = algorithm;
solver_ = algorithm;
}

OptimizationAlgorithm LeastSquaresEquationsSolver::getOptimizationAlgorithm() const
{
return algorithm_;
return solver_;
}

/* Solve attempt to find one root to the system of non-linear equations function(x) = 0
Expand All @@ -81,29 +86,8 @@ OptimizationAlgorithm LeastSquaresEquationsSolver::getOptimizationAlgorithm() co
Point LeastSquaresEquationsSolver::solve(const Function & function,
const Point & startingPoint) const
{
UnsignedInteger callsNumber = 0;
const UnsignedInteger maximumCallsNumber = getMaximumCallsNumber();
const Scalar absoluteError = getAbsoluteError();
const Scalar relativeError = getRelativeError();
const Scalar residualError = getResidualError();
LeastSquaresProblem lsqProblem(function);
OptimizationAlgorithm lsqAlgorithm;
if (useDefaultOptimizationAlgorithm_)
lsqAlgorithm = OptimizationAlgorithm().Build(lsqProblem);
else
lsqAlgorithm = algorithm_;
lsqAlgorithm.setStartingPoint(startingPoint);
lsqAlgorithm.setMaximumCallsNumber(maximumCallsNumber);
lsqAlgorithm.setMaximumAbsoluteError(absoluteError);
lsqAlgorithm.setMaximumRelativeError(relativeError);
lsqAlgorithm.setMaximumResidualError(residualError);
lsqAlgorithm.run();
callsNumber = lsqAlgorithm.getResult().getCallsNumber();
callsNumber_ = callsNumber;
const Point min_value_obtained = lsqAlgorithm.getResult().getOptimalValue();
if ( residualError < min_value_obtained[0]) throw InternalException(HERE) << "Error: solver did not find a solution that satisfies the threshold, here obtained residual=" << min_value_obtained[0];
const Point result = lsqAlgorithm.getResult().getOptimalPoint();
return result;
const Interval bounds;
return solve(function, startingPoint, bounds);
}

/* Solve attempt to find one root to the system of non-linear equations function(x) = 0
Expand All @@ -113,46 +97,46 @@ Point LeastSquaresEquationsSolver::solve(const Function & function,
const Point & startingPoint,
const Interval & bounds) const
{
UnsignedInteger callsNumber = 0;
const UnsignedInteger maximumCallsNumber = getMaximumCallsNumber();
const Scalar absoluteError = getAbsoluteError();
const Scalar relativeError = getRelativeError();
const Scalar residualError = getResidualError();
LeastSquaresProblem lsqProblem(function);
lsqProblem.setBounds(bounds);
OptimizationAlgorithm lsqAlgorithm;
if (useDefaultOptimizationAlgorithm_)
lsqAlgorithm = OptimizationAlgorithm().Build(lsqProblem);
else
lsqAlgorithm = algorithm_;
lsqAlgorithm.setStartingPoint(startingPoint);
lsqAlgorithm.setMaximumCallsNumber(maximumCallsNumber);
lsqAlgorithm.setMaximumAbsoluteError(absoluteError);
lsqAlgorithm.setMaximumRelativeError(relativeError);
lsqAlgorithm.setMaximumResidualError(residualError);
lsqAlgorithm.run();
callsNumber = lsqAlgorithm.getResult().getCallsNumber();
callsNumber_ = callsNumber;
const Point min_value_obtained = lsqAlgorithm.getResult().getOptimalValue();
if ( residualError < min_value_obtained[0]) throw InternalException(HERE) << "Error: solver did not find a solution that satisfies the threshold, here obtained residual=" << min_value_obtained[0];
const Point result = lsqAlgorithm.getResult().getOptimalPoint();
const UnsignedInteger boundsDimension = bounds.getDimension();
if (boundsDimension == function.getInputDimension())
lsqProblem.setBounds(bounds);
if ((boundsDimension > 0) && (boundsDimension != function.getInputDimension()))
throw InvalidArgumentException(HERE) << "Bounds should be of dimension 0 or dimension = " << function.getInputDimension()
<< ". Here bounds's dimension = " << boundsDimension;
OptimizationAlgorithm solver(solver_);
solver.setStartingPoint(startingPoint);
try
{
solver.setProblem(lsqProblem);
}
catch (const InvalidArgumentException &)
{
LOGWARN("Default optimization algorithm could not solve the least squares problem. Trying to set up a new one...");
solver = OptimizationAlgorithm::Build(lsqProblem);
}
solver.setProblem(lsqProblem);
solver.run();
callsNumber_ = solver.getResult().getCallsNumber();
const Point min_value_obtained = solver.getResult().getOptimalValue();
if ( getResidualError() < min_value_obtained[0])
throw InternalException(HERE) << "Error: solver did not find a solution that satisfies the threshold, here obtained residual=" << min_value_obtained[0];
const Point result = solver.getResult().getOptimalPoint();
return result;
}

/* Method save() stores the object through the StorageManager */
void LeastSquaresEquationsSolver::save(Advocate & adv) const
{
SolverImplementation::save(adv);
adv.saveAttribute( "useDefaultOptimizationAlgorithm_", useDefaultOptimizationAlgorithm_ );
adv.saveAttribute( "algorithm_", algorithm_ );
adv.saveAttribute( "solver_", solver_ );
}

/* Method load() reloads the object from the StorageManager */
void LeastSquaresEquationsSolver::load(Advocate & adv)
{
SolverImplementation::load(adv);
adv.loadAttribute( "useDefaultOptimizationAlgorithm_", useDefaultOptimizationAlgorithm_ );
adv.loadAttribute( "algorithm_", algorithm_ );
adv.loadAttribute( "solver_", solver_ );
}

END_NAMESPACE_OPENTURNS
7 changes: 2 additions & 5 deletions lib/src/Base/Solver/openturns/LeastSquaresEquationsSolver.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,8 @@ public:
void load(Advocate & adv) override;

private:
/** Flag to tell to use the default available optimizer */
Bool useDefaultOptimizationAlgorithm_ = true;

/** optimization algorithm */
OptimizationAlgorithm algorithm_;
/** optimization solver */
OptimizationAlgorithm solver_;

}; /* Class LeastSquaresEquationsSolver */

Expand Down

0 comments on commit 9397132

Please sign in to comment.