Skip to content

Commit

Permalink
Add LeastSquaresEquationsSolver
Browse files Browse the repository at this point in the history
Implementation of the Least Squares Solver for a system of non-linear equations

The work is a continuation of the PR proposed by A.H accounting for some remarks
  • Loading branch information
sofianehaddad authored Dec 3, 2024
1 parent 75ca8cf commit db722e5
Show file tree
Hide file tree
Showing 20 changed files with 533 additions and 12 deletions.
1 change: 1 addition & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
==== Major changes ====

==== New classes ====
* New LeastSquaresEquationsSolver class (openturns.experimental)

==== API changes ====

Expand Down
2 changes: 2 additions & 0 deletions lib/src/Base/Solver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ ot_add_source_file (SolverImplementation.cxx)
ot_add_source_file (Bisection.cxx)
ot_add_source_file (Secant.cxx)
ot_add_source_file (Brent.cxx)
ot_add_source_file (LeastSquaresEquationsSolver.cxx)
ot_add_source_file (ODESolver.cxx)
ot_add_source_file (ODESolverImplementation.cxx)
ot_add_source_file (RungeKutta.cxx)
ot_add_source_file (Fehlberg.cxx)

ot_install_header_file (Brent.hxx)
ot_install_header_file (LeastSquaresEquationsSolver.hxx)
ot_install_header_file (Bisection.hxx)
ot_install_header_file (Secant.hxx)
ot_install_header_file (SolverImplementation.hxx)
Expand Down
142 changes: 142 additions & 0 deletions lib/src/Base/Solver/LeastSquaresEquationsSolver.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// -*- C++ -*-
/**
* @brief Implementation class of an unbounded solver for systems of non-linear equations based on least square optimization
*
* Copyright 2005-2025 Airbus-EDF-IMACS-ONERA-Phimeca
*
* This library is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this library. If not, see <http://www.gnu.org/licenses/>.
*
*/

#include "openturns/LeastSquaresEquationsSolver.hxx"
#include "openturns/LeastSquaresProblem.hxx"
#include "openturns/Log.hxx"
#include "openturns/OptimizationAlgorithm.hxx"
#include "openturns/Log.hxx"
#include "openturns/PersistentObjectFactory.hxx"
#include "openturns/SymbolicFunction.hxx"

BEGIN_NAMESPACE_OPENTURNS

/**
* @class LeastSquaresEquationsSolver
*
* This class is an interface for the nonlinear LeastSquaresEquationsSolver
*/

CLASSNAMEINIT(LeastSquaresEquationsSolver)

static const Factory<LeastSquaresEquationsSolver> Factory_LeastSquaresEquationsSolver;

/* Parameter constructor */
LeastSquaresEquationsSolver::LeastSquaresEquationsSolver(const Scalar absoluteError,
const Scalar relativeError,
const Scalar residualError,
const UnsignedInteger maximumCallsNumber)
: SolverImplementation(absoluteError, relativeError, residualError, maximumCallsNumber)
{
// LeastSquares problem
LeastSquaresProblem problem(SymbolicFunction("x", "x"));
solver_ = OptimizationAlgorithm::Build(problem);
solver_.setMaximumCallsNumber(maximumCallsNumber);
solver_.setMaximumAbsoluteError(absoluteError);
solver_.setMaximumRelativeError(relativeError);
solver_.setMaximumResidualError(residualError);
}

/* Virtual constructor */
LeastSquaresEquationsSolver * LeastSquaresEquationsSolver::clone() const
{
return new LeastSquaresEquationsSolver(*this);
}

/* String converter */
String LeastSquaresEquationsSolver::__repr__() const
{
OSS oss;
oss << "class=" << LeastSquaresEquationsSolver::GetClassName()
<< " derived from " << SolverImplementation::__repr__();
return oss;
}

void LeastSquaresEquationsSolver::setOptimizationAlgorithm(const OptimizationAlgorithm & algorithm)
{
solver_ = algorithm;
}

OptimizationAlgorithm LeastSquaresEquationsSolver::getOptimizationAlgorithm() const
{
return solver_;
}

/* Solve attempt to find one root to the system of non-linear equations function(x) = 0
given a starting point x with a least square optimization method.
*/
Point LeastSquaresEquationsSolver::solve(const Function & function,
const Point & startingPoint) const
{
const Interval bounds;
return solve(function, startingPoint, bounds);
}

/* Solve attempt to find one root to the system of non-linear equations function(x) = 0
given a starting point x with a least square optimization method.
*/
Point LeastSquaresEquationsSolver::solve(const Function & function,
const Point & startingPoint,
const Interval & bounds) const
{
LeastSquaresProblem lsqProblem(function);
const UnsignedInteger boundsDimension = bounds.getDimension();
if (boundsDimension == function.getInputDimension())
lsqProblem.setBounds(bounds);
if ((boundsDimension > 0) && (boundsDimension != function.getInputDimension()))
throw InvalidArgumentException(HERE) << "Bounds should be of dimension 0 or dimension = " << function.getInputDimension()
<< ". Here bounds's dimension = " << boundsDimension;
OptimizationAlgorithm solver(solver_);
solver.setStartingPoint(startingPoint);
try
{
solver.setProblem(lsqProblem);
}
catch (const InvalidArgumentException &)
{
LOGWARN("Default optimization algorithm could not solve the least squares problem. Trying to set up a new one...");
solver = OptimizationAlgorithm::Build(lsqProblem);
}
solver.setProblem(lsqProblem);
solver.run();
callsNumber_ = solver.getResult().getCallsNumber();
const Point min_value_obtained = solver.getResult().getOptimalValue();
if ( getResidualError() < min_value_obtained[0])
throw InternalException(HERE) << "Error: solver did not find a solution that satisfies the threshold, here obtained residual=" << min_value_obtained[0];
const Point result = solver.getResult().getOptimalPoint();
return result;
}

/* Method save() stores the object through the StorageManager */
void LeastSquaresEquationsSolver::save(Advocate & adv) const
{
SolverImplementation::save(adv);
adv.saveAttribute( "solver_", solver_ );
}

/* Method load() reloads the object from the StorageManager */
void LeastSquaresEquationsSolver::load(Advocate & adv)
{
SolverImplementation::load(adv);
adv.loadAttribute( "solver_", solver_ );
}

END_NAMESPACE_OPENTURNS
15 changes: 15 additions & 0 deletions lib/src/Base/Solver/Solver.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ Scalar Solver::solve(const Function & function,
return getImplementation()->solve(function, value, infPoint, supPoint, infValue, supValue);
}

/** Solve attempt to find one root to a system of equations function(x) = 0 given a starting point x_0 */
Point Solver::solve(const Function & function,
const Point & startingPoint) const
{
return getImplementation()->solve(function, startingPoint);
}

/** Solve attempt to find one root to a system of equations function(x) = 0 given a starting point x_0 and bounds */
Point Solver::solve(const Function & function,
const Point & startingPoint,
const Interval & bounds) const
{
return getImplementation()->solve(function, startingPoint, bounds);
}

/* Absolute error accessor */
void Solver::setAbsoluteError(const Scalar absoluteError)
{
Expand Down
14 changes: 14 additions & 0 deletions lib/src/Base/Solver/SolverImplementation.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -217,4 +217,18 @@ Scalar SolverImplementation::solve(const UniVariateFunction &,
throw NotYetImplementedException(HERE) << "In SolverImplementation::solve(const UniVariateFunction & function, const Scalar value, const Scalar infPoint, const Scalar supPoint, const Scalar infValue, const Scalar supValue) const";
}

/** Solve attempt to find one root to a system of equations function(x) = 0 given a starting point x_0 */
Point SolverImplementation::solve(const Function &,
const Point &) const
{
throw NotYetImplementedException(HERE) << "In SolverImplementation::solve(const Function &, const Point &)";
}

Point SolverImplementation::solve(const Function &,
const Point &,
const Interval &) const
{
throw NotYetImplementedException(HERE) << "In SolverImplementation::solve(const Function &, const Point &, const Interval&)";
}

END_NAMESPACE_OPENTURNS
79 changes: 79 additions & 0 deletions lib/src/Base/Solver/openturns/LeastSquaresEquationsSolver.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// -*- C++ -*-
/**
* @brief Implementation class of an unbounded solver for systems of non-linear equations based on least square optimization
*
* Copyright 2005-2025 Airbus-EDF-IMACS-ONERA-Phimeca
*
* This library is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this library. If not, see <http://www.gnu.org/licenses/>.
*
*/
#ifndef OPENTURNS_LEASTSQUARESEQUATIONSSOLVER_HXX
#define OPENTURNS_LEASTSQUARESEQUATIONSSOLVER_HXX

#include "openturns/OTprivate.hxx"
#include "openturns/SolverImplementation.hxx"
#include "openturns/OptimizationAlgorithm.hxx"


BEGIN_NAMESPACE_OPENTURNS

/**
* @class LeastSquaresEquationsSolver
*
* This class is an interface for the 1D nonlinear solverImplementations
*/
class OT_API LeastSquaresEquationsSolver :
public SolverImplementation
{
CLASSNAME
public:

/** Parameter constructor */
explicit LeastSquaresEquationsSolver(const Scalar absoluteError = ResourceMap::GetAsScalar("Solver-DefaultAbsoluteError"),
const Scalar relativeError = ResourceMap::GetAsScalar("Solver-DefaultRelativeError"),
const Scalar residualError = ResourceMap::GetAsScalar("Solver-DefaultResidualError"),
const UnsignedInteger maximumFunctionEvaluation = ResourceMap::GetAsUnsignedInteger("Solver-DefaultMaximumFunctionEvaluation"));


/** Virtual constructor */
LeastSquaresEquationsSolver * clone() const override;

/** String converter */
String __repr__() const override;

/** Solve attempt to find one root to a system of equations function(x) = 0 given a starting point x_0 */
using SolverImplementation::solve;
Point solve(const Function & function,
const Point & startingPoint) const override;
Point solve(const Function & function,
const Point & startingPoint,
const Interval & bounds) const override;

/** optimization accessors */
OptimizationAlgorithm getOptimizationAlgorithm() const;
void setOptimizationAlgorithm(const OptimizationAlgorithm & algorithm);

/** save/load */
void save(Advocate & adv) const override;
void load(Advocate & adv) override;

private:
/** optimization solver */
OptimizationAlgorithm solver_;

}; /* Class LeastSquaresEquationsSolver */

END_NAMESPACE_OPENTURNS

#endif /* OPENTURNS_LsqSolver_HXX */
1 change: 1 addition & 0 deletions lib/src/Base/Solver/openturns/OTSolver.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@
#include "openturns/ODESolverImplementation.hxx"
#include "openturns/RungeKutta.hxx"
#include "openturns/Fehlberg.hxx"
#include "openturns/LeastSquaresEquationsSolver.hxx"

#endif /* OPENTURNS_OTSOLVER_HXX */
9 changes: 9 additions & 0 deletions lib/src/Base/Solver/openturns/Solver.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ public:
const Scalar infValue,
const Scalar supValue) const;

/** Solve attempt to find one root to a system of equations function(x) = 0 given a starting point x_0 */
virtual Point solve(const Function & function,
const Point & startingPoint) const;

/** Solve attempt to find one root to a system of equations function(x) = 0 given a starting point x_0 and bounds */
virtual Point solve(const Function & function,
const Point & startingPoint,
const Interval & bounds) const;

/** Absolute error accessor */
void setAbsoluteError(const Scalar absoluteError);
Scalar getAbsoluteError() const;
Expand Down
9 changes: 9 additions & 0 deletions lib/src/Base/Solver/openturns/SolverImplementation.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ public:
const Scalar infValue,
const Scalar supValue) const;

/** Solve attempt to find one root to a system of equations function(x) = 0 given a starting point x_0 */
virtual Point solve(const Function & function,
const Point & startingPoint) const;

/** Solve attempt to find one root to a system of equations function(x) = 0 given a starting point x_0 and bounds */
virtual Point solve(const Function & function,
const Point & startingPoint,
const Interval & bounds) const;

/** Absolute error accessor */
void setAbsoluteError(const Scalar absoluteError);
Scalar getAbsoluteError() const;
Expand Down
1 change: 1 addition & 0 deletions lib/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ ot_check_test (MultiStart_std IGNOREOUT)
# Solver
ot_check_test (Bisection_std)
ot_check_test (Brent_std)
ot_check_test (LeastSquaresEquationsSolver_std IGNOREOUT)
ot_check_test (Secant_std)
ot_check_test (RungeKutta_std)
ot_check_test (Fehlberg_std)
Expand Down
59 changes: 59 additions & 0 deletions lib/test/t_LeastSquaresEquationsSolver_std.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// -*- C++ -*-
/**
* @brief The test file of class LeastSquaresEquationsSolver for standard methods
*
* Copyright 2005-2025 Airbus-EDF-IMACS-ONERA-Phimeca
*
* This library is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this library. If not, see <http://www.gnu.org/licenses/>.
*
*/
#include "openturns/OT.hxx"
#include "openturns/OTtestcode.hxx"

using namespace OT;
using namespace OT::Test;

int main(int, char *[])
{
TESTPREAMBLE;
OStream fullprint(std::cout);

try
{
/** Analytical construction */
Description input = {"x","y"};
Description formulas = {"y * x - sin(2 * x)","1 + cos(y) + x"};
SymbolicFunction analytical(input, formulas);

LeastSquaresEquationsSolver algo;
algo.setResidualError(1e-5);
algo.setMaximumCallsNumber(1000),
fullprint << "algo=" << algo << std::endl;
Point startingPoint = {2.0, 1.0};
const Point optimalValue(2);
Point solution(algo.solve(analytical, startingPoint));
fullprint << "Solve " << formulas << "= [0,0] for " << input <<std::endl;
fullprint << "[x,y] = " << solution << std::endl;
fullprint << "algo=" << algo << std::endl;
assert_almost_equal(analytical(solution), optimalValue, 1e-5, 1e-5);
}
catch (TestFailed & ex)
{
std::cerr << ex << std::endl;
return ExitCode::Error;
}


return ExitCode::Success;
}
Loading

0 comments on commit db722e5

Please sign in to comment.