Skip to content

Commit 27927b5

Browse files
committed
ENH: test take_along_axis default axis=-1
1 parent 0a181ed commit 27927b5

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

Diff for: array_api_tests/test_indexing_functions.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -73,22 +73,30 @@ def test_take_along_axis(x, data):
7373
# TODO
7474
# 2. negative indices
7575
# 3. different dtypes for indices
76-
axis = data.draw(st.integers(-x.ndim, max(x.ndim - 1, 0)), label="axis")
77-
len_axis = data.draw(st.integers(0, 2*x.shape[axis]), label="len_axis")
76+
axis = data.draw(
77+
st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
78+
label="axis"
79+
)
80+
if axis is None:
81+
axis_kw = {}
82+
n_axis = x.ndim - 1
83+
else:
84+
axis_kw = {"axis": axis}
85+
n_axis = axis + x.ndim if axis < 0 else axis
7886

79-
n_axis = axis + x.ndim if axis < 0 else axis
87+
len_axis = data.draw(st.integers(0, 2*x.shape[n_axis]), label="len_axis")
8088
idx_shape = x.shape[:n_axis] + (len_axis,) + x.shape[n_axis+1:]
8189
indices = data.draw(
8290
hh.arrays(
8391
shape=idx_shape,
8492
dtype=dh.default_int,
85-
elements={"min_value": 0, "max_value": x.shape[axis]-1}
93+
elements={"min_value": 0, "max_value": x.shape[n_axis]-1}
8694
),
8795
label="indices"
8896
)
8997
note(f"{indices=} {idx_shape=}")
9098

91-
out = xp.take_along_axis(x, indices, axis=axis)
99+
out = xp.take_along_axis(x, indices, **axis_kw)
92100

93101
ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype)
94102
ph.assert_shape(

0 commit comments

Comments
 (0)