Skip to content
119 changes: 101 additions & 18 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,16 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)


def cov(
    m: Array,
    /,
    *,
    axis: int = -1,
    correction: int | float = 1,
    fweights: Array | None = None,
    aweights: Array | None = None,
    xp: ModuleType | None = None,
) -> Array:
    """
    Estimate a covariance matrix (or a stack of covariance matrices).

    Each element :math:`C_{ij}` of the result is the covariance of
    :math:`x_i` and :math:`x_j`; the element :math:`C_{ii}` is the variance
    of :math:`x_i`.

    Extends ``numpy.cov`` with support for batch input and array-api
    backends. Naming follows the array-api conventions used elsewhere in
    this library (``axis``, ``correction``) rather than the numpy spellings
    (``rowvar``, ``bias``, ``ddof``); see Notes for the mapping.

    Parameters
    ----------
    m : array
        An array of shape ``(..., N, M)`` whose innermost two dimensions
        contain *M* observations of *N* variables by default. The axis of
        observations is controlled by `axis`.
    axis : int, optional
        Axis of `m` containing the observations. Default: ``-1`` (the last
        axis), matching the array-api convention. Use ``axis=-2`` (or ``0``
        for 2-D input) to treat each column as a variable, which
        corresponds to ``rowvar=False`` in ``numpy.cov``.
    correction : int or float, optional
        Degrees of freedom correction: normalization divides by
        ``N - correction`` (for unweighted input). Default: ``1``, which
        gives the unbiased estimate (matches ``numpy.cov`` default of
        ``bias=False``). Set to ``0`` for the biased estimate (``N``
        normalization). Corresponds to ``ddof`` in ``numpy.cov`` and to
        ``correction`` in ``numpy.var``/``std`` and ``torch.cov``.
    fweights : array, optional
        1-D array of integer frequency weights: the number of times each
        observation is repeated. Same as ``fweights`` in
        ``numpy.cov``/``torch.cov``.
    aweights : array, optional
        1-D array of observation-vector weights (analytic weights). Larger
        values mark more important observations. Same as ``aweights`` in
        ``numpy.cov``/``torch.cov``.
    xp : array_namespace, optional
        The standard-compatible namespace for `m`. Default: infer.

    Returns
    -------
    array
        An array having shape ``(..., N, N)`` whose innermost two
        dimensions represent the covariance matrix of the variables.

    Notes
    -----
    Mapping from ``numpy.cov`` to this function::

        numpy.cov(m, rowvar=True)  -> cov(m, axis=-1)  # default
        numpy.cov(m, rowvar=False) -> cov(m, axis=-2)
        numpy.cov(m, bias=True)    -> cov(m, correction=0)
        numpy.cov(m, ddof=k)       -> cov(m, correction=k)
        numpy.cov(m, fweights=f)   -> cov(m, fweights=f)
        numpy.cov(m, aweights=a)   -> cov(m, aweights=a)

    Unlike ``numpy.cov``, a ``RuntimeWarning`` for non-positive effective
    degrees of freedom is only emitted on the unweighted path. The
    weighted path omits the check so that lazy backends (e.g. Dask) can
    stay lazy end-to-end; choose ``correction`` and weights such that the
    effective normalizer is positive.
    """
    if xp is None:
        xp = array_namespace(m, fweights, aweights)

    # Reject an out-of-range `axis` eagerly. `max(..., 1)` lets 0-D input
    # still accept the default axis of -1 (or 0).
    bound = max(m.ndim, 1)
    if not -bound <= axis < bound:
        msg = f"axis {axis} is out of bounds for array of dimension {m.ndim}"
        raise IndexError(msg)

    # Canonical layout: observations on the last axis. After this, every
    # backend sees the same convention and `rowvar` is never needed.
    if m.ndim >= 2 and axis not in (-1, m.ndim - 1):
        m = xp.moveaxis(m, axis, -1)

    # Native `cov` implementations (numpy `ddof`, torch `correction`) only
    # accept integers; non-integer-valued corrections take the generic path.
    correction_is_integral = isinstance(correction, int) or correction.is_integer()
    weighted = fweights is not None or aweights is not None

    if m.ndim <= 2 and correction_is_integral:
        if is_torch_namespace(xp):
            fw = None if fweights is None else xp.asarray(fweights)
            aw = None if aweights is None else xp.asarray(aweights)
            return xp.cov(m, correction=int(correction), fweights=fw, aweights=aw)
        # `dask.array.cov` forces `.compute()` whenever weights are given:
        # its internal `if fact <= 0` check on a lazy 0-D scalar triggers
        # materialization. Route dask+weights to the generic impl, which
        # stays fully lazy (only sum/matmul, no scalar check).
        native = (
            is_numpy_namespace(xp)
            or is_cupy_namespace(xp)
            or is_jax_namespace(xp)
            or (is_dask_namespace(xp) and not weighted)
        )
        if native:
            return xp.cov(
                m,
                ddof=int(correction),
                fweights=fweights,
                aweights=aweights,
            )

    return _funcs.cov(
        m,
        correction=correction,
        fweights=fweights,
        aweights=aweights,
        xp=xp,
    )


def create_diagonal(
Expand Down
78 changes: 63 additions & 15 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,31 +281,79 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...
return tuple(out)


def cov(
    m: Array,
    /,
    *,
    correction: int | float = 1,
    fweights: Array | None = None,
    aweights: Array | None = None,
    xp: ModuleType,
) -> Array:  # numpydoc ignore=PR01,RT01
    """See docstring in array_api_extra._delegation."""
    # No copy needed: the computation below never mutates `m` in place.
    m = xp.asarray(m)
    dtype = (
        xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64)
    )

    m = atleast_nd(m, ndim=2, xp=xp)
    m = xp.astype(m, dtype)

    # Validate weight shapes (eager metadata, lazy-safe). Native backends
    # validate themselves; this covers the generic path (array-api-strict,
    # sparse, and the dask+weights fallback where the native check is
    # bypassed to preserve laziness).
    n_obs = m.shape[-1]
    for name, w_in in (("fweights", fweights), ("aweights", aweights)):
        if w_in is None:
            continue
        if w_in.ndim != 1:
            msg = f"`{name}` must be 1-D, got ndim={w_in.ndim}"
            raise ValueError(msg)
        if w_in.shape[0] != n_obs:
            msg = (
                f"`{name}` has length {w_in.shape[0]} but `m` has {n_obs} observations"
            )
            raise ValueError(msg)

    # Combined weight vector: fweights * aweights (either may be absent).
    fw = None
    if fweights is not None:
        fw = xp.astype(xp.asarray(fweights), dtype)
    aw = None
    if aweights is not None:
        aw = xp.astype(xp.asarray(aweights), dtype)
    if fw is None and aw is None:
        w = None
    elif fw is None:
        w = aw
    elif aw is None:
        w = fw
    else:
        w = fw * aw

    m_shape = eager_shape(m)

    if w is None:
        avg = xp.mean(m, axis=-1, keepdims=True)
        fact = m_shape[-1] - correction
        if fact <= 0:
            # Matches numpy.cov: warn, then divide by 0 to yield inf/nan.
            warnings.warn(
                "Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2
            )
            fact = 0
    else:
        # Weighted mean and normalizer, matching numpy.cov:
        #   fact = v1 - correction                       (fweights only)
        #   fact = v1 - correction * sum(w*aw) / v1      (aweights present)
        # No dof check here: it would force materialization on lazy backends.
        v1 = xp.sum(w, axis=-1)
        avg = xp.sum(m * w, axis=-1, keepdims=True) / v1
        if aw is None:
            fact = v1 - correction
        else:
            fact = v1 - correction * xp.sum(w * aw, axis=-1) / v1

    m_c = m - avg
    m_w = m_c if w is None else m_c * w
    m_cT = xp.matrix_transpose(m_c)
    if xp.isdtype(m_cT.dtype, "complex floating"):
        m_cT = xp.conj(m_cT)
    c = (m_w @ m_cT) / fact
    # Squeeze only the trailing (variable) axes: a 1-D / single-variable
    # input collapses to a scalar exactly like `numpy.cov`, while size-1
    # *batch* dimensions are preserved so the documented `(..., N, N)`
    # return shape holds. (Squeezing every size-1 axis of `c` would drop
    # a batch dimension of length 1.)
    axes = tuple(ax for ax in (c.ndim - 2, c.ndim - 1) if c.shape[ax] == 1)
    return xp.squeeze(c, axis=axes)

Expand Down
116 changes: 116 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,122 @@ def test_batch(self, xp: ModuleType):
ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var))
xp_assert_close(res, xp.asarray(ref))

@pytest.mark.parametrize("bias", [True, False, 0, 1])
def test_bias(self, xp: ModuleType, bias: bool):
    # `bias` maps to `correction`: bias=True -> correction=0, bias=False -> 1.
    a = np.array([-2.1, -1, 4.3])
    b = np.array([3, 1.1, 0.12])
    stacked = np.stack((a, b), axis=0)
    expected = np.cov(stacked, bias=bias)
    actual = cov(
        xp.asarray(stacked, dtype=xp.float64),
        correction=0 if bias else 1,
    )
    xp_assert_close(actual, xp.asarray(expected, dtype=xp.float64), rtol=1e-6)

@pytest.mark.parametrize("bias", [True, False, 0, 1])
def test_bias_batch(self, xp: ModuleType, bias: bool):
    # Batched input must match np.cov applied slice-by-slice.
    rng = np.random.default_rng(8847643423)
    batch_shape, n_var, n_obs = (3, 4), 3, 20
    data = rng.random((*batch_shape, n_var, n_obs))
    actual = cov(xp.asarray(data), correction=0 if bias else 1)
    flat = np.reshape(data, (-1, n_var, n_obs))
    expected = np.stack([np.cov(s, bias=bias) for s in flat])
    expected = np.reshape(expected, (*batch_shape, n_var, n_var))
    xp_assert_close(actual, xp.asarray(expected))

def test_correction(self, xp: ModuleType):
    # Integer `correction` mirrors numpy's `ddof`.
    rng = np.random.default_rng(20260417)
    data = rng.random((3, 20))
    for ddof in (0, 1, 2):
        expected = np.cov(data, ddof=ddof)
        actual = cov(xp.asarray(data), correction=ddof)
        xp_assert_close(actual, xp.asarray(expected))

def test_correction_float(self, xp: ModuleType):
    # Float correction: reference computed by hand (numpy.cov rejects
    # non-integer ddof; our generic path supports it).
    rng = np.random.default_rng(20260417)
    data = rng.random((3, 20))
    n_obs = data.shape[-1]
    centered = data - data.mean(axis=-1, keepdims=True)
    expected = centered @ centered.T / (n_obs - 1.5)
    actual = cov(xp.asarray(data), correction=1.5)
    xp_assert_close(actual, xp.asarray(expected))

def test_axis(self, xp: ModuleType):
    # axis=0 and axis=-2 on 2-D input correspond to numpy's rowvar=False.
    rng = np.random.default_rng(20260417)
    data = rng.random((20, 3))  # observations on axis 0
    expected = np.cov(data, rowvar=False)
    for ax in (0, -2):
        xp_assert_close(cov(xp.asarray(data), axis=ax), xp.asarray(expected))

def test_frequency_weights(self, xp: ModuleType):
    # Integer frequency weights must match numpy's `fweights`.
    rng = np.random.default_rng(20260417)
    data = rng.random((3, 10))
    freq = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1], dtype=np.int64)
    expected = np.cov(data, fweights=freq)
    actual = cov(xp.asarray(data), fweights=xp.asarray(freq))
    xp_assert_close(actual, xp.asarray(expected))

def test_weights(self, xp: ModuleType):
    # Analytic weights must match numpy's `aweights`.
    rng = np.random.default_rng(20260417)
    data = rng.random((3, 10))
    analytic = rng.random(10)
    expected = np.cov(data, aweights=analytic)
    actual = cov(xp.asarray(data), aweights=xp.asarray(analytic))
    xp_assert_close(actual, xp.asarray(expected))

def test_both_weights(self, xp: ModuleType):
    # fweights and aweights combined, across several corrections.
    rng = np.random.default_rng(20260417)
    data = rng.random((3, 10))
    freq = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1], dtype=np.int64)
    analytic = rng.random(10)
    for ddof in (0, 1, 2):
        expected = np.cov(data, ddof=ddof, fweights=freq, aweights=analytic)
        actual = cov(
            xp.asarray(data),
            correction=ddof,
            fweights=xp.asarray(freq),
            aweights=xp.asarray(analytic),
        )
        xp_assert_close(actual, xp.asarray(expected))

def test_batch_with_weights(self, xp: ModuleType):
    # Weights apply uniformly to every batch slice.
    rng = np.random.default_rng(20260417)
    batch_shape, n_var, n_obs = (2, 3), 3, 15
    data = rng.random((*batch_shape, n_var, n_obs))
    analytic = rng.random(n_obs)
    actual = cov(xp.asarray(data), aweights=xp.asarray(analytic))
    flat = np.reshape(data, (-1, n_var, n_obs))
    expected = np.reshape(
        np.stack([np.cov(s, aweights=analytic) for s in flat]),
        (*batch_shape, n_var, n_var),
    )
    xp_assert_close(actual, xp.asarray(expected))

def test_axis_with_weights(self, xp: ModuleType):
    # axis=-2 (observations on first of 2D) combined with weights:
    # verifies that moveaxis and weight alignment cooperate.
    rng = np.random.default_rng(20260417)
    data = rng.random((15, 3))  # observations on axis 0
    analytic = rng.random(15)
    freq = np.asarray(
        [1, 2, 1, 3, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1, 1], dtype=np.int64
    )
    expected = np.cov(data, rowvar=False, fweights=freq, aweights=analytic)
    actual = cov(
        xp.asarray(data),
        axis=-2,
        fweights=xp.asarray(freq),
        aweights=xp.asarray(analytic),
    )
    xp_assert_close(actual, xp.asarray(expected))

def test_axis_out_of_bounds(self, xp: ModuleType):
    # An out-of-range axis must raise IndexError, mirroring numpy semantics.
    data = xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    with pytest.raises(IndexError):
        _ = cov(data, axis=5)



@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False)
class TestOneHot:
Expand Down