diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 1635bdd2a..27aa16acc 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: include: - - {os: windows-latest, python: "3.11", dask-version: "2025.12.0", name: "Dask 2025.12.0"} + - {os: windows-latest, python: "3.11", dask-version: "2026.3.0", name: "Dask 2026.3.0"} - {os: windows-latest, python: "3.13", dask-version: "latest", name: "Dask latest"} - {os: ubuntu-latest, python: "3.11", dask-version: "latest", name: "Dask latest"} - {os: ubuntu-latest, python: "3.13", dask-version: "latest", name: "Dask latest"} diff --git a/pyproject.toml b/pyproject.toml index 07ec8140b..04bbb2d44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,8 +26,8 @@ dependencies = [ "annsel>=0.1.2", "click", "dask-image", - "dask>=2025.12.0,<2026.1.2", - "distributed<2026.1.2", + "dask>=2026.3.0", + "distributed>=2026.3.0", "datashader", "fsspec[s3,http]", "geopandas>=0.14", @@ -50,6 +50,7 @@ dependencies = [ "xarray>=2024.10.0", "xarray-spatial>=0.3.5", "zarr>=3.0.0", + "zarrs", ] [project.optional-dependencies] torch = [ @@ -62,6 +63,9 @@ extra = [ ] [dependency-groups] +sharding = [ + "zarrs", +] dev = [ "bump2version", ] diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 7ba66e710..1bb0483c9 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -131,6 +131,10 @@ "settings", ] +import zarr + +zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"}) + def __getattr__(name: str) -> Any: if name in _submodules: diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 739b225fe..d042fa6c3 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1108,6 +1108,7 @@ def write( update_sdata_path: bool = True, sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, 
shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, + raster_write_kwargs: dict[str, dict[str, Any] | Any] | None = None, ) -> None: """ Write the `SpatialData` object to a Zarr store. @@ -1155,6 +1156,7 @@ def write( shapes_geometry_encoding Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet` for details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`. + """ from spatialdata._io._utils import _resolve_zarr_store from spatialdata._io.format import _parse_formats @@ -1173,6 +1175,13 @@ def write( store.close() for element_type, element_name, element in self.gen_elements(): + element_raster_write_kwargs = None + if element_type in ("images", "labels") and raster_write_kwargs: + if kwargs := raster_write_kwargs.get(element_name): + element_raster_write_kwargs = kwargs + elif not any(isinstance(x, dict) for x in raster_write_kwargs.values()): + element_raster_write_kwargs = raster_write_kwargs + self._write_element( element=element, zarr_container_path=file_path, @@ -1181,6 +1190,7 @@ def write( overwrite=False, parsed_formats=parsed, shapes_geometry_encoding=shapes_geometry_encoding, + element_raster_write_kwargs=element_raster_write_kwargs, ) if self.path != file_path and update_sdata_path: @@ -1198,6 +1208,7 @@ def _write_element( overwrite: bool, parsed_formats: dict[str, SpatialDataFormatType] | None = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, + element_raster_write_kwargs: dict[str, Any] | None = None, ) -> None: from spatialdata._io.io_zarr import _get_groups_for_element @@ -1231,6 +1242,7 @@ def _write_element( group=element_group, name=element_name, element_format=parsed_formats["raster"], + storage_options=element_raster_write_kwargs, ) elif element_type == "labels": write_labels( @@ -1238,6 +1250,7 @@ def _write_element( group=root_group, name=element_name, element_format=parsed_formats["raster"], + 
storage_options=element_raster_write_kwargs, ) elif element_type == "points": write_points( diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index a8b2ab2ce..a588ce092 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -148,13 +148,13 @@ def _prepare_storage_options( return None if isinstance(storage_options, dict): prepared = dict(storage_options) - if "chunks" in prepared: + if "chunks" in prepared and prepared["chunks"] is not None: prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) return prepared prepared_options = [dict(options) for options in storage_options] for options in prepared_options: - if "chunks" in options: + if "chunks" in options and options["chunks"] is not None: options["chunks"] = _normalize_explicit_chunks(options["chunks"]) return prepared_options @@ -289,6 +289,10 @@ def _write_raster( metadata Additional metadata for the raster element """ + from dataclasses import asdict + + from spatialdata import settings + if raster_type not in ["image", "labels"]: raise ValueError(f"{raster_type} is not a valid raster type. Must be 'image' or 'labels'.") # "name" and "label_metadata" are only used for labels. 
"name" is written in write_multiscale_ngff() but ignored in @@ -305,6 +309,18 @@ def _write_raster( for c in channels: metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload] + if isinstance(storage_options, dict): + storage_options = { + **{k: v for k, v in asdict(settings).items() if k in ("chunks", "shards")}, + **storage_options, + } + elif isinstance(storage_options, list): + storage_options = [ + {**{k: v for k, v in asdict(settings).items() if k in ("chunks", "shards")}, **x} for x in storage_options + ] + elif not storage_options: + storage_options = {k: v for k, v in asdict(settings).items() if k in ("chunks", "shards")} + if isinstance(raster_data, DataArray): _write_raster_dataarray( raster_type, diff --git a/src/spatialdata/config.py b/src/spatialdata/config.py index 35b96e5f7..ebb7a0163 100644 --- a/src/spatialdata/config.py +++ b/src/spatialdata/config.py @@ -1,8 +1,18 @@ from __future__ import annotations -from dataclasses import dataclass +import json +import os +from dataclasses import asdict, dataclass +from pathlib import Path from typing import Literal +from platformdirs import user_config_dir + + +def _config_path() -> Path: + """Return the platform-appropriate path to the user config file.""" + return Path(user_config_dir(appname="spatialdata")) / "settings.json" + @dataclass class Settings: @@ -10,6 +20,8 @@ class Settings: Attributes ---------- + custom_config_path + The path specified by the user of where to store the settings. shapes_geometry_encoding Default geometry encoding for GeoParquet files when writing shapes. Can be "WKB" (Well-Known Binary) or "geoarrow". @@ -18,13 +30,157 @@ class Settings: Chunk sizes bigger than this value (bytes) can trigger a compression error. See https://github.com/scverse/spatialdata/issues/812#issuecomment-2559380276 If detected during parsing/validation, a warning is raised. + chunks + The chunksize to use for chunking an array. 
Length of the tuple must match + the number of dimensions. + shards + The default shard size (zarr v3) to use when storing arrays. Length of the tuple + must match the number of dimensions. """ + custom_config_path: Path | None = None shapes_geometry_encoding: Literal["WKB", "geoarrow"] = "WKB" large_chunk_threshold_bytes: int = 2147483647 + chunks: tuple[int, ...] | None = None + shards: tuple[int, ...] | None = None + + def save(self, path: Path | str | None = None) -> None: + """Store current settings on disk. + + If Path is specified, it will store the config settings to this location. Otherwise, stores + the config in the default config directory for the given operating system. + + Parameters + ---------- + path + The path to use for storing settings if different from default. Must be + a json file. This will be stored in the global config as the custom_config_path. + + Returns + ------- + Path + The path the settings were written to. + """ + target = Path(path) if path else _config_path() + + if not str(target).endswith(".json"): + raise ValueError("Path must end with .json") + + if path is not None: + data = asdict(self) + data["custom_config_path"] = str(target) + with target.open("w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + global_path = _config_path() + global_path.parent.mkdir(parents=True, exist_ok=True) + try: + with global_path.open(encoding="utf-8") as f: + global_data = json.load(f) + except (json.JSONDecodeError, OSError): + global_data = {} + global_data["custom_config_path"] = str(target) + with global_path.open("w", encoding="utf-8") as f: + json.dump(global_data, f, indent=2) + else: + target.parent.mkdir(parents=True, exist_ok=True) + data = asdict(self) + data["custom_config_path"] = str(data["custom_config_path"]) if data["custom_config_path"] else None + with target.open("w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + @classmethod + def load(cls, path: Path | str | None = None) -> Settings: + """Load 
settings from disk. + + This method falls back to default settings if either there is no config at the + given path or there is a decoding error. Unknown or renamed keys in the file + are silently ignored, e.g. old config files will not cause errors. + + Parameters + ---------- + path + The path to the config file if different from default. If not specified, + the default location is used. + + Returns + ------- + Settings + A populated Settings instance. + """ + target = Path(path) if path else _config_path() + + if not target.exists(): + instance = cls() + instance.apply_env() + return instance + + try: + with target.open(encoding="utf-8") as f: + data = json.load(f) + except (json.JSONDecodeError, OSError): + instance = cls() + instance.apply_env() + return instance + + # This prevents fields from old config files to be used. + known_fields = {k: v for k, v in data.items() if k in cls.__dataclass_fields__} + instance = cls(**known_fields) + instance.apply_env() + return instance + + def reset(self) -> None: + """Inplace reset all settings to their built-in defaults (in memory only). + + Call 'save' method afterwards if you want the reset to be persisted. + """ + defaults = Settings() + for field_name in self.__dataclass_fields__: + setattr(self, field_name, getattr(defaults, field_name)) + + def apply_env(self) -> None: + """Apply environment variable overrides on top of the current state. + + Env vars take precedence over both the config file and any + in-session assignments. Useful in CI pipelines or HPC clusters + where you cannot edit the config file. 
+ + Supported variables + ------------------- + SPATIALDATA_CUSTOM_CONFIG_PATH -> custom_config_path + SPATIALDATA_SHAPES_GEOMETRY_ENCODING → shapes_geometry_encoding + SPATIALDATA_LARGE_CHUNK_THRESHOLD_BYTES → large_chunk_threshold_bytes + SPATIALDATA_CHUNKS → chunks + SPATIALDATA_SHARDS → shards (integer or "none") + """ + _ENV: dict[str, tuple[str, type]] = { + "SPATIALDATA_CUSTOM_CONFIG_PATH": ("custom_config_path", Path), + "SPATIALDATA_SHAPES_GEOMETRY_ENCODING": ("shapes_geometry_encoding", str), + "SPATIALDATA_LARGE_CHUNK_THRESHOLD_BYTES": ("large_chunk_threshold_bytes", int), + "SPATIALDATA_CHUNKS": ("chunks", str), + "SPATIALDATA_SHARDS": ("shards", str), # handled specially below + } + for env_key, (field_name, cast) in _ENV.items(): + raw = os.environ.get(env_key) + if raw is None: + continue + if field_name == "shards": + setattr(self, field_name, None if raw.lower() in ("none", "") else int(raw)) + else: + setattr(self, field_name, cast(raw)) + + def __repr__(self) -> str: + fields = ", ".join(f"{k}={v!r}" for k, v in asdict(self).items()) + return f"Settings({fields})" + + @staticmethod + def config_path() -> Path: + """Return platform-specific path where settings are stored.""" + return _config_path() + -settings = Settings() +settings = Settings.load() # Backwards compatibility alias LARGE_CHUNK_THRESHOLD_BYTES = settings.large_chunk_threshold_bytes diff --git a/tests/conftest.py b/tests/conftest.py index c97939129..5a73b5b35 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -642,3 +642,14 @@ def complex_sdata() -> SpatialData: sdata.tables["labels_table"].layers["log"] = np.log1p(np.abs(sdata.tables["labels_table"].X)) return sdata + + +@pytest.fixture() +def settings_cls(tmp_path, monkeypatch): + """ + Provide setting class with default path redirected. 
+ """ + from spatialdata.config import Settings + + monkeypatch.setattr("spatialdata.config._config_path", lambda: tmp_path / "default_settings.json") + return Settings diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 209a43046..f17222062 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -743,6 +743,103 @@ def test_single_scale_image_roundtrip_stays_dataarray(tmp_path: Path) -> None: assert list(image_group.keys()) == ["s0"] +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) +def test_write_image_sharding(tmp_path: Path, sdata_container_format: SpatialDataContainerFormatType) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=(1, 100, 200)) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + sdata = SpatialData(images={"image": image}) + path = tmp_path / "data.zarr" + + if sdata_container_format.zarr_format == 2: + with pytest.raises(ValueError, match="Zarr format 2 arrays can only"): + sdata.write( + path, + sdata_formats=sdata_container_format, + raster_write_kwargs={"chunks": (1, 50, 100), "shards": (1, 100, 200)}, + ) + else: + sdata.write( + path, + sdata_formats=sdata_container_format, + raster_write_kwargs={"chunks": (1, 50, 100), "shards": (1, 100, 200)}, + ) + + image_group = zarr.open_group(path / "images" / "image", mode="r") + arr = image_group["s0"] + + assert arr.chunks == (1, 50, 100) + assert arr.shards == (1, 100, 200) + + +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) +def test_write_multiscale_image_sharding( + tmp_path: Path, sdata_container_format: SpatialDataContainerFormatType +) -> None: + data = da.from_array(RNG.random((3, 1600, 2000)), chunks=(1, 100, 200)) + image = Image2DModel.parse(data, dims=("c", "y", "x"), scale_factors=[2]) + sdata = SpatialData(images={"image": image}) + path = tmp_path / "data.zarr" + + if sdata_container_format.zarr_format == 2: + with pytest.raises(ValueError, match="Zarr format 2 arrays can only"): 
+ sdata.write( + path, + sdata_formats=sdata_container_format, + raster_write_kwargs={"chunks": (1, 50, 100), "shards": (1, 100, 200)}, + ) + else: + sdata.write( + path, + sdata_formats=sdata_container_format, + raster_write_kwargs={"chunks": (1, 50, 100), "shards": (1, 100, 200)}, + ) + + image_group = zarr.open_group(path / "images" / "image", mode="r") + arr1 = image_group["s0"] + + assert arr1.chunks == (1, 50, 100) + assert arr1.shards == (1, 100, 200) + + arr2 = image_group["s0"] + + assert arr2.chunks == (1, 50, 100) + assert arr2.shards == (1, 100, 200) + + +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) +def test_write_image_sharding_keyword(tmp_path: Path, sdata_container_format: SpatialDataContainerFormatType) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=(1, 100, 200)) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + image2 = Image2DModel.parse(data.copy(), dims=("c", "y", "x")) + sdata = SpatialData(images={"image": image, "other_image": image2}) + path = tmp_path / "data.zarr" + + if sdata_container_format.zarr_format == 2: + with pytest.raises(ValueError, match="Zarr format 2 arrays can only"): + sdata.write( + path, + sdata_formats=sdata_container_format, + raster_write_kwargs={"image": {"chunks": (1, 50, 100), "shards": (1, 100, 200)}}, + ) + else: + sdata.write( + path, + sdata_formats=sdata_container_format, + raster_write_kwargs={"image": {"chunks": (1, 50, 100), "shards": (1, 100, 200)}}, + ) + + image_group = zarr.open_group(path / "images" / "image", mode="r") + arr = image_group["s0"] + + assert arr.chunks == (1, 50, 100) + assert arr.shards == (1, 100, 200) + + other_group = zarr.open_group(path / "images" / "other_image", mode="r") + arr = other_group["s0"] + + assert arr.chunks == (1, 100, 200) + + @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: # data only 
in-memory, so the SpatialData object and all its elements are self-contained diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py new file mode 100644 index 000000000..1d9e28cbd --- /dev/null +++ b/tests/utils/test_config.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import os +from pathlib import Path + + +def _config_path_for(tmp_path: Path) -> Path: + return tmp_path / "settings.json" + + +class TestDefaults: + def test_default_settings(self, settings_cls): + s = settings_cls() + assert s.shapes_geometry_encoding == "WKB" + assert s.large_chunk_threshold_bytes == 2_147_483_647 + assert s.chunks is None + assert s.shards is None + assert s.custom_config_path is None + + def test_change_settings_default_path(self, settings_cls): + s = settings_cls() + s.shapes_geometry_encoding = "geoarrow" + s.large_chunk_threshold_bytes = 1_000_000_000 + s.chunks = (512, 512) + s.shards = (1024, 1024) + s.save() + s = settings_cls().load() + assert s.shapes_geometry_encoding == "geoarrow" + assert s.large_chunk_threshold_bytes == 1_000_000_000 + assert s.chunks == [512, 512] + assert s.shards == [1024, 1024] + assert s.custom_config_path is None + + def test_change_settings_custom_path(self, settings_cls, tmp_path): + os.environ["SPATIALDATA_SHAPES_GEOMETRY_ENCODING"] = "geoarrow" + target_path = tmp_path / "custom_settings.json" + s = settings_cls().load() + assert s.shapes_geometry_encoding == "geoarrow" + + # We set the value also using environment variables to test whether these properly overwrite + s.large_chunk_threshold_bytes = 1_000_000_000 + os.environ["SPATIALDATA_LARGE_CHUNK_THRESHOLD_BYTES"] = "1_111_111_111" + + s.chunks = (512, 512) + s.shards = (1024, 1024) + s.save(path=target_path) + s = settings_cls().load() + assert s.shapes_geometry_encoding == "geoarrow" + assert s.large_chunk_threshold_bytes == 1_111_111_111 + assert s.chunks is None + assert s.shards is None + assert s.custom_config_path == str(target_path) + + s.reset() + 
s.save() + assert s.custom_config_path is None # reset() restored the default None before saving + s = settings_cls().load() + assert s.custom_config_path is None