Source code for shared_lib_manager

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""The implementation of loading for packages that contain a reusable native library."""

from __future__ import annotations

import ctypes
import os
import platform
from pathlib import Path
from typing import TYPE_CHECKING, Callable

if TYPE_CHECKING:
    from collections.abc import Iterable


# Once we require Python 3.10, switch to using a dataclass with kw_only=True
[docs] class PlatformLibrary: """A tuple containing the paths to a library on different platforms. Parameters ---------- Linux : os.PathLike The path to the library on Linux. Darwin : os.PathLike The path to the library on macOS. Windows : os.PathLike The path to the library on Windows. default : typing.Callable[[], os.PathLike] A callable that returns the default path to the library. This is used when the current platform is not found in the library paths. It may also be used by libraries that require a more general key for determining what path to use if the choice needs to be made at runtime based on additional factors. """ Linux: Path | None Darwin: Path | None Windows: Path | None def __init__( self, *, Darwin: os.PathLike | str | None = None, # noqa: N803 Linux: os.PathLike | str | None = None, # noqa: N803 Windows: os.PathLike | str | None = None, # noqa: N803 default: Callable[[], os.PathLike | str] | None = None, ): # public attributes should correspond to platform.system() return values: # https://docs.python.org/3/library/platform.html#platform.system # TODO: Determine if sys.platform is more appropriate # https://discuss.python.org/t/clarify-usage-of-platform-system/70900/4 if not all( isinstance(path, (os.PathLike, str)) or path is None for path in (Darwin, Linux, Windows) ): raise TypeError("Paths must be instances of pathlib.Path, str, or None.") self.Darwin = Path(Darwin) if Darwin else None self.Linux = Path(Linux) if Linux else None self.Windows = Path(Windows) if Windows else None if not all( p.is_absolute() for p in (self.Darwin, self.Linux, self.Windows) if p is not None ): raise ValueError("All paths must be absolute.") self.default = default
[docs] class LibraryLoader: """Loader for a set of native libraries associated with a module. Parameters ---------- libraries : dict[str, PlatformLibrary | tuple[os.PathLike | str, os.PathLike | str, os.PathLike | str]] A mapping from library names to the paths of the libraries on each platform. If a tuple is passed, it must be ordered as (Linux, Darwin, Windows). """ # noqa: E501 def __init__( self, libraries: dict[str, PlatformLibrary], ): platform_name = platform.system() self._libraries = {} for lib, path in libraries.items(): if not isinstance(path, PlatformLibrary): raise TypeError( f"Invalid path {path} for library {lib}. Expected a tuple or " "PlatformLibrary." ) try: self._libraries[lib] = getattr(path, platform_name) except AttributeError: if path.default is not None: self._libraries[lib] = Path(path.default()) else: raise ValueError( f"No library {lib} found for the current platform " f"{platform_name}. This is a bug in the wheel, please report " "to the maintainer." ) from None @staticmethod def _load(library_path: Path | str) -> None: """Load the library at the given path with RTLD_LOCAL.""" ctypes.CDLL(str(library_path), mode=ctypes.RTLD_LOCAL)
[docs] def load( self, libraries: Iterable[str] | None = None, *, prefer_system: bool = False ) -> None: """Load the native library and return the ctypes.CDLL object. Parameters ---------- libraries : typing.Iterable[str] | None The names of the libraries to load. If None, all libraries are loaded. prefer_system : bool Whether or not to try loading a system library before the local version. Default is False. """ # Always load the library in local mode. if libraries is None: libraries = self._libraries.keys() for library_name in libraries: try: library_path = self._libraries[library_name] except KeyError: raise ValueError( f"Library {library_name} not found in the package." ) from None if prefer_system or os.getenv( f"PREFER_{library_name.upper()}_SYSTEM_LIBRARY", "false" ).lower() not in ("false", 0): try: self._load(library_path.name) except OSError: self._load(library_path) else: self._load(library_path)