Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Efficiently count zero elements in numpy array?

I need to count the number of zero elements in numpy arrays. I'm aware of the numpy.count_nonzero function, but there appears to be no analog for counting zero elements.

My arrays are not very large (typically fewer than 1E5 elements), but the operation is performed several million times.

Of course I could use len(arr) - np.count_nonzero(arr), but I wonder if there's a more efficient way to do it.

Here's a MWE of how I do it currently:

import numpy as np
import timeit

arrs = []
for _ in range(1000):
    arrs.append(np.random.randint(-5, 5, 10000))


def func1():
    for arr in arrs:
        zero_els = len(arr) - np.count_nonzero(arr)


print(timeit.timeit(func1, number=10))
Question from: https://stackoverflow.com/questions/42916330/efficiently-count-zero-elements-in-numpy-array


1 Answer


A faster approach (about 1.6x in the timings below) is to still use np.count_nonzero(), but pass it the condition you want counted.

In [3]: arr
Out[3]:
array([[1, 2, 0, 3],
       [3, 9, 0, 4]])

In [4]: np.count_nonzero(arr==0)
Out[4]: 2

In [5]: def func_cnt():
   ...:     for arr in arrs:
   ...:         # arr == 0 is a boolean mask; count_nonzero counts its True entries
   ...:         zero_els = np.count_nonzero(arr == 0)
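To make the comparison with the question's approach concrete, here is a minimal sketch on the same 2-D example array. It also shows why `arr.size` is the safer choice than `len(arr)` once an array has more than one dimension; the variable names are only illustrative.

```python
import numpy as np

arr = np.array([[1, 2, 0, 3],
                [3, 9, 0, 4]])

# Count True entries of the boolean mask (the approach from this answer).
zeros_by_mask = np.count_nonzero(arr == 0)

# Total number of elements minus the nonzero ones.
zeros_by_size = arr.size - np.count_nonzero(arr)

# len(arr) is only the length of the first axis (2 rows here),
# so this variant is wrong for anything but 1-D arrays.
zeros_by_len = len(arr) - np.count_nonzero(arr)

print(zeros_by_mask, zeros_by_size, zeros_by_len)  # 2 2 -4
```

For the 1-D arrays in the question, `len(arr)` and `arr.size` agree, so all three variants give the same count there.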

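You can also use np.where(), but it's slower than np.count_nonzero(). Note that np.where() returns one index array per dimension, so count the entries of one of those arrays rather than the length of the tuple:

In [6]: np.where(arr == 0)
Out[6]: (array([0, 1]), array([2, 2]))

In [7]: len(np.where(arr == 0)[0])
Out[7]: 2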
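One thing to watch with `np.where()`: the length of the returned tuple is the number of dimensions, not the number of matches. In the example above both happen to be 2. A small sketch with three zeros (an array made up for illustration) separates the two:

```python
import numpy as np

arr = np.array([[0, 2, 0, 3],
                [3, 9, 0, 4]])   # three zeros in a 2-D array

idx = np.where(arr == 0)         # tuple of index arrays, one per dimension

n_dims = len(idx)                # length of the tuple: number of dimensions
n_zeros = len(idx[0])            # length of one index array: number of matches

print(n_dims, n_zeros, np.count_nonzero(arr == 0))  # 2 3 3
```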

Efficiency (fastest first):

In [8]: %timeit func_cnt()
10 loops, best of 3: 29.2 ms per loop

In [9]: %timeit func1()
10 loops, best of 3: 46.5 ms per loop

In [10]: %timeit func_where()
10 loops, best of 3: 61.2 ms per loop
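The timings above refer to `func_where`, which is not shown in the answer. To reproduce the comparison outside IPython, something like the following sketch could be used. The body of `func_where` is an assumption about what was timed, and the functions return their counts (unlike the originals) so the three variants can be checked against each other. Absolute numbers will depend on the machine and NumPy version.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
arrs = [rng.integers(-5, 5, 10000) for _ in range(1000)]


def func1():
    # Approach from the question: total size minus nonzero count.
    return [arr.size - np.count_nonzero(arr) for arr in arrs]


def func_cnt():
    # Approach from this answer: count True entries of the mask.
    return [np.count_nonzero(arr == 0) for arr in arrs]


def func_where():
    # Hypothetical definition: the original answer does not show it.
    return [len(np.where(arr == 0)[0]) for arr in arrs]


for func in (func_cnt, func1, func_where):
    seconds = timeit.timeit(func, number=10) / 10
    print(f"{func.__name__}: {seconds * 1e3:.2f} ms per call")
```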

More speedups with accelerators

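With JAX and access to an accelerator (GPU/TPU), the timing below comes out more than three orders of magnitude faster. Treat that figure with caution, though: the jitted func_cnt takes no arguments and returns nothing, so the compiler is free to discard the counting work, and JAX dispatches asynchronously, so the measurement likely reflects mostly call overhead rather than the counting itself. Another advantage of JAX is that NumPy code needs very little modification to become JAX compatible. Here is the example as run: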

In [1]: import numpy as np
In [2]: import jax.numpy as jnp
In [3]: from jax import jit

# set up inputs
In [4]: arrs = []
In [5]: for _ in range(1000):
   ...:     arrs.append(np.random.randint(-5, 5, 10000))

# JIT'd function that performs the counting task
In [6]: @jit
   ...: def func_cnt():
   ...:     for arr in arrs:
   ...:         zero_els = jnp.count_nonzero(arr==0)

# efficiency test
In [7]: %timeit func_cnt()
15.6 μs ± 391 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
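Because the jitted function above takes no input and returns nothing, its timing is hard to interpret. A sketch of a version that is easier to benchmark fairly is shown below. It assumes the arrays all have the same length, so they can be stacked into one 2-D batch; the names `batch` and `count_zeros_per_row` are illustrative, not part of the original answer.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Same data shape as in the question, stacked as 1000 rows of 10000 values.
rng = np.random.default_rng(0)
batch_np = rng.integers(-5, 5, size=(1000, 10000), dtype=np.int32)
batch = jnp.asarray(batch_np)  # move the data to the default JAX device


@jax.jit
def count_zeros_per_row(x):
    # The input is an argument and the counts are returned,
    # so the compiled function has real work to do.
    return jnp.count_nonzero(x == 0, axis=1)


counts = count_zeros_per_row(batch)
# JAX dispatches asynchronously; wait for the result before stopping a timer.
counts.block_until_ready()
```

When timing this, call the function once beforehand so that compilation is excluded, and include `block_until_ready()` in the timed statement. How much faster it is than the NumPy loop depends on the hardware and on whether the host-to-device transfer is counted.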
