Skip to content

Commit accf8c2

Browse files
authored
Fix svd function return dtype (data-apis#619)
1 parent 4d1372b commit accf8c2

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

Diff for: src/array_api_stubs/_2021_12/linalg.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def solve(x1: array, x2: array, /) -> array:
364364
an array containing the solution to the system ``AX = B`` for each square matrix. The returned array must have the same shape as ``x2`` (i.e., the array corresponding to ``B``) and must have a floating-point data type determined by :ref:`type-promotion`.
365365
"""
366366

367-
def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array, ...]]:
367+
def svd(x: array, /, *, full_matrices: bool = True) -> Tuple[array, array, array]:
368368
"""
369369
Returns a singular value decomposition A = USVh of a matrix (or a stack of matrices) ``x``, where ``U`` is a matrix (or a stack of matrices) with orthonormal columns, ``S`` is a vector of non-negative numbers (or stack of vectors), and ``Vh`` is a matrix (or a stack of matrices) with orthonormal rows.
370370
@@ -379,7 +379,7 @@ def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array,
379379
-------
380380
..
381381
NOTE: once complex numbers are supported, each square matrix must be Hermitian.
382-
out: Union[array, Tuple[array, ...]]
382+
out: Tuple[array, array, array]
383383
a namedtuple ``(U, S, Vh)`` whose
384384
385385
- first element must have the field name ``U`` and must be an array whose shape depends on the value of ``full_matrices`` and contain matrices with orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True``, the array must have shape ``(..., M, M)``. If ``full_matrices`` is ``False``, the array must have shape ``(..., M, K)``, where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x``.

Diff for: src/array_api_stubs/_2022_12/linalg.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ def solve(x1: array, x2: array, /) -> array:
609609
"""
610610

611611

612-
def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array, ...]]:
612+
def svd(x: array, /, *, full_matrices: bool = True) -> Tuple[array, array, array]:
613613
r"""
614614
Returns a singular value decomposition (SVD) of a matrix (or a stack of matrices) ``x``.
615615
@@ -649,7 +649,7 @@ def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array,
649649
650650
Returns
651651
-------
652-
out: Union[array, Tuple[array, ...]]
652+
out: Tuple[array, array, array]
653653
a namedtuple ``(U, S, Vh)`` whose
654654
655655
- first element must have the field name ``U`` and must be an array whose shape depends on the value of ``full_matrices`` and contain matrices with orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True``, the array must have shape ``(..., M, M)``. If ``full_matrices`` is ``False``, the array must have shape ``(..., M, K)``, where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x``. Must have the same data type as ``x``.

Diff for: src/array_api_stubs/_draft/linalg.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ def solve(x1: array, x2: array, /) -> array:
609609
"""
610610

611611

612-
def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array, ...]]:
612+
def svd(x: array, /, *, full_matrices: bool = True) -> Tuple[array, array, array]:
613613
r"""
614614
Returns a singular value decomposition (SVD) of a matrix (or a stack of matrices) ``x``.
615615
@@ -649,7 +649,7 @@ def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array,
649649
650650
Returns
651651
-------
652-
out: Union[array, Tuple[array, ...]]
652+
out: Tuple[array, array, array]
653653
a namedtuple ``(U, S, Vh)`` whose
654654
655655
- first element must have the field name ``U`` and must be an array whose shape depends on the value of ``full_matrices`` and contain matrices with orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True``, the array must have shape ``(..., M, M)``. If ``full_matrices`` is ``False``, the array must have shape ``(..., M, K)``, where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x``. Must have the same data type as ``x``.

0 commit comments

Comments
 (0)