diff --git a/xrft/tests/test_xrft.py b/xrft/tests/test_xrft.py index dc661879..7b6a9301 100644 --- a/xrft/tests/test_xrft.py +++ b/xrft/tests/test_xrft.py @@ -12,6 +12,7 @@ import xarray.testing as xrt import xrft +from ..xrft import _apply_window @pytest.fixture() @@ -524,13 +525,30 @@ def test_cross_spectrum(self, dask): cs = xrft.cross_spectrum( da, da2, dim=dim, shift=True, window="hann", detrend="constant" ) - test = (daft * np.conj(daft2)).values / N ** 4 + test = (daft * np.conj(daft2)) / N ** 4 dk = np.diff(np.fft.fftfreq(N, 1.0))[0] test /= dk ** 2 npt.assert_almost_equal(cs.values, test) npt.assert_almost_equal(np.ma.masked_invalid(cs).mask.sum(), 0.0) + cs = xrft.cross_spectrum( + da, + da2, + dim=dim, + shift=True, + window="hann", + detrend="constant", + window_correction=True, + ) + test = (daft * np.conj(daft2)) / N ** 4 + window, _ = _apply_window(da, dim, window_type="hann") + dk = np.diff(np.fft.fftfreq(N, 1.0))[0] + test /= dk ** 2 * (window ** 2).mean() + + npt.assert_almost_equal(cs.values, test) + npt.assert_almost_equal(np.ma.masked_invalid(cs).mask.sum(), 0.0) + with pytest.raises(ValueError): xrft.cross_spectrum(da, da2, dim=dim, window=None, window_correction=True) diff --git a/xrft/xrft.py b/xrft/xrft.py index fb7bc622..b5abb8f5 100644 --- a/xrft/xrft.py +++ b/xrft/xrft.py @@ -863,7 +863,7 @@ def cross_spectrum( "window_correction can only be applied when windowing is turned on." ) else: - windows, _ = _apply_window(da, dim, window_type=kwargs.get("window")) + windows, _ = _apply_window(da1, dim, window_type=kwargs.get("window")) cs = cs / (windows ** 2).mean() fs = np.prod([float(cs[d].spacing) for d in updated_dims]) cs *= fs @@ -874,7 +874,7 @@ def cross_spectrum( "window_correction can only be applied when windowing is turned on." ) else: - windows, _ = _apply_window(da, dim, window_type=kwargs.get("window")) + windows, _ = _apply_window(da1, dim, window_type=kwargs.get("window")) cs = cs / windows.mean() ** 2 fs = np.prod([float(cs[d].spacing) for d in updated_dims]) cs *= fs ** 2