Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - Pyro: samples of a Bernoulli random variable have more than one element

I'm new to Pyro and trying to get my first stochastic process model working. I adapted the code from here to suit my example problem, which is simply two Gaussians with a discrete probability of the sample coming from one or the other.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import HMC, MCMC

# Actual data sample
observations = torch.tensor(
    [0.00528813, -0.00589001, -1.20608593, 0.00190794,
     0.89052784,  0.66690464,  0.57295968, 0.02605967]
)

# Define the process
def model(observations):
    
    a_prior = dist.Beta(2, 2)
    a = pyro.sample("a", a_prior)
    c = pyro.sample('c', dist.Bernoulli(a))
    if c.item() == 1.0:
        my_dist = dist.Normal(0.785, 1.0)
    else:
        my_dist = dist.Normal(0.0, 0.01)
    
    for i, observation in enumerate(observations):
        measurement = pyro.sample(f'obs_{i}', my_dist, obs=observation)

# Clear parameters
pyro.clear_param_store()

# Define the MCMC kernel function
my_kernel = HMC(model)

# Define the MCMC algorithm
my_mcmc = MCMC(my_kernel,
               num_samples=5000,
               warmup_steps=50)

# Run the algorithm, passing the observations 
my_mcmc.run(observations)

The exception raised is:

<ipython-input-2-a668622a0fb9> in model(observations)
     11     a = pyro.sample("a", a_prior)
     12     c = pyro.sample('c', dist.Bernoulli(a))
---> 13     if c.item() == 1.0:
     14         my_dist = dist.Normal(0.785, 1.0)
     15     else:

ValueError: only one element tensors can be converted to Python scalars
Trace Shapes:    
 Param Sites:    
Sample Sites:    
       a dist   |
        value   |
       c dist   |
        value 2 |

I inspected c in the debugger, and for some reason it has two elements the second time model() is called:

tensor([0., 1.])

What is causing this? I expected c to be a simple scalar with the value 0 or 1.

As a further test, the conditional statement works fine when taking samples in the normal way:

# Conditional switch test
a_prior = dist.Beta(2, 2)
a = pyro.sample("a", a_prior)
for i in range(5):
    c = pyro.sample('c', dist.Bernoulli(a))
    if c.item() == 1.0:
        print(1, end=' ')
    else:
        print(0, end=' ')

# 0 0 1 0 0 
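
For comparison, here is the same generative story written as a forward simulation in plain Python, with no Pyro involved. It is only a sketch for illustration: the function name simulate and its seed argument are hypothetical, and it assumes a single switch c shared by all observations, as in model() above. When sampling forward like this, c is always one value, which matches the behaviour of the test loop.

```python
import random

def simulate(num_obs, seed=0):
    """Draw one dataset from the generative story in the question (plain Python, no Pyro)."""
    rng = random.Random(seed)
    a = rng.betavariate(2, 2)           # a ~ Beta(2, 2)
    c = 1 if rng.random() < a else 0    # c ~ Bernoulli(a): a single 0/1 value
    # One switch for the whole dataset, mirroring the if/else in model().
    loc, scale = (0.785, 1.0) if c == 1 else (0.0, 0.01)
    xs = [rng.gauss(loc, scale) for _ in range(num_obs)]
    return a, c, xs

a, c, xs = simulate(num_obs=8)
print(c, [round(x, 3) for x in xs])
```
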
Question from: https://stackoverflow.com/questions/65880002/pyro-samples-of-a-bernoulli-random-variable-have-more-than-one-element
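
A likely explanation, as far as I understand Pyro's MCMC kernels: HMC and NUTS cannot sample a discrete latent site such as c, so they enumerate it in parallel instead. During inference, pyro.sample('c', ...) then returns every value in the support at once, tensor([0., 1.]), and c.item() fails because the tensor has two elements. The usual advice is to avoid Python branching on c and pick the Normal parameters by tensor indexing, or to sum c out by hand. The sketch below is plain Python rather than Pyro, and its helper names (normal_logpdf, joint_logprob, marginal_logprob) are hypothetical. It only illustrates the quantity that enumeration computes for this model: the likelihood with c summed over both values.

```python
import math

observations = [0.00528813, -0.00589001, -1.20608593, 0.00190794,
                0.89052784, 0.66690464, 0.57295968, 0.02605967]

def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

def joint_logprob(a, c, xs):
    """log p(c, xs | a) for one fixed value of the switch c."""
    loc, scale = (0.785, 1.0) if c == 1 else (0.0, 0.01)
    log_prior_c = math.log(a) if c == 1 else math.log(1.0 - a)
    return log_prior_c + sum(normal_logpdf(x, loc, scale) for x in xs)

def marginal_logprob(a, xs):
    """log p(xs | a): c summed over {0, 1}, using log-sum-exp for stability."""
    terms = [joint_logprob(a, c, xs) for c in (0, 1)]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))

print(marginal_logprob(0.5, observations))
```

With c summed out this way, only the continuous variable a remains, which is the kind of target HMC is designed for.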


