Skip to content

Commit

Permalink
In LeastSquaresEquationsSolver, allow for user defined optimization
Browse files Browse the repository at this point in the history
Introduce new {get, set}OptimizationAlgorithm. The purpose is to allow
user sets his own optimization solver. The algorithm might be
incompatible with the problem and then LSES throws an exception when
running solve.
  • Loading branch information
sofianehaddad committed Dec 3, 2024
1 parent e67f3dc commit 9d9847d
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 2 deletions.
39 changes: 37 additions & 2 deletions lib/src/Base/Solver/LeastSquaresEquationsSolver.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ String LeastSquaresEquationsSolver::__repr__() const
return oss;
}

void LeastSquaresEquationsSolver::setOptimizationAlgorithm(const OptimizationAlgorithm & algorithm)
{
useDefaultOptimizationAlgorithm_ = false;
algorithm_ = algorithm;
}

OptimizationAlgorithm LeastSquaresEquationsSolver::getOptimizationAlgorithm() const
{
return algorithm_;
}

/* Solve attempt to find one root to the system of non-linear equations function(x) = 0
given a starting point x with a least square optimization method.
*/
Expand All @@ -76,7 +87,11 @@ Point LeastSquaresEquationsSolver::solve(const Function & function,
const Scalar relativeError = getRelativeError();
const Scalar residualError = getResidualError();
LeastSquaresProblem lsqProblem(function);
OptimizationAlgorithm lsqAlgorithm = OptimizationAlgorithm().Build(lsqProblem);
OptimizationAlgorithm lsqAlgorithm;
if (useDefaultOptimizationAlgorithm_)
lsqAlgorithm = OptimizationAlgorithm().Build(lsqProblem);
else
lsqAlgorithm = algorithm_;
lsqAlgorithm.setStartingPoint(startingPoint);
lsqAlgorithm.setMaximumCallsNumber(maximumCallsNumber);
lsqAlgorithm.setMaximumAbsoluteError(absoluteError);
Expand Down Expand Up @@ -105,7 +120,11 @@ Point LeastSquaresEquationsSolver::solve(const Function & function,
const Scalar residualError = getResidualError();
LeastSquaresProblem lsqProblem(function);
lsqProblem.setBounds(bounds);
OptimizationAlgorithm lsqAlgorithm = OptimizationAlgorithm().Build(lsqProblem);
OptimizationAlgorithm lsqAlgorithm;
if (useDefaultOptimizationAlgorithm_)
lsqAlgorithm = OptimizationAlgorithm().Build(lsqProblem);
else
lsqAlgorithm = algorithm_;
lsqAlgorithm.setStartingPoint(startingPoint);
lsqAlgorithm.setMaximumCallsNumber(maximumCallsNumber);
lsqAlgorithm.setMaximumAbsoluteError(absoluteError);
Expand All @@ -120,4 +139,20 @@ Point LeastSquaresEquationsSolver::solve(const Function & function,
return result;
}

/* Method save() stores the object through the StorageManager */
void LeastSquaresEquationsSolver::save(Advocate & adv) const
{
SolverImplementation::save(adv);
adv.saveAttribute( "useDefaultOptimizationAlgorithm_", useDefaultOptimizationAlgorithm_ );
adv.saveAttribute( "algorithm_", algorithm_ );
}

/* Method load() reloads the object from the StorageManager */
void LeastSquaresEquationsSolver::load(Advocate & adv)
{
SolverImplementation::load(adv);
adv.loadAttribute( "useDefaultOptimizationAlgorithm_", useDefaultOptimizationAlgorithm_ );
adv.loadAttribute( "algorithm_", algorithm_ );
}

END_NAMESPACE_OPENTURNS
13 changes: 13 additions & 0 deletions lib/src/Base/Solver/openturns/LeastSquaresEquationsSolver.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "openturns/OTprivate.hxx"
#include "openturns/SolverImplementation.hxx"
#include "openturns/OptimizationAlgorithm.hxx"


BEGIN_NAMESPACE_OPENTURNS
Expand Down Expand Up @@ -59,8 +60,20 @@ public:
const Point & startingPoint,
const Interval & bounds) const override;

/** optimization accessors */
OptimizationAlgorithm getOptimizationAlgorithm() const;
void setOptimizationAlgorithm(const OptimizationAlgorithm & algorithm);

/** save/load */
void save(Advocate & adv) const override;
void load(Advocate & adv) override;

private:
/** Flag to tell to use the default available optimizer */
Bool useDefaultOptimizationAlgorithm_ = true;

/** optimization algorithm */
OptimizationAlgorithm algorithm_;

}; /* Class LeastSquaresEquationsSolver */

Expand Down
22 changes: 22 additions & 0 deletions python/src/LeastSquaresEquationsSolver_doc.i.in
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,25 @@ Notes
-----
LeastSquaresEquationsSolver might fail and not obtain a result lower than the
specified threshold, in this case an error is thrown."

// ---------------------------------------------------------------------

%feature("docstring") OT::LeastSquaresEquationsSolver::getOptimizationAlgorithm
"Get the used algorithm for the optimization.

Returns
-------
algorithm : :class:`~openturns.OptimizationAlgorithm`
The used optimization algorithm.
"

// ---------------------------------------------------------------------

%feature("docstring") OT::LeastSquaresEquationsSolver::setOptimizationAlgorithm
"Set the used algorithm for the optimization.

Parameters
----------
algorithm : :class:`~openturns.OptimizationAlgorithm`
The optimization algorithm to be used.
"

0 comments on commit 9d9847d

Please sign in to comment.