@@ -71,13 +71,13 @@ def test_take(x, data):
71
71
)
72
72
def test_take_along_axis (x , data ):
73
73
# TODO
74
- # 1. negative axis
75
74
# 2. negative indices
76
75
# 3. different dtypes for indices
77
- axis = data .draw (st .integers (0 , max (x .ndim - 1 , 0 )), label = "axis" )
76
+ axis = data .draw (st .integers (- x . ndim , max (x .ndim - 1 , 0 )), label = "axis" )
78
77
len_axis = data .draw (st .integers (0 , 2 * x .shape [axis ]), label = "len_axis" )
79
78
80
- idx_shape = x .shape [:axis ] + (len_axis ,) + x .shape [axis + 1 :]
79
+ n_axis = axis + x .ndim if axis < 0 else axis
80
+ idx_shape = x .shape [:n_axis ] + (len_axis ,) + x .shape [n_axis + 1 :]
81
81
indices = data .draw (
82
82
hh .arrays (
83
83
shape = idx_shape ,
@@ -94,7 +94,7 @@ def test_take_along_axis(x, data):
94
94
ph .assert_shape (
95
95
"take_along_axis" ,
96
96
out_shape = out .shape ,
97
- expected = x .shape [:axis ] + (len_axis ,) + x .shape [axis + 1 :],
97
+ expected = x .shape [:n_axis ] + (len_axis ,) + x .shape [n_axis + 1 :],
98
98
kw = dict (
99
99
x = x ,
100
100
indices = indices ,
@@ -103,12 +103,11 @@ def test_take_along_axis(x, data):
103
103
)
104
104
105
105
# value test: notation is from `np.take_along_axis` docstring
106
- Ni , Nk = x .shape [:axis ], x .shape [axis + 1 :]
106
+ Ni , Nk = x .shape [:n_axis ], x .shape [n_axis + 1 :]
107
107
for ii in sh .ndindex (Ni ):
108
108
for kk in sh .ndindex (Nk ):
109
109
a_1d = x [ii + (slice (None ),) + kk ]
110
110
i_1d = indices [ii + (slice (None ),) + kk ]
111
111
o_1d = out [ii + (slice (None ),) + kk ]
112
112
for j in range (len_axis ):
113
113
assert o_1d [j ] == a_1d [i_1d [j ]], f'{ ii = } , { kk = } , { j = } '
114
-
0 commit comments