Skip to content

TYP: reject bool in the ord params of vector_norm and matrix_norm #310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 19, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .._internal import get_xp
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
from ._typing import Array, DType, Namespace
from ._typing import Array, DType, JustFloat, JustInt, Namespace


# These are in the main NumPy namespace but not in numpy.linalg
Expand Down Expand Up @@ -139,7 +139,7 @@ def matrix_norm(
xp: Namespace,
*,
keepdims: bool = False,
ord: float | Literal["fro", "nuc"] | None = "fro",
ord: JustInt | JustFloat | Literal["fro", "nuc"] | None = "fro",
) -> Array:
return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)

Expand All @@ -155,7 +155,7 @@ def vector_norm(
*,
axis: int | tuple[int, ...] | None = None,
keepdims: bool = False,
ord: float = 2,
ord: JustInt | JustFloat = 2,
) -> Array:
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
# when axis=None and the input is 2-D, so to force a vector norm, we make
Expand Down
46 changes: 45 additions & 1 deletion array_api_compat/common/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

from collections.abc import Mapping
from types import ModuleType as Namespace
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar
from typing import (
TYPE_CHECKING,
Literal,
Protocol,
TypeAlias,
TypedDict,
TypeVar,
final,
)

if TYPE_CHECKING:
from _typeshed import Incomplete
Expand All @@ -21,6 +29,37 @@
_T_co = TypeVar("_T_co", covariant=True)


# These "Just" types are equivalent to the `Just` type from the `optype` library,
# apart from them not being `@runtime_checkable`.
# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
@final
class JustInt(Protocol):
@property
def __class__(self, /) -> type[int]: ...
@__class__.setter
def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]


@final
class JustFloat(Protocol):
@property
def __class__(self, /) -> type[float]: ...
@__class__.setter
def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]


@final
class JustComplex(Protocol):
@property
def __class__(self, /) -> type[complex]: ...
@__class__.setter
def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]


#


class NestedSequence(Protocol[_T_co]):
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
def __len__(self, /) -> int: ...
Expand Down Expand Up @@ -121,6 +160,8 @@ class DTypesAll(DTypesBool, DTypesNumeric):
# `__array_namespace_info__.dtypes(kind=?)` (fallback)
DTypesAny: TypeAlias = Mapping[str, DType]

NormOrder: TypeAlias = JustFloat | Literal[-2, -1, 1, 2]


__all__ = [
"Array",
Expand All @@ -140,6 +181,9 @@ class DTypesAll(DTypesBool, DTypesNumeric):
"Device",
"HasShape",
"Namespace",
"JustInt",
"JustFloat",
"JustComplex",
"NestedSequence",
"SupportsArrayNamespace",
"SupportsBufferProtocol",
Expand Down
8 changes: 6 additions & 2 deletions array_api_compat/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot
from ._typing import Array, DType
from ..common._typing import JustInt, JustFloat

# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
Expand Down Expand Up @@ -84,8 +85,8 @@ def vector_norm(
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
# float stands for inf | -inf, which are not valid for Literal
ord: Union[int, float] = 2,
# JustFloat stands for inf | -inf, which are not valid for Literal
ord: JustInt | JustFloat = 2,
**kwargs,
) -> Array:
# torch.vector_norm incorrectly treats axis=() the same as axis=None
Expand Down Expand Up @@ -115,3 +116,6 @@ def vector_norm(
_all_ignore = ['torch_linalg', 'sum']

del linalg_all

def __dir__() -> list[str]:
return __all__
Loading