Tricky numpy argmax on the last dimension of a 3-dimensional ndarray
Suppose I have an array of shape (9, 1, 3):
array([[[  6,  12, 108]],
       [[122, 112,  38]],
       [[ 57, 101,  62]],
       [[119,  76, 177]],
       [[ 46,  62,   2]],
       [[127,  61, 155]],
       [[  5,   6, 151]],
       [[  5,   8, 185]],
       [[109, 167,  33]]])
I want to find the index, along the first axis, of the block that contains the largest value in the last dimension. Here the maximum is 185, so the index I'm after is 7.
I guess the solution involves reshape, but I can't work it out. Thanks for your help!
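To pin down exactly what is being asked, here is a minimal brute-force sketch that rebuilds the array from the printed values (the name `arr` is an assumption, matching the answer) and scans each block with a plain loop. It is only a reference for checking the vectorized answers, not a recommended approach.

```python
import numpy as np

# The (9, 1, 3) array from the question.
arr = np.array([[[  6,  12, 108]],
                [[122, 112,  38]],
                [[ 57, 101,  62]],
                [[119,  76, 177]],
                [[ 46,  62,   2]],
                [[127,  61, 155]],
                [[  5,   6, 151]],
                [[  5,   8, 185]],
                [[109, 167,  33]]])

# Naive reference: visit every block along the first axis and
# remember the index of the block whose maximum is largest so far.
best_idx, best_val = None, None
for i, block in enumerate(arr):
    block_max = block.max()
    if best_val is None or block_max > best_val:
        best_idx, best_val = i, block_max

print(arr.shape, best_idx, best_val)  # (9, 1, 3) 7 185
```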
Solution
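I'm not sure what's tricky about it. One way to get the index is to take np.max along the last axis and then apply np.argmax to the result. Because np.argmax without an axis argument works on the flattened array, this relies on the middle axis having length 1, so the flat index into the (9, 1) result equals the index along the first axis: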
# find `max` element along last axis
# and get the index using `argmax` where `arr` is your array
In [53]: np.argmax(np.max(arr, axis=2))
Out[53]: 7
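If the middle axis could be longer than 1, the flattening shortcut above would no longer give the first-axis index directly. A minimal sketch of two general variants, assuming the same `arr`: pass a tuple of axes to `max`, or use the reshape idea from the question to flatten each block before taking the maximum.

```python
import numpy as np

# The (9, 1, 3) array from the question.
arr = np.array([[[  6,  12, 108]],
                [[122, 112,  38]],
                [[ 57, 101,  62]],
                [[119,  76, 177]],
                [[ 46,  62,   2]],
                [[127,  61, 155]],
                [[  5,   6, 151]],
                [[  5,   8, 185]],
                [[109, 167,  33]]])

# Variant 1: reduce both trailing axes at once, leaving one max per block.
block_max = arr.max(axis=(1, 2))          # shape (9,)
idx_a = int(np.argmax(block_max))

# Variant 2: flatten each block with reshape, then reduce along axis 1.
flat_blocks = arr.reshape(arr.shape[0], -1)  # shape (9, 3)
idx_b = int(flat_blocks.max(axis=1).argmax())

print(idx_a, idx_b)  # 7 7
```

Both variants return a first-axis index regardless of the size of the middle axis. As with any np.argmax call, ties resolve to the first occurrence.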
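Alternatively, as @PaulPanzer suggested in the comments, you can take the argmax of the whole array and convert the flat index back into a full coordinate with np.unravel_index; the first element of the result is the index you want: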
In [63]: np.unravel_index(np.argmax(arr), arr.shape)
Out[63]: (7, 0, 2)
In [64]: arr[(7, 0, 2)]
Out[64]: 185
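A self-contained sketch of the np.unravel_index approach, assuming the same `arr`, that also shows the intermediate flat index. Since the array is stored in row-major (C) order, integer division of the flat index by the size of one block gives the same first-axis index; that last line is an illustrative equivalent, not part of the suggestion in the comments.

```python
import numpy as np

# The (9, 1, 3) array from the question.
arr = np.array([[[  6,  12, 108]],
                [[122, 112,  38]],
                [[ 57, 101,  62]],
                [[119,  76, 177]],
                [[ 46,  62,   2]],
                [[127,  61, 155]],
                [[  5,   6, 151]],
                [[  5,   8, 185]],
                [[109, 167,  33]]])

# argmax without an axis searches the flattened array and returns one flat index.
flat_idx = int(np.argmax(arr))

# unravel_index maps the flat index back to one index per axis.
coords = tuple(int(i) for i in np.unravel_index(flat_idx, arr.shape))

# Equivalent first-axis index: divide by the number of elements per block (1 * 3).
block_idx = flat_idx // (arr.shape[1] * arr.shape[2])

print(flat_idx, coords, int(arr[coords]), block_idx)  # 23 (7, 0, 2) 185 7
```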