Skip to content

Commit 4371506 — Merge branch 'main' into typ_v4
(merge commit with 2 parents: 49f9ba7 and 5e14b53)

File tree: 3 files changed (+52 lines added, -6 lines removed)

Diff for: array_api_compat/common/_linalg.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from .._internal import get_xp
1414
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
15-
from ._typing import Array, DType, Namespace
15+
from ._typing import Array, DType, JustFloat, JustInt, Namespace
1616

1717

1818
# These are in the main NumPy namespace but not in numpy.linalg
@@ -139,7 +139,7 @@ def matrix_norm(
139139
xp: Namespace,
140140
*,
141141
keepdims: bool = False,
142-
ord: float | Literal["fro", "nuc"] | None = "fro",
142+
ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro",
143143
) -> Array:
144144
return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
145145

@@ -155,7 +155,7 @@ def vector_norm(
155155
*,
156156
axis: int | tuple[int, ...] | None = None,
157157
keepdims: bool = False,
158-
ord: float = 2,
158+
ord: JustInt | JustFloat = 2,
159159
) -> Array:
160160
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
161161
# when axis=None and the input is 2-D, so to force a vector norm, we make

Diff for: array_api_compat/common/_typing.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
from __future__ import annotations
22

33
from types import ModuleType as Namespace
4-
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar
4+
from typing import (
5+
TYPE_CHECKING,
6+
Literal,
7+
Protocol,
8+
TypeAlias,
9+
TypedDict,
10+
TypeVar,
11+
final,
12+
)
513

614
if TYPE_CHECKING:
715
from _typeshed import Incomplete
@@ -20,6 +28,37 @@
2028
_T_co = TypeVar("_T_co", covariant=True)
2129

2230

31+
# These "Just" types are equivalent to the `Just` type from the `optype` library,
32+
# apart from them not being `@runtime_checkable`.
33+
# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
34+
# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
35+
@final
36+
class JustInt(Protocol):
37+
@property
38+
def __class__(self, /) -> type[int]: ...
39+
@__class__.setter
40+
def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
41+
42+
43+
@final
44+
class JustFloat(Protocol):
45+
@property
46+
def __class__(self, /) -> type[float]: ...
47+
@__class__.setter
48+
def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
49+
50+
51+
@final
52+
class JustComplex(Protocol):
53+
@property
54+
def __class__(self, /) -> type[complex]: ...
55+
@__class__.setter
56+
def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
57+
58+
59+
#
60+
61+
2362
class NestedSequence(Protocol[_T_co]):
2463
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
2564
def __len__(self, /) -> int: ...
@@ -78,6 +117,9 @@ def shape(self, /) -> tuple[_T_co, ...]: ...
78117
"Device",
79118
"HasShape",
80119
"Namespace",
120+
"JustInt",
121+
"JustFloat",
122+
"JustComplex",
81123
"NestedSequence",
82124
"SupportsArrayNamespace",
83125
"SupportsBufferProtocol",

Diff for: array_api_compat/torch/linalg.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# These functions are in both the main and linalg namespaces
1515
from ._aliases import matmul, matrix_transpose, tensordot
1616
from ._typing import Array, DType
17+
from ..common._typing import JustInt, JustFloat
1718

1819
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
1920
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
@@ -82,8 +83,8 @@ def vector_norm(
8283
*,
8384
axis: int | tuple[int, ...] | None = None,
8485
keepdims: bool = False,
85-
# float stands for inf | -inf, which are not valid for Literal
86-
ord: float = 2,
86+
# JustFloat stands for inf | -inf, which are not valid for Literal
87+
ord: JustInt | JustFloat = 2,
8788
**kwargs: object,
8889
) -> Array:
8990
# torch.vector_norm incorrectly treats axis=() the same as axis=None
@@ -113,3 +114,6 @@ def vector_norm(
113114
_all_ignore = ['torch_linalg', 'sum']
114115

115116
del linalg_all
117+
118+
def __dir__() -> list[str]:
119+
return __all__

Commit comments: 0