
ENH: cov: expose correction and weights parameters#690

Open
bruAristimunha wants to merge 10 commits into data-apis:main from bruAristimunha:cov_parameters

Conversation

@bruAristimunha

Resolves #688.

Summary

  • Adds axis, correction, frequency_weights, and weights parameters to xpx.cov, unlocking the degrees-of-freedom and weighted variants that numpy.cov and torch.cov already support.
  • Naming follows array-api conventions (axis, correction) used elsewhere in this library rather than numpy's (rowvar, bias, ddof). The docstring includes a one-to-one mapping for users migrating from numpy.cov.

Design

The delegation moves observations to the last axis via xp.moveaxis, which collapses rowvar out of backend dispatch entirely — only ddof (numpy/cupy/dask/jax) vs correction (torch) differs between branches.
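The moveaxis trick can be sketched in plain numpy (an illustrative standalone helper, not the PR's actual delegation code):

```python
import numpy as np

def cov_obs_last(m, axis=-1):
    """Illustrative sketch: move the observation axis to the end, so a
    single np.cov call (its default rowvar=True orientation) handles
    both layouts without ever dispatching on rowvar."""
    m = np.moveaxis(np.asarray(m), axis, -1)  # observations -> last axis
    return np.cov(m)

rng = np.random.default_rng(0)
x = rng.standard_normal((100, 3))  # 100 observations of 3 variables

# Observations along axis 0 is numpy's rowvar=False layout:
np.testing.assert_allclose(cov_obs_last(x, axis=0), np.cov(x, rowvar=False))
```

Because the orientation is normalized up front, the backend branches only need to translate the `ddof`/`correction` keyword.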

Fallbacks to the generic implementation (_funcs.cov):

  • m.ndim > 2 (batched input, which no native backend supports).
  • Non-integer correction (rejected by numpy.cov's ddof).
  • Dask with weights — dask.array.cov forces .compute() on a lazy 0-D scalar via its internal if fact <= 0 check. The generic path stays fully lazy because its weighted branch doesn't compare fact to zero (noted in docstring).

Weighted formula in _funcs.cov matches numpy's (algebraically): c = (m_c · w) @ m_c.T / (v1 - correction · v2 / v1).
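To make the formula concrete, here is a standalone sketch checked against numpy.cov. The helper name `weighted_cov` and its structure are illustrative assumptions, not the PR's internal code:

```python
import numpy as np

def weighted_cov(m, fweights=None, aweights=None, correction=1):
    # Illustrative sketch of the weighted formula, not the PR's code.
    m = np.asarray(m, dtype=float)
    n = m.shape[-1]
    w = np.ones(n)
    if fweights is not None:
        w = w * fweights
    if aweights is not None:
        w = w * aweights
    v1 = w.sum()                      # total weight
    v2 = (w * aweights).sum() if aweights is not None else v1
    fact = v1 - correction * v2 / v1  # degrees-of-freedom denominator
    avg = (m * w).sum(axis=-1, keepdims=True) / v1
    m_c = m - avg                     # centered observations
    return (m_c * w) @ m_c.T / fact

rng = np.random.default_rng(1)
x = rng.standard_normal((3, 50))
fw = rng.integers(1, 5, size=50)
aw = rng.random(50) + 0.5
np.testing.assert_allclose(
    weighted_cov(x, fw, aw), np.cov(x, fweights=fw, aweights=aw, ddof=1)
)
```

Note that when `aweights` is None, `v2 == v1` and the denominator collapses to the familiar `v1 - correction`.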

Tests

New TestCov cases validate against np.cov as reference:

  • test_correction (integer ddof)
  • test_correction_float (generic-path-only, hand-computed reference)
  • test_axis / test_axis_with_weights / test_axis_out_of_bounds
  • test_frequency_weights / test_weights / test_both_weights
  • test_batch_with_weights

Test plan

  • pytest tests/test_funcs.py::TestCov — 126 passed across numpy, torch, jax, dask, array-api-strict
  • pytest tests/test_funcs.py full — 4263 passed, 0 failed
  • lefthook run pre-commit — ruff, numpydoc, mypy, pyright, typos all green
  • Dask laziness verified — lazy_xp_function(cov) asserts 0 .compute() calls, holds for weighted path via the fallback

Resolves data-apis#688. Adds `axis`, `correction`, `frequency_weights`, and
`weights` to `cov`, giving users control over the degrees-of-freedom
correction and the observation-axis / weighted variants that
`numpy.cov` and `torch.cov` already support.

Naming follows array-api conventions (`axis`, `correction`) rather
than numpy's (`rowvar`, `bias`, `ddof`); the docstring includes a
one-to-one mapping. The delegation moves observations to the last
axis via `xp.moveaxis`, collapsing `rowvar` out of the backend
dispatch — only `ddof` vs `correction` differs between branches.

Dask's native `cov` forces `.compute()` on a lazy scalar when any
weights are given, so weighted dask inputs fall through to the
generic implementation, which is fully lazy.
@betatim (Member) commented Apr 20, 2026

It looks like the cov you are adding follows the pytorch signature; can you explain a bit why you chose that? In my PR I thought following the NumPy API made sense, because it seems that most libraries use it.

The PR description mentions that other functions in this library already use correction and axis. Is that the reason to also do it here? I'm interested in your thinking.

# `numpy.cov` (and cupy/dask/jax) require integer `ddof`; `torch.cov`
# requires integer `correction`. For non-integer-valued `correction`,
# fall through to the generic implementation.
integer_correction = isinstance(correction, int) or correction.is_integer()
Member

Why do we allow non-integer corrections in the first place? Is this to let people pass correction=1. instead of raising an error? Or do people really use values like correction=1.234? (I'm not familiar with advanced uses.)

Author

Here I follow the xp.var approach of allowing both int and float: https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html.

As for the use case, I don't know either. We could narrow the scope relative to var and allow only integers; I think that is a good idea!

Member

Maybe @lucascolley has an idea for use-cases?


If observations have weights, the unbiased correction is often not n-1. Instead, it depends on the sum of weights and their dispersion. Another instance where the correction is not an integer is for autocorrelated data.

One of the reasons for not using ddof was to get away from the implicit integer-encoded mental model of correction factors.
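As a concrete instance of the weighted case: for reliability weights, the unbiased estimator's denominator is V1 - V2/V1 with V2 = sum(w**2), i.e. an effective correction of V2/V1, which is rarely an integer. A minimal numeric check (the weight values are arbitrary):

```python
import numpy as np

# Reliability weights: the unbiased weighted-variance denominator is
# V1 - V2 / V1, i.e. an effective correction of V2 / V1.
w = np.array([0.2, 0.5, 1.0, 1.3])
v1, v2 = w.sum(), (w**2).sum()
effective_correction = v2 / v1
print(effective_correction)  # ~0.993, not an integer
```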

Author

many thanks for the input @kgryte 🙏🏽

Comment thread src/array_api_extra/_delegation.py
Comment thread src/array_api_extra/_lib/_funcs.py Outdated
@bruAristimunha
Author

Hey @betatim!

This was a hard decision, but I can align more strictly with numpy if you prefer.

I basically looked at what was already implemented in the array API and how it handles the parameter names I was trying to introduce. For each numpy parameter (bias, rowvar, ddof, fweights, and aweights), I checked how it had been mapped here in the past.

For bias and ddof becoming correction: functions like xp.var, xp.std, and I think xp.sum rename the numpy defaults to the array API specification names.

https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html
https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html

There was a discussion about using correction instead of bias+ddof in these functions: it was introduced in data-apis/array-api#10, with further discussion in data-apis/array-api#695, led by @kgryte.

For rowvar becoming axis, I just followed the signature of the other functions; axis is the convention they adopted.

For frequency_weights and weights, the choice comes from my experience with pyRiemann. The only similar precedent I remember is statsmodels (freq_weights, var_weights): https://www.statsmodels.org/stable/generated/statsmodels.genmod.generalized_linear_model.GLM.html#statsmodels.genmod.generalized_linear_model.GLM.freq_weights

I think scikit-learn uses sample_weight more, but I can accommodate any request on naming.


@betatim left a comment


What is your thinking on validating the weights passed in? Things like checking that the shapes make sense and that the weights are all positive (is that actually required? how does it fit with being lazy?)

@bruAristimunha
Author

I like this idea a lot, @betatim! I think it will make the checks in libraries that use array-api-extra much lighter.

@bruAristimunha
Author

FYI @qbarthelemy and @agramfort

Comment thread src/array_api_extra/_lib/_funcs.py Outdated
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
Comment thread src/array_api_extra/_delegation.py Outdated
bruAristimunha and others added 2 commits April 20, 2026 13:39
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
@betatim (Member) commented Apr 20, 2026

Thanks a lot for the detailed answer in #690 (comment) - I didn't realise there was precedent for using correction in functions like var. I think it makes sense to copy that and use correction for cov as well. Worth documenting the translation!

What is the "temporary deployed" thing that keeps happening?

@bruAristimunha
Author

It is not me, @betatim; I think it is something @lucascolley is pushing here: #699

@bruAristimunha
Author

Happy that you liked the response @betatim :)

I think I addressed all the points from you and @qbarthelemy; can we merge?

@lucascolley
Member

What is the "temporary deployed" thing that keeps happening?

fixed in bd3652a

@lucascolley lucascolley changed the title ENH: expose correction and weights parameters in cov ENH: cov: expose correction and weights parameters Apr 20, 2026

@lucascolley left a comment


I took an initial look, seems pretty good!

One high level comment @bruAristimunha — could you demonstrate that this works as expected when used in a branch of sklearn? You should be able to change https://github.com/scikit-learn/scikit-learn/blob/06aded051fe6c7c9970b7e13c3669f952a799831/maint_tools/vendor_array_api_extra.sh#L8-L9 to point to this branch and commit hash.

Comment on lines -286 to +294
- m = xp.asarray(m, copy=True)
+ m = xp.asarray(m)
Member

why do we drop this?

Author

It was mostly a small optimization: the new covariance code no longer mutates in place.

If I understand correctly, we needed the copy because of this line:

m -= avg

But now we do:

m_c = m - avg
m_w = m_c if w is None else m_c * w
m_cT = xp.matrix_transpose(m_c)
c = (m_w @ m_cT) / fact

I noticed this by accident while running the speed tests, where I saw a small regression. I think it's worth dropping the copy.

Comment thread src/array_api_extra/_delegation.py Outdated
Comment on lines +214 to +218
# Validate axis against m.ndim.
ndim = max(m.ndim, 1)
if not -ndim <= axis < ndim:
    msg = f"axis {axis} is out of bounds for array of dimension {m.ndim}"
    raise IndexError(msg)
Member

just a thought, maybe some common logic can be extracted out from this and

if axis != () and (min(axis) < -ndim or max(axis) >= ndim):
    err_msg = (
        f"a provided axis position is out of bounds for array of dimension {a.ndim}"
    )

that is for a tuple of axes though, so maybe not

Author

I couldn't think of a way to factor this out. The only thing that came to mind is normalize_axis_index from numpy, but then we would need another PR to introduce it into the library ;p

https://numpy.org/doc/2.1/reference/generated/numpy.lib.array_utils.normalize_axis_index.html#numpy.lib.array_utils.normalize_axis_index

Comment thread src/array_api_extra/_delegation.py Outdated
Comment on lines +225 to +239
# Validate weight shapes (eager metadata, lazy-safe). Value-based
# checks (non-negative, integer dtype) are intentionally skipped so
# lazy backends don't trigger compute -- same tradeoff as dask.cov.
n_obs = m.shape[-1]
for name, w in (("fweights", fweights), ("aweights", aweights)):
    if w is None:
        continue
    if w.ndim != 1:
        msg = f"`{name}` must be 1-D, got ndim={w.ndim}"
        raise ValueError(msg)
    if w.shape[0] != n_obs:
        msg = f"`{name}` has length {w.shape[0]} but `m` has {n_obs} observations"
        raise ValueError(msg)
Member

must this happen at the delegation layer? What happens if we just let the backends to which we delegate error out instead?

Author

I think all the backends (numpy, torch, jax, dask) will raise on their own, except for the 0-D scalar fweights case in dask, so that should be fine.

I mostly pushed this to address @betatim's suggestion (#690 (review)), but if you prefer, I can let the backends handle it, since they already validate.

Author

Let me validate only in the dask case.

Author

Moving this to _funcs.


Comment thread src/array_api_extra/_lib/_funcs.py
Comment thread src/array_api_extra/_lib/_funcs.py Outdated
@bruAristimunha
Author

Hey @betatim,

Since you have the first covariance PR in scikit-learn, can you help with this small test requested by @lucascolley?

One high level comment @bruAristimunha — could you demonstrate that this works as expected when used in a branch of sklearn? You should be able to change https://github.com/scikit-learn/scikit-learn/blob/06aded051fe6c7c9970b7e13c3669f952a799831/maint_tools/vendor_array_api_extra.sh#L8-L9 to point to this branch and commit hash.

@bruAristimunha
Author

Hey @lucascolley,

I did it in my branch, built on top of @betatim's work for scikit-learn's first covariance; you can see more here: scikit-learn/scikit-learn#33600


Labels: enhancement (New feature or request)

Successfully merging this pull request may close these issues: EHN: make covariance more flexible

5 participants