Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Where is my mistake in this complex numpy usage?

I have recently been learning NumPy, and I came across an exercise that gives me three arrays:

- q: A numpy array of shape (1, K) (queries)
- k: A numpy array of shape (N, K) (keys)
- v: A numpy array of shape (N, 1) (values)

and asks me to compute sum_i exp(-||q - k_i||^2) * v[i], where k_i is the i-th row of k.
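To make the target formula concrete, here is a minimal sketch that evaluates it with an explicit Python loop over the rows of k. It is slow, but it is easy to check any vectorized version against it. The function name `reference_sum` and the random test arrays are illustrative assumptions, not part of the exercise.

```python
import numpy as np

def reference_sum(q, k, v):
    """Compute sum_i exp(-||q - k_i||^2) * v[i] with an explicit loop.

    Assumes q has shape (1, K), k has shape (N, K), v has shape (N, 1).
    """
    total = 0.0
    for i in range(k.shape[0]):
        diff = q[0] - k[i]               # q - k_i, shape (K,)
        sq_norm = np.dot(diff, diff)     # squared L2 norm ||q - k_i||^2
        total += np.exp(-sq_norm) * v[i, 0]
    return total

rng = np.random.default_rng(0)
q = rng.normal(size=(1, 3))
k = rng.normal(size=(5, 3))
v = rng.normal(size=(5, 1))
print(reference_sum(q, k, v))
```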

My code is:

    (np.exp(np.sum((np.tile(q, (np.shape(k)[0], 1)) - k)**2, axis = -1)**0.5 * -1).T.dot(v))[0]

But the value is not correct.

If anyone knows where my mistake is, please help me; I have been stuck on this for a long time.

For readability, here is my one-line code broken down step by step:

    np.tile(q, (np.shape(k)[0], 1))

This makes N stacked copies of q (an (N, K) array) for the following operation.

    np.tile(q, (np.shape(k)[0], 1)) - k

This simply computes q - k elementwise, giving a same-size (N, K) matrix whose i-th row is q - k[i].

    np.sum((np.tile(q, (np.shape(k)[0], 1)) - k)**2, axis = -1)**0.5

This gives the row-wise L2 norm, so we now have an array of shape (N,) whose i-th entry is the norm of q - k[i].

    np.exp(np.sum((np.tile(q, (np.shape(k)[0], 1)) - k)**2, axis = -1)**0.5 * -1).T

Then we take exp of its negative and finally dot it with v, which computes the sum automatically; the trailing [0] extracts the scalar from the resulting length-1 array.
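The steps above can be traced on small random arrays to check each intermediate shape. In the sketch below (the sizes N=4, K=3 are arbitrary assumptions), the `axis=-1` sum yields shape (N,) rather than (N, 1), `.T` leaves a 1-D array unchanged, and the `**0.5` step produces the norm ||q - k_i|| itself, not its square.

```python
import numpy as np

rng = np.random.default_rng(1)
N, K = 4, 3                               # arbitrary sizes for illustration
q = rng.normal(size=(1, K))
k = rng.normal(size=(N, K))
v = rng.normal(size=(N, 1))

tiled = np.tile(q, (np.shape(k)[0], 1))   # step 1: N copies of q, shape (N, K)
diff = tiled - k                          # step 2: row i is q - k[i], shape (N, K)
norms = np.sum(diff**2, axis=-1)**0.5     # step 3: row-wise L2 norm, shape (N,)
weights = np.exp(norms * -1).T            # step 4: exp(-||q - k_i||); .T is a no-op on 1-D
result = weights.dot(v)                   # (N,) dot (N, 1) -> shape (1,)

for name, arr in [("tiled", tiled), ("diff", diff), ("norms", norms),
                  ("weights", weights), ("result", result)]:
    print(name, arr.shape)
# tiled (4, 3) / diff (4, 3) / norms (4,) / weights (4,) / result (1,)
```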

I also tried computing the norm with np.linalg.norm instead, but the result does not change:

    np.exp(-np.linalg.norm((np.tile(q, (np.shape(k)[0], 1)) - k), axis=1, ord = 2)).T.dot(v)[0]
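Switching to `np.linalg.norm` cannot change the result, because with `ord=2` and `axis=1` it computes the same row-wise Euclidean norm as the manual `sum(...)**0.5`. A small check, assuming random test arrays:

```python
import numpy as np

rng = np.random.default_rng(2)
q = rng.normal(size=(1, 3))
k = rng.normal(size=(6, 3))

diff = np.tile(q, (np.shape(k)[0], 1)) - k
manual = np.sum(diff**2, axis=-1)**0.5            # norm via sum of squares, then sqrt
library = np.linalg.norm(diff, axis=1, ord=2)     # same row-wise Euclidean norm

print(np.allclose(manual, library))   # True: both are ||q - k_i||, not ||q - k_i||^2
```

Both variants therefore carry the same missing square.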

Question from: https://stackoverflow.com/questions/66056124/where-is-my-mistake-in-this-complex-numpy-usage

He who fights dragons too long becomes a dragon himself; gaze too long into the abyss, and the abyss gazes back...

1 Answer

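Looking at the formula you shared, it takes the exponential of the negative squared norm, exp(-||q - k_i||^2). Your code raises the sum of squares to the power 0.5, which gives the norm ||q - k_i|| itself, so you missed the square term. Drop the **0.5 (or square the result of np.linalg.norm). Try this: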

    np.exp(np.sum((np.tile(q, (np.shape(k)[0], 1)) - k)**2, axis = -1) * -1).T.dot(v)

or

    np.exp(-np.linalg.norm((np.tile(q, (np.shape(k)[0], 1)) - k), axis=1, ord = 2)**2).T.dot(v)
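As a side note, `np.tile` is not needed here: NumPy broadcasting subtracts the (1, K) array q from every row of the (N, K) array k directly. Below is a minimal sketch of the corrected computation written that way, checked against an explicit loop. The function names `kernel_sum` and `loop_sum` and the random inputs are illustrative assumptions.

```python
import numpy as np

def kernel_sum(q, k, v):
    """Vectorized sum_i exp(-||q - k_i||^2) * v[i] using broadcasting."""
    sq_dists = np.sum((q - k)**2, axis=1)   # (1, K) - (N, K) broadcasts to (N, K); sum -> (N,)
    return np.exp(-sq_dists) @ v            # (N,) @ (N, 1) -> shape (1,)

def loop_sum(q, k, v):
    """Same formula with an explicit loop, for checking."""
    return sum(np.exp(-np.sum((q[0] - k[i])**2)) * v[i, 0] for i in range(k.shape[0]))

rng = np.random.default_rng(3)
q = rng.normal(size=(1, 4))
k = rng.normal(size=(7, 4))
v = rng.normal(size=(7, 1))

print(kernel_sum(q, k, v)[0], loop_sum(q, k, v))
```

Broadcasting avoids materializing the tiled copy of q, which matters when N is large.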


...