-
Notifications
You must be signed in to change notification settings - Fork 18
ENH: cov: expose correction and weights parameters #690
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0c57e2b
d9701e0
72e2d61
c0a20b0
c621a74
98f216a
06b4007
8b5c471
cb717b0
e34d415
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -81,7 +81,16 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array | |
| return _funcs.atleast_nd(x, ndim=ndim, xp=xp) | ||
|
|
||
|
|
||
| def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: | ||
| def cov( | ||
| m: Array, | ||
| /, | ||
| *, | ||
| axis: int = -1, | ||
| correction: int | float = 1, | ||
| fweights: Array | None = None, | ||
| aweights: Array | None = None, | ||
| xp: ModuleType | None = None, | ||
| ) -> Array: | ||
| """ | ||
| Estimate a covariance matrix (or a stack of covariance matrices). | ||
|
|
||
|
|
@@ -92,16 +101,37 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: | |
| :math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance | ||
| of :math:`x_i`. | ||
|
|
||
| With the exception of supporting batch input, this provides a subset of | ||
| the functionality of ``numpy.cov``. | ||
| Extends ``numpy.cov`` with support for batch input and array-api | ||
| backends. Naming follows the array-api conventions used elsewhere in | ||
| this library (``axis``, ``correction``) rather than the numpy spellings | ||
| (``rowvar``, ``bias``, ``ddof``); see Notes for the mapping. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| m : array | ||
| An array of shape ``(..., N, M)`` whose innermost two dimensions | ||
| contain *M* observations of *N* variables. That is, | ||
| each row of `m` represents a variable, and each column a single | ||
| observation of all those variables. | ||
| contain *M* observations of *N* variables by default. The axis of | ||
| observations is controlled by `axis`. | ||
| axis : int, optional | ||
| Axis of `m` containing the observations. Default: ``-1`` (the last | ||
| axis), matching the array-api convention. Use ``axis=-2`` (or ``0`` | ||
| for 2-D input) to treat each column as a variable, which | ||
| corresponds to ``rowvar=False`` in ``numpy.cov``. | ||
| correction : int or float, optional | ||
| Degrees of freedom correction: normalization divides by | ||
| ``N - correction`` (for unweighted input). Default: ``1``, which | ||
| gives the unbiased estimate (matches ``numpy.cov`` default of | ||
| ``bias=False``). Set to ``0`` for the biased estimate (``N`` | ||
| normalization). Corresponds to ``ddof`` in ``numpy.cov`` and to | ||
| ``correction`` in ``numpy.var``/``std`` and ``torch.cov``. | ||
| fweights : array, optional | ||
| 1-D array of integer frequency weights: the number of times each | ||
| observation is repeated. Same as ``fweights`` in | ||
| ``numpy.cov``/``torch.cov``. | ||
| aweights : array, optional | ||
| 1-D array of observation-vector weights (analytic weights). Larger | ||
| values mark more important observations. Same as ``aweights`` in | ||
| ``numpy.cov``/``torch.cov``. | ||
| xp : array_namespace, optional | ||
| The standard-compatible namespace for `m`. Default: infer. | ||
|
|
||
|
|
@@ -111,6 +141,23 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: | |
| An array having shape (..., N, N) whose innermost two dimensions represent | ||
| the covariance matrix of the variables. | ||
|
|
||
| Notes | ||
| ----- | ||
| Mapping from ``numpy.cov`` to this function:: | ||
|
|
||
| numpy.cov(m, rowvar=True) -> cov(m, axis=-1) # default | ||
| numpy.cov(m, rowvar=False) -> cov(m, axis=-2) | ||
| numpy.cov(m, bias=True) -> cov(m, correction=0) | ||
| numpy.cov(m, ddof=k) -> cov(m, correction=k) | ||
| numpy.cov(m, fweights=f) -> cov(m, fweights=f) | ||
| numpy.cov(m, aweights=a) -> cov(m, aweights=a) | ||
|
|
||
| Unlike ``numpy.cov``, a ``RuntimeWarning`` for non-positive effective | ||
| degrees of freedom is only emitted on the unweighted path. The | ||
| weighted path omits the check so that lazy backends (e.g. Dask) can | ||
| stay lazy end-to-end; choose ``correction`` and weights such that the | ||
| effective normalizer is positive. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> import array_api_strict as xp | ||
|
|
@@ -162,18 +209,54 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: | |
| """ | ||
|
|
||
| if xp is None: | ||
| xp = array_namespace(m) | ||
|
|
||
| if ( | ||
| is_numpy_namespace(xp) | ||
| or is_cupy_namespace(xp) | ||
| or is_torch_namespace(xp) | ||
| or is_dask_namespace(xp) | ||
| or is_jax_namespace(xp) | ||
| ) and m.ndim <= 2: | ||
| return xp.cov(m) | ||
|
|
||
| return _funcs.cov(m, xp=xp) | ||
| xp = array_namespace(m, fweights, aweights) | ||
|
|
||
| # Validate axis against m.ndim. | ||
| ndim = max(m.ndim, 1) | ||
| if not -ndim <= axis < ndim: | ||
| msg = f"axis {axis} is out of bounds for array of dimension {m.ndim}" | ||
| raise IndexError(msg) | ||
|
|
||
| # Normalize: observations on the last axis. After this, every backend | ||
| # sees the same convention and we never need to deal with `rowvar`. | ||
| if m.ndim >= 2 and axis not in (-1, m.ndim - 1): | ||
| m = xp.moveaxis(m, axis, -1) | ||
|
|
||
| # `numpy.cov` (and cupy/dask/jax) require integer `ddof`; `torch.cov` | ||
| # requires integer `correction`. For non-integer-valued `correction`, | ||
| # fall through to the generic implementation. | ||
| integer_correction = isinstance(correction, int) or correction.is_integer() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we allow non integer corrections in the first place? Is this to allow people to pass
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here, I follow the xp.var approach to allow int and float, https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html. Thinking about the use case, I don't know either. We could reduce the scope relative to var and allow only integers. I think it is a good idea!
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe @lucascolley has an idea for use-cases? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If observations have weights, the unbiased correction is often not One of the reasons for not using
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. many thanks for the input @kgryte 🙏🏽 |
||
| has_weights = fweights is not None or aweights is not None | ||
|
|
||
| if m.ndim <= 2 and integer_correction: | ||
| if is_torch_namespace(xp): | ||
| fw = None if fweights is None else xp.asarray(fweights) | ||
| aw = None if aweights is None else xp.asarray(aweights) | ||
| return xp.cov(m, correction=int(correction), fweights=fw, aweights=aw) | ||
| # `dask.array.cov` forces `.compute()` whenever weights are given: | ||
| # its internal `if fact <= 0` check on a lazy 0-D scalar triggers | ||
| # materialization. Route to the generic impl, which is fully lazy | ||
| # because it only does sum/matmul and skips that scalar check. | ||
| if ( | ||
| is_numpy_namespace(xp) | ||
| or is_cupy_namespace(xp) | ||
| or is_jax_namespace(xp) | ||
| or (is_dask_namespace(xp) and not has_weights) | ||
| ): | ||
| return xp.cov( | ||
| m, | ||
| ddof=int(correction), | ||
| fweights=fweights, | ||
| aweights=aweights, | ||
| ) | ||
|
|
||
| return _funcs.cov( | ||
| m, | ||
| correction=correction, | ||
| fweights=fweights, | ||
| aweights=aweights, | ||
| xp=xp, | ||
| ) | ||
|
|
||
|
|
||
| def create_diagonal( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -281,31 +281,79 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ... | |
| return tuple(out) | ||
|
|
||
|
|
||
| def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01 | ||
| def cov( | ||
| m: Array, | ||
| /, | ||
| *, | ||
| correction: int | float = 1, | ||
| fweights: Array | None = None, | ||
| aweights: Array | None = None, | ||
| xp: ModuleType, | ||
| ) -> Array: # numpydoc ignore=PR01,RT01 | ||
| """See docstring in array_api_extra._delegation.""" | ||
| m = xp.asarray(m, copy=True) | ||
| m = xp.asarray(m) | ||
|
Comment on lines
-286
to
+294
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we drop this?
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It was mostly a small optimization that I made, since the new covariance code no longer mutates in place. If I understand correctly, we needed the copy because of this line. But now, as we are doing: m_c = m - avg
m_w = m_c if w is None else m_c * w
m_cT = xp.matrix_transpose(m_c)
c = (m_w @ m_cT) / fact
I noticed this by accident: while running the speed test I saw a small regression. I think it's worth disabling the copy. |
||
| dtype = ( | ||
| xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64) | ||
| ) | ||
|
|
||
| m = atleast_nd(m, ndim=2, xp=xp) | ||
| m = xp.astype(m, dtype) | ||
|
|
||
| avg = xp.mean(m, axis=-1, keepdims=True) | ||
| # Validate weight shapes (eager metadata, lazy-safe). Native backends | ||
| # validate themselves; this covers the generic path (array-api-strict, | ||
| # sparse, and the dask+weights fallback where the native check is | ||
| # bypassed to preserve laziness). | ||
| n_obs = m.shape[-1] | ||
| for name, w_in in (("fweights", fweights), ("aweights", aweights)): | ||
| if w_in is None: | ||
| continue | ||
| if w_in.ndim != 1: | ||
| msg = f"`{name}` must be 1-D, got ndim={w_in.ndim}" | ||
| raise ValueError(msg) | ||
| if w_in.shape[0] != n_obs: | ||
| msg = ( | ||
| f"`{name}` has length {w_in.shape[0]} but `m` has {n_obs} observations" | ||
| ) | ||
| raise ValueError(msg) | ||
|
|
||
| fw = None | ||
| if fweights is not None: | ||
| fw = xp.astype(xp.asarray(fweights), dtype) | ||
| aw = None | ||
| if aweights is not None: | ||
| aw = xp.astype(xp.asarray(aweights), dtype) | ||
| if fw is None and aw is None: | ||
| w = None | ||
| elif fw is None: | ||
| w = aw | ||
| elif aw is None: | ||
| w = fw | ||
| else: | ||
| w = fw * aw | ||
|
|
||
| m_shape = eager_shape(m) | ||
| fact = m_shape[-1] - 1 | ||
|
|
||
| if fact <= 0: | ||
| warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) | ||
| fact = 0 | ||
|
|
||
| m -= avg | ||
| m_transpose = xp.matrix_transpose(m) | ||
| if xp.isdtype(m_transpose.dtype, "complex floating"): | ||
| m_transpose = xp.conj(m_transpose) | ||
| c = xp.matmul(m, m_transpose) | ||
| c /= fact | ||
| if w is None: | ||
| avg = xp.mean(m, axis=-1, keepdims=True) | ||
| fact = m_shape[-1] - correction | ||
| if fact <= 0: | ||
| warnings.warn( | ||
| "Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2 | ||
| ) | ||
| fact = 0 | ||
| else: | ||
| v1 = xp.sum(w, axis=-1) | ||
| avg = xp.sum(m * w, axis=-1, keepdims=True) / v1 | ||
| if aw is None: | ||
| fact = v1 - correction | ||
| else: | ||
| fact = v1 - correction * xp.sum(w * aw, axis=-1) / v1 | ||
|
|
||
| m_c = m - avg | ||
| m_w = m_c if w is None else m_c * w | ||
| m_cT = xp.matrix_transpose(m_c) | ||
|
lucascolley marked this conversation as resolved.
|
||
| if xp.isdtype(m_cT.dtype, "complex floating"): | ||
| m_cT = xp.conj(m_cT) | ||
| c = (m_w @ m_cT) / fact | ||
| axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1) | ||
| return xp.squeeze(c, axis=axes) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just a thought, maybe some common logic can be extracted out from this and
array-api-extra/src/array_api_extra/_delegation.py
Lines 313 to 316 in af12cd5
that is for a tuple of axes though, so maybe not
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I couldn't think of a way to optimize this. The only thing that came to mind is normalize_axis_index from numpy, but then we would need another PR to introduce it into this library ;p
https://numpy.org/doc/2.1/reference/generated/numpy.lib.array_utils.normalize_axis_index.html#numpy.lib.array_utils.normalize_axis_index