ENH: cov: expose correction and weights parameters #690
bruAristimunha wants to merge 10 commits into data-apis:main
Conversation
Resolves data-apis#688. Adds `axis`, `correction`, `frequency_weights`, and `weights` to `cov`, giving users control over the degrees-of-freedom correction and the observation-axis / weighted variants that `numpy.cov` and `torch.cov` already support. Naming follows array-api conventions (`axis`, `correction`) rather than numpy's (`rowvar`, `bias`, `ddof`); the docstring includes a one-to-one mapping. The delegation moves observations to the last axis via `xp.moveaxis`, collapsing `rowvar` out of the backend dispatch — only `ddof` vs `correction` differs between branches. Dask's native `cov` forces `.compute()` on a lazy scalar when any weights are given, so weighted dask inputs fall through to the generic implementation, which is fully lazy.
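The weighted variant described above can be sanity-checked against `numpy.cov` directly. A minimal sketch, assuming reliability (`aweights`-style) weights only; variable names are illustrative, not the PR's actual code:

```python
import numpy as np

rng = np.random.default_rng(0)
m = rng.standard_normal((3, 8))       # 3 variables, 8 observations
aw = rng.uniform(0.5, 2.0, size=8)    # reliability weights

# Same algebra numpy.cov uses for aweights with its default ddof=1:
v1 = aw.sum()
v2 = (aw * aw).sum()
fact = v1 - v2 / v1                   # generally non-integer
avg = np.average(m, axis=1, weights=aw)
m_c = m - avg[:, None]                # center out-of-place
c = (m_c * aw) @ m_c.T / fact

assert np.allclose(c, np.cov(m, aweights=aw))
```

Note that `fact` here is not `n - 1`, which is why the weighted path cannot reuse a plain integer degrees-of-freedom divisor.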
It looks like the PR description mentions that other functions in this library already use …
```python
# `numpy.cov` (and cupy/dask/jax) require integer `ddof`; `torch.cov`
# requires integer `correction`. For non-integer-valued `correction`,
# fall through to the generic implementation.
integer_correction = isinstance(correction, int) or correction.is_integer()
```
Why do we allow non-integer corrections in the first place? Is this to allow people to pass `correction=1.` instead of raising an error? Or do people really use corrections like `correction=1.234`? (I'm not familiar with advanced uses.)
Here I follow the `xp.var` approach of allowing both int and float: https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html.
As for the use case, I don't know either. We could reduce the scope relative to `var` and allow only integers — I think that's a good idea!
If observations have weights, the unbiased correction is often not n-1. Instead, it depends on the sum of weights and their dispersion. Another instance where the correction is not an integer is for autocorrelated data.
One of the reasons for not using ddof was to get away from the implicit integer-encoded mental model of correction factors.
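To make the weighted case above concrete, here is a small sketch of why the unbiased correction stops being an integer once weights enter. The `V1 - V2 / V1` divisor is the one `numpy.cov` uses when `aweights` are given with `ddof=1`:

```python
import numpy as np

# With reliability weights, the unbiased divisor is not n - 1 but
# V1 - V2 / V1, where V1 = sum(w) and V2 = sum(w**2).
w = np.array([0.5, 1.0, 1.5, 2.0])
v1 = w.sum()            # 5.0
v2 = (w * w).sum()      # 7.5
fact = v1 - v2 / v1     # 3.5 -- a non-integer effective divisor
print(fact)             # 3.5
```

So "correction" in the weighted setting is inherently real-valued, not a count of lost degrees of freedom.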
Hey @betatim! This was a hard decision I had to make, but I can be stricter than numpy if you prefer. I basically looked at what was already implemented in the array API and how it handles the parameter names I was trying to introduce — for each parameter, I checked how the migration from numpy was made in the past.
For `correction`: https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html. There was a discussion on using `correction` instead of `bias`+`ddof` in these functions. It was introduced in data-apis/array-api#10, and later there were some interesting discussions in data-apis/array-api#695; it was @kgryte who led the discussion.
And for `frequency_weights` and `weights`, it was my experience in Pyriemann that drove the decision. The only place I remember using something similar is statsmodels (`freq_weights`, `var_weights`): https://www.statsmodels.org/stable/generated/statsmodels.genmod.generalized_linear_model.GLM.html#statsmodels.genmod.generalized_linear_model.GLM.freq_weights
I think in scikit-learn you use `sample_weight` more, but I can accommodate any request about this.
betatim left a comment
What is your thinking on validating the weights passed in? Things like checking the shapes make sense, that they are all positive (is this actually required? how does it fit with being lazy?)
I liked this idea a lot @betatim! I think it will make the checks in libraries that use array-api-extra much lighter.
83b7e1b to d9701e0 (compare)
FYI @qbarthelemy and @agramfort
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
Thanks a lot for the detailed answer in #690 (comment) - I didn't realise there was precedent for using these parameter names.
What is the "temporary deployed" thing that keeps happening?
It's not me @betatim, I think it's something that @lucascolley is pushing here: #699
Happy that you liked the response @betatim :) I think I addressed all the points from you and @qbarthelemy; can we merge?
fixed in bd3652a
lucascolley left a comment
I took an initial look, seems pretty good!
One high level comment @bruAristimunha — could you demonstrate that this works as expected when used in a branch of sklearn? You should be able to change https://github.com/scikit-learn/scikit-learn/blob/06aded051fe6c7c9970b7e13c3669f952a799831/maint_tools/vendor_array_api_extra.sh#L8-L9 to point to this branch and commit hash.
```diff
- m = xp.asarray(m, copy=True)
+ m = xp.asarray(m)
```
It was mostly one small optimization that I made, as the new covariance code doesn't mutate in place anymore.
If I understand correctly, we needed the copy because of this line:

```python
m -= avg
```

But now we are doing:

```python
m_c = m - avg
m_w = m_c if w is None else m_c * w
m_cT = xp.matrix_transpose(m_c)
c = (m_w @ m_cT) / fact
```

I noticed this by accident while running the speed test and saw a small regression, so I think it's worth disabling the copy.
```python
# Validate axis against m.ndim.
ndim = max(m.ndim, 1)
if not -ndim <= axis < ndim:
    msg = f"axis {axis} is out of bounds for array of dimension {m.ndim}"
    raise IndexError(msg)
```
Just a thought: maybe some common logic can be extracted from this and `array-api-extra/src/array_api_extra/_delegation.py`, lines 313 to 316 in af12cd5. That one is for a tuple of axes though, so maybe not.
I couldn't think of a way to factor this out. The only thing I could think of here is `normalize_axis_index` from numpy, but then we would need another PR to introduce it into the library ;p
```python
# Validate weight shapes (eager metadata, lazy-safe). Value-based
# checks (non-negative, integer dtype) are intentionally skipped so
# lazy backends don't trigger compute -- same tradeoff as dask.cov.
n_obs = m.shape[-1]
for name, w in (("fweights", fweights), ("aweights", aweights)):
    if w is None:
        continue
    if w.ndim != 1:
        msg = f"`{name}` must be 1-D, got ndim={w.ndim}"
        raise ValueError(msg)
    if w.shape[0] != n_obs:
        msg = (
            f"`{name}` has length {w.shape[0]} but `m` has {n_obs} observations"
        )
        raise ValueError(msg)
```
must this happen at the delegation layer? What happens if we just let the backends to which we delegate error out instead?
I think all the backends (numpy, torch, jax, dask) will raise, except for the 0-D scalar fweights case in dask, so it should be fine.
I mostly pushed this to address @betatim's suggestion (#690 (review)), but if you prefer, I can let the backends handle it, since they already validate.
Let me validate only in the dask case, then.
Moving this to `_funcs`.
Hey @betatim, as you have the first covariance PR in scikit-learn, can you help with this small test requested by @lucascolley?
Hey @lucascolley, I made a branch built on top of @betatim's work on scikit-learn's first covariance; you can check it here: scikit-learn/scikit-learn#33600
Resolves #688.
Summary
- Adds `axis`, `correction`, `frequency_weights`, and `weights` parameters to `xpx.cov`, unlocking the degrees-of-freedom and weighted variants that `numpy.cov` and `torch.cov` already support.
- Naming follows the array-api conventions (`axis`, `correction`) used elsewhere in this library rather than numpy's (`rowvar`, `bias`, `ddof`). The docstring includes a one-to-one mapping for users migrating from `numpy.cov`.

Design
The delegation moves observations to the last axis via `xp.moveaxis`, which collapses `rowvar` out of backend dispatch entirely — only `ddof` (numpy/cupy/dask/jax) vs `correction` (torch) differs between branches.

Fallbacks to the generic implementation (`_funcs.cov`):
- `m.ndim > 2` (batched input, not supported by any native implementation).
- Non-integer-valued `correction` (rejected by `numpy.cov`'s `ddof`).
- Weighted dask input: `dask.array.cov` forces `.compute()` on a lazy 0-D scalar via its internal `if fact <= 0` check. The generic path stays fully lazy because its weighted branch doesn't compare `fact` to zero (noted in docstring).

Weighted formula in `_funcs.cov` matches numpy's (algebraically): `c = (m_c · w) @ m_c.T / (v1 - correction · v2 / v1)`.

Tests
New `TestCov` cases validate against `np.cov` as reference:
- `test_correction` (integer ddof)
- `test_correction_float` (generic-path-only, hand-computed reference)
- `test_axis` / `test_axis_with_weights` / `test_axis_out_of_bounds`
- `test_frequency_weights` / `test_weights` / `test_both_weights`
- `test_batch_with_weights`

Test plan
- `pytest tests/test_funcs.py::TestCov` — 126 passed across numpy, torch, jax, dask, array-api-strict
- `pytest tests/test_funcs.py` full — 4263 passed, 0 failed
- `lefthook run pre-commit` — ruff, numpydoc, mypy, pyright, typos all green
- `lazy_xp_function(cov)` asserts 0 `.compute()` calls; holds for the weighted path via the fallback
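As a sanity check on the frequency-weight semantics the tests rely on, frequency weights are supposed to mean "this observation occurred k times", so weighting must agree with physically repeating columns. A minimal sketch using `numpy.cov`'s equivalent `fweights` as the reference:

```python
import numpy as np

rng = np.random.default_rng(42)
m = rng.standard_normal((2, 4))
fw = np.array([1, 3, 2, 1])

# fweights[i] = number of times observation i occurred, so weighting
# must be equivalent to repeating each column fweights[i] times.
repeated = np.repeat(m, fw, axis=1)
assert np.allclose(np.cov(m, fweights=fw), np.cov(repeated))
```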
pytest tests/test_funcs.py::TestCov— 126 passed across numpy, torch, jax, dask, array-api-strictpytest tests/test_funcs.pyfull — 4263 passed, 0 failedlefthook run pre-commit— ruff, numpydoc, mypy, pyright, typos all greenlazy_xp_function(cov)asserts 0.compute()calls, holds for weighted path via the fallback