Skip to content

Commit

Permalink
Handle scipy 1.15 changes related to derivative function used in be…
Browse files Browse the repository at this point in the history
…nchmarking problems (#707)

* Add packaging module dependency to handle version numbers

* Import derivative wrt scipy version

* Fix derivative import

* Fix typos
  • Loading branch information
relf authored Jan 7, 2025
1 parent ce1a818 commit 4eee8f8
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 11 deletions.
1 change: 1 addition & 0 deletions doc/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Cython
packaging
numpy
scipy
scikit-learn
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
Cython
packaging
numpy
scipy < 1.15
scipy
scikit-learn
pyDOE3
numba # JIT compiler
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
"smt.kernels",
],
install_requires=[
"packaging",
"scikit-learn",
"pyDOE3",
"scipy",
Expand Down
4 changes: 2 additions & 2 deletions smt/problems/torsion_vibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""

import numpy as np
from scipy.misc import derivative
from smt.utils.misc import SCIPY_DERIVATIVE

from smt.problems.problem import Problem

Expand Down Expand Up @@ -91,7 +91,7 @@ def wraps(x):
args[var] = x
return func(*args)

return derivative(wraps, point[var], dx=1e-6)
return SCIPY_DERIVATIVE(wraps, point[var], dx=1e-6)

def func(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14):
K1 = np.pi * x2 * x0 / (32 * x1)
Expand Down
4 changes: 2 additions & 2 deletions smt/problems/water_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

import numpy as np
from scipy.misc import derivative
from smt.utils.misc import SCIPY_DERIVATIVE

from smt.problems.problem import Problem

Expand Down Expand Up @@ -57,7 +57,7 @@ def wraps(x):
args[var] = x
return func(*args)

return derivative(wraps, point[var], dx=1e-6)
return SCIPY_DERIVATIVE(wraps, point[var], dx=1e-6)

def func(x0, x1, x2, x3, x4, x5, x6, x7):
return (
Expand Down
4 changes: 2 additions & 2 deletions smt/problems/water_flow_lfidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

import numpy as np
from scipy.misc import derivative
from smt.utils.misc import SCIPY_DERIVATIVE

from smt.problems.problem import Problem

Expand Down Expand Up @@ -51,7 +51,7 @@ def wraps(x):
args[var] = x
return func(*args)

return derivative(wraps, point[var], dx=1e-6)
return SCIPY_DERIVATIVE(wraps, point[var], dx=1e-6)

def func(x0, x1, x2, x3, x4, x5, x6, x7):
return (
Expand Down
4 changes: 2 additions & 2 deletions smt/problems/welded_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

import numpy as np
from scipy.misc import derivative
from smt.utils.misc import SCIPY_DERIVATIVE

from smt.problems.problem import Problem

Expand Down Expand Up @@ -55,7 +55,7 @@ def wraps(x):
args[var] = x
return func(*args)

return derivative(wraps, point[var], dx=1e-6)
return SCIPY_DERIVATIVE(wraps, point[var], dx=1e-6)

def func(x0, x1, x2):
tau1 = 6000 / (np.sqrt(2) * x1 * x2)
Expand Down
4 changes: 2 additions & 2 deletions smt/problems/wing_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

import numpy as np
from scipy.misc import derivative
from smt.utils.misc import SCIPY_DERIVATIVE

from smt.problems.problem import Problem

Expand Down Expand Up @@ -61,7 +61,7 @@ def wraps(x):
args[var] = x
return func(*args)

return derivative(wraps, point[var], dx=1e-6)
return SCIPY_DERIVATIVE(wraps, point[var], dx=1e-6)

def func(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9):
return (
Expand Down
11 changes: 11 additions & 0 deletions smt/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@

import numpy as np

import scipy
from packaging.version import Version

# Since scipy 1.15, derivative function has moved from scipy.misc to scipy.differentiate
# As derivative is used by several benchmarking problems, we initialize a constant
# once here as a proxy of the derivative function wrt the installed scipy version
if Version(scipy.__version__) >= Version("1.15"):
SCIPY_DERIVATIVE = scipy.differentiate.derivative
else:
SCIPY_DERIVATIVE = scipy.misc.derivative


def standardization(X, y):
"""
Expand Down

0 comments on commit 4eee8f8

Please sign in to comment.