Skip to content

Commit 1e58129

Browse files
committed
ENH: add crude assert_allclose; use value testing in vecdot
1 parent a71b4c0 commit 1e58129

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

array_api_tests/pytest_helpers.py

+10
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
"assert_0d_equals",
2929
"assert_fill",
3030
"assert_array_elements",
31+
"assert_allclose"
3132
]
3233

3334

@@ -599,3 +600,12 @@ def assert_array_elements(
599600
at_expected = expected[idx]
600601
msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
601602
assert at_out == at_expected, msg
603+
604+
605+
def assert_allclose(actual, desired, *, atol=1e-7, rtol=1e-7, equal_nan=True, msg_extra=""):
606+
if equal_nan:
607+
# XXX assert same position, mask away
608+
pass
609+
610+
delta = xp.abs(actual - desired)
611+
assert xp.all(delta < atol + xp.abs(actual)*rtol), msg_extra

array_api_tests/test_linalg.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,11 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
106106
msg_extra = f'{x_idxes = }, {res_idx = }'
107107
assert_equal(res_stack, decomp_res_stack, msg_extra)
108108
if true_val:
109-
assert_equal(decomp_res_stack, true_val(*x_stacks, **kw), msg_extra)
109+
expected = true_val(*x_stacks, **kw)
110+
if decomp_res_stack.dtype in dh.all_float_dtypes:
111+
ph.assert_allclose(decomp_res_stack, expected, msg_extra=msg_extra)
112+
else:
113+
assert_equal(decomp_res_stack, expected, msg_extra)
110114

111115

112116
def _test_namedtuple(res, fields, func_name):

0 commit comments

Comments
 (0)