From 0c57e2ba9ebdc89101b2f5bd1a914349bb060c22 Mon Sep 17 00:00:00 2001 From: Bru Date: Fri, 17 Apr 2026 12:37:31 +0200 Subject: [PATCH 01/11] ENH: expose correction and weights parameters in cov MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolves #688. Adds `axis`, `correction`, `frequency_weights`, and `weights` to `cov`, giving users control over the degrees-of-freedom correction and the observation-axis / weighted variants that `numpy.cov` and `torch.cov` already support. Naming follows array-api conventions (`axis`, `correction`) rather than numpy's (`rowvar`, `bias`, `ddof`); the docstring includes a one-to-one mapping. The delegation moves observations to the last axis via `xp.moveaxis`, collapsing `rowvar` out of the backend dispatch — only `ddof` vs `correction` differs between branches. Dask's native `cov` forces `.compute()` on a lazy scalar when any weights are given, so weighted dask inputs fall through to the generic implementation, which is fully lazy. --- src/array_api_extra/_delegation.py | 120 +++++++++++++++++++++++++---- src/array_api_extra/_lib/_funcs.py | 66 ++++++++++++---- tests/test_funcs.py | 91 ++++++++++++++++++++++ 3 files changed, 246 insertions(+), 31 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 46639559..31ace576 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -81,7 +81,16 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array return _funcs.atleast_nd(x, ndim=ndim, xp=xp) -def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: +def cov( + m: Array, + /, + *, + axis: int = -1, + correction: int | float = 1, + frequency_weights: Array | None = None, + weights: Array | None = None, + xp: ModuleType | None = None, +) -> Array: """ Estimate a covariance matrix (or a stack of covariance matrices). 
@@ -92,16 +101,37 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: :math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance of :math:`x_i`. - With the exception of supporting batch input, this provides a subset of - the functionality of ``numpy.cov``. + Extends ``numpy.cov`` with support for batch input and array-api + backends. Naming follows the array-api conventions used elsewhere in + this library (``axis``, ``correction``) rather than the numpy spellings + (``rowvar``, ``bias``, ``ddof``); see Notes for the mapping. Parameters ---------- m : array An array of shape ``(..., N, M)`` whose innermost two dimensions - contain *M* observations of *N* variables. That is, - each row of `m` represents a variable, and each column a single - observation of all those variables. + contain *M* observations of *N* variables by default. The axis of + observations is controlled by `axis`. + axis : int, optional + Axis of `m` containing the observations. Default: ``-1`` (the last + axis), matching the array-api convention. Use ``axis=-2`` (or ``0`` + for 2-D input) to treat each column as a variable, which + corresponds to ``rowvar=False`` in ``numpy.cov``. + correction : int or float, optional + Degrees of freedom correction: normalization divides by + ``N - correction`` (for unweighted input). Default: ``1``, which + gives the unbiased estimate (matches ``numpy.cov`` default of + ``bias=False``). Set to ``0`` for the biased estimate (``N`` + normalization). Corresponds to ``ddof`` in ``numpy.cov`` and to + ``correction`` in ``numpy.var``/``std`` and ``torch.cov``. + frequency_weights : array, optional + 1-D array of integer frequency weights: the number of times each + observation is repeated. Corresponds to ``fweights`` in + ``numpy.cov``/``torch.cov``. + weights : array, optional + 1-D array of observation-vector weights (analytic weights). Larger + values mark more important observations. 
Corresponds to + ``aweights`` in ``numpy.cov``/``torch.cov``. xp : array_namespace, optional The standard-compatible namespace for `m`. Default: infer. @@ -111,6 +141,23 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: An array having shape (..., N, N) whose innermost two dimensions represent the covariance matrix of the variables. + Notes + ----- + Mapping from ``numpy.cov`` to this function:: + + numpy.cov(m, rowvar=True) -> cov(m, axis=-1) # default + numpy.cov(m, rowvar=False) -> cov(m, axis=-2) + numpy.cov(m, bias=True) -> cov(m, correction=0) + numpy.cov(m, ddof=k) -> cov(m, correction=k) + numpy.cov(m, fweights=f) -> cov(m, frequency_weights=f) + numpy.cov(m, aweights=a) -> cov(m, weights=a) + + Unlike ``numpy.cov``, a ``RuntimeWarning`` for non-positive effective + degrees of freedom is only emitted on the unweighted path. The + weighted path omits the check so that lazy backends (e.g. Dask) can + stay lazy end-to-end; choose ``correction`` and weights such that the + effective normalizer is positive. + Examples -------- >>> import array_api_strict as xp @@ -164,16 +211,57 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: if xp is None: xp = array_namespace(m) - if ( - is_numpy_namespace(xp) - or is_cupy_namespace(xp) - or is_torch_namespace(xp) - or is_dask_namespace(xp) - or is_jax_namespace(xp) - ) and m.ndim <= 2: - return xp.cov(m) - - return _funcs.cov(m, xp=xp) + # Validate axis against m.ndim. + ndim = max(m.ndim, 1) + if not -ndim <= axis < ndim: + msg = f"axis {axis} is out of bounds for array of dimension {m.ndim}" + raise IndexError(msg) + + # Normalize: observations on the last axis. After this, every backend + # sees the same convention and we never need to deal with `rowvar`. + if m.ndim >= 2 and axis not in (-1, m.ndim - 1): + m = xp.moveaxis(m, axis, -1) + + # `numpy.cov` (and cupy/dask/jax) require integer `ddof`; `torch.cov` + # requires integer `correction`. 
For non-integer-valued `correction`, + # fall through to the generic implementation. + integer_correction = isinstance(correction, int) or correction.is_integer() + has_weights = frequency_weights is not None or weights is not None + + if m.ndim <= 2 and integer_correction: + if is_torch_namespace(xp): + device = get_device(m) + fw = ( + None + if frequency_weights is None + else xp.asarray(frequency_weights, device=device) + ) + aw = None if weights is None else xp.asarray(weights, device=device) + return xp.cov(m, correction=int(correction), fweights=fw, aweights=aw) + # `dask.array.cov` forces `.compute()` whenever weights are given: + # its internal `if fact <= 0` check on a lazy 0-D scalar triggers + # materialization. Route to the generic impl, which is fully lazy + # because it only does sum/matmul and skips that scalar check. + if ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_jax_namespace(xp) + or (is_dask_namespace(xp) and not has_weights) + ): + return xp.cov( + m, + ddof=int(correction), + fweights=frequency_weights, + aweights=weights, + ) + + return _funcs.cov( + m, + correction=correction, + frequency_weights=frequency_weights, + weights=weights, + xp=xp, + ) def create_diagonal( diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 97904ddb..4f9309ec 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -281,9 +281,17 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ... 
return tuple(out) -def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01 +def cov( + m: Array, + /, + *, + correction: int | float = 1, + frequency_weights: Array | None = None, + weights: Array | None = None, + xp: ModuleType, +) -> Array: # numpydoc ignore=PR01,RT01 """See docstring in array_api_extra._delegation.""" - m = xp.asarray(m, copy=True) + m = xp.asarray(m) dtype = ( xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64) ) @@ -291,21 +299,49 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01 m = atleast_nd(m, ndim=2, xp=xp) m = xp.astype(m, dtype) - avg = xp.mean(m, axis=-1, keepdims=True) + device = _compat.device(m) + fw = ( + None + if frequency_weights is None + else xp.astype(xp.asarray(frequency_weights, device=device), dtype) + ) + aw = ( + None + if weights is None + else xp.astype(xp.asarray(weights, device=device), dtype) + ) + if fw is None and aw is None: + w = None + elif fw is None: + w = aw + elif aw is None: + w = fw + else: + w = fw * aw m_shape = eager_shape(m) - fact = m_shape[-1] - 1 - - if fact <= 0: - warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) - fact = 0 - - m -= avg - m_transpose = xp.matrix_transpose(m) - if xp.isdtype(m_transpose.dtype, "complex floating"): - m_transpose = xp.conj(m_transpose) - c = xp.matmul(m, m_transpose) - c /= fact + if w is None: + avg = xp.mean(m, axis=-1, keepdims=True) + fact = m_shape[-1] - correction + if fact <= 0: + warnings.warn( + "Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2 + ) + fact = 0 + else: + v1 = xp.sum(w, axis=-1) + avg = xp.sum(m * w, axis=-1, keepdims=True) / v1 + if aw is None: + fact = v1 - correction + else: + fact = v1 - correction * xp.sum(w * aw, axis=-1) / v1 + + m_c = m - avg + m_w = m_c if w is None else m_c * w + m_cT = xp.matrix_transpose(m_c) + if xp.isdtype(m_cT.dtype, "complex floating"): + m_cT = xp.conj(m_cT) + c = xp.matmul(m_w, 
m_cT) / fact axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1) return xp.squeeze(c, axis=axes) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 6a11e059..a3cf59a2 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -608,6 +608,97 @@ def test_batch(self, xp: ModuleType): ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var)) xp_assert_close(res, xp.asarray(ref)) + def test_correction(self, xp: ModuleType): + rng = np.random.default_rng(20260417) + m = rng.random((3, 20)) + for correction in (0, 1, 2): + ref = np.cov(m, ddof=correction) + res = cov(xp.asarray(m), correction=correction) + xp_assert_close(res, xp.asarray(ref)) + + def test_correction_float(self, xp: ModuleType): + # Float correction: reference computed by hand (numpy.cov rejects + # non-integer ddof; our generic path supports it). + rng = np.random.default_rng(20260417) + m = rng.random((3, 20)) + n = m.shape[-1] + centered = m - m.mean(axis=-1, keepdims=True) + ref = centered @ centered.T / (n - 1.5) + res = cov(xp.asarray(m), correction=1.5) + xp_assert_close(res, xp.asarray(ref)) + + def test_axis(self, xp: ModuleType): + rng = np.random.default_rng(20260417) + m = rng.random((20, 3)) # observations on axis 0 + ref = np.cov(m, rowvar=False) + res = cov(xp.asarray(m), axis=0) + xp_assert_close(res, xp.asarray(ref)) + res_neg = cov(xp.asarray(m), axis=-2) + xp_assert_close(res_neg, xp.asarray(ref)) + + def test_frequency_weights(self, xp: ModuleType): + rng = np.random.default_rng(20260417) + m = rng.random((3, 10)) + fw = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1], dtype=np.int64) + ref = np.cov(m, fweights=fw) + res = cov(xp.asarray(m), frequency_weights=xp.asarray(fw)) + xp_assert_close(res, xp.asarray(ref)) + + def test_weights(self, xp: ModuleType): + rng = np.random.default_rng(20260417) + m = rng.random((3, 10)) + aw = rng.random(10) + ref = np.cov(m, aweights=aw) + res = cov(xp.asarray(m), weights=xp.asarray(aw)) + xp_assert_close(res, 
xp.asarray(ref)) + + def test_both_weights(self, xp: ModuleType): + rng = np.random.default_rng(20260417) + m = rng.random((3, 10)) + fw = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1], dtype=np.int64) + aw = rng.random(10) + for correction in (0, 1, 2): + ref = np.cov(m, ddof=correction, fweights=fw, aweights=aw) + res = cov( + xp.asarray(m), + correction=correction, + frequency_weights=xp.asarray(fw), + weights=xp.asarray(aw), + ) + xp_assert_close(res, xp.asarray(ref)) + + def test_batch_with_weights(self, xp: ModuleType): + rng = np.random.default_rng(20260417) + batch_shape = (2, 3) + n_var, n_obs = 3, 15 + m = rng.random((*batch_shape, n_var, n_obs)) + aw = rng.random(n_obs) + res = cov(xp.asarray(m), weights=xp.asarray(aw)) + ref_list = [np.cov(m_, aweights=aw) for m_ in np.reshape(m, (-1, n_var, n_obs))] + ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var)) + xp_assert_close(res, xp.asarray(ref)) + + def test_axis_with_weights(self, xp: ModuleType): + # axis=-2 (observations on first of 2D) combined with weights: + # verifies that moveaxis and weight alignment cooperate. 
+ rng = np.random.default_rng(20260417) + m = rng.random((15, 3)) # observations on axis 0 + aw = rng.random(15) + fw = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1, 1], dtype=np.int64) + ref = np.cov(m, rowvar=False, fweights=fw, aweights=aw) + res = cov( + xp.asarray(m), + axis=-2, + frequency_weights=xp.asarray(fw), + weights=xp.asarray(aw), + ) + xp_assert_close(res, xp.asarray(ref)) + + def test_axis_out_of_bounds(self, xp: ModuleType): + m = xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + with pytest.raises(IndexError): + _ = cov(m, axis=5) + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False) class TestOneHot: From d9701e06bc9c6b2180123a4e6677786b623e3543 Mon Sep 17 00:00:00 2001 From: Bru Date: Mon, 20 Apr 2026 12:06:47 +0200 Subject: [PATCH 02/11] MNT: drop device= in cov weights --- src/array_api_extra/_delegation.py | 7 ++----- src/array_api_extra/_lib/_funcs.py | 5 ++--- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 31ace576..eebba14d 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -230,13 +230,10 @@ def cov( if m.ndim <= 2 and integer_correction: if is_torch_namespace(xp): - device = get_device(m) fw = ( - None - if frequency_weights is None - else xp.asarray(frequency_weights, device=device) + None if frequency_weights is None else xp.asarray(frequency_weights) ) - aw = None if weights is None else xp.asarray(weights, device=device) + aw = None if weights is None else xp.asarray(weights) return xp.cov(m, correction=int(correction), fweights=fw, aweights=aw) # `dask.array.cov` forces `.compute()` whenever weights are given: # its internal `if fact <= 0` check on a lazy 0-D scalar triggers diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 4f9309ec..23f9a1e9 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ 
b/src/array_api_extra/_lib/_funcs.py @@ -299,16 +299,15 @@ def cov( m = atleast_nd(m, ndim=2, xp=xp) m = xp.astype(m, dtype) - device = _compat.device(m) fw = ( None if frequency_weights is None - else xp.astype(xp.asarray(frequency_weights, device=device), dtype) + else xp.astype(xp.asarray(frequency_weights), dtype) ) aw = ( None if weights is None - else xp.astype(xp.asarray(weights, device=device), dtype) + else xp.astype(xp.asarray(weights), dtype) ) if fw is None and aw is None: w = None From 72e2d6135a8d26b7e80beab44acd5d08ce690e55 Mon Sep 17 00:00:00 2001 From: Bru Date: Mon, 20 Apr 2026 12:11:49 +0200 Subject: [PATCH 03/11] STY: formatter --- src/array_api_extra/_delegation.py | 4 +--- src/array_api_extra/_lib/_funcs.py | 6 +----- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index eebba14d..52424067 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -230,9 +230,7 @@ def cov( if m.ndim <= 2 and integer_correction: if is_torch_namespace(xp): - fw = ( - None if frequency_weights is None else xp.asarray(frequency_weights) - ) + fw = None if frequency_weights is None else xp.asarray(frequency_weights) aw = None if weights is None else xp.asarray(weights) return xp.cov(m, correction=int(correction), fweights=fw, aweights=aw) # `dask.array.cov` forces `.compute()` whenever weights are given: diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 23f9a1e9..6055a7ee 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -304,11 +304,7 @@ def cov( if frequency_weights is None else xp.astype(xp.asarray(frequency_weights), dtype) ) - aw = ( - None - if weights is None - else xp.astype(xp.asarray(weights), dtype) - ) + aw = None if weights is None else xp.astype(xp.asarray(weights), dtype) if fw is None and aw is None: w = None elif fw is None: From 
c0a20b08a6ad9f89daf91cada37292cf90011efb Mon Sep 17 00:00:00 2001 From: Bru Date: Mon, 20 Apr 2026 12:20:17 +0200 Subject: [PATCH 04/11] TST: add bias tests from #691 --- tests/test_funcs.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index a3cf59a2..22e15b9b 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -608,6 +608,30 @@ def test_batch(self, xp: ModuleType): ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var)) xp_assert_close(res, xp.asarray(ref)) + @pytest.mark.parametrize("bias", [True, False, 0, 1]) + def test_bias(self, xp: ModuleType, bias: bool): + # `bias` maps to `correction`: bias=True -> correction=0, bias=False -> 1. + x = np.array([-2.1, -1, 4.3]) + y = np.array([3, 1.1, 0.12]) + X = np.stack((x, y), axis=0) + ref = np.cov(X, bias=bias) + xp_assert_close( + cov(xp.asarray(X, dtype=xp.float64), correction=0 if bias else 1), + xp.asarray(ref, dtype=xp.float64), + rtol=1e-6, + ) + + @pytest.mark.parametrize("bias", [True, False, 0, 1]) + def test_bias_batch(self, xp: ModuleType, bias: bool): + rng = np.random.default_rng(8847643423) + batch_shape = (3, 4) + n_var, n_obs = 3, 20 + m = rng.random((*batch_shape, n_var, n_obs)) + res = cov(xp.asarray(m), correction=0 if bias else 1) + ref_list = [np.cov(m_, bias=bias) for m_ in np.reshape(m, (-1, n_var, n_obs))] + ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var)) + xp_assert_close(res, xp.asarray(ref)) + def test_correction(self, xp: ModuleType): rng = np.random.default_rng(20260417) m = rng.random((3, 20)) From c621a74bb16aca12ab5190a464693e3bb59e04f6 Mon Sep 17 00:00:00 2001 From: Bru Date: Mon, 20 Apr 2026 13:26:20 +0200 Subject: [PATCH 05/11] STY: assign fw with an if-statement in generic cov MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Barthélemy --- src/array_api_extra/_lib/_funcs.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff 
--git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 6055a7ee..744a0dbb 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -299,11 +299,9 @@ def cov( m = atleast_nd(m, ndim=2, xp=xp) m = xp.astype(m, dtype) - fw = ( - None - if frequency_weights is None - else xp.astype(xp.asarray(frequency_weights), dtype) - ) + fw = None + if frequency_weights is not None: + fw = xp.astype(xp.asarray(frequency_weights), dtype) aw = None if weights is None else xp.astype(xp.asarray(weights), dtype) if fw is None and aw is None: w = None From 98f216a5a4698e5d3c2b560f4a9de7a221807cc9 Mon Sep 17 00:00:00 2001 From: Bru Date: Mon, 20 Apr 2026 13:39:04 +0200 Subject: [PATCH 06/11] DOC: rename weight parameters to fweights/aweights in cov docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Barthélemy --- src/array_api_extra/_delegation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 52424067..b8765c1b 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -124,11 +124,11 @@ def cov( ``bias=False``). Set to ``0`` for the biased estimate (``N`` normalization). Corresponds to ``ddof`` in ``numpy.cov`` and to ``correction`` in ``numpy.var``/``std`` and ``torch.cov``. - frequency_weights : array, optional + fweights : array, optional 1-D array of integer frequency weights: the number of times each observation is repeated. Corresponds to ``fweights`` in ``numpy.cov``/``torch.cov``. - weights : array, optional + aweights : array, optional 1-D array of observation-vector weights (analytic weights). Larger values mark more important observations. Corresponds to ``aweights`` in ``numpy.cov``/``torch.cov``. 
From 06b4007435394974bed6bac57c6de474f0a5debd Mon Sep 17 00:00:00 2001 From: Bru Date: Mon, 20 Apr 2026 13:57:08 +0200 Subject: [PATCH 07/11] MNT: rename weights params to fweights/aweights --- src/array_api_extra/_delegation.py | 28 ++++++++++++++-------------- src/array_api_extra/_lib/_funcs.py | 12 +++++++----- tests/test_funcs.py | 14 +++++++------- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index b8765c1b..337f975a 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -87,8 +87,8 @@ def cov( *, axis: int = -1, correction: int | float = 1, - frequency_weights: Array | None = None, - weights: Array | None = None, + fweights: Array | None = None, + aweights: Array | None = None, xp: ModuleType | None = None, ) -> Array: """ @@ -126,12 +126,12 @@ def cov( ``correction`` in ``numpy.var``/``std`` and ``torch.cov``. fweights : array, optional 1-D array of integer frequency weights: the number of times each - observation is repeated. Corresponds to ``fweights`` in + observation is repeated. Same as ``fweights`` in ``numpy.cov``/``torch.cov``. aweights : array, optional 1-D array of observation-vector weights (analytic weights). Larger - values mark more important observations. Corresponds to - ``aweights`` in ``numpy.cov``/``torch.cov``. + values mark more important observations. Same as ``aweights`` in + ``numpy.cov``/``torch.cov``. xp : array_namespace, optional The standard-compatible namespace for `m`. Default: infer. 
@@ -149,8 +149,8 @@ def cov( numpy.cov(m, rowvar=False) -> cov(m, axis=-2) numpy.cov(m, bias=True) -> cov(m, correction=0) numpy.cov(m, ddof=k) -> cov(m, correction=k) - numpy.cov(m, fweights=f) -> cov(m, frequency_weights=f) - numpy.cov(m, aweights=a) -> cov(m, weights=a) + numpy.cov(m, fweights=f) -> cov(m, fweights=f) + numpy.cov(m, aweights=a) -> cov(m, aweights=a) Unlike ``numpy.cov``, a ``RuntimeWarning`` for non-positive effective degrees of freedom is only emitted on the unweighted path. The @@ -226,12 +226,12 @@ def cov( # requires integer `correction`. For non-integer-valued `correction`, # fall through to the generic implementation. integer_correction = isinstance(correction, int) or correction.is_integer() - has_weights = frequency_weights is not None or weights is not None + has_weights = fweights is not None or aweights is not None if m.ndim <= 2 and integer_correction: if is_torch_namespace(xp): - fw = None if frequency_weights is None else xp.asarray(frequency_weights) - aw = None if weights is None else xp.asarray(weights) + fw = None if fweights is None else xp.asarray(fweights) + aw = None if aweights is None else xp.asarray(aweights) return xp.cov(m, correction=int(correction), fweights=fw, aweights=aw) # `dask.array.cov` forces `.compute()` whenever weights are given: # its internal `if fact <= 0` check on a lazy 0-D scalar triggers @@ -246,15 +246,15 @@ def cov( return xp.cov( m, ddof=int(correction), - fweights=frequency_weights, - aweights=weights, + fweights=fweights, + aweights=aweights, ) return _funcs.cov( m, correction=correction, - frequency_weights=frequency_weights, - weights=weights, + fweights=fweights, + aweights=aweights, xp=xp, ) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 744a0dbb..4970b5fa 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -286,8 +286,8 @@ def cov( /, *, correction: int | float = 1, - frequency_weights: Array | None = None, 
- weights: Array | None = None, + fweights: Array | None = None, + aweights: Array | None = None, xp: ModuleType, ) -> Array: # numpydoc ignore=PR01,RT01 """See docstring in array_api_extra._delegation.""" @@ -300,9 +300,11 @@ def cov( m = xp.astype(m, dtype) fw = None - if frequency_weights is not None: - fw = xp.astype(xp.asarray(frequency_weights), dtype) - aw = None if weights is None else xp.astype(xp.asarray(weights), dtype) + if fweights is not None: + fw = xp.astype(xp.asarray(fweights), dtype) + aw = None + if aweights is not None: + aw = xp.astype(xp.asarray(aweights), dtype) if fw is None and aw is None: w = None elif fw is None: diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 22e15b9b..8ba9b63b 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -665,7 +665,7 @@ def test_frequency_weights(self, xp: ModuleType): m = rng.random((3, 10)) fw = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1], dtype=np.int64) ref = np.cov(m, fweights=fw) - res = cov(xp.asarray(m), frequency_weights=xp.asarray(fw)) + res = cov(xp.asarray(m), fweights=xp.asarray(fw)) xp_assert_close(res, xp.asarray(ref)) def test_weights(self, xp: ModuleType): @@ -673,7 +673,7 @@ def test_weights(self, xp: ModuleType): m = rng.random((3, 10)) aw = rng.random(10) ref = np.cov(m, aweights=aw) - res = cov(xp.asarray(m), weights=xp.asarray(aw)) + res = cov(xp.asarray(m), aweights=xp.asarray(aw)) xp_assert_close(res, xp.asarray(ref)) def test_both_weights(self, xp: ModuleType): @@ -686,8 +686,8 @@ def test_both_weights(self, xp: ModuleType): res = cov( xp.asarray(m), correction=correction, - frequency_weights=xp.asarray(fw), - weights=xp.asarray(aw), + fweights=xp.asarray(fw), + aweights=xp.asarray(aw), ) xp_assert_close(res, xp.asarray(ref)) @@ -697,7 +697,7 @@ def test_batch_with_weights(self, xp: ModuleType): n_var, n_obs = 3, 15 m = rng.random((*batch_shape, n_var, n_obs)) aw = rng.random(n_obs) - res = cov(xp.asarray(m), weights=xp.asarray(aw)) + res = cov(xp.asarray(m), 
aweights=xp.asarray(aw)) ref_list = [np.cov(m_, aweights=aw) for m_ in np.reshape(m, (-1, n_var, n_obs))] ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var)) xp_assert_close(res, xp.asarray(ref)) @@ -713,8 +713,8 @@ def test_axis_with_weights(self, xp: ModuleType): res = cov( xp.asarray(m), axis=-2, - frequency_weights=xp.asarray(fw), - weights=xp.asarray(aw), + fweights=xp.asarray(fw), + aweights=xp.asarray(aw), ) xp_assert_close(res, xp.asarray(ref)) From 8b5c47167a621b427538b03f3e490408d3013da8 Mon Sep 17 00:00:00 2001 From: Bru Date: Mon, 20 Apr 2026 14:04:35 +0200 Subject: [PATCH 08/11] ENH: validate weights shape in cov --- src/array_api_extra/_delegation.py | 16 ++++++++++++++++ tests/test_funcs.py | 13 +++++++++++++ 2 files changed, 29 insertions(+) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 337f975a..e3daaa35 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -222,6 +222,22 @@ def cov( if m.ndim >= 2 and axis not in (-1, m.ndim - 1): m = xp.moveaxis(m, axis, -1) + # Validate weight shapes (eager metadata, lazy-safe). Value-based + # checks (non-negative, integer dtype) are intentionally skipped so + # lazy backends don't trigger compute -- same tradeoff as dask.cov. + n_obs = m.shape[-1] + for name, w in (("fweights", fweights), ("aweights", aweights)): + if w is None: + continue + if w.ndim != 1: + msg = f"`{name}` must be 1-D, got ndim={w.ndim}" + raise ValueError(msg) + if w.shape[0] != n_obs: + msg = ( + f"`{name}` has length {w.shape[0]} but `m` has {n_obs} observations" + ) + raise ValueError(msg) + # `numpy.cov` (and cupy/dask/jax) require integer `ddof`; `torch.cov` # requires integer `correction`. For non-integer-valued `correction`, # fall through to the generic implementation. 
diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 8ba9b63b..26f7a031 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -723,6 +723,19 @@ def test_axis_out_of_bounds(self, xp: ModuleType): with pytest.raises(IndexError): _ = cov(m, axis=5) + def test_weights_shape_validation(self, xp: ModuleType): + m = xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + # Wrong length. + with pytest.raises(ValueError, match="`fweights` has length"): + _ = cov(m, fweights=xp.asarray([1, 2])) + with pytest.raises(ValueError, match="`aweights` has length"): + _ = cov(m, aweights=xp.asarray([0.1, 0.2])) + # Wrong ndim. + with pytest.raises(ValueError, match="`fweights` must be 1-D"): + _ = cov(m, fweights=xp.asarray([[1, 2, 3]])) + with pytest.raises(ValueError, match="`aweights` must be 1-D"): + _ = cov(m, aweights=xp.asarray([[0.1, 0.2, 0.3]])) + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False) class TestOneHot: From cb717b01436a042d55889cedb74d441ca810e4ce Mon Sep 17 00:00:00 2001 From: Bru Date: Mon, 20 Apr 2026 15:35:17 +0200 Subject: [PATCH 09/11] MNT: address lucascolley review --- src/array_api_extra/_delegation.py | 6 ++---- src/array_api_extra/_lib/_funcs.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index e3daaa35..b3ff052e 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -209,7 +209,7 @@ def cov( """ if xp is None: - xp = array_namespace(m) + xp = array_namespace(m, fweights, aweights) # Validate axis against m.ndim. 
ndim = max(m.ndim, 1) @@ -233,9 +233,7 @@ def cov( msg = f"`{name}` must be 1-D, got ndim={w.ndim}" raise ValueError(msg) if w.shape[0] != n_obs: - msg = ( - f"`{name}` has length {w.shape[0]} but `m` has {n_obs} observations" - ) + msg = f"`{name}` has length {w.shape[0]} but `m` has {n_obs} observations" raise ValueError(msg) # `numpy.cov` (and cupy/dask/jax) require integer `ddof`; `torch.cov` diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 4970b5fa..3f061cec 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -336,7 +336,7 @@ def cov( m_cT = xp.matrix_transpose(m_c) if xp.isdtype(m_cT.dtype, "complex floating"): m_cT = xp.conj(m_cT) - c = xp.matmul(m_w, m_cT) / fact + c = (m_w @ m_cT) / fact axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1) return xp.squeeze(c, axis=axes) From e34d415ec22eaa3f069fa0d3a35b62ed1d4e181d Mon Sep 17 00:00:00 2001 From: Bru Date: Mon, 20 Apr 2026 16:01:19 +0200 Subject: [PATCH 10/11] MNT: move weights validation to generic cov --- src/array_api_extra/_delegation.py | 14 -------------- src/array_api_extra/_lib/_funcs.py | 17 +++++++++++++++++ tests/test_funcs.py | 12 ------------ 3 files changed, 17 insertions(+), 26 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index b3ff052e..f1d9ca79 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -222,20 +222,6 @@ def cov( if m.ndim >= 2 and axis not in (-1, m.ndim - 1): m = xp.moveaxis(m, axis, -1) - # Validate weight shapes (eager metadata, lazy-safe). Value-based - # checks (non-negative, integer dtype) are intentionally skipped so - # lazy backends don't trigger compute -- same tradeoff as dask.cov. 
- n_obs = m.shape[-1] - for name, w in (("fweights", fweights), ("aweights", aweights)): - if w is None: - continue - if w.ndim != 1: - msg = f"`{name}` must be 1-D, got ndim={w.ndim}" - raise ValueError(msg) - if w.shape[0] != n_obs: - msg = f"`{name}` has length {w.shape[0]} but `m` has {n_obs} observations" - raise ValueError(msg) - # `numpy.cov` (and cupy/dask/jax) require integer `ddof`; `torch.cov` # requires integer `correction`. For non-integer-valued `correction`, # fall through to the generic implementation. diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 3f061cec..d7a753d7 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -299,6 +299,23 @@ def cov( m = atleast_nd(m, ndim=2, xp=xp) m = xp.astype(m, dtype) + # Validate weight shapes (eager metadata, lazy-safe). Native backends + # validate themselves; this covers the generic path (array-api-strict, + # sparse, and the dask+weights fallback where the native check is + # bypassed to preserve laziness). + n_obs = m.shape[-1] + for name, w_in in (("fweights", fweights), ("aweights", aweights)): + if w_in is None: + continue + if w_in.ndim != 1: + msg = f"`{name}` must be 1-D, got ndim={w_in.ndim}" + raise ValueError(msg) + if w_in.shape[0] != n_obs: + msg = ( + f"`{name}` has length {w_in.shape[0]} but `m` has {n_obs} observations" + ) + raise ValueError(msg) + fw = None if fweights is not None: fw = xp.astype(xp.asarray(fweights), dtype) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 26f7a031..6620fad1 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -723,18 +723,6 @@ def test_axis_out_of_bounds(self, xp: ModuleType): with pytest.raises(IndexError): _ = cov(m, axis=5) - def test_weights_shape_validation(self, xp: ModuleType): - m = xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - # Wrong length. 
- with pytest.raises(ValueError, match="`fweights` has length"): - _ = cov(m, fweights=xp.asarray([1, 2])) - with pytest.raises(ValueError, match="`aweights` has length"): - _ = cov(m, aweights=xp.asarray([0.1, 0.2])) - # Wrong ndim. - with pytest.raises(ValueError, match="`fweights` must be 1-D"): - _ = cov(m, fweights=xp.asarray([[1, 2, 3]])) - with pytest.raises(ValueError, match="`aweights` must be 1-D"): - _ = cov(m, aweights=xp.asarray([[0.1, 0.2, 0.3]])) @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False) From f56037212e7bae41ef80375e2599ea88cd57ea35 Mon Sep 17 00:00:00 2001 From: Bru Date: Mon, 20 Apr 2026 22:41:40 +0200 Subject: [PATCH 11/11] STY: drop redundant parentheses in cov normalization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Barthélemy --- src/array_api_extra/_lib/_funcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index d7a753d7..deb46b15 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -353,7 +353,7 @@ def cov( m_cT = xp.matrix_transpose(m_c) if xp.isdtype(m_cT.dtype, "complex floating"): m_cT = xp.conj(m_cT) - c = (m_w @ m_cT) / fact + c = m_w @ m_cT / fact axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1) return xp.squeeze(c, axis=axes)