Skip to content

Commit 0a181ed

Browse files
committed
ENH: test test_along_axis with axis<0
1 parent dcb22c0 commit 0a181ed

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

array_api_tests/test_indexing_functions.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ def test_take(x, data):
7171
)
7272
def test_take_along_axis(x, data):
7373
# TODO
74-
# 1. negative axis
7574
# 2. negative indices
7675
# 3. different dtypes for indices
77-
axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis")
76+
axis = data.draw(st.integers(-x.ndim, max(x.ndim - 1, 0)), label="axis")
7877
len_axis = data.draw(st.integers(0, 2*x.shape[axis]), label="len_axis")
7978

80-
idx_shape = x.shape[:axis] + (len_axis,) + x.shape[axis+1:]
79+
n_axis = axis + x.ndim if axis < 0 else axis
80+
idx_shape = x.shape[:n_axis] + (len_axis,) + x.shape[n_axis+1:]
8181
indices = data.draw(
8282
hh.arrays(
8383
shape=idx_shape,
@@ -94,7 +94,7 @@ def test_take_along_axis(x, data):
9494
ph.assert_shape(
9595
"take_along_axis",
9696
out_shape=out.shape,
97-
expected=x.shape[:axis] + (len_axis,) + x.shape[axis+1:],
97+
expected=x.shape[:n_axis] + (len_axis,) + x.shape[n_axis+1:],
9898
kw=dict(
9999
x=x,
100100
indices=indices,
@@ -103,12 +103,11 @@ def test_take_along_axis(x, data):
103103
)
104104

105105
# value test: notation is from `np.take_along_axis` docstring
106-
Ni, Nk = x.shape[:axis], x.shape[axis+1:]
106+
Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:]
107107
for ii in sh.ndindex(Ni):
108108
for kk in sh.ndindex(Nk):
109109
a_1d = x[ii + (slice(None),) + kk]
110110
i_1d = indices[ii + (slice(None),) + kk]
111111
o_1d = out[ii + (slice(None),) + kk]
112112
for j in range(len_axis):
113113
assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'
114-

0 commit comments

Comments
 (0)