From b0c88f4dfb8aac39e1471d8068fab796f3e986bd Mon Sep 17 00:00:00 2001 From: jorenham Date: Thu, 17 Apr 2025 22:34:39 +0200 Subject: [PATCH 1/4] TYP: auto-plagiarize the optypean `Just*` types --- array_api_compat/common/_typing.py | 46 +++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index d7deade1..b87eabd7 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -2,7 +2,15 @@ from collections.abc import Mapping from types import ModuleType as Namespace -from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar +from typing import ( + TYPE_CHECKING, + Literal, + Protocol, + TypeAlias, + TypedDict, + TypeVar, + final, +) if TYPE_CHECKING: from _typeshed import Incomplete @@ -21,6 +29,37 @@ _T_co = TypeVar("_T_co", covariant=True) +# These "Just" types are equivalent to the `Just` type from the `optype` library, +# apart from them not being `@runtime_checkable`. +# - docs: https://github.com/jorenham/optype/blob/master/README.md#just +# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py +@final +class JustInt(Protocol): + @property + def __class__(self, /) -> type[int]: ... + @__class__.setter + def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] + + +@final +class JustFloat(Protocol): + @property + def __class__(self, /) -> type[float]: ... + @__class__.setter + def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] + + +@final +class JustComplex(Protocol): + @property + def __class__(self, /) -> type[complex]: ... + @__class__.setter + def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] + + +# + + class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... @@ -121,6 +160,8 @@ class DTypesAll(DTypesBool, DTypesNumeric): # `__array_namespace_info__.dtypes(kind=?)` (fallback) DTypesAny: TypeAlias = Mapping[str, DType] +NormOrder: TypeAlias = JustFloat | Literal[-2, -1, 1, 2] + __all__ = [ "Array", @@ -140,6 +181,9 @@ class DTypesAll(DTypesBool, DTypesNumeric): "Device", "HasShape", "Namespace", + "JustInt", + "JustFloat", + "JustComplex", "NestedSequence", "SupportsArrayNamespace", "SupportsBufferProtocol", From 5a1a00ba2297545f6a85ece5e08f2695978f6d31 Mon Sep 17 00:00:00 2001 From: jorenham Date: Thu, 17 Apr 2025 22:35:33 +0200 Subject: [PATCH 2/4] TYP: reject `bool` in the `ord` params of `vector_norm` and `matrix_norm` --- array_api_compat/common/_linalg.py | 6 +++--- array_api_compat/torch/linalg.py | 8 ++++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 7e002aed..eb4a5de2 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -12,7 +12,7 @@ from .._internal import get_xp from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot -from ._typing import Array, DType, Namespace +from ._typing import Array, DType, JustFloat, JustInt, Namespace # These are in the main NumPy namespace but not in numpy.linalg @@ -139,7 +139,7 @@ def matrix_norm( xp: Namespace, *, keepdims: bool = False, - ord: float | Literal["fro", "nuc"] | None = "fro", + ord: JustInt | JustFloat | Literal["fro", "nuc"] | None = "fro", ) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) @@ -155,7 +155,7 @@ def vector_norm( *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - ord: float = 2, + ord: JustInt | JustFloat = 2, ) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 1ff7319d..70d72405 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -16,6 +16,7 @@ # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot from ._typing import Array, DType +from ..common._typing import JustInt, JustFloat # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 @@ -84,8 +85,8 @@ def vector_norm( *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - # float stands for inf | -inf, which are not valid for Literal - ord: Union[int, float] = 2, + # JustFloat stands for inf | -inf, which are not valid for Literal + ord: JustInt | JustFloat = 2, **kwargs, ) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None @@ -115,3 +116,6 @@ def vector_norm( _all_ignore = ['torch_linalg', 'sum'] del linalg_all + +def __dir__() -> list[str]: + return __all__ From 920cb8059a477368a5e5c30a142167ee4b4818bc Mon Sep 17 00:00:00 2001 From: jorenham Date: Thu, 17 Apr 2025 22:38:19 +0200 Subject: [PATCH 3/4] TYP: remove accidental type alias --- array_api_compat/common/_typing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index b87eabd7..cd26feeb 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -160,8 +160,6 @@ class DTypesAll(DTypesBool, DTypesNumeric): # `__array_namespace_info__.dtypes(kind=?)` (fallback) DTypesAny: TypeAlias = Mapping[str, DType] -NormOrder: TypeAlias = JustFloat | Literal[-2, -1, 1, 2] - __all__ = [ "Array", From 9eb938a12e5a54ea26de0e84b9d51ecf0b697b1a Mon Sep 17 00:00:00 2001 From: Joren Hammudoglu Date: Sat, 19 Apr 2025 15:45:22 +0200 Subject: [PATCH 4/4] TYP: Tighten the `ord` param of `matrix_norm` Co-authored-by: Lucas Colley --- array_api_compat/common/_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index eb4a5de2..7ad87a1b 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -139,7 +139,7 @@ def matrix_norm( xp: Namespace, *, keepdims: bool = False, - ord: JustInt | JustFloat | Literal["fro", "nuc"] | None = "fro", + ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro", ) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)