Source code for pythonwrench.typing.checks

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import collections.abc
import functools
import inspect
import logging
import sys
import types
from numbers import Integral
from types import FunctionType, MethodType
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    Literal,
    Mapping,
    Sequence,
    Tuple,
    Type,
    TypedDict,
    Union,
)

import typing_extensions
from typing_extensions import (
    NotRequired,
    ParamSpec,
    Required,
    TypeGuard,
    TypeIs,
    TypeVar,
    get_args,
    get_origin,
)

from pythonwrench.typing.classes import (
    BuiltinCollection,
    BuiltinNumber,
    BuiltinScalar,
    DataclassInstance,
    NamedTupleInstance,
    NoneType,
)

T = TypeVar("T")
P = ParamSpec("P")

logger = logging.getLogger(__name__)


def check_args_types(fn: Callable[P, T]) -> Callable[P, T]:
    """Decorator to check argument types before call to a function.

    Each given argument is checked with `isinstance_generic` against the
    annotation of the parameter that receives it. Parameters without annotation
    are not checked. Extra positional or keyword arguments are checked against
    the annotation of the `*args` / `**kwargs` parameter when there is one.
    Arguments that do not match any parameter are left to the wrapped function,
    which raises its usual TypeError.

    Args:
        fn: Function or method to wrap.

    Returns:
        A wrapper with the same signature and metadata as `fn`.

    Raises:
        TypeError: At decoration time if `fn` is not a function or method, and
            at call time if at least one argument does not match its annotation.

    Example
    -------
    >>> import pythonwrench as pw
    >>> @pw.check_args_types
    ... def f(a: int, b: str) -> str:
    ...     return a * b
    >>> f(1, "a")  # pass check
    'a'
    >>> f(1, 2)  # raises TypeError from decorator
    """
    if not isinstance(fn, (FunctionType, MethodType)):
        msg = f"Invalid argument type {type(fn)}. (expected function or method)"
        raise TypeError(msg)

    empty = inspect.Parameter.empty
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )

    # Split the signature once, at decoration time, by the way each parameter can be filled.
    positional_annotations = []
    keyword_annotations = {}
    var_positional_annotation = empty
    var_keyword_annotation = empty
    for name, param in inspect.signature(fn).parameters.items():
        if param.kind in positional_kinds:
            positional_annotations.append(param.annotation)
        if param.kind in keyword_kinds:
            keyword_annotations[name] = param.annotation
        if param.kind is inspect.Parameter.VAR_POSITIONAL:
            var_positional_annotation = param.annotation
        elif param.kind is inspect.Parameter.VAR_KEYWORD:
            var_keyword_annotation = param.annotation

    num_named_positional = len(positional_annotations)

    @functools.wraps(fn)
    def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        msgs = []

        for i, v in enumerate(args):
            if i < num_named_positional:
                expected = positional_annotations[i]
            else:
                # Extra positional values are collected by *args (if any).
                expected = var_positional_annotation

            if expected is empty or isinstance_generic(v, expected):
                continue
            msg = f" - invalid argument n°{i + 1} with value {v!r}; expected an instance of {expected}."
            msgs.append(msg)

        for k, v in kwargs.items():
            # Unknown names are collected by **kwargs (if any).
            expected = keyword_annotations.get(k, var_keyword_annotation)

            if expected is empty or isinstance_generic(v, expected):
                continue
            msg = f" - invalid argument '{k}' with value {v!r}; expected an instance of {expected}."
            msgs.append(msg)

        if msgs:
            num_given = len(args) + len(kwargs)
            header = f"{fn.__name__}() has {len(msgs)}/{num_given} invalid argument(s):"
            raise TypeError("\n".join([header] + msgs))

        return fn(*args, **kwargs)

    return _wrapper


def isinstance_generic(
    obj: Any,
    class_or_tuple: Union[Type[T], None, Tuple[Type[T], ...], Any],
    *,
    check_only_first: bool = False,
) -> TypeIs[T]:
    """Improved isinstance(...) function that supports parametrized Union, TypedDict, Literal, Mapping or Iterable.

    Args:
        obj: Object to check.
        class_or_tuple: Type to check. Can be a parametrized type from `typing`,
            None (matches only None), or a tuple of types (matches if any of them matches).
        check_only_first: If True, check only the first element when checking for Iterable[type]
            or Tuple[type, ...]. An empty iterable is accepted. defaults to False.

    Returns:
        True if `obj` matches `class_or_tuple`, False otherwise.

    Raises:
        TypeError: If `obj` is a generator or the target type is a generator type,
            when a parametrized collection type is checked.
        NotImplementedError: If the target type is a parametrized Callable or
            another unsupported parametrized type.

    Example 1
    ---------
    >>> isinstance_generic({"a": 1, "b": 2}, dict)
    True
    >>> isinstance_generic({"a": 1, "b": 2}, dict[str, int])
    True
    >>> isinstance_generic({"a": 1, "b": 2}, dict[str, str])
    False
    >>> from typing import Literal
    >>> isinstance_generic({"a": 1, "b": 2}, dict[str, Literal[1, 2]])
    True

    """
    if class_or_tuple is Any or class_or_tuple is typing_extensions.Any:
        return True
    if class_or_tuple is None:
        return obj is None
    if isinstance(class_or_tuple, tuple):
        return any(
            isinstance_generic(obj, target_type_i) for target_type_i in class_or_tuple
        )

    if is_typed_dict(class_or_tuple):
        return _isinstance_generic_typed_dict(obj, class_or_tuple)

    origin = get_origin(class_or_tuple)
    if origin is None:
        return isinstance(obj, class_or_tuple)  # type: ignore

    # Special case for empty tuple because get_args(Tuple[()]) returns () and not ((),) in python >= 3.11
    # More info at https://github.com/python/cpython/issues/91137
    if class_or_tuple == Tuple[()]:
        # Explicit checks (instead of `obj == ()`) guarantee a bool result for any obj.
        return isinstance(obj, tuple) and len(obj) == 0

    args = get_args(class_or_tuple)
    # get_origin(...) gives collections.abc.Callable for typing.Callable, so both are compared.
    if origin is Callable or origin is collections.abc.Callable:
        if len(args) == 0:
            return callable(obj)
        # TODO: impl
        msg = "Function `isinstance_generic` currently does not support parametrized Callable."
        raise NotImplementedError(msg)

    if len(args) == 0:
        return isinstance_generic(obj, origin)

    # `X | Y` unions have origin types.UnionType (python >= 3.10); fall back to Union when missing.
    if origin is Union or origin is getattr(types, "UnionType", Union):
        return any(isinstance_generic(obj, arg) for arg in args)

    if origin is Literal or origin is typing_extensions.Literal:
        return obj in args

    if isinstance(obj, Generator):
        msg = f"Invalid argument type {type(obj)}. (cannot check elements in generator)"
        raise TypeError(msg)

    if issubclass(origin, Generator):
        msg = f"Invalid argument type {origin}. (cannot check generator type)"
        raise TypeError(msg)

    if issubclass(origin, Mapping):
        assert len(args) == 2, f"{args=}"
        if not isinstance_generic(obj, origin):
            return False

        return all(isinstance_generic(k, args[0]) for k in obj.keys()) and all(
            isinstance_generic(v, args[1]) for v in obj.values()
        )

    if issubclass(origin, Tuple):
        if not isinstance_generic(obj, origin):
            return False
        elif len(args) == 1 and args[0] == ():
            return len(obj) == 0
        elif len(args) == 2 and args[1] is ...:
            if check_only_first:
                # zip(...) below stops after the first element.
                args = (args[0],)
            else:
                args = tuple([args[0]] * len(obj))
        elif len(obj) != len(args):
            return False
        return all(isinstance_generic(xi, ti) for xi, ti in zip(obj, args))

    if issubclass(origin, Iterable):
        if not isinstance_generic(obj, origin):
            return False

        if check_only_first:
            # A sentinel avoids StopIteration on empty iterables, which are
            # accepted like in the exhaustive all(...) check below.
            missing = object()
            first = next(iter(obj), missing)
            return first is missing or isinstance_generic(first, args[0])
        else:
            return all(isinstance_generic(xi, args[0]) for xi in obj)

    msg = f"Unsupported type {class_or_tuple}. (expected unparametrized type or parametrized Union, TypedDict, Literal, Mapping or Iterable)"
    raise NotImplementedError(msg)


def _isinstance_generic_typed_dict(x: Any, target_type: type) -> bool:
    """Returns True if x is a dict that matches the TypedDict class `target_type`.

    The object must be a dict with str keys, contain every required key, contain
    no key that is not declared, and every present value must match the declared
    type of its key.

    Args:
        x: Object to check.
        target_type: TypedDict class to check against.
    """
    if not isinstance_generic(x, Dict[str, Any]):
        return False

    total: bool = target_type.__total__
    annotations = target_type.__annotations__

    # `__total__` only describes the last class of a TypedDict hierarchy, so keys
    # inherited from a base with another totality need the per-key sets when available.
    required_keys = getattr(target_type, "__required_keys__", None)
    optional_keys = getattr(target_type, "__optional_keys__", None)

    required_annotations = {}
    optional_annotations = {}
    for k, v in annotations.items():
        origin = get_origin(v)
        # Explicit Required[...] / NotRequired[...] wrappers take precedence and are unwrapped.
        if origin is Required:
            required_annotations[k] = get_args(v)[0]
        elif origin is NotRequired:
            optional_annotations[k] = get_args(v)[0]
        elif required_keys is not None and k in required_keys:
            required_annotations[k] = v
        elif optional_keys is not None and k in optional_keys:
            optional_annotations[k] = v
        elif total:
            required_annotations[k] = v
        else:
            optional_annotations[k] = v

    present_keys = set(x.keys())
    if not present_keys.issuperset(required_annotations.keys()):
        return False

    declared_keys = set(required_annotations.keys()) | set(optional_annotations.keys())
    if not present_keys.issubset(declared_keys):
        return False

    for k, v in required_annotations.items():
        if not isinstance_generic(x[k], v):
            return False

    for k, v in optional_annotations.items():
        if k in x and not isinstance_generic(x[k], v):
            return False

    return True


def is_builtin_collection(x: Any, *, strict: bool = False) -> TypeIs[BuiltinCollection]:
    """Returns True if x is an instance of a builtin collection type (list, tuple, dict, set, frozenset).

    Args:
        x: Object to check.
        strict: If True, instances of custom subtypes of the builtin collections are rejected. defaults to False.
    """
    collection_types = (list, tuple, dict, set, frozenset)
    if not isinstance(x, collection_types):
        return False
    # In strict mode the object must also be an instance of the builtin class itself.
    return not strict or is_builtin_obj(x)


def is_builtin_number(x: Any, *, strict: bool = False) -> TypeIs[BuiltinNumber]:
    """Returns True if x is an instance of a builtin number type (int, float, bool, complex).

    Args:
        x: Object to check.
        strict: If True, instances of custom subtypes of the builtin numbers are rejected. defaults to False.
    """
    number_types = (int, float, bool, complex)
    matches_type = isinstance(x, number_types)
    if strict:
        # Strict mode additionally requires an instance of the builtin class itself.
        return is_builtin_obj(x) and matches_type
    return matches_type
def is_builtin_obj(x: Any) -> bool:
    """Returns True if object is an instance of a builtin object.

    Note: Classes themselves are rejected, and so are instances of a custom
    subtype of a builtin object, because their class is not defined in the
    `builtins` module.
    """
    if isinstance(x, type):
        return False
    owner_module = x.__class__.__module__
    return owner_module == "builtins"
def is_builtin_scalar(x: Any, *, strict: bool = False) -> TypeIs[BuiltinScalar]:
    """Returns True if x is an instance of a builtin scalar type (int, float, bool, complex, NoneType, str, bytes).

    Args:
        x: Object to check.
        strict: If True, instances of subtypes of the builtin scalars are rejected. defaults to False.
    """
    scalar_types = (int, float, bool, complex, NoneType, str, bytes)
    # Non-strict mode skips the builtin-class requirement.
    passes_strictness = (not strict) or is_builtin_obj(x)
    return passes_strictness and isinstance(x, scalar_types)
def is_dataclass_instance(x: Any) -> TypeIs[DataclassInstance]:
    """Returns True if argument is a dataclass instance.

    Unlike function `dataclasses.is_dataclass`, this function returns False for a dataclass type.
    """
    return not isinstance(x, type) and isinstance_generic(x, DataclassInstance)


def is_iterable_bool(
    x: Any,
    *,
    accept_generator: bool = True,
) -> TypeIs[Iterable[bool]]:
    """Returns True if x is an iterable of bool.

    Args:
        x: Object to check.
        accept_generator: If False, generators are rejected. If True, the generator is consumed by the check. defaults to True.
    """
    if isinstance(x, Generator):
        # isinstance_generic(...) refuses generators, so their elements are checked here.
        return accept_generator and all(isinstance(xi, bool) for xi in x)
    return isinstance_generic(x, Iterable[bool])


def is_iterable_bytes_or_list(
    x: Any,
    *,
    accept_generator: bool = True,
) -> TypeIs[Iterable[Union[bytes, list]]]:
    """Returns True if x is an iterable of bytes or list objects.

    Args:
        x: Object to check.
        accept_generator: If False, generators are rejected. If True, the generator is consumed by the check. defaults to True.
    """
    if isinstance(x, Generator):
        # isinstance_generic(...) refuses generators, so their elements are checked here.
        return accept_generator and all(isinstance(xi, (bytes, list)) for xi in x)
    return isinstance_generic(x, Iterable[Union[bytes, list]])


def is_iterable_float(
    x: Any,
    *,
    accept_generator: bool = True,
) -> TypeIs[Iterable[float]]:
    """Returns True if x is an iterable of float.

    Args:
        x: Object to check.
        accept_generator: If False, generators are rejected. If True, the generator is consumed by the check. defaults to True.
    """
    if isinstance(x, Generator):
        # isinstance_generic(...) refuses generators, so their elements are checked here.
        return accept_generator and all(isinstance(xi, float) for xi in x)
    return isinstance_generic(x, Iterable[float])


def is_iterable_int(
    x: Any,
    *,
    accept_bool: bool = True,
    accept_generator: bool = True,
) -> TypeIs[Iterable[int]]:
    """Returns True if x is an iterable of int.

    Args:
        x: Object to check.
        accept_bool: If False, an iterable whose elements are all bool is rejected. defaults to True.
        accept_generator: If False, generators are rejected. If True, the generator is consumed by the check. defaults to True.
    """
    if isinstance(x, Generator):
        if not accept_generator:
            return False
        # Elements may be inspected twice below, which a generator cannot support,
        # and isinstance_generic(...) refuses generators anyway.
        x = list(x)
    return isinstance_generic(x, Iterable[int]) and (
        accept_bool or not isinstance_generic(x, Iterable[bool])
    )


def is_iterable_integral(
    x: Any,
    *,
    accept_generator: bool = True,
) -> TypeIs[Iterable[Integral]]:
    """Returns True if x is an iterable of `numbers.Integral` instances.

    Args:
        x: Object to check.
        accept_generator: If False, generators are rejected. If True, the generator is consumed by the check. defaults to True.
    """
    if isinstance(x, Generator):
        # isinstance_generic(...) refuses generators, so their elements are checked here.
        return accept_generator and all(isinstance(xi, Integral) for xi in x)
    return isinstance_generic(x, Iterable[Integral])


def is_iterable_str(
    x: Any,
    *,
    accept_str: bool = True,
    accept_generator: bool = True,
) -> TypeGuard[Iterable[str]]:
    """Returns True if x is an iterable of str.

    Args:
        x: Object to check.
        accept_str: Result returned when x is itself a str. defaults to True.
        accept_generator: If False, generators are rejected. If True, the generator is consumed by the check. defaults to True.
    """
    if isinstance(x, str):
        return accept_str
    if isinstance(x, Generator):
        return accept_generator and all(isinstance(xi, str) for xi in x)
    return isinstance_generic(x, Iterable[str])


def is_namedtuple_instance(x: Any) -> TypeIs[NamedTupleInstance]:
    """Returns True if argument is a NamedTuple instance (and not a NamedTuple type)."""
    return not isinstance(x, type) and isinstance_generic(x, NamedTupleInstance)


def is_sequence_str(
    x: Any,
    *,
    accept_str: bool = True,
) -> TypeGuard[Sequence[str]]:
    """Returns True if x is a sequence of str.

    Args:
        x: Object to check.
        accept_str: Result returned when x is itself a str. defaults to True.
    """
    if isinstance(x, str):
        return accept_str
    return isinstance(x, Sequence) and all(isinstance(xi, str) for xi in x)


def is_typed_dict(x: Any) -> TypeGuard[type]:
    """Returns True if x is a TypedDict class.

    The check is delegated to `typing_extensions.is_typeddict`, so TypedDict classes
    that inherit from another TypedDict, or that were created with
    `typing_extensions.TypedDict`, are recognized too.
    """
    return typing_extensions.is_typeddict(x)