End-to-end support for coarser-than-daily count signals#794
End-to-end support for coarser-than-daily count signals#794cdc-mitzimorris merged 56 commits intomainfrom
Conversation
High-level comment: do we definitely need to enforce all weekly quantities sharing the same week? I agree that in many cases a user will want this, but it's not a given. For example, you could imagine two weekly aggregate observables, one reported in MMWR epiweeks, the other in isoweeks. Similarly, while matching temporal process weeks to observation weeks makes sense to me as a default, I don't think we should enforce it. |
|
@damonbayer and @dylanhmorris - I have addressed all comments.
ready for re-review. |
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 42 out of 42 changed files in this pull request and generated 3 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| n_steps = self._resolve_n_coarse(n_timepoints) | ||
| coarse = self.inner.sample( | ||
| n_timepoints=n_steps, | ||
| initial_value=initial_value, | ||
| n_processes=n_processes, | ||
| name_prefix=name_prefix, | ||
| ) |
There was a problem hiding this comment.
StepwiseTemporalProcess.sample accepts first_day_dow but does not forward it to the inner process. This breaks valid compositions like StepwiseTemporalProcess(inner=WeeklyTemporalProcess(...)) where the inner requires first_day_dow at sample time. Forward first_day_dow to inner.sample (it can be ignored by processes that don't use it) or explicitly forbid inners that require a calendar anchor with a clear error.
| if self.reporting_schedule == "regular": | ||
| if obs is None: | ||
| return | ||
| n_periods = self._n_periods(n_total, first_day_dow) | ||
| obs = jnp.asarray(obs) | ||
| if obs.ndim != 2: | ||
| raise ValueError( | ||
| f"Observation '{self.name}': regular-schedule obs must " | ||
| f"be 2D (n_periods, n_observed_subpops); got shape {obs.shape}" | ||
| ) | ||
| if obs.shape[0] != n_periods: | ||
| raise ValueError( | ||
| f"Observation '{self.name}': obs dimension 0 length " | ||
| f"{obs.shape[0]} must equal n_periods ({n_periods}). " | ||
| f"Pad with NaN for unobserved periods." | ||
| ) | ||
| if subpop_indices is not None: | ||
| n_observed = jnp.asarray(subpop_indices).shape[0] | ||
| if obs.shape[1] != n_observed: | ||
| raise ValueError( | ||
| f"Observation '{self.name}': obs dimension 1 length " | ||
| f"{obs.shape[1]} must equal len(subpop_indices) " | ||
| f"({n_observed})" | ||
| ) | ||
| return |
There was a problem hiding this comment.
SubpopulationCounts.validate_data allows obs to be provided with subpop_indices=None (regular and irregular schedules), but SubpopulationCounts.sample always raises when subpop_indices is missing. This can let validate_data pass while sampling fails later. Either require subpop_indices in validate_data whenever obs is not None, or define a default behavior (e.g., all subpops) and implement it consistently in sample/validate_data.
damonbayer
left a comment
There was a problem hiding this comment.
One small suggestion. Thanks @cdc-mitzimorris!
Overview
Adds support for count observations aggregated to a weekly grid while the renewal equation continues to be evaluated daily. Two design pieces:
End-to-end flow for a weekly signal
pyrenew.time.daily_to_weekly.Relationship to pyrenew-hew
The production pyrenew-hew model parameterizes$\mathcal{R}(t)$ weekly and aggregates daily predicted hospital admissions to a weekly grid. This PR brings the same capability into PyRenew while making parameter cadence a user choice rather than a fixed coupling:
StepwiseTemporalProcess(step_size=7, alignment="calendar_week", week_start_dow=...).Observation cadence and parameter cadence are independent design choices; the builder no longer enforces a pairing rule between them.
Reviewer guide
Review bottom-up through the dependency chain. Each unit's changes are self-contained.
1. Synthetic data refresh (120 → 126 days)
datagen_he_CA_126.py,synthetic_CA_126/*.csv,synthetic_data.py,test_datagen_he_CA_126.py,test_datasets_synthetic.py.true_parameters.jsonmatches the generating process.2. Observation base validators —
pyrenew/observation/base.py_validate_aggregation_params,_compute_period_offset,_validate_period_end_times._validate_shapes_matchreplaces_validate_obs_times_shape._validate_dow_effectnow uses the sharedrequire_shapehelper.(period_end_dow + 1 - first_day_dow) % 7and period-boundary alignment.3. Latent: temporal processes —
pyrenew/latent/temporal_processes.pyStepwiseTemporalProcess;step_sizeattr added to theTemporalProcessProtocol and to AR1 / DifferencedAR1 / RandomWalk."model_index"(default) starts blocks at model index 0;"calendar_week"aligns weekly blocks to a declaredweek_start_dow.pyrenew.time.weekly_to_daily; coarse trajectory recorded as{name_prefix}_coarse.first_day_dowthreaded through the Protocol so calendar-aligned wrappers can use the model-axis day-of-week; standard processes ignore it.4. Latent: shape contracts —
pyrenew/latent/{population,subpopulation}_infections.pysample()methods acceptfirst_day_dowand forward it to their temporal processes.pyrenew.arrayutils.require_shapehelper.5. Count observation path —
pyrenew/observation/count_observations.pyaggregation_period,reporting_schedule,period_end_dow._aggregatewrapspyrenew.time.daily_to_weekly.validate_data/samplefor regular vs. irregular schedules.SubpopulationCountscall sites updated;times→period_end_timesrename consistent.6. Measurement observations —
pyrenew/observation/measurement_observations.py7. Builder + model coherence —
pyrenew/model/{pyrenew_builder,multisignal_model}.pyMultiSignalModel.sample()acceptsfirst_day_dowand forwards it to the latent process._validate_coherenceenforces calendar-anchor and structural coherence:aggregation_period > 1must agree onperiod_end_dow.step_sizemust be a positive integer.week_start_dowconsistent with weeklyperiod_end_dow:period_end_dow == (week_start_dow + 6) % 7.first_day_dowrequired atvalidate_datawhen any obs hasaggregation_period > 1.8. Integration test —
test/integration/test_population_infections_he_weekly.pynumpyro.enable_x64()+set_host_device_count(4).9. Test fixtures + config —
test/conftest.py,pyproject.toml,_typos.tomlWrongShapeTemporalProcess,ConstantTemporalProcess,InvalidStepSizeTemporalProcess).integrationmarker added so-m "not integration"skips MCMC tests.10. Tutorial —
docs/tutorials/building_multisignal_models.qmdalignment="calendar_week",week_start_dow=6) with weekly observations (period_end_dow=5).Where to focus review attention
_compute_period_offsetand the period-boundary check in_validate_period_end_times— these govern correctness of weekly alignment.StepwiseTemporalProcesscalendar-week broadcasting —weekly_to_dailyis reused; sanity check withn_timepoints=17,first_day_dow=3,week_start_dow=6produces 3 coarse samples broadcast to[c0×3, c1×7, c2×7]._validate_coherence— each has pass + distinct-failure tests intest_pyrenew_builder.py.CountObservation._aggregateruns inside the numpyro-traced graph — likelihood scoring is at weekly scale, not post-hoc.Incidental fixes (not directly tied to #789)
Found while implementing aggregation; small enough that separating them would create churn.
_validate_index_arrayempty-array guard (observation/base.py) —jnp.any(indices < 0)on empty arrays returnedFalse; now returns early onsize == 0._validate_index_array/_validate_obs_densendim check — previously accepted non-1D arrays and relied on silent broadcasting; now rejects with a clear error.test_ar_process_asymptotics— the bound|long_ts[-1]| < 3 * noise_sdwas too tight (stationary SD strictly > innovation SD when autoregressive coefficients are non-zero) and latently flaky; replaced with closed-form stationary SD per order._validate_obs_times_shape→_validate_shapes_matchrename — prerequisite refactor; same shape-match logic needed for(obs, period_end_times)pairs._typos.toml: allowdows.pyproject.toml: registerintegrationpytest marker.