Skip to content

Commit 5e14b53

Browse files
TYP: reject bool in the ord params of vector_norm and matrix_norm (#310)
* TYP: auto-plagiarize the optypean `Just*` types * TYP: reject `bool` in the `ord` params of `vector_norm` and `matrix_norm` * TYP: remove accidental type alias * TYP: Tighten the `ord` param of `matrix_norm` Co-authored-by: Lucas Colley <[email protected]> --------- Co-authored-by: Lucas Colley <[email protected]>
1 parent 205c967 commit 5e14b53

File tree

3 files changed

+52
-6
lines changed

3 files changed

+52
-6
lines changed

Diff for: array_api_compat/common/_linalg.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from .._internal import get_xp
1414
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
15-
from ._typing import Array, DType, Namespace
15+
from ._typing import Array, DType, JustFloat, JustInt, Namespace
1616

1717

1818
# These are in the main NumPy namespace but not in numpy.linalg
@@ -139,7 +139,7 @@ def matrix_norm(
139139
xp: Namespace,
140140
*,
141141
keepdims: bool = False,
142-
ord: float | Literal["fro", "nuc"] | None = "fro",
142+
ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro",
143143
) -> Array:
144144
return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
145145

@@ -155,7 +155,7 @@ def vector_norm(
155155
*,
156156
axis: int | tuple[int, ...] | None = None,
157157
keepdims: bool = False,
158-
ord: float = 2,
158+
ord: JustInt | JustFloat = 2,
159159
) -> Array:
160160
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
161161
# when axis=None and the input is 2-D, so to force a vector norm, we make

Diff for: array_api_compat/common/_typing.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22

33
from collections.abc import Mapping
44
from types import ModuleType as Namespace
5-
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar
5+
from typing import (
6+
TYPE_CHECKING,
7+
Literal,
8+
Protocol,
9+
TypeAlias,
10+
TypedDict,
11+
TypeVar,
12+
final,
13+
)
614

715
if TYPE_CHECKING:
816
from _typeshed import Incomplete
@@ -21,6 +29,37 @@
2129
_T_co = TypeVar("_T_co", covariant=True)
2230

2331

32+
# These "Just" types are equivalent to the `Just` type from the `optype` library,
33+
# apart from them not being `@runtime_checkable`.
34+
# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
35+
# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
36+
@final
37+
class JustInt(Protocol):
38+
@property
39+
def __class__(self, /) -> type[int]: ...
40+
@__class__.setter
41+
def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
42+
43+
44+
@final
45+
class JustFloat(Protocol):
46+
@property
47+
def __class__(self, /) -> type[float]: ...
48+
@__class__.setter
49+
def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
50+
51+
52+
@final
53+
class JustComplex(Protocol):
54+
@property
55+
def __class__(self, /) -> type[complex]: ...
56+
@__class__.setter
57+
def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
58+
59+
60+
#
61+
62+
2463
class NestedSequence(Protocol[_T_co]):
2564
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
2665
def __len__(self, /) -> int: ...
@@ -140,6 +179,9 @@ class DTypesAll(DTypesBool, DTypesNumeric):
140179
"Device",
141180
"HasShape",
142181
"Namespace",
182+
"JustInt",
183+
"JustFloat",
184+
"JustComplex",
143185
"NestedSequence",
144186
"SupportsArrayNamespace",
145187
"SupportsBufferProtocol",

Diff for: array_api_compat/torch/linalg.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# These functions are in both the main and linalg namespaces
1717
from ._aliases import matmul, matrix_transpose, tensordot
1818
from ._typing import Array, DType
19+
from ..common._typing import JustInt, JustFloat
1920

2021
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
2122
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
@@ -84,8 +85,8 @@ def vector_norm(
8485
*,
8586
axis: Optional[Union[int, Tuple[int, ...]]] = None,
8687
keepdims: bool = False,
87-
# float stands for inf | -inf, which are not valid for Literal
88-
ord: Union[int, float] = 2,
88+
# JustFloat stands for inf | -inf, which are not valid for Literal
89+
ord: JustInt | JustFloat = 2,
8990
**kwargs,
9091
) -> Array:
9192
# torch.vector_norm incorrectly treats axis=() the same as axis=None
@@ -115,3 +116,6 @@ def vector_norm(
115116
_all_ignore = ['torch_linalg', 'sum']
116117

117118
del linalg_all
119+
120+
def __dir__() -> list[str]:
121+
return __all__

0 commit comments

Comments
 (0)