Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share

python - faster alternative to numpy.where?

I have a 3D array filled with integers from 0 to N. I need a list of the indices corresponding to where the array equals 1, 2, 3, ..., N. I can do it with np.where as follows:

N = 300
shape = (1000,1000,10)
data = np.random.randint(0,N+1,shape)
indx = [np.where(data == i_id) for i_id in range(1,data.max()+1)]

but this is quite slow. According to the question "fast python numpy where functionality?", it should be possible to speed up the index search quite a lot, but I haven't been able to transfer the methods proposed there to my problem of getting the actual indices. What would be the best way to speed up the above code?

As an add-on: I want to store the indices later, for which it makes sense to use np.ravel_multi_index to reduce storage from three index arrays to a single one, i.e. using:

indx = [np.ravel_multi_index(np.where(data == i_id), data.shape) for i_id in range(1, data.max()+1)]

which is closer to e.g. Matlab's find function. Can this be directly incorporated in a solution that doesn't use np.where?
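For context, np.ravel_multi_index and np.unravel_index are exact inverses of each other, so storing only the flat indices loses no information. A quick sketch of the round trip (variable names here are just for illustration):

```python
import numpy as np

shape = (1000, 1000, 10)
data = np.random.randint(0, 301, shape)

# Tuple of three index arrays where data == 1 ...
multi = np.where(data == 1)
# ... collapsed to a single flat index per match:
flat = np.ravel_multi_index(multi, data.shape)

# The original triple of index arrays is fully recoverable:
recovered = np.unravel_index(flat, data.shape)
assert all(np.array_equal(a, b) for a, b in zip(multi, recovered))
```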



1 Answer


I think that a standard vectorized approach to this problem would end up being very memory intensive – for int64 data, it would require about 8 * N * data.size bytes, or roughly 22 GB of memory for the example you gave above. I'm assuming that is not an option.

You might make some progress by using a sparse matrix to store the locations of the unique values. For example:

import numpy as np
from scipy.sparse import csr_matrix

def compute_M(data):
    cols = np.arange(data.size)
    # One entry per element: the value data.ravel()[k] selects the row,
    # and the stored payload is the flat position k itself.
    return csr_matrix((cols, (data.ravel(), cols)),
                      shape=(data.max() + 1, data.size))

def get_indices_sparse(data):
    M = compute_M(data)
    # Row i of M holds the flat indices where data == i;
    # unravel them back into tuples of 3D index arrays.
    return [np.unravel_index(row.data, data.shape) for row in M]

This takes advantage of fast code within the sparse matrix constructor to organize the data in a useful way, constructing a sparse matrix where row i contains just the indices where the flattened data equals i.
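To make that concrete, here is the same construction on a tiny 2x2 array (the array and variable names are mine, purely for illustration):

```python
import numpy as np
from scipy.sparse import csr_matrix

small = np.array([[2, 0],
                  [1, 2]])
cols = np.arange(small.size)          # flat positions 0..3
M = csr_matrix((cols, (small.ravel(), cols)),
               shape=(small.max() + 1, small.size))

# Row i of M stores the flat positions where small == i.
rows = [row.data for row in M]
# small.ravel() == [2, 0, 1, 2], so e.g. rows[2] holds positions 0 and 3.
```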

To test it out, I'll also define a function that does your straightforward method:

def get_indices_simple(data):
    return [np.where(data == i) for i in range(0, data.max() + 1)]

The two functions give the same results for the same input:

data_small = np.random.randint(0, 100, size=(100, 100, 10))
all(np.allclose(i1, i2)
    for i1, i2 in zip(get_indices_simple(data_small),
                      get_indices_sparse(data_small)))
# True

And the sparse method is an order of magnitude faster than the simple method for your dataset:

data = np.random.randint(0, 301, size=(1000, 1000, 10))

%time ind = get_indices_simple(data)
# CPU times: user 14.1 s, sys: 638 ms, total: 14.7 s
# Wall time: 14.8 s

%time ind = get_indices_sparse(data)
# CPU times: user 881 ms, sys: 301 ms, total: 1.18 s
# Wall time: 1.18 s

%time M = compute_M(data)
# CPU times: user 216 ms, sys: 148 ms, total: 365 ms
# Wall time: 363 ms

The other benefit of the sparse method is that the matrix M ends up being a very compact and efficient way to store all the relevant information for later use, as mentioned in the add-on part of your question. Hope that's useful!
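On the add-on specifically: since row.data already contains flattened indices, you could skip the unravel_index/ravel_multi_index round trip entirely if flat, Matlab-find-style indices are all you need to store. A minimal sketch (the function name get_flat_indices_sparse is my own):

```python
import numpy as np
from scipy.sparse import csr_matrix

def compute_M(data):
    cols = np.arange(data.size)
    return csr_matrix((cols, (data.ravel(), cols)),
                      shape=(data.max() + 1, data.size))

def get_flat_indices_sparse(data):
    # row.data already holds flat (raveled) positions into data.ravel(),
    # so no ravel_multi_index step is needed before storing them.
    return [row.data for row in compute_M(data)]
```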


Edit: I realized there was a bug in the initial version: it failed if any value in the range didn't appear in the data. That's now fixed above.

