# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import ctypes
import functools
from collections.abc import Callable
from dataclasses import dataclass

from cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib import (
    load_nvidia_dynamic_lib as _load_nvidia_dynamic_lib,
)
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS


class QueryDriverCudaVersionError(RuntimeError):
    """Signals that ``query_driver_cuda_version()`` was unable to obtain the CUDA driver version."""


@dataclass(frozen=True, slots=True)
class DriverCudaVersion:
    """
    Version of the CUDA driver interface, as returned by ``cuDriverGetVersion()``.

    This is the number ``nvidia-smi`` prints as ``CUDA Version``. It is *not*
    the graphics driver release that ``nvidia-smi`` prints as
    ``Driver Version``: the value describes the CUDA user-mode driver (UMD)
    interface, not the kernel-mode driver (KMD) package.

    Example ``nvidia-smi`` banner::

        | NVIDIA-SMI 595.58.03   Driver Version: 595.58.03   CUDA Version: 13.2 |

    For that banner the matching instance is
    ``DriverCudaVersion(encoded=13020, major=13, minor=2)`` (``CUDA Version:
    13.2``); ``Driver Version: 595.58.03`` is unrelated to this class.
    """

    # Raw integer exactly as produced by ``cuDriverGetVersion()``.
    encoded: int
    # ``encoded // 1000``.
    major: int
    # ``(encoded % 1000) // 10``.
    minor: int


@functools.cache
def query_driver_cuda_version() -> DriverCudaVersion:
    """Query the CUDA driver and return its version split into major/minor parts.

    The encoded integer from ``_query_driver_cuda_version_int()`` is decoded
    as ``major = encoded // 1000`` and ``minor = (encoded % 1000) // 10``.
    The function is wrapped in ``functools.cache``.

    Any exception raised while querying or decoding is re-raised as
    ``QueryDriverCudaVersionError`` with the original exception chained as
    its ``__cause__``.
    """
    try:
        encoded = _query_driver_cuda_version_int()
        # divmod yields the same quotient/remainder pair as `//` and `%`.
        major, below_major = divmod(encoded, 1000)
        return DriverCudaVersion(encoded=encoded, major=major, minor=below_major // 10)
    except Exception as exc:
        raise QueryDriverCudaVersionError("Failed to query the CUDA driver version.") from exc


def _query_driver_cuda_version_int() -> int:
    """Invoke ``cuDriverGetVersion()`` and return the encoded integer it writes.

    The driver library is resolved with ``_load_nvidia_dynamic_lib("cuda")``
    and then opened through ctypes by the absolute path of that result.
    A ``RuntimeError`` is raised when the call returns a nonzero status.
    """
    cuda_dl = _load_nvidia_dynamic_lib("cuda")
    # `ctypes.WinDLL` is only looked up on Windows, where it exists at runtime.
    # The ignore is for mypy runs on Linux, whose platform stubs lack that attribute.
    open_library: Callable[[str], ctypes.CDLL] = (
        ctypes.WinDLL if IS_WINDOWS else ctypes.CDLL  # type: ignore[attr-defined]
    )
    driver_lib = open_library(cuda_dl.abs_path)

    # Declare the C prototype: int cuDriverGetVersion(int *).
    get_version = driver_lib.cuDriverGetVersion
    get_version.restype = ctypes.c_int
    get_version.argtypes = [ctypes.POINTER(ctypes.c_int)]

    out_value = ctypes.c_int()
    status = get_version(ctypes.byref(out_value))
    if status == 0:
        return out_value.value
    raise RuntimeError(f"Failed to query CUDA driver version via cuDriverGetVersion() (status={status}).")
def test_real_query_driver_cuda_version(info_summary_append):
    """Query the real driver and check that major/minor agree with the encoded value.

    When the query fails, the failure is only recorded in the info summary
    unless ``STRICTNESS`` is ``"all_must_work"``, in which case it propagates.
    Both caches are cleared before the query and again afterwards, whatever
    the outcome.
    """

    def _drop_caches():
        # Reset the library-loader cache first, then the version-query cache.
        driver_info._load_nvidia_dynamic_lib.cache_clear()
        driver_info.query_driver_cuda_version.cache_clear()

    _drop_caches()
    try:
        version = driver_info.query_driver_cuda_version()
    except driver_info.QueryDriverCudaVersionError as exc:
        if STRICTNESS != "all_must_work":
            info_summary_append(f"driver version unavailable: {exc.__class__.__name__}: {exc}")
            return
        raise
    finally:
        _drop_caches()

    info_summary_append(f"driver_version={version.major}.{version.minor} (encoded={version.encoded})")
    assert version.encoded > 0
    expected_major, below_major = divmod(version.encoded, 1000)
    assert version.major == expected_major
    assert version.minor == below_major // 10
# SPDX-License-Identifier: Apache-2.0

import ctypes

import pytest

from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL
from cuda.pathfinder._utils import driver_info


@pytest.fixture(autouse=True)
def _clear_driver_cuda_version_query_cache():
    """Empty the ``query_driver_cuda_version`` cache before and after every test."""
    driver_info.query_driver_cuda_version.cache_clear()
    yield
    driver_info.query_driver_cuda_version.cache_clear()


class _FakeCuDriverGetVersion:
    """Callable stand-in for the ``cuDriverGetVersion`` ctypes function pointer."""

    def __init__(self, *, status: int, version: int):
        # The code under test assigns these two attributes before calling.
        self.argtypes = None
        self.restype = None
        self._status = status
        self._version = version

    def __call__(self, version_ptr) -> int:
        """Store the canned version through ``version_ptr`` and return the canned status."""
        out_param = ctypes.cast(version_ptr, ctypes.POINTER(ctypes.c_int))
        out_param.contents.value = self._version
        return self._status


class _FakeDriverLib:
    """Stand-in for the opened driver library; exposes only ``cuDriverGetVersion``."""

    def __init__(self, *, status: int, version: int):
        self.cuDriverGetVersion = _FakeCuDriverGetVersion(status=status, version=version)


def _loaded_cuda(abs_path: str) -> LoadedDL:
    """Build a ``LoadedDL`` record that points at ``abs_path``."""
    fields = {
        "abs_path": abs_path,
        "was_already_loaded_from_elsewhere": False,
        "_handle_uint": 0xBEEF,
        "found_via": "system-search",
    }
    return LoadedDL(**fields)


def test_query_driver_cuda_version_uses_windll_on_windows(monkeypatch):
    """With ``IS_WINDOWS`` true, the library is opened via ``ctypes.WinDLL`` using the resolved path."""
    nvcuda_path = r"C:\Windows\System32\nvcuda.dll"
    stub_lib = _FakeDriverLib(status=0, version=12080)
    opened: list[str] = []

    def fake_windll(abs_path: str):
        opened.append(abs_path)
        return stub_lib

    monkeypatch.setattr(driver_info, "IS_WINDOWS", True)
    monkeypatch.setattr(driver_info, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_cuda(nvcuda_path))
    # raising=False lets the patch apply even where `ctypes` has no `WinDLL` attribute.
    monkeypatch.setattr(driver_info.ctypes, "WinDLL", fake_windll, raising=False)

    assert driver_info._query_driver_cuda_version_int() == 12080
    assert opened == [nvcuda_path]


def test_query_driver_cuda_version_returns_parsed_dataclass(monkeypatch):
    """An encoded value of 12080 is decoded as major 12, minor 8."""
    monkeypatch.setattr(driver_info, "_query_driver_cuda_version_int", lambda: 12080)

    expected = driver_info.DriverCudaVersion(encoded=12080, major=12, minor=8)
    assert driver_info.query_driver_cuda_version() == expected


def test_query_driver_cuda_version_wraps_internal_failures(monkeypatch):
    """A low-level failure surfaces as ``QueryDriverCudaVersionError`` chained to the root cause."""
    root_cause = RuntimeError("low-level query failed")

    def fail_query_driver_cuda_version_int() -> int:
        raise root_cause

    monkeypatch.setattr(driver_info, "_query_driver_cuda_version_int", fail_query_driver_cuda_version_int)

    expectation = pytest.raises(
        driver_info.QueryDriverCudaVersionError,
        match="Failed to query the CUDA driver version",
    )
    with expectation as exc_info:
        driver_info.query_driver_cuda_version()

    assert exc_info.value.__cause__ is root_cause


def test_query_driver_cuda_version_int_raises_when_cuda_call_fails(monkeypatch):
    """A nonzero status from ``cuDriverGetVersion`` becomes a ``RuntimeError`` naming that status."""
    failing_lib = _FakeDriverLib(status=1, version=0)

    monkeypatch.setattr(driver_info, "IS_WINDOWS", False)
    monkeypatch.setattr(driver_info, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_cuda("/usr/lib/libcuda.so.1"))
    monkeypatch.setattr(driver_info.ctypes, "CDLL", lambda _abs_path: failing_lib)

    with pytest.raises(RuntimeError, match=r"cuDriverGetVersion\(\) \(status=1\)"):
        driver_info._query_driver_cuda_version_int()