I'm new to Pyro and am trying to get my first stochastic process model working. I adapted the code from here to suit my example problem, which is simply two Gaussians with a discrete probability of the sample coming from one or the other.
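For reference, here is the likelihood the model encodes once the switch $c$ is summed out. This is a sketch under an assumption taken from the code rather than the prose: a single $c$ is drawn for the whole data set, not one per observation. Pyro's `dist.Normal` takes a standard deviation, so the variances are $1.0^2$ and $0.01^2$:

$$
p(x_{1:N} \mid a) = a \prod_{i=1}^{N} \mathcal{N}\!\left(x_i;\, 0.785,\, 1.0^2\right) + (1 - a) \prod_{i=1}^{N} \mathcal{N}\!\left(x_i;\, 0,\, 0.01^2\right), \qquad a \sim \mathrm{Beta}(2, 2), \quad N = 8
$$

If instead each observation should pick its own component, $c$ would be drawn once per observation and the sum and product swap order: $\prod_{i} \left[ a\, \mathcal{N}(x_i; 0.785, 1.0^2) + (1 - a)\, \mathcal{N}(x_i; 0, 0.01^2) \right]$.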
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import HMC, MCMC

# Actual data sample
observations = torch.tensor(
    [0.00528813, -0.00589001, -1.20608593, 0.00190794,
     0.89052784, 0.66690464, 0.57295968, 0.02605967]
)

# Define the process
def model(observations):
    a_prior = dist.Beta(2, 2)
    a = pyro.sample("a", a_prior)
    c = pyro.sample('c', dist.Bernoulli(a))
    if c.item() == 1.0:
        my_dist = dist.Normal(0.785, 1.0)
    else:
        my_dist = dist.Normal(0.0, 0.01)
    for i, observation in enumerate(observations):
        measurement = pyro.sample(f'obs_{i}', my_dist, obs=observation)

# Clear parameters
pyro.clear_param_store()
# Define the MCMC kernel function
my_kernel = HMC(model)
# Define the MCMC algorithm
my_mcmc = MCMC(my_kernel,
               num_samples=5000,
               warmup_steps=50)
# Run the algorithm, passing the observations
my_mcmc.run(observations)
```
The exception raised is:
```
<ipython-input-2-a668622a0fb9> in model(observations)
     11     a = pyro.sample("a", a_prior)
     12     c = pyro.sample('c', dist.Bernoulli(a))
---> 13     if c.item() == 1.0:
     14         my_dist = dist.Normal(0.785, 1.0)
     15     else:

ValueError: only one element tensors can be converted to Python scalars
Trace Shapes:
 Param Sites:
Sample Sites:
       a dist   |
        value   |
       c dist   |
        value 2 |
```
I had a look at `c` using the debugger, and for some reason it has two elements the second time `model()` is called:

```
tensor([0., 1.])
```

What is causing this? I expected it to be a scalar taking the value 0 or 1.
As a further test, the conditional statement works fine when I draw samples in the normal way, outside MCMC:
```python
# Conditional switch test
a_prior = dist.Beta(2, 2)
a = pyro.sample("a", a_prior)
for i in range(5):
    c = pyro.sample('c', dist.Bernoulli(a))
    if c.item() == 1.0:
        print(1, end=' ')
    else:
        print(0, end=' ')
# 0 0 1 0 0
```
Question from:
https://stackoverflow.com/questions/65880002/pyro-samples-of-a-bernoulli-random-variable-have-more-than-one-element