Python – Tricky numpy argmax on the last dimension of a 3-dimensional ndarray


I have an array of shape (9, 1, 3):

array([[[  6,  12, 108]],

       [[122, 112,  38]],

       [[ 57, 101,  62]],

       [[119,  76, 177]],

       [[ 46,  62,   2]],

       [[127,  61, 155]],

       [[  5,   6, 151]],

       [[  5,   8, 185]],

       [[109, 167,  33]]])

I want to find the index, along the first axis, of the entry that contains the largest value. The largest value here is 185, so the index is 7.

I guess the solution is related to reshape, but I can't figure it out. Thanks for your help!

Solution

I’m not sure what’s tricky about it. However, one way to get the index of the largest element is to take np.max along the last axis and then apply np.argmax to the result:

# find `max` element along last axis 
# and get the index using `argmax` where `arr` is your array
In [53]: np.argmax(np.max(arr, axis=2))
Out[53]: 7
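
For anyone who wants to run this outside an IPython session, here is a minimal self-contained sketch of the same two-step approach. The array is copied from the question; the names `row_max` and `row_index` are only illustrative. Note that np.argmax without an axis argument works on the flattened input, so the result matches the first-axis index only because the middle axis has length 1.

```python
import numpy as np

# The (9, 1, 3) array from the question
arr = np.array([[[6, 12, 108]], [[122, 112, 38]], [[57, 101, 62]],
                [[119, 76, 177]], [[46, 62, 2]], [[127, 61, 155]],
                [[5, 6, 151]], [[5, 8, 185]], [[109, 167, 33]]])

# Maximum along the last axis: one value per entry, shape (9, 1)
row_max = np.max(arr, axis=2)

# With no axis given, argmax searches the flattened array; because the
# middle axis has length 1, the flat index equals the first-axis index
row_index = np.argmax(row_max)

print(row_max.shape)  # (9, 1)
print(row_index)      # 7
```

If the middle axis were longer than 1, the flat index would no longer be the first-axis index and would need to be converted, for example with np.unravel_index.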

Or, as @PaulPanzer suggested in the comments, you can get the full index of the maximum with np.unravel_index:

In [63]: np.unravel_index(np.argmax(arr), arr.shape)
Out[63]: (7, 0, 2)

In [64]: arr[(7, 0, 2)]
Out[64]: 185
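
As a rough, runnable sketch of how this fits together: np.argmax on the whole array returns a position in the flattened array, and np.unravel_index converts that position back into one index per axis. The snippet also shows a reshape-based variant along the lines the question guessed at; that variant is an illustration and not part of the original answer. Variable names such as `flat_index` and `rows` are hypothetical.

```python
import numpy as np

# The (9, 1, 3) array from the question
arr = np.array([[[6, 12, 108]], [[122, 112, 38]], [[57, 101, 62]],
                [[119, 76, 177]], [[46, 62, 2]], [[127, 61, 155]],
                [[5, 6, 151]], [[5, 8, 185]], [[109, 167, 33]]])

# Position of the maximum in the flattened array (7 * 3 + 2 = 23)
flat_index = np.argmax(arr)

# Convert the flat position back into one index per axis
i, j, k = np.unravel_index(flat_index, arr.shape)
print(i, j, k)       # 7 0 2
print(arr[i, j, k])  # 185

# Reshape variant: drop the length-1 axis, take the maximum of each
# row, then find which row holds the largest of those maxima
rows = arr.reshape(9, 3)
row_index = rows.max(axis=1).argmax()
print(row_index)     # 7
```

The first element of the unraveled tuple is the first-axis index the question asks for, and this form keeps working when the middle axis has more than one element.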
