Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


performance - Minimizing overhead due to the large number of Numpy dot calls

My problem is the following: I have an iterative algorithm that, at each iteration, needs to perform several matrix-matrix multiplications dot(A_i, B_i), for i = 1 ... k. Since these multiplications are performed with NumPy's dot, I know they call a level-3 BLAS implementation, which is quite fast. The problem is that the number of calls is huge, and it turned out to be a bottleneck in my program. I would like to minimize the overhead of all these calls by computing fewer products with bigger matrices.

For simplicity, consider that all matrices are n x n (n is usually not large; it ranges between 1 and 1000). One workaround would be to form the block-diagonal matrix diag(A_i) and compute the product below.

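[Figure diag_blk: diag(A_1, ..., A_k) multiplied by the vertically stacked matrix [B_1; ...; B_k], giving [A_1 B_1; ...; A_k B_k]]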

This takes just one call to dot, but now the program wastes a lot of time multiplying by zeros. So the idea doesn't seem to pay off, but it does give the result [A_1 B_1, ..., A_k B_k], that is, all products stacked in a single big matrix.
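A small NumPy sketch of the block-diagonal idea, assuming all A_i and B_i are n x n and stored as (k, n, n) arrays (the names `blockdiag_product`, `A` and `B` are illustrative). It builds diag(A_1, ..., A_k), multiplies it by the vertically stacked B_i, and checks the result against the individual products. The dense (kn x kn)(kn x n) product costs k^2 n^3 multiply-adds versus k n^3 for the separate products, i.e. k times more work, which is the waste described above.

```python
import numpy as np

def blockdiag_product(A, B):
    """Compute all A[i] @ B[i] via one dense product with diag(A_1, ..., A_k).

    A, B: arrays of shape (k, n, n). Returns shape (k*n, n): the products
    A_1 B_1, ..., A_k B_k stacked vertically.
    """
    k, n, _ = A.shape
    D = np.zeros((k * n, k * n), dtype=A.dtype)
    for i in range(k):
        # place A_i on the diagonal; everything else stays zero
        D[i * n:(i + 1) * n, i * n:(i + 1) * n] = A[i]
    # stack B_1, ..., B_k vertically into a (k*n, n) matrix
    B_stacked = B.reshape(k * n, n)
    return D @ B_stacked

rng = np.random.default_rng(0)
k, n = 4, 3
A = rng.random((k, n, n))
B = rng.random((k, n, n))

C = blockdiag_product(A, B)
expected = np.vstack([A[i] @ B[i] for i in range(k)])
print(np.allclose(C, expected))  # True

# multiply-adds: dense product vs. k separate n x n products
dense_flops = (k * n) * (k * n) * n
separate_flops = k * n**3
print(dense_flops // separate_flops)  # 4, i.e. k
```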

My question is: is there a way to compute [A_1 B_1, ..., A_k B_k] with a single function call? Or, more to the point, how can I compute these products faster than with a loop of NumPy dot calls?


1 Answer


It depends on the size of the matrices.

Edit

For larger n x n matrices (from about n = 20 upwards), a BLAS call from compiled code is faster; for smaller matrices, custom Numba or Cython kernels are usually faster.
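For reference, NumPy can already compute all k products in a single call once the matrices are stacked into (k, n, n) arrays: np.matmul (the @ operator) and np.einsum both treat the leading axis as a batch axis. The sketch below (random data, illustrative names) checks that they agree with the per-pair np.dot loop from the question; whether they are fast enough for small n is what the kernels in this answer address.

```python
import numpy as np

rng = np.random.default_rng(1)
k, n = 1000, 2
A = rng.random((k, n, n))  # A_1, ..., A_k stacked along axis 0
B = rng.random((k, n, n))  # B_1, ..., B_k stacked along axis 0

# the loop the question wants to avoid: one np.dot call per pair
C_loop = np.empty((k, n, n))
for i in range(k):
    C_loop[i] = np.dot(A[i], B[i])

# one call each: matmul and einsum treat axis 0 as a batch dimension
# (np.dot(A, B) on 3-D arrays would NOT do this; it returns shape (k, n, k, n))
C_matmul = A @ B
C_einsum = np.einsum("xik,xkj->xij", A, B)

print(np.allclose(C_loop, C_matmul), np.allclose(C_loop, C_einsum))  # True True
```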

The following method generates custom dot functions for given input shapes. With this method it is also possible to benefit from compiler optimizations such as loop unrolling, which are especially important for small matrices.

Note that generating and compiling one kernel takes approximately 1 s, so make sure to call the generator only when you really have to.
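One way to follow that advice is to memoize the generator per shape, e.g. with functools.lru_cache. The sketch below uses a hypothetical NumPy stand-in `make_kernel` in place of gen_dot_nm so it runs without Numba; `get_kernel`, `batched_dot` and `build_count` are illustrative names. With Numba, the returned jitted function compiles on its first call, and caching that function object keeps the compiled code for all later calls with the same shape.

```python
import functools
import numpy as np

build_count = 0  # counts how often the (expensive) factory actually runs

def make_kernel(x, y, z):
    """Hypothetical stand-in for gen_dot_nm: returns a batched matmul kernel.

    With Numba, building and compiling a kernel costs about a second; here
    the kernel is plain NumPy so the sketch runs without Numba.
    """
    global build_count
    build_count += 1

    def kernel(A, B):
        assert A.shape[1:] == (x, y) and B.shape[1:] == (y, z)
        return A @ B

    return kernel

@functools.lru_cache(maxsize=None)
def get_kernel(x, y, z):
    # one kernel per distinct (x, y, z); later calls reuse the cached one
    return make_kernel(x, y, z)

def batched_dot(A, B):
    # look up (or build once) the kernel specialized for these shapes
    return get_kernel(A.shape[1], A.shape[2], B.shape[2])(A, B)

rng = np.random.default_rng(2)
A = rng.random((100, 2, 3))
B = rng.random((100, 3, 4))
for _ in range(5):
    C = batched_dot(A, B)
print(build_count)  # 1
```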

Generator function

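import numpy as np
import numba as nb

def gen_dot_nm(x,y,z):
    """
    Return a jitted batched dot function for inputs of shape (N,x,y) and (N,y,z).
    x, y and z are closure variables, which Numba treats as constants when
    compiling, so the small kernel's loop bounds are known at compile time.
    """
    #small kernels: explicit loops, parallelized over the batch axis
    @nb.njit(fastmath=True,parallel=True)
    def dot_numba(A,B):
        """
        calculate the dot product (x,y)x(y,z) for each of the N pairs
        """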
        assert A.shape[0]==B.shape[0]
        assert A.shape[2]==B.shape[1]

        assert A.shape[1]==x
        assert B.shape[1]==y
        assert B.shape[2]==z

        res=np.empty((A.shape[0],A.shape[1],B.shape[2]),dtype=A.dtype)
        for ii in nb.prange(A.shape[0]):
            for i in range(x):
                for j in range(z):
                    acc=0.
                    for k in range(y):
                        acc+=A[ii,i,k]*B[ii,k,j]
                    res[ii,i,j]=acc
        return res

    #large kernels: one BLAS call (np.dot) per batch element
    #(np.dot inside Numba requires SciPy to be installed)
    @nb.njit(fastmath=True,parallel=True)
    def dot_BLAS(A,B):
        assert A.shape[0]==B.shape[0]
        assert A.shape[2]==B.shape[1]

        res=np.empty((A.shape[0],A.shape[1],B.shape[2]),dtype=A.dtype)
        for ii in nb.prange(A.shape[0]):
            res[ii]=np.dot(A[ii],B[ii])
        return res

    #If any dimension is 20 or larger,
    #calling BLAS is faster
    if x>=20 or y>=20 or z>=20:
        return dot_BLAS
    else:
        return dot_numba

Usage example

A=np.random.rand(1000,2,2)
B=np.random.rand(1000,2,2)

dot22=gen_dot_nm(2,2,2)
X=dot22(A,B) #first call compiles the kernel; keep it out of the timing
%timeit X3=dot22(A,B)
#5.94 μs ± 21.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) 

Old answer

Another alternative, though with more work involved, would be to use a specialized BLAS implementation that generates custom kernels for very small matrices just in time, and then to call these kernels from C.
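The core idea, generating code specialized for one fixed small shape with every loop unrolled, can be illustrated in pure Python. The sketch below is not a BLAS and not the approach of any particular library: the hypothetical helper `gen_unrolled` builds source text for a fixed (x, y, z), compiles it with exec, and the resulting function computes each output entry as one explicit sum of products over the whole batch axis using NumPy slicing. Compiling the generated source with Numba instead would be a natural next step, but is omitted here.

```python
import numpy as np

def gen_unrolled(x, y, z):
    """Generate a batched (N,x,y) @ (N,y,z) kernel with all loops unrolled.

    Hypothetical helper: each output entry becomes one explicit sum of
    products, evaluated with NumPy slicing over the batch axis.
    """
    lines = ["def kernel(A, B):",
             "    res = np.empty((A.shape[0], %d, %d), dtype=A.dtype)" % (x, z)]
    for i in range(x):
        for j in range(z):
            terms = " + ".join("A[:, %d, %d] * B[:, %d, %d]" % (i, k, k, j)
                               for k in range(y))
            lines.append("    res[:, %d, %d] = %s" % (i, j, terms))
    lines.append("    return res")
    src = "\n".join(lines)
    namespace = {"np": np}
    exec(src, namespace)  # defines `kernel` inside `namespace`
    return namespace["kernel"], src

dot_232, source = gen_unrolled(2, 3, 2)
print(source.splitlines()[2].strip())
# res[:, 0, 0] = A[:, 0, 0] * B[:, 0, 0] + A[:, 0, 1] * B[:, 1, 0] + A[:, 0, 2] * B[:, 2, 0]

rng = np.random.default_rng(3)
A = rng.random((500, 2, 3))
B = rng.random((500, 3, 2))
print(np.allclose(dot_232(A, B), A @ B))  # True
```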

Example

import numpy as np
import numba as nb

#Don't use this for larger submatrices
@nb.njit(fastmath=True,parallel=True)
def dot(A,B):
    assert A.shape[0]==B.shape[0]
    assert A.shape[2]==B.shape[1]

    res=np.empty((A.shape[0],A.shape[1],B.shape[2]),dtype=A.dtype)
    for ii in nb.prange(A.shape[0]):
        for i in range(A.shape[1]):
            for j in range(B.shape[2]):
                acc=0.
                for k in range(B.shape[1]):
                    acc+=A[ii,i,k]*B[ii,k,j]
                res[ii,i,j]=acc
    return res

@nb.njit(fastmath=True,parallel=True)
def dot_22(A,B):
    assert A.shape[0]==B.shape[0]
    assert A.shape[1]==2
    assert A.shape[2]==2
    assert B.shape[1]==2
    assert B.shape[2]==2

    res=np.empty((A.shape[0],A.shape[1],B.shape[2]),dtype=A.dtype)
    for ii in nb.prange(A.shape[0]):
        res[ii,0,0]=A[ii,0,0]*B[ii,0,0]+A[ii,0,1]*B[ii,1,0]
        res[ii,0,1]=A[ii,0,0]*B[ii,0,1]+A[ii,0,1]*B[ii,1,1]
        res[ii,1,0]=A[ii,1,0]*B[ii,0,0]+A[ii,1,1]*B[ii,1,0]
        res[ii,1,1]=A[ii,1,0]*B[ii,0,1]+A[ii,1,1]*B[ii,1,1]
    return res
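If Numba is not available, the same unrolled 2x2 formula used in dot_22 can be written with NumPy slicing, so that each statement processes all k matrices at once. This is a sketch (the name `dot_22_numpy` is illustrative); it creates temporary arrays for every term, and its speed relative to A@B and dot_22 has not been measured here, so it would need to be timed on your own machine.

```python
import numpy as np

def dot_22_numpy(A, B):
    """NumPy version of the unrolled 2x2 formula used in dot_22.

    Each line computes one output entry for all k matrices at once via
    slicing; no Numba needed, but temporaries are allocated per term.
    """
    assert A.shape[1:] == (2, 2) and B.shape[1:] == (2, 2)
    res = np.empty((A.shape[0], 2, 2), dtype=A.dtype)
    res[:, 0, 0] = A[:, 0, 0] * B[:, 0, 0] + A[:, 0, 1] * B[:, 1, 0]
    res[:, 0, 1] = A[:, 0, 0] * B[:, 0, 1] + A[:, 0, 1] * B[:, 1, 1]
    res[:, 1, 0] = A[:, 1, 0] * B[:, 0, 0] + A[:, 1, 1] * B[:, 1, 0]
    res[:, 1, 1] = A[:, 1, 0] * B[:, 0, 1] + A[:, 1, 1] * B[:, 1, 1]
    return res

A = np.random.rand(1000, 2, 2)
B = np.random.rand(1000, 2, 2)
print(np.allclose(dot_22_numpy(A, B), A @ B))  # True
```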

Timings

A=np.random.rand(1000,2,2)
B=np.random.rand(1000,2,2)

X=A@B
X2=np.einsum("xik,xkj->xij",A,B)
X3=dot_22(A,B) #warm-up call so compilation is not included in the timings
X4=dot(A,B)    #warm-up call so compilation is not included in the timings

%timeit X=A@B
#262 μs ± 2.55 μs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.einsum("xik,xkj->xij",A,B,optimize=True)
#264 μs ± 3.22 μs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit X3=dot_22(A,B)
#5.68 μs ± 27.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit X4=dot(A,B)
#9.79 μs ± 61.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
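Outside IPython, where %timeit is not available, a comparable measurement can be taken with the standard timeit module. The sketch below (function names are illustrative) times the per-pair np.dot loop from the question against a single batched A@B; compiled Numba kernels could be added to the tuple after a warm-up call. It reports the best of several repeats rather than %timeit's mean ± std. dev., and it states no expected numbers because results depend on the machine and the NumPy/BLAS build.

```python
import timeit
import numpy as np

def loop_dot(A, B):
    # the approach from the question: one np.dot call per matrix pair
    return np.array([np.dot(a, b) for a, b in zip(A, B)])

def batched_matmul(A, B):
    # a single batched call over the leading axis
    return A @ B

def best_time_us(func, *args, number=100, repeat=5):
    """Best mean time per call, in microseconds, over `repeat` runs."""
    times = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    return min(times) / number * 1e6

A = np.random.rand(1000, 2, 2)
B = np.random.rand(1000, 2, 2)
assert np.allclose(loop_dot(A, B), batched_matmul(A, B))
for f in (loop_dot, batched_matmul):
    print(f"{f.__name__}: {best_time_us(f, A, B):.1f} us per call")
```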

