Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

matlab - How to compute a sum of binomial coefficients more efficiently?

I need to evaluate the following equation:

[equation image not available; the code below computes $P = \sum_{i=0}^{D-1} \sum_{j=0}^{i} \binom{k_1}{j}\binom{k_2}{i-j} \big/ \binom{k_1+k_2}{i}$, skipping terms with $i-j > k_2$ or $j > k_1$]

where k1 and k2 are given. I am using MATLAB to compute P, and I believe my implementation of the equation is correct. However, it is very slow, and I suspect the binomial coefficients are the cause. Given the form of the equation, is there a more efficient way to compute it? Thanks to all.

For k1=150, k2=150, D=200, it takes 11.6 seconds.

function main
warning ('off');
  function test_binom()
      k1=150; k2=150; D=200; P=0;
      for i=0:D-1
          for j=0:i
              if (i-j>k2||j>k1) 
                  continue;
              end
              P=P+nchoosek(k1,j)*nchoosek(k2,i-j)/nchoosek((k1+k2),i);          
          end 
      end
  end
f = @()test_binom(); 
timeit(f)
end

Update: After measuring the run time, I found that nchoosek accounts for most of the computation time. I therefore rewrote the function as follows:

function re=choose(n, k)
    if (k == 0)
        re=1;
    else
        re=(n * choose(n - 1, k - 1)) / k;
    end
end

Now the computation time is reduced to 0.25 seconds. Is there a better way?


1 Answer


You can cache the results of nchoosek in a table so that each coefficient is evaluated only once. An implementation of the binomial coefficient is also provided:

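% Binomial coefficient C(n,k). Each numerator factor is scaled by a root of
% the denominator factorial, which helps avoid overflow in the product.
function nk = nchoosek2(n, k)
    if n-k > k
        nk = prod((k+1:n) .* prod((1:n-k).^(-1/(n-k))));
    else
        nk = prod((n-k+1:n) .* prod((1:k).^(-1/k)));
    end
end

% Store and retrieve results of nchoosek2 to/from a table.
% D, K1 and K2 are only needed on the first call, which sizes the table.
function ret = choose(n, k, D, K1, K2)
    persistent binTable
    if isempty(binTable)
        % MATLAB does not allow initializing a persistent variable in its declaration
        binTable = zeros(max([D+1, K1+K2+1]), D+1);
    end
    if binTable(n+1,k+1) == 0      % 0 marks "not computed yet"
        binTable(n+1,k+1) = nchoosek2(n,k);
    end
    ret = binTable(n+1,k+1);
end

function P = tst()
    k1=150; k2=150; D=200; P=0;
    choose(1,0,D,k1,k2);           % first call allocates the table
    for i = 0:D-1
        % The loop bounds replace the "continue" test. Note that the question's
        % loop also admits j == k1; use min(i,k1) to reproduce it exactly.
        for j = max(i - k2, 0):min(i, k1-1)
            P=P+choose(k1,j)*choose(k2,i-j)/choose((k1+k2),i);
        end
    end
end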

Your code using nchoosek2, compared with this version: online demo
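As a further option, the same sum can be evaluated without any call to nchoosek by working with logarithms of the coefficients. The following is only a minimal sketch, not the method from the answer above: it assumes the j-range implied by the question's skip test (i-j <= k2 and j <= k1) and D-1 <= k1+k2, and the helper name logC is made up for illustration. It relies on the built-in gammaln, using log C(n,k) = gammaln(n+1) - gammaln(k+1) - gammaln(n-k+1).

```matlab
% Sketch: evaluate the double sum in log space with gammaln (no nchoosek calls).
k1 = 150; k2 = 150; D = 200;           % values from the question

% logC is a hypothetical helper: log of C(n,k), valid for 0 <= k <= n
logC = @(n, k) gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1);

P = 0;
for i = 0:D-1
    % j-range taken from the question's skip test: i-j <= k2 and j <= k1
    j = max(i - k2, 0):min(i, k1);
    % each term is C(k1,j)*C(k2,i-j)/C(k1+k2,i), computed as exp of a sum of logs
    P = P + sum(exp(logC(k1, j) + logC(k2, i - j) - logC(k1 + k2, i)));
end
disp(P)
```

The inner loop is vectorized over j, so only D passes of the outer loop remain. As a sanity check, with this j-range each inner sum equals 1 by Vandermonde's identity, so P should come out very close to D; a tighter upper bound on j gives a smaller value.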

