diff --git a/csaps/_sspndg.py b/csaps/_sspndg.py index e485f12..3789998 100644 --- a/csaps/_sspndg.py +++ b/csaps/_sspndg.py @@ -107,12 +107,19 @@ def __call__(self, interpolation axis in the original array with the shape of x. """ + x = ndgrid_prepare_data_vectors(x, 'x', min_size=1) if len(x) != self.ndim: raise ValueError( f"'x' sequence must have length {self.ndim} according to 'breaks'") + if nu is None: + nu = (0,) * len(x) + + if extrapolate is None: + extrapolate = True + shape = tuple(x.size for x in x) coeffs = ndg_coeffs_to_flatten(self.coeffs) @@ -128,8 +135,9 @@ def __call__(self, coeffs = coeffs.reshape(c_shape) coeffs_cnl = umv_coeffs_to_canonical(coeffs, self.pieces[i]) - coeffs = PPoly.construct_fast(coeffs_cnl, self.breaks[i], - extrapolate=extrapolate, axis=1)(x[i]) + + spline = PPoly.construct_fast(coeffs_cnl, self.breaks[i], axis=1) + coeffs = spline(x[i], nu=nu[i], extrapolate=extrapolate) shape_r = (*coeffs_shape[:ndim_m1], shape[i]) coeffs = coeffs.reshape(shape_r).transpose(permuted_axes) diff --git a/tests/test_ndg.py b/tests/test_ndg.py index f276afc..c1975cf 100644 --- a/tests/test_ndg.py +++ b/tests/test_ndg.py @@ -3,6 +3,7 @@ import pytest import numpy as np +from scipy.interpolate import NdPPoly import csaps @@ -196,3 +197,29 @@ def test_auto_smooth_2d(ndgrid_2d_data): assert s.smooth == pytest.approx(smooth_expected) assert zi == pytest.approx(zi_expected) + + +@pytest.mark.parametrize('nu', [ + None, + (0, 0), + (1, 1), + (2, 2), +]) +@pytest.mark.parametrize('extrapolate', [ + None, + True, + False, +]) +def test_evaluate_nu_extrapolate(nu: tuple, extrapolate: bool): + x = ([1, 2, 3, 4], [1, 2, 3, 4]) + xi = ([0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]) + y = np.arange(4 * 4).reshape((4, 4)) + + ss = csaps.NdGridCubicSmoothingSpline(x, y, smooth=1.0) + y_ss = ss(xi, nu=nu, extrapolate=extrapolate) + + pp = NdPPoly(ss.spline.c, x) + xx = tuple(np.meshgrid(*xi, indexing='ij')) + y_pp = pp(xx, nu=nu, extrapolate=extrapolate) + + np.testing.assert_allclose(y_ss, y_pp, rtol=1e-05, atol=1e-08, equal_nan=True) diff --git a/tests/test_umv.py b/tests/test_umv.py index aa6a579..3040a80 100644 --- a/tests/test_umv.py +++ b/tests/test_umv.py @@ -241,3 +241,19 @@ def test_cubic_bc_natural(): assert cs.c == pytest.approx(ss.spline.c) assert y_cs == pytest.approx(y_ss) + + +@pytest.mark.parametrize('nu', [0, 1, 2]) +@pytest.mark.parametrize('extrapolate', [None, True, False, 'periodic']) +def test_evaluate_nu_extrapolate(nu, extrapolate): + x = [1, 2, 3, 4] + xi = [0, 1, 2, 3, 4, 5] + y = [1, 2, 3, 4] + + cs = CubicSpline(x, y) + y_cs = cs(xi, nu=nu, extrapolate=extrapolate) + + ss = csaps.CubicSmoothingSpline(x, y, smooth=1.0) + y_ss = ss(xi, nu=nu, extrapolate=extrapolate) + + np.testing.assert_allclose(y_ss, y_cs, rtol=1e-05, atol=1e-08, equal_nan=True)